diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-10 14:47:44 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-01-10 14:47:44 +0100 |
commit | 7d48aa76c6c89851bf5d99109e41d2b485bc87ab (patch) | |
tree | 9a624a94ca3d5e8071c90f6c9a8cbb5bb4c55e44 | |
parent | 8c6329d755c778850bba7c1c1ed69eafebba8863 (diff) |
Maintain TensorType in documents
35 files changed, 229 insertions, 152 deletions
diff --git a/config-model/src/main/java/com/yahoo/documentmodel/VespaDocumentType.java b/config-model/src/main/java/com/yahoo/documentmodel/VespaDocumentType.java index 793a5fcff6c..09aa55f776b 100644 --- a/config-model/src/main/java/com/yahoo/documentmodel/VespaDocumentType.java +++ b/config-model/src/main/java/com/yahoo/documentmodel/VespaDocumentType.java @@ -32,7 +32,6 @@ public class VespaDocumentType { vespa.add(PositionDataType.INSTANCE); vespa.add(DataType.URI); vespa.add(DataType.PREDICATE); - vespa.add(DataType.TENSOR); return vespa; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java index 950ec791368..5856caeb692 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/Attribute.java @@ -25,10 +25,6 @@ public final class Attribute implements Cloneable, Serializable { private Type type; private CollectionType collectionType; - /** True if only the enum information should be read from this attribute - * (i.e. the actual values are not relevant, only which documents have the - * same values) Used for collapsing and unique. - */ private boolean removeIfZero = false; private boolean createIfNonExistent = false; private boolean enableBitVectors = false; @@ -224,31 +220,22 @@ public final class Attribute implements Cloneable, Serializable { } /** Converts to the right field type from an attribute type */ - public static DataType convertAttrType(Type attrType) { - if (attrType== Type.STRING) { - return DataType.STRING; - } else if (attrType== Type.INTEGER) { - return DataType.INT; - } else if (attrType== Type.LONG) { - return DataType.LONG; - } else if (attrType== Type.FLOAT) { - return DataType.FLOAT; - } else if (attrType== Type.DOUBLE) { - return DataType.DOUBLE; - } else if (attrType == Type.BYTE) { - return DataType.BYTE; - } else if (attrType == Type.PREDICATE) { - return DataType.PREDICATE; - } else if (attrType == Type.TENSOR) { - return DataType.TENSOR; - } else { - throw new IllegalArgumentException("Don't know which attribute type to " + - "convert " + attrType + " to"); + private DataType toDataType(Type attributeType) { + switch (attributeType) { + case STRING : return DataType.STRING; + case INTEGER: return DataType.INT; + case LONG: return DataType.LONG; + case FLOAT: return DataType.FLOAT; + case DOUBLE: return DataType.DOUBLE; + case BYTE: return DataType.BYTE; + case PREDICATE: return DataType.PREDICATE; + case TENSOR: DataType.getTensor(tensorType.orElseThrow(IllegalStateException::new)); + default: throw new IllegalArgumentException("Unknown attribute type " + attributeType); } } public DataType getDataType() { - DataType dataType = Attribute.convertAttrType(type); + DataType dataType = toDataType(type); if (collectionType.equals(Attribute.CollectionType.ARRAY)) { return DataType.getArray(dataType); } else if (collectionType.equals(Attribute.CollectionType.WEIGHTEDSET)) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java index ae16f6cfed8..f618c55adfc 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/TensorFieldProcessor.java @@ -23,13 +23,11 @@ public class TensorFieldProcessor extends Processor { @Override public void process() { for (SDField field : search.allFieldsList()) { - if (field.getDataType() == DataType.TENSOR) { - warnUseOfTensorFieldAsAttribute(field); - validateIndexingScripsForTensorField(field); - validateAttributeSettingForTensorField(field); - } else { - validateDataTypeForField(field); - } + if ( ! (field.getDataType() instanceof TensorDataType)) continue; + + warnUseOfTensorFieldAsAttribute(field); + validateIndexingScripsForTensorField(field); + validateAttributeSettingForTensorField(field); } } @@ -55,9 +53,4 @@ public class TensorFieldProcessor extends Processor { } } - private void validateDataTypeForField(SDField field) { - if (field.getDataType().getPrimitiveType() == DataType.TENSOR) { - fail(search, field, "A field with collection type of tensor is not supported. Use simple type 'tensor' instead."); - } - } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/SearchDataTypeValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/SearchDataTypeValidator.java index c03fb0617b8..1d39d1e6928 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/SearchDataTypeValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/SearchDataTypeValidator.java @@ -43,7 +43,7 @@ public class SearchDataTypeValidator extends Validator { for (Field field : doc.fieldSet()) { DataType fieldType = field.getDataType(); disallowIndexingOfMaps(cluster, def, field); - if (!validateDataType(fieldType)) { + if ( ! isSupportedInSearchClusters(fieldType)) { throw new IllegalArgumentException("Field type '" + fieldType.getName() + "' is illegal for search " + "clusters (field '" + field.getName() + "' in definition '" + def.getName() + "' for cluster '" + cluster.getClusterName() + "')."); @@ -51,28 +51,30 @@ public class SearchDataTypeValidator extends Validator { } } - private boolean validateDataType(DataType dataType) { - if (dataType instanceof ArrayDataType || - dataType instanceof WeightedSetDataType) - { - return validateDataType(((CollectionDataType)dataType).getNestedType()); + private boolean isSupportedInSearchClusters(DataType dataType) { + if (dataType instanceof ArrayDataType || dataType instanceof WeightedSetDataType) { + return isSupportedInSearchClusters(((CollectionDataType)dataType).getNestedType()); } - if (dataType instanceof StructDataType) { + else if (dataType instanceof StructDataType) { return true; // Struct will work for summary TODO maybe check individual fields } - if (dataType instanceof MapDataType) { + else if (dataType instanceof MapDataType) { return true; // Maps will work for summary, see disallowIndexingOfMaps() } - return dataType.equals(DataType.INT) || - dataType.equals(DataType.FLOAT) || - dataType.equals(DataType.STRING) || - dataType.equals(DataType.RAW) || - dataType.equals(DataType.LONG) || - dataType.equals(DataType.DOUBLE) || - dataType.equals(DataType.URI) || - dataType.equals(DataType.BYTE) || - dataType.equals(DataType.PREDICATE) || - dataType.equals(DataType.TENSOR); + else if (dataType instanceof TensorDataType) { + return true; + } + else { + return dataType.equals(DataType.INT) || + dataType.equals(DataType.FLOAT) || + dataType.equals(DataType.STRING) || + dataType.equals(DataType.RAW) || + dataType.equals(DataType.LONG) || + dataType.equals(DataType.DOUBLE) || + dataType.equals(DataType.URI) || + dataType.equals(DataType.BYTE) || + dataType.equals(DataType.PREDICATE); + } } private void disallowIndexingOfMaps(AbstractSearchCluster cluster, SearchDefinition def, Field field) { diff --git a/document/src/main/java/com/yahoo/document/ArrayDataType.java b/document/src/main/java/com/yahoo/document/ArrayDataType.java index 640bd94bd1c..4b24bb1ae00 100644 --- a/document/src/main/java/com/yahoo/document/ArrayDataType.java +++ b/document/src/main/java/com/yahoo/document/ArrayDataType.java @@ -8,9 +8,10 @@ import java.util.ArrayList; import java.util.List; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class ArrayDataType extends CollectionDataType { + // The global class identifier shared with C++. public static int classId = registerClass(Ids.document + 54, ArrayDataType.class); diff --git a/document/src/main/java/com/yahoo/document/DataType.java b/document/src/main/java/com/yahoo/document/DataType.java index 19726ac7b99..3fb1e75262e 100644 --- a/document/src/main/java/com/yahoo/document/DataType.java +++ b/document/src/main/java/com/yahoo/document/DataType.java @@ -14,6 +14,7 @@ import com.yahoo.document.datatypes.Raw; import com.yahoo.document.datatypes.StringFieldValue; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.datatypes.UriFieldValue; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.objects.Identifiable; import com.yahoo.vespa.objects.Ids; import com.yahoo.vespa.objects.ObjectVisitor; @@ -50,14 +51,13 @@ public abstract class DataType extends Identifiable implements Serializable, Com public final static PrimitiveDataType URI = new PrimitiveDataType("uri", 10, UriFieldValue.class, new UriFieldValue.Factory()); public final static NumericDataType BYTE = new NumericDataType("byte", 16, ByteFieldValue.class, ByteFieldValue.getFactory()); public final static PrimitiveDataType PREDICATE = new PrimitiveDataType("predicate", 20, PredicateFieldValue.class, PredicateFieldValue.getFactory()); - public final static PrimitiveDataType TENSOR = new PrimitiveDataType("tensor", 21, TensorFieldValue.class, TensorFieldValue.getFactory()); - // ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference + // ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference, tensor // Tags are converted to weightedset<string> when reading the search definition TODO: Remove it public final static WeightedSetDataType TAG = new WeightedSetDataType(DataType.STRING, true, true); public static int lastPredefinedDataTypeId() { - return 21; + return 20; } /** Set to true when this type is registered in a type manager. From that time we should refuse changes. */ @@ -196,6 +196,11 @@ public abstract class DataType extends Identifiable implements Serializable, Com return new WeightedSetDataType(type, createIfNonExistent, removeIfZero); } + /** Returns the given tensor type as a DataType */ + public static TensorDataType getTensor(TensorType type) { + return new TensorDataType(type); + } + public String getName() { return name; } diff --git a/document/src/main/java/com/yahoo/document/Document.java b/document/src/main/java/com/yahoo/document/Document.java index 34c952e1cec..e1d912b4e51 100644 --- a/document/src/main/java/com/yahoo/document/Document.java +++ b/document/src/main/java/com/yahoo/document/Document.java @@ -26,7 +26,7 @@ import java.util.Map; * be removed soon. * * @author bratseth - * @author <a href="einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class Document extends StructuredFieldValue { @@ -267,7 +267,9 @@ public class Document extends StructuredFieldValue { } /** Returns true if the argument is a document which has the same set of values */ + @Override public boolean equals(Object o) { + if (o == this) return true; if (!(o instanceof Document)) return false; Document other = (Document) o; return (super.equals(o) && docId.equals(other.docId) && @@ -394,4 +396,5 @@ public class Document extends StructuredFieldValue { comp = body.compareTo(otherValue.body); return comp; } + } diff --git a/document/src/main/java/com/yahoo/document/Field.java b/document/src/main/java/com/yahoo/document/Field.java index 86543916b42..b80d92e0e4d 100644 --- a/document/src/main/java/com/yahoo/document/Field.java +++ b/document/src/main/java/com/yahoo/document/Field.java @@ -13,7 +13,7 @@ import java.io.Serializable; * A name and type. Fields are contained in document types to describe their fields, * but is also used to represent name/type pairs which are not part of document types. * - * @author <a href="mailto:thomasg@yahoo-inc.com">Thomas Gundersen</a> + * @author Thomas Gundersen * @author bratseth */ public class Field extends FieldBase implements FieldSet, Comparable, Serializable { diff --git a/document/src/main/java/com/yahoo/document/PrimitiveDataType.java b/document/src/main/java/com/yahoo/document/PrimitiveDataType.java index a0024bb0497..23bf4b43ccf 100644 --- a/document/src/main/java/com/yahoo/document/PrimitiveDataType.java +++ b/document/src/main/java/com/yahoo/document/PrimitiveDataType.java @@ -8,9 +8,10 @@ import com.yahoo.vespa.objects.ObjectVisitor; import java.util.Objects; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class PrimitiveDataType extends DataType { + public static abstract class Factory { public abstract FieldValue create(); } diff --git a/document/src/main/java/com/yahoo/document/TensorDataType.java b/document/src/main/java/com/yahoo/document/TensorDataType.java new file mode 100644 index 00000000000..46d5633f9f7 --- /dev/null +++ b/document/src/main/java/com/yahoo/document/TensorDataType.java @@ -0,0 +1,48 @@ +package com.yahoo.document; + +import com.yahoo.document.datatypes.FieldValue; +import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.objects.Ids; + +/** + * A DataType containing a tensor type + * + * @author bratseth + */ +public class TensorDataType extends DataType { + + private final TensorType tensorType; + + // The global class identifier shared with C++. + public static int classId = registerClass(Ids.document + 59, TensorDataType.class); + + public TensorDataType(TensorType tensorType) { + super(tensorType.toString(), 0); + this.tensorType = tensorType; + setId(getName().toLowerCase().hashCode()); + } + + public TensorDataType clone() { + return (TensorDataType)super.clone(); + } + + @Override + public FieldValue createFieldValue() { + return new TensorFieldValue(tensorType); + } + + @Override + public Class<? extends TensorFieldValue> getValueClass() { + return TensorFieldValue.class; + } + + @Override + public boolean isValueCompatible(FieldValue value) { + return value != null && TensorFieldValue.class.isAssignableFrom(value.getClass()); + } + + /** Returns the type of the tensor this field can hold */ + public TensorType getTensorType() { return tensorType; } + +} diff --git a/document/src/main/java/com/yahoo/document/datatypes/Array.java b/document/src/main/java/com/yahoo/document/datatypes/Array.java index 66cc472de69..09fb8b71db1 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/Array.java +++ b/document/src/main/java/com/yahoo/document/datatypes/Array.java @@ -16,7 +16,7 @@ import java.util.*; /** * FieldValue which encapsulates a Array value * - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public final class Array<T extends FieldValue> extends CollectionFieldValue<T> implements List<T> { diff --git a/document/src/main/java/com/yahoo/document/datatypes/Struct.java b/document/src/main/java/com/yahoo/document/datatypes/Struct.java index 5a01dc33aa1..b5920cb4758 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/Struct.java +++ b/document/src/main/java/com/yahoo/document/datatypes/Struct.java @@ -215,7 +215,6 @@ public class Struct extends StructuredFieldValue { if (!super.equals(o)) return false; Struct struct = (Struct) o; - return values.equals(struct.values); } @@ -388,4 +387,5 @@ public class Struct extends StructuredFieldValue { } return fieldType.cast(fieldValue); } + } diff --git a/document/src/main/java/com/yahoo/document/datatypes/StructuredFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/StructuredFieldValue.java index b4585a2188d..7957d9b812f 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/StructuredFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/StructuredFieldValue.java @@ -10,7 +10,7 @@ import java.util.List; import java.util.Map; /** - * @author <a href="mailto:humbe@yahoo-inc.com">Håkon Humberset</a> + * @author Håkon Humberset */ public abstract class StructuredFieldValue extends CompositeFieldValue { diff --git a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java index bee29478219..9d8e9a83b5e 100644 --- a/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java +++ b/document/src/main/java/com/yahoo/document/datatypes/TensorFieldValue.java @@ -3,11 +3,12 @@ package com.yahoo.document.datatypes; import com.yahoo.document.DataType; import com.yahoo.document.Field; -import com.yahoo.document.PrimitiveDataType; +import com.yahoo.document.TensorDataType; import com.yahoo.document.serialization.FieldReader; import com.yahoo.document.serialization.FieldWriter; import com.yahoo.document.serialization.XmlStream; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import java.util.Objects; import java.util.Optional; @@ -20,12 +21,18 @@ import java.util.Optional; public class TensorFieldValue extends FieldValue { private Optional<Tensor> tensor; + + private final TensorDataType dataType; - public TensorFieldValue() { - tensor = Optional.empty(); + /** Create an empty tensor field value */ + public TensorFieldValue(TensorType type) { + this.dataType = new TensorDataType(type); + this.tensor = Optional.empty(); } + /** Create a tensor field value containing the given tensor */ public TensorFieldValue(Tensor tensor) { + this.dataType = new TensorDataType(tensor.type()); this.tensor = Optional.of(tensor); } @@ -34,8 +41,8 @@ public class TensorFieldValue extends FieldValue { } @Override - public DataType getDataType() { - return DataType.TENSOR; + public TensorDataType getDataType() { + return dataType; } @Override @@ -51,16 +58,23 @@ public class TensorFieldValue extends FieldValue { @Override public void assign(Object o) { if (o == null) { - tensor = Optional.empty(); + assignTensor(Optional.empty()); } else if (o instanceof Tensor) { - tensor = Optional.of((Tensor)o); + assignTensor(Optional.of((Tensor)o)); } else if (o instanceof TensorFieldValue) { - tensor = ((TensorFieldValue)o).getTensor(); + assignTensor(((TensorFieldValue)o).getTensor()); } else { throw new IllegalArgumentException("Expected class '" + getClass().getName() + "', got '" + - o.getClass().getName() + "'."); + o.getClass().getName() + "'."); } } + + public void assignTensor(Optional<Tensor> tensor) { + if (tensor.isPresent() && ! dataType.getTensorType().isAssignableTo(tensor.get().type())) + throw new IllegalArgumentException("Type mismatch: Cannot assign tensor of type " + tensor.get().type() + + " to field of type " + dataType.getTensorType()); + this.tensor = tensor; + } @Override public void serialize(Field field, FieldWriter writer) { @@ -74,27 +88,14 @@ public class TensorFieldValue extends FieldValue { @Override public boolean equals(Object o) { - if (this == o) { - return true; - } - if (!(o instanceof TensorFieldValue)) { - return false; - } - TensorFieldValue rhs = (TensorFieldValue)o; - if (!Objects.equals(tensor, rhs.tensor)) { - return false; - } + if (this == o) return true; + if ( ! (o instanceof TensorFieldValue)) return false; + + TensorFieldValue other = (TensorFieldValue)o; + if ( ! dataType.getTensorType().equals(other.dataType.getTensorType())) return false; + if ( ! tensor.equals(other.tensor)) return false; return true; } - public static PrimitiveDataType.Factory getFactory() { - return new PrimitiveDataType.Factory() { - - @Override - public FieldValue create() { - return new TensorFieldValue(); - } - }; - } } diff --git a/document/src/main/java/com/yahoo/document/serialization/DocumentReader.java b/document/src/main/java/com/yahoo/document/serialization/DocumentReader.java index 52a62caf296..5f1b227790b 100644 --- a/document/src/main/java/com/yahoo/document/serialization/DocumentReader.java +++ b/document/src/main/java/com/yahoo/document/serialization/DocumentReader.java @@ -9,7 +9,7 @@ import com.yahoo.document.DocumentTypeManager; /** * This interface is used to implement custom deserialization of document updates. * - * @author <a href="mailto:ravishar@yahoo-inc.com">Ravi Sharma</a> + * @author Ravi Sharma * @author baldersheim */ public interface DocumentReader { diff --git a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer42.java b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer42.java index 9e764aae798..6e9495b1437 100644 --- a/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer42.java +++ b/document/src/main/java/com/yahoo/document/serialization/VespaDocumentDeserializer42.java @@ -115,8 +115,8 @@ public class VespaDocumentDeserializer42 extends VespaDocumentSerializer42 imple // Verify that we have correct version version = getShort(null); if (version < 6 || version > Document.SERIALIZED_VERSION) { - throw new DeserializationException( - "Unknown version " + version + ", expected " + Document.SERIALIZED_VERSION + "."); + throw new DeserializationException("Unknown version " + version + ", expected " + + Document.SERIALIZED_VERSION + "."); } int dataLength = 0; @@ -278,7 +278,7 @@ public class VespaDocumentDeserializer42 extends VespaDocumentSerializer42 imple int encodedTensorLength = buf.getInt1_4Bytes(); if (encodedTensorLength > 0) { byte[] encodedTensor = getBytes(null, encodedTensorLength); - value.assign(TypedBinaryFormat.decode(encodedTensor)); + value.assign(TypedBinaryFormat.decode(value.getDataType().getTensorType(), encodedTensor)); } else { value.clear(); } @@ -328,7 +328,7 @@ public class VespaDocumentDeserializer42 extends VespaDocumentSerializer42 imple fieldIdsAndLengths.add(new Tuple2<>(getInt1_4Bytes(null), getInt2_4_8Bytes(null))); } - //save a reference to the big buffer we're reading from: + // save a reference to the big buffer we're reading from: GrowableByteBuffer bigBuf = buf; if (version < 7) { diff --git a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java index 413d1581e58..02773c7dad0 100644 --- a/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java +++ b/document/src/test/java/com/yahoo/document/DocumentUpdateTestCase.java @@ -10,6 +10,7 @@ import com.yahoo.document.update.FieldUpdate; import com.yahoo.document.update.ValueUpdate; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import java.io.FileOutputStream; import java.io.IOException; @@ -43,6 +44,7 @@ public class DocumentUpdateTestCase extends junit.framework.TestCase { private final String documentId = "doc:something:foooo"; private final String tensorField = "tensorfield"; + private final TensorType tensorType = new TensorType.Builder().mapped("x").build(); private Document createDocument() { return new Document(docMan.getDocumentType("foobar"), new DocumentId(documentId)); @@ -60,7 +62,7 @@ public class DocumentUpdateTestCase extends junit.framework.TestCase { DataType stringwset = DataType.getWeightedSet(DataType.STRING); docType.addField(new Field("strwset", stringwset)); - docType.addField(new Field(tensorField, DataType.TENSOR)); + docType.addField(new Field(tensorField, new TensorDataType(tensorType))); docMan.register(docType); docType2 = new DocumentType("otherdoctype"); @@ -625,7 +627,7 @@ public class DocumentUpdateTestCase extends junit.framework.TestCase { private DocumentUpdate createTensorAssignUpdate() { DocumentUpdate result = new DocumentUpdate(docType, new DocumentId(documentId)); result.addFieldUpdate(FieldUpdate.createAssign(docType.getField(tensorField), - createTensorFieldValue("{{x:0}:2.0}"))); + createTensorFieldValue("{{x:0}:2.0}"))); return result; } diff --git a/document/src/test/java/com/yahoo/document/datatypes/TensorFieldValueTestCase.java b/document/src/test/java/com/yahoo/document/datatypes/TensorFieldValueTestCase.java index 80386141968..c94c917d2ca 100644 --- a/document/src/test/java/com/yahoo/document/datatypes/TensorFieldValueTestCase.java +++ b/document/src/test/java/com/yahoo/document/datatypes/TensorFieldValueTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.document.datatypes; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import org.junit.Test; import static org.junit.Assert.assertFalse; @@ -13,8 +14,8 @@ import static org.junit.Assert.assertTrue; */ public class TensorFieldValueTestCase { - private static TensorFieldValue createFieldValue(String tensor) { - return new TensorFieldValue(Tensor.from(tensor)); + private static TensorFieldValue createFieldValue(String tensorString) { + return new TensorFieldValue(Tensor.from(tensorString)); } @Test @@ -23,20 +24,27 @@ public class TensorFieldValueTestCase { } @Test + public void requireThatDifferentTensorTypesWithEmptyValuesAreNotEqual() { + TensorFieldValue field1 = new TensorFieldValue(new TensorType.Builder().mapped("x").build()); + TensorFieldValue field2 = new TensorFieldValue(new TensorType.Builder().indexed("y").build()); + assertFalse(field1.equals(field2)); + } + + @Test public void requireThatDifferentTensorValuesAreNotEqual() { - TensorFieldValue lhs = createFieldValue("{{x:0}:2.0}"); - TensorFieldValue rhs = createFieldValue("{{x:0}:3.0}"); - assertFalse(lhs.equals(rhs)); - assertFalse(lhs.equals(new TensorFieldValue())); + TensorFieldValue field1 = createFieldValue("{{x:0}:2.0}"); + TensorFieldValue field2 = createFieldValue("{{x:0}:3.0}"); + assertFalse(field1.equals(field2)); + assertFalse(field1.equals(new TensorFieldValue(TensorType.empty))); } @Test public void requireThatSameTensorValueIsEqual() { Tensor tensor = Tensor.from("{{x:0}:2.0}"); - TensorFieldValue lhs = new TensorFieldValue(tensor); - TensorFieldValue rhs = new TensorFieldValue(tensor); - assertTrue(lhs.equals(lhs)); - assertTrue(lhs.equals(rhs)); - assertTrue(lhs.equals(createFieldValue("{{x:0}:2.0}"))); + TensorFieldValue field1 = new TensorFieldValue(tensor); + TensorFieldValue field2 = new TensorFieldValue(tensor); + assertTrue(field1.equals(field1)); + assertTrue(field1.equals(field2)); + assertTrue(field1.equals(createFieldValue("{{x:0}:2.0}"))); } } diff --git a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java index ffec7927ab3..a0f993fd2fc 100644 --- a/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java +++ b/document/src/test/java/com/yahoo/document/json/DocumentUpdateJsonSerializerTest.java @@ -10,7 +10,9 @@ import com.yahoo.document.DocumentUpdate; import com.yahoo.document.Field; import com.yahoo.document.MapDataType; import com.yahoo.document.PositionDataType; +import com.yahoo.document.TensorDataType; import com.yahoo.document.WeightedSetDataType; +import com.yahoo.tensor.TensorType; import com.yahoo.text.Utf8; import org.junit.Test; @@ -27,6 +29,7 @@ import static com.yahoo.test.json.JsonTestHelper.inputJson; */ public class DocumentUpdateJsonSerializerTest { + final static TensorType tensorType = new TensorType.Builder().mapped("x").mapped("y").build(); final static DocumentTypeManager types = new DocumentTypeManager(); final static JsonFactory parserFactory = new JsonFactory(); final static DocumentType docType = new DocumentType("doctype"); @@ -39,7 +42,7 @@ public class DocumentUpdateJsonSerializerTest { docType.addField(new Field("float_field", DataType.FLOAT)); docType.addField(new Field("double_field", DataType.DOUBLE)); docType.addField(new Field("byte_field", DataType.BYTE)); - docType.addField(new Field("tensor_field", DataType.TENSOR)); + docType.addField(new Field("tensor_field", new TensorDataType(tensorType))); docType.addField(new Field("predicate_field", DataType.PREDICATE)); docType.addField(new Field("raw_field", DataType.RAW)); docType.addField(new Field("int_array", new ArrayDataType(DataType.INT))); diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index 6c46f743332..e61ed6cf541 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -19,6 +19,7 @@ import com.yahoo.document.Field; import com.yahoo.document.MapDataType; import com.yahoo.document.PositionDataType; import com.yahoo.document.StructDataType; +import com.yahoo.document.TensorDataType; import com.yahoo.document.WeightedSetDataType; import com.yahoo.document.datatypes.Array; import com.yahoo.document.datatypes.FieldValue; @@ -38,6 +39,7 @@ import com.yahoo.document.update.FieldUpdate; import com.yahoo.document.update.MapValueUpdate; import com.yahoo.document.update.ValueUpdate; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.text.Utf8; import org.apache.commons.codec.binary.Base64; import org.junit.After; @@ -66,11 +68,12 @@ import static org.junit.Assert.*; /** * Basic test of JSON streams to Vespa document instances. * - * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a> + * @author Steinar Knutsen */ public class JsonReaderTestCase { - DocumentTypeManager types; - JsonFactory parserFactory; + + private DocumentTypeManager types; + private JsonFactory parserFactory; @Rule public ExpectedException exception = ExpectedException.none(); @@ -133,7 +136,8 @@ public class JsonReaderTestCase { } { DocumentType x = new DocumentType("testtensor"); - x.addField(new Field("tensorfield", DataType.TENSOR)); + TensorType tensorType = new TensorType.Builder().mapped("x").mapped("y").build(); + x.addField(new Field("tensorfield", new TensorDataType(tensorType))); types.registerDocumentType(x); } { @@ -1082,16 +1086,6 @@ public class JsonReaderTestCase { } @Test - public void testParsingOfTensorWithSingleCellWithoutValue() { - assertTensorField("{{x:a}:0.0}", - createPutWithTensor("{ " - + " \"cells\": [ " - + " { \"address\": { \"x\": \"a\" } } " - + " ]" - + "}")); - } - - @Test public void testAssignUpdateOfEmptyTensor() { assertTensorAssignUpdate("{}", createAssignUpdateWithTensor("{}")); } diff --git a/document/src/test/java/com/yahoo/document/json/JsonWriterTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonWriterTestCase.java index 171676be694..99f7deb7bd7 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonWriterTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonWriterTestCase.java @@ -19,8 +19,10 @@ import com.yahoo.document.Field; import com.yahoo.document.MapDataType; import com.yahoo.document.PositionDataType; import com.yahoo.document.StructDataType; +import com.yahoo.document.TensorDataType; import com.yahoo.document.WeightedSetDataType; import com.yahoo.document.datatypes.TensorFieldValue; +import com.yahoo.tensor.TensorType; import com.yahoo.text.Utf8; import org.apache.commons.codec.binary.Base64; import org.junit.After; @@ -40,7 +42,7 @@ import static org.junit.Assert.assertSame; /** * Functional tests for com.yahoo.document.json.JsonWriter. * - * @author <a href="mailto:steinar@yahoo-inc.com">Steinar Knutsen</a> + * @author Steinar Knutsen */ public class JsonWriterTestCase { @@ -115,7 +117,8 @@ public class JsonWriterTestCase { } { DocumentType x = new DocumentType("testtensor"); - x.addField(new Field("tensorfield", DataType.TENSOR)); + TensorType tensorType = new TensorType.Builder().mapped("x").mapped("y").build(); + x.addField(new Field("tensorfield", new TensorDataType(tensorType))); types.registerDocumentType(x); } } @@ -310,8 +313,7 @@ public class JsonWriterTestCase { @Test public void testWritingOfEmptyTensor() throws IOException { - assertTensorRoundTripEquality("{}", - "{ \"cells\": [] }"); + assertTensorRoundTripEquality("{}", "{ \"cells\": [] }"); } @Test @@ -344,10 +346,11 @@ public class JsonWriterTestCase { @Test public void testWritingOfTensorFieldValueWithoutTensor() throws IOException { - DocumentType tensorType = types.getDocumentType("testtensor"); + DocumentType documentTypeWithTensor = types.getDocumentType("testtensor"); String docId = "id:unittest:testtensor::0"; - Document doc = new Document(tensorType, docId); - doc.setFieldValue(tensorType.getField("tensorfield"), new TensorFieldValue()); + Document doc = new Document(documentTypeWithTensor, docId); + Field tensorField = documentTypeWithTensor.getField("tensorfield"); + doc.setFieldValue(tensorField, new TensorFieldValue(((TensorDataType)tensorField.getDataType()).getTensorType())); assertEqualJson(asDocument(docId, "{ \"tensorfield\": {} }"), JsonWriter.toByteArray(doc)); } diff --git a/document/src/test/java/com/yahoo/document/serialization/SerializationTestUtils.java b/document/src/test/java/com/yahoo/document/serialization/SerializationTestUtils.java index f3987085e32..7e3fabc30fb 100644 --- a/document/src/test/java/com/yahoo/document/serialization/SerializationTestUtils.java +++ b/document/src/test/java/com/yahoo/document/serialization/SerializationTestUtils.java @@ -2,6 +2,7 @@ package com.yahoo.document.serialization; import com.yahoo.document.Document; +import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.io.GrowableByteBuffer; import java.io.IOException; diff --git a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java index 22cb35ae937..ae61bb3cf6f 100644 --- a/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java +++ b/document/src/test/java/com/yahoo/document/serialization/TensorFieldValueSerializationTestCase.java @@ -4,8 +4,10 @@ package com.yahoo.document.serialization; import com.yahoo.document.DataType; import com.yahoo.document.Document; import com.yahoo.document.DocumentType; +import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import org.junit.Test; import java.io.IOException; @@ -19,30 +21,31 @@ import static org.junit.Assert.assertEquals; */ public class TensorFieldValueSerializationTestCase { + private final static TensorType tensorType = new TensorType.Builder().mapped("dimX").mapped("dimY").build(); private final static String TENSOR_FIELD = "my_tensor"; private final static String TENSOR_FILES = "src/test/resources/tensor/"; - private final static TestDocumentFactory docFactory = - new TestDocumentFactory(createDocType(), "id:test:my_type::foo"); + private final static TestDocumentFactory docFactory = new TestDocumentFactory(createDocType(), + "id:test:my_type::foo"); private static DocumentType createDocType() { DocumentType type = new DocumentType("my_type"); - type.addField(TENSOR_FIELD, DataType.TENSOR); + type.addField(TENSOR_FIELD, new TensorDataType(tensorType)); return type; } @Test public void requireThatTensorFieldValueIsSerializedAndDeserialized() { - assertSerialization(new TensorFieldValue()); - assertSerialization(createTensor("{}")); - assertSerialization(createTensor("{{dimX:a,dimY:bb}:2.0,{dimX:ccc,dimY:dddd}:3.0,{dimX:e,dimY:ff}:5.0}")); + assertSerialization(new TensorFieldValue(tensorType)); + assertSerialization(createTensor(tensorType, "{}")); + assertSerialization(createTensor(tensorType, "{{dimX:a,dimY:bb}:2.0,{dimX:ccc,dimY:dddd}:3.0,{dimX:e,dimY:ff}:5.0}")); } @Test public void requireThatSerializationMatchesCpp() throws IOException { - assertSerializationMatchesCpp("non_existing_tensor", new TensorFieldValue()); - assertSerializationMatchesCpp("empty_tensor", createTensor("{}")); + assertSerializationMatchesCpp("non_existing_tensor", new TensorFieldValue(tensorType)); + assertSerializationMatchesCpp("empty_tensor", createTensor(tensorType, "{}")); assertSerializationMatchesCpp("multi_cell_tensor", - createTensor("{{dimX:a,dimY:bb}:2.0,{dimX:ccc,dimY:dddd}:3.0,{dimX:e,dimY:ff}:5.0}")); + createTensor(tensorType, "{{dimX:a,dimY:bb}:2.0,{dimX:ccc,dimY:dddd}:3.0,{dimX:e,dimY:ff}:5.0}")); } private static void assertSerialization(TensorFieldValue tensor) { @@ -60,8 +63,8 @@ public class TensorFieldValueSerializationTestCase { SerializationTestUtils.assertSerializationMatchesCpp(TENSOR_FILES, fileName, document, docFactory); } - private static TensorFieldValue createTensor(String tensor) { - return new TensorFieldValue(Tensor.from(tensor)); + private static TensorFieldValue createTensor(TensorType type, String tensorCellString) { + return new TensorFieldValue(Tensor.from(type, tensorCellString)); } } diff --git a/document/src/test/java/com/yahoo/vespaxmlparser/VespaXmlFieldReaderTestCase.java b/document/src/test/java/com/yahoo/vespaxmlparser/VespaXmlFieldReaderTestCase.java index 3cfbcac5b62..3dc6ebd1403 100644 --- a/document/src/test/java/com/yahoo/vespaxmlparser/VespaXmlFieldReaderTestCase.java +++ b/document/src/test/java/com/yahoo/vespaxmlparser/VespaXmlFieldReaderTestCase.java @@ -12,6 +12,7 @@ import com.yahoo.document.predicate.FeatureRange; import com.yahoo.document.predicate.FeatureSet; import com.yahoo.document.predicate.Predicate; import com.yahoo.document.serialization.DeserializationException; +import com.yahoo.tensor.TensorType; import org.apache.commons.codec.binary.Base64; import org.junit.Test; @@ -75,8 +76,8 @@ public class VespaXmlFieldReaderTestCase { @Test public void requireThatPutsForTensorFieldsAreNotSupported() throws Exception { - assertThrows(new Field("my_tensor", DataType.TENSOR), "", - "Field 'my_tensor': XML input for fields of type TENSOR is not supported. Please use JSON input instead."); + assertThrows(new Field("my_tensor", new TensorDataType(TensorType.empty)), "", + "Field 'my_tensor': XML input for fields of type TENSOR is not supported. Please use JSON input instead."); } private class MockedReaderFixture { diff --git a/document/src/test/java/com/yahoo/vespaxmlparser/VespaXmlUpdateReaderTestCase.java b/document/src/test/java/com/yahoo/vespaxmlparser/VespaXmlUpdateReaderTestCase.java index 8730265c80d..8a5fabde9ea 100644 --- a/document/src/test/java/com/yahoo/vespaxmlparser/VespaXmlUpdateReaderTestCase.java +++ b/document/src/test/java/com/yahoo/vespaxmlparser/VespaXmlUpdateReaderTestCase.java @@ -7,7 +7,9 @@ import com.yahoo.document.DocumentTypeManager; import com.yahoo.document.DocumentUpdate; import com.yahoo.document.Field; import com.yahoo.document.StructDataType; +import com.yahoo.document.TensorDataType; import com.yahoo.document.serialization.DeserializationException; +import com.yahoo.tensor.TensorType; import org.junit.Ignore; import org.junit.Test; @@ -215,8 +217,8 @@ public class VespaXmlUpdateReaderTestCase { @Test public void requireThatUpdatesForTensorFieldsAreNotSupported() throws Exception { - assertThrows(new Field("my_tensor", DataType.TENSOR), "<assign field='my_tensor'></assign>", - "Field 'my_tensor': XML input for fields of type TENSOR is not supported. Please use JSON input instead."); + assertThrows(new Field("my_tensor", new TensorDataType(TensorType.empty)), "<assign field='my_tensor'></assign>", + "Field 'my_tensor': XML input for fields of type TENSOR is not supported. Please use JSON input instead."); } private static void assertThrows(Field field, String fieldXml, String expected) throws Exception { diff --git a/document/src/test/resources/tensor/empty_tensor__cpp b/document/src/test/resources/tensor/empty_tensor__cpp Binary files differindex 365182b14eb..7915a6842e0 100644 --- a/document/src/test/resources/tensor/empty_tensor__cpp +++ b/document/src/test/resources/tensor/empty_tensor__cpp diff --git a/document/src/test/resources/tensor/empty_tensor__java b/document/src/test/resources/tensor/empty_tensor__java Binary files differindex 365182b14eb..7915a6842e0 100644 --- a/document/src/test/resources/tensor/empty_tensor__java +++ b/document/src/test/resources/tensor/empty_tensor__java diff --git a/document/src/test/resources/tensor/multi_cell_tensor__cpp b/document/src/test/resources/tensor/multi_cell_tensor__cpp Binary files differindex c0b2b3a165a..d34080b6b99 100644 --- a/document/src/test/resources/tensor/multi_cell_tensor__cpp +++ b/document/src/test/resources/tensor/multi_cell_tensor__cpp diff --git a/document/src/test/resources/tensor/multi_cell_tensor__java b/document/src/test/resources/tensor/multi_cell_tensor__java Binary files differindex d923fc10559..d34080b6b99 100644 --- a/document/src/test/resources/tensor/multi_cell_tensor__java +++ b/document/src/test/resources/tensor/multi_cell_tensor__java diff --git a/document/src/test/resources/tensor/non_existing_tensor__cpp b/document/src/test/resources/tensor/non_existing_tensor__cpp Binary files differindex 08cbcac6dd3..2c5746e19f3 100644 --- a/document/src/test/resources/tensor/non_existing_tensor__cpp +++ b/document/src/test/resources/tensor/non_existing_tensor__cpp diff --git a/document/src/test/resources/tensor/non_existing_tensor__java b/document/src/test/resources/tensor/non_existing_tensor__java Binary files differindex 08cbcac6dd3..2c5746e19f3 100644 --- a/document/src/test/resources/tensor/non_existing_tensor__java +++ b/document/src/test/resources/tensor/non_existing_tensor__java diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 82f36972a47..e58a08c8d31 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -77,6 +77,25 @@ public class TensorType { return Optional.empty(); } + /** + * Returns whether a tensor of the given type can be assigned to this type, + * i.e of this type is a generalization of the given type. + */ + public boolean isAssignableTo(TensorType other) { + if (other.dimensions().size() != this.dimensions().size()) return false; + for (int i = 0; i < other.dimensions().size(); i++) { + Dimension thisDimension = this.dimensions().get(i); + Dimension otherDimension = other.dimensions().get(i); + if (thisDimension.isIndexed() != other.isIndexed()) return false; + if ( ! thisDimension.name().equals(otherDimension.name())) return false; + if (thisDimension.size().isPresent()) { + if ( ! otherDimension.size().isPresent()) return false; + if (otherDimension.size().get() > thisDimension.size().get() ) return false; + } + } + return true; + } + @Override public String toString() { return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 5a45f20b6d8..5bb93b9da83 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization; import com.google.common.annotations.Beta; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; /** * Class used by clients for serializing a Tensor object into binary format or diff --git a/vespajlib/src/main/java/com/yahoo/vespa/objects/Identifiable.java b/vespajlib/src/main/java/com/yahoo/vespa/objects/Identifiable.java index e0edc6f4e64..d303a69a68d 100644 --- a/vespajlib/src/main/java/com/yahoo/vespa/objects/Identifiable.java +++ b/vespajlib/src/main/java/com/yahoo/vespa/objects/Identifiable.java @@ -16,7 +16,7 @@ import java.util.HashMap; * methods. * * @author baldersheim - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ public class Identifiable extends Selectable implements Cloneable { @@ -177,7 +177,7 @@ public class Identifiable extends Selectable implements Cloneable { * * @param id The class identifier to register with. * @param spec The class to register. - * @return The identifier argument. + * @return the identifier argument. */ protected static int registerClass(int id, Class<? extends Identifiable> spec) { if (registry == null) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index ad908101329..8f96edf7dd8 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -32,7 +32,7 @@ public class SparseBinaryFormatTestCase { private static void assertSerialization(Tensor tensor) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(encodedTensor); + Tensor decodedTensor = TypedBinaryFormat.decode(null, encodedTensor); assertEquals(tensor, decodedTensor); } |