summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/ValidateFieldTypes.java6
-rw-r--r--document/src/main/java/com/yahoo/document/TensorDataType.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java43
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java4
5 files changed, 68 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/ValidateFieldTypes.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/ValidateFieldTypes.java
index d9b93d11a52..8e0f7a7f340 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/ValidateFieldTypes.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/ValidateFieldTypes.java
@@ -15,7 +15,7 @@ import java.util.HashMap;
import java.util.Map;
/**
- * This Processor checks to make sure all fields with the same name have the same {@link DataType}. This check
+ * This Processor makes sure all fields with the same name have the same {@link DataType}. This check
* explicitly disregards whether a field is an index field, an attribute or a summary field. This is a requirement if we
* hope to move to a model where index fields, attributes and summary fields share a common field class.
*
@@ -50,7 +50,7 @@ public class ValidateFieldTypes extends Processor {
if (seenType == null) {
seenFields.put(fieldName, fieldType);
} else if ( ! compatibleTypes(seenType, fieldType)) {
- throw newProcessException(searchName, fieldName, "Incompatible types. Expected " +
+ throw newProcessException(searchName, fieldName, "Incompatible types. Expected " +
seenType.getName() + " for " + fieldDesc +
" '" + fieldName + "', got " + fieldType.getName() + ".");
}
@@ -69,5 +69,5 @@ public class ValidateFieldTypes extends Processor {
}
return seenType.equals(fieldType);
}
-
+
}
diff --git a/document/src/main/java/com/yahoo/document/TensorDataType.java b/document/src/main/java/com/yahoo/document/TensorDataType.java
index 50e9cf0f60f..b21461597bf 100644
--- a/document/src/main/java/com/yahoo/document/TensorDataType.java
+++ b/document/src/main/java/com/yahoo/document/TensorDataType.java
@@ -42,7 +42,7 @@ public class TensorDataType extends DataType {
if (value == null) return false;
if ( ! TensorFieldValue.class.isAssignableFrom(value.getClass())) return false;
TensorFieldValue tensorValue = (TensorFieldValue)value;
- return tensorType.isAssignableTo(tensorValue.getDataType().getTensorType());
+ return tensorType.isConvertibleTo(tensorValue.getDataType().getTensorType());
}
/** Returns the type of the tensor this field can hold */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 5b98a1b4fb5..8ff9774fc7d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -82,6 +82,21 @@ public class TensorType {
* i.e if the given type is a generalization of this type.
*/
public boolean isAssignableTo(TensorType generalization) {
+ return isConvertibleOrAssignableTo(generalization, false);
+ }
+
+ /**
+ * Returns whether this type can be converted to the given type.
+ * This is true if this type isAssignableTo the given type or
+ * if it is not assignable only because it has a shorter dimension length
+ * than the given type in some shared dimension(s), as it can then be
+ * converted to the given type by zero padding.
+ */
+ public boolean isConvertibleTo(TensorType generalization) {
+ return isConvertibleOrAssignableTo(generalization, true);
+ }
+
+ private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible) {
if (generalization.dimensions().size() != this.dimensions().size()) return false;
for (int i = 0; i < generalization.dimensions().size(); i++) {
Dimension thisDimension = this.dimensions().get(i);
@@ -90,7 +105,12 @@ public class TensorType {
if ( ! thisDimension.name().equals(generalizationDimension.name())) return false;
if (generalizationDimension.size().isPresent()) {
if ( ! thisDimension.size().isPresent()) return false;
- if (thisDimension.size().get() > generalizationDimension.size().get() ) return false;
+ if (convertible) {
+ if (thisDimension.size().get() > generalizationDimension.size().get()) return false;
+ }
+ else { // assignable
+ if (!thisDimension.size().get().equals(generalizationDimension.size().get())) return false;
+ }
}
}
return true;
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index ef973d03ccb..eef0b090fd1 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -5,11 +5,14 @@ import org.junit.Test;
import static org.hamcrest.Matchers.containsString;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
/**
* @author geirst
+ * @author bratseth
*/
public class TensorTypeTestCase {
@@ -61,6 +64,30 @@ public class TensorTypeTestCase {
assertIllegalTensorType("tensor(x{10})", "Failed parsing element 'x{10}' in type spec 'tensor(x{10})'");
}
+ @Test
+ public void testAssignableTo() {
+ assertIsAssignableTo("tensor(x[])", "tensor(x[])");
+ assertUnassignableTo("tensor(x[])", "tensor(y[])");
+ assertIsAssignableTo("tensor(x[10])", "tensor(x[])");
+ assertUnassignableTo("tensor(x[])", "tensor(x[10])");
+ assertUnassignableTo("tensor(x[10])", "tensor(x[5])");
+ assertUnassignableTo("tensor(x[5])", "tensor(x[10])");
+ assertUnassignableTo("tensor(x{})", "tensor(x[])");
+ assertIsAssignableTo("tensor(x{},y[10])", "tensor(x{},y[])");
+ }
+
+ @Test
+ public void testConvertibleTo() {
+ assertIsConvertibleTo("tensor(x[])", "tensor(x[])");
+ assertUnconvertibleTo("tensor(x[])", "tensor(y[])");
+ assertIsConvertibleTo("tensor(x[10])", "tensor(x[])");
+ assertUnconvertibleTo("tensor(x[])", "tensor(x[10])");
+ assertUnconvertibleTo("tensor(x[10])", "tensor(x[5])");
+ assertIsConvertibleTo("tensor(x[5])", "tensor(x[10])"); // Different from assignable
+ assertUnconvertibleTo("tensor(x{})", "tensor(x[])");
+ assertIsConvertibleTo("tensor(x{},y[10])", "tensor(x{},y[])");
+ }
+
private static void assertTensorType(String typeSpec) {
assertTensorType(typeSpec, typeSpec);
}
@@ -78,4 +105,20 @@ public class TensorTypeTestCase {
}
}
+ private void assertIsAssignableTo(String specificType, String generalType) {
+ assertTrue(TensorType.fromSpec(specificType).isAssignableTo(TensorType.fromSpec(generalType)));
+ }
+
+ private void assertUnassignableTo(String specificType, String generalType) {
+ assertFalse(TensorType.fromSpec(specificType).isAssignableTo(TensorType.fromSpec(generalType)));
+ }
+
+ private void assertIsConvertibleTo(String specificType, String generalType) {
+ assertTrue(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType)));
+ }
+
+ private void assertUnconvertibleTo(String specificType, String generalType) {
+ assertFalse(TensorType.fromSpec(specificType).isConvertibleTo(TensorType.fromSpec(generalType)));
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 15a872e439f..4a975b83ec0 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -1,16 +1,13 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
-import com.google.common.collect.Sets;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import org.junit.Ignore;
import org.junit.Test;
import java.util.Arrays;
import java.util.Optional;
-import java.util.Set;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;
@@ -34,7 +31,6 @@ public class DenseBinaryFormatTestCase {
@Test
public void testSerializationToSeparateType() {
assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[],y[])"));
- assertSerialization(Tensor.from("tensor(x[1],y[1]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[2],y[2])"));
try {
assertSerialization(Tensor.from("tensor(x[2],y[2]):{{x:0,y:0}:2.0}"), TensorType.fromSpec("tensor(x[1],y[1])"));
fail("Expected exception");