diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
commit | 5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch) | |
tree | 2b65d4f48b92bf7ec846b3efd5d5259244bc234a | |
parent | 6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff) |
Add tensor value type
40 files changed, 338 insertions, 249 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index a0f35dbefe6..75b3af47954 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -191,7 +191,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement else { // default dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString(); } - return Optional.of(new TensorType.Builder().mapped(dimension).build()); + + // TODO: Determine the type of the weighted set/vector and use that as value type + return Optional.of(new TensorType.Builder(TensorType.Value.DOUBLE).mapped(dimension).build()); } /** Binds the given list of formal arguments to their actual values */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index f197e2dfe6d..e12cc60b041 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -453,10 +453,9 @@ public class ConvertedModel { */ // TODO: determine when this is not necessary! private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); + if (after.equals(before)) return node; + + TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); for (TensorType.Dimension dimension : before.dimensions()) { if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { typeBuilder.indexed(dimension.name(), 1); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java index 5c96635fd8f..80440ac8eb4 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java @@ -144,7 +144,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); - exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'"); + exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(x)'. Dimension 'x' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])"); RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " constants {\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java index 2fcf5809ea5..f53ca15635f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java @@ -39,7 +39,7 @@ public class TensorFieldTestCase { @Test public void requireThatIllegalTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); - exception.expectMessage("Field type: Illegal tensor type spec: Failed parsing element 'invalid' in type spec 'tensor(invalid)'"); + exception.expectMessage("Field type: Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(invalid)'. Dimension 'invalid' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])"); SearchBuilder.createFromString(getSd("field f1 type tensor(invalid) { indexing: attribute }")); } diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java index 8eaf4cc08cb..c05c3589a30 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java @@ -77,7 +77,7 @@ public class QueryProfileTypeTestCase { type.addField(new FieldDescription("myBoolean", FieldType.fromString("boolean", registry)), registry); type.addField(new FieldDescription("ranking.features.query(myTensor1)", FieldType.fromString("tensor(a{},b{})", registry)), registry); type.addField(new FieldDescription("ranking.features.query(myTensor2)", FieldType.fromString("tensor(x[2],y[2])", registry)), registry); - type.addField(new FieldDescription("ranking.features.query(myTensor3)", FieldType.fromString("tensor(x{})",registry)), registry); + type.addField(new FieldDescription("ranking.features.query(myTensor3)", FieldType.fromString("tensor<float>(x{})",registry)), registry); type.addField(new FieldDescription("myQuery", FieldType.fromString("query", registry)), registry); type.addField(new FieldDescription("myQueryProfile", FieldType.fromString("query-profile", registry),"qp"), registry); } @@ -136,7 +136,7 @@ public class QueryProfileTypeTestCase { assertEquals(true, properties.get("myBoolean")); assertEquals(Tensor.from(tensorString1), properties.get("ranking.features.query(myTensor1)")); assertEquals(Tensor.from("tensor(x[2],y[2])", tensorString2), properties.get("ranking.features.query(myTensor2)")); - assertEquals(Tensor.from("tensor(x{})", tensorString3), properties.get("ranking.features.query(myTensor3)")); + assertEquals(Tensor.from("tensor<float>(x{})", tensorString3), properties.get("ranking.features.query(myTensor3)")); // TODO: assertEquals(..., cprofile.get("myQuery")); assertEquals("value1", properties.get("myQueryProfile.anyString")); assertEquals("value1", properties.get("QP.anyString")); diff --git a/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java index 3fa7f1ee47e..b5c4166e4de 100644 --- a/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java @@ -3,6 +3,7 @@ package com.yahoo.search.yql; import static org.junit.Assert.*; +import com.yahoo.search.query.QueryTree; import org.apache.http.client.utils.URIBuilder; import org.junit.After; import org.junit.Before; @@ -29,20 +30,20 @@ public class UserInputTestCase { @Before public void setUp() throws Exception { - searchChain = new Chain<Searcher>(new MinimalQueryInserter()); + searchChain = new Chain<>(new MinimalQueryInserter()); context = Execution.Context.createContextStub(null); execution = new Execution(searchChain, context); } @After - public void tearDown() throws Exception { + public void tearDown() { searchChain = null; context = null; execution = null; } @Test - public final void testSimpleUserInput() { + public void testSimpleUserInput() { { URIBuilder builder = searchUri(); builder.setParameter("yql", @@ -70,7 +71,7 @@ public class UserInputTestCase { } @Test - public final void testRawUserInput() { + public void testRawUserInput() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"grammar\": \"raw\"}]userInput(\"nal le\");"); @@ -79,7 +80,7 @@ public class UserInputTestCase { } @Test - public final void testSegmentedUserInput() { + public void testSegmentedUserInput() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"grammar\": \"segment\"}]userInput(\"nal le\");"); @@ -88,7 +89,7 @@ public class UserInputTestCase { } @Test - public final void testSegmentedNoiseUserInput() { + public void testSegmentedNoiseUserInput() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"grammar\": \"segment\"}]userInput(\"^^^^^^^^\");"); @@ -97,7 +98,7 @@ public class UserInputTestCase { } @Test - public final void testCustomDefaultIndexUserInput() { + public void testCustomDefaultIndexUserInput() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"defaultIndex\": \"glompf\"}]userInput(\"nalle\");"); @@ -106,7 +107,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputStemming() { + public void testAnnotatedUserInputStemming() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"stem\": false}]userInput(\"nalle\");"); @@ -117,7 +118,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputUnrankedTerms() { + public void testAnnotatedUserInputUnrankedTerms() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"ranked\": false}]userInput(\"nalle\");"); @@ -128,7 +129,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputFiltersTerms() { + public void testAnnotatedUserInputFiltersTerms() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"filter\": true}]userInput(\"nalle\");"); @@ -139,7 +140,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputCaseNormalization() { + public void testAnnotatedUserInputCaseNormalization() { URIBuilder builder = searchUri(); builder.setParameter( "yql", @@ -151,7 +152,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputAccentRemoval() { + public void testAnnotatedUserInputAccentRemoval() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"accentDrop\": false}]userInput(\"nalle\");"); @@ -162,7 +163,7 @@ public class UserInputTestCase { } @Test - public final void testAnnotatedUserInputPositionData() { + public void testAnnotatedUserInputPositionData() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where [{\"usePositionData\": false}]userInput(\"nalle\");"); @@ -173,7 +174,7 @@ public class UserInputTestCase { } @Test - public final void testQueryPropertiesAsStringArguments() { + public void testQueryPropertiesAsStringArguments() { URIBuilder builder = searchUri(); builder.setParameter("nalle", "bamse"); builder.setParameter("meta", "syntactic"); @@ -197,7 +198,7 @@ public class UserInputTestCase { } @Test - public final void testEmptyUserInput() { + public void testEmptyUserInput() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where userInput(\"\");"); @@ -205,7 +206,7 @@ public class UserInputTestCase { } @Test - public final void testEmptyUserInputFromQueryProperty() { + public void testEmptyUserInputFromQueryProperty() { URIBuilder builder = searchUri(); builder.setParameter("foo", ""); builder.setParameter("yql", @@ -214,7 +215,7 @@ public class UserInputTestCase { } @Test - public final void testEmptyQueryProperty() { + public void testEmptyQueryProperty() { URIBuilder builder = searchUri(); builder.setParameter("foo", ""); builder.setParameter("yql", "select * from sources * where bar contains \"a\" and nonEmpty(foo contains @foo);"); @@ -222,7 +223,7 @@ public class UserInputTestCase { } @Test - public final void testEmptyQueryPropertyInsideExpression() { + public void testEmptyQueryPropertyInsideExpression() { URIBuilder builder = searchUri(); builder.setParameter("foo", ""); builder.setParameter("yql", @@ -231,7 +232,7 @@ public class UserInputTestCase { } @Test - public final void testCompositeWithoutArguments() { + public void testCompositeWithoutArguments() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where bar contains \"a\" and foo contains phrase();"); searchAndAssertNoErrors(builder); @@ -241,7 +242,7 @@ public class UserInputTestCase { } @Test - public final void testAnnoyingPlacementOfNonEmpty() { + public void testAnnoyingPlacementOfNonEmpty() { URIBuilder builder = searchUri(); builder.setParameter("yql", "select * from sources * where bar contains \"a\" and foo contains nonEmpty(phrase(\"a\", \"b\"));"); @@ -254,7 +255,7 @@ public class UserInputTestCase { } @Test - public final void testAllowEmptyUserInput() { + public void testAllowEmptyUserInput() { URIBuilder builder = searchUri(); builder.setParameter("foo", ""); builder.setParameter("yql", "select * from sources * where [{\"allowEmpty\": true}]userInput(@foo);"); @@ -262,7 +263,7 @@ public class UserInputTestCase { } @Test - public final void testAllowEmptyNullFromQueryParsing() { + public void testAllowEmptyNullFromQueryParsing() { URIBuilder builder = searchUri(); builder.setParameter("foo", ",,,,,,,,"); builder.setParameter("yql", "select * from sources * where [{\"allowEmpty\": true}]userInput(@foo);"); @@ -270,7 +271,7 @@ public class UserInputTestCase { } @Test - public final void testDisallowEmptyNullFromQueryParsing() { + public void testDisallowEmptyNullFromQueryParsing() { URIBuilder builder = searchUri(); builder.setParameter("foo", ",,,,,,,,"); builder.setParameter("yql", "select * from sources * where userInput(@foo);"); diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index 2773f9d31da..435c8fcdc65 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -38,7 +38,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { * Converts the given tensor type to a type that is compatible for being used in this update (has only mapped dimensions). */ public static TensorType convertDimensionsToMapped(TensorType type) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(type.valueType()); type.dimensions().stream().forEach(dim -> builder.mapped(dim.name())); return builder.build(); } diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java index 335cda8e133..981120af145 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java @@ -97,7 +97,7 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> { } public static TensorType extractSparseDimensions(TensorType type) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(type.valueType()); type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name())); return builder.build(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java index c4acfeb3235..9c8f6238731 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -29,9 +29,17 @@ public class OrderedTensorType { private final long[] innerSizesVespa; private final int[] dimensionMap; - private OrderedTensorType(List<TensorType.Dimension> dimensions) { + private OrderedTensorType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { this.dimensions = Collections.unmodifiableList(dimensions); - this.type = new TensorType.Builder(dimensions).build(); + this.type = new TensorType.Builder(valueType, dimensions).build(); + this.innerSizesOriginal = new long[dimensions.size()]; + this.innerSizesVespa = new long[dimensions.size()]; + this.dimensionMap = createDimensionMap(); + } + + private OrderedTensorType(TensorType type) { + this.dimensions = type.dimensions(); + this.type = type; this.innerSizesOriginal = new long[dimensions.size()]; this.innerSizesVespa = new long[dimensions.size()]; this.dimensionMap = createDimensionMap(); @@ -136,11 +144,11 @@ public class OrderedTensorType { renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); } } - return new OrderedTensorType(renamedDimensions); + return new OrderedTensorType(type.valueType(), renamedDimensions); } public OrderedTensorType rename(String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.valueType()); for (int i = 0; i < dimensions.size(); ++ i) { String dimensionName = dimensionPrefix + i; Optional<Long> dimSize = dimensions.get(i).size(); @@ -154,7 +162,7 @@ public class OrderedTensorType { } public static OrderedTensorType standardType(OrderedTensorType type) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.type().valueType()); for (int i = 0; i < type.dimensions().size(); ++ i) { TensorType.Dimension dim = type.dimensions().get(i); String dimensionName = "d" + i; @@ -193,18 +201,18 @@ public class OrderedTensorType { * where dimensions are listed in the order of this rather than the natural order of their names. */ public static OrderedTensorType fromSpec(String typeSpec) { - return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); + return new OrderedTensorType(TensorType.fromSpec(typeSpec)); } - public static OrderedTensorType fromDimensionList(List<Long> dims) { - return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ... + public static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions) { + return fromDimensionList(valueType, dimensions, "d"); // standard naming convention: d0, d1, ... } - private static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (int i = 0; i < dims.size(); ++ i) { + private static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions, String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(valueType); + for (int i = 0; i < dimensions.size(); ++ i) { String dimensionName = dimensionPrefix + i; - Long dimSize = dims.get(i); + Long dimSize = dimensions.get(i); if (dimSize >= 0) { builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); } else { @@ -216,9 +224,15 @@ public class OrderedTensorType { public static class Builder { + private final TensorType.Value valueType; private final List<TensorType.Dimension> dimensions; public Builder() { + this(TensorType.Value.DOUBLE); + } + + public Builder(TensorType.Value valueType) { + this.valueType = valueType; this.dimensions = new ArrayList<>(); } @@ -228,7 +242,7 @@ public class OrderedTensorType { } public OrderedTensorType build() { - return new OrderedTensorType(dimensions); + return new OrderedTensorType(valueType, dimensions); } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index dd2add973e4..5cc1defc010 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -16,8 +16,10 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import ai.vespa.rankingexpression.importer.operations.NoOp; import ai.vespa.rankingexpression.importer.operations.Reshape; import ai.vespa.rankingexpression.importer.operations.Shape; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; +import onnx.Onnx.TensorProto.DataType; import java.util.List; import java.util.stream.Collectors; @@ -114,7 +116,8 @@ class GraphImporter { } else if (isConstantTensor(name, onnxGraph)) { Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); - OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList()); + OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(toValueType(tensorProto.getDataType()), + tensorProto.getDimsList()); operation = new Constant(intermediateGraph.name(), name, defaultType); operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); @@ -133,6 +136,25 @@ class GraphImporter { return operation; } + private static TensorType.Value toValueType(DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT8: return TensorType.Value.FLOAT; + case INT16: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.DOUBLE; + case UINT8: return TensorType.Value.FLOAT; + case UINT16: return TensorType.Value.FLOAT; + case UINT32: return TensorType.Value.FLOAT; + case UINT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { Onnx.ValueInfoProto value = getArgumentTensor(name, graph); Onnx.TensorProto tensor = getConstantTensor(name, graph); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index f251a14213b..79b399f2c6f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -36,7 +36,7 @@ class TypeConverter { private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(TensorType.Value.DOUBLE); for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index 1a564661ccb..7ae50a0549d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -21,20 +21,15 @@ public class ConcatV2 extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { - return null; - } + if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null; IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input - if (!concatDimOp.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "concat dimension must be a constant."); - } + if ( ! concatDimOp.getConstantValue().isPresent()) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a constant."); + Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor(); - if (concatDimTensor.type().rank() != 0) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "concat dimension must be a scalar."); - } + if (concatDimTensor.type().rank() != 0) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a scalar."); OrderedTensorType aType = inputs.get(0).type().get(); concatDimensionIndex = (int)concatDimTensor.asDouble(); @@ -42,10 +37,9 @@ public class ConcatV2 extends IntermediateOperation { for (int i = 1; i < inputs.size() - 1; ++i) { OrderedTensorType bType = inputs.get(i).type().get(); - if (bType.rank() != aType.rank()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "inputs must have save rank."); - } + if (bType.rank() != aType.rank()) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Inputs must have the same rank."); + for (int j = 0; j < aType.rank(); ++j) { long dimSizeA = aType.dimensions().get(j).size().orElse(-1L); long dimSizeB = bType.dimensions().get(j).size().orElse(-1L); @@ -58,7 +52,7 @@ public class ConcatV2 extends IntermediateOperation { } } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); int dimensionIndex = 0; for (TensorType.Dimension dimension : aType.dimensions()) { if (dimensionIndex == concatDimensionIndex) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 8ae6d81b8d4..c64b9ded601 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -27,20 +27,15 @@ public class ExpandDims extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; IntermediateOperation axisOperation = inputs().get(1); if (!axisOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + - "axis must be a constant."); + throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); - if (axis.type().rank() != 0) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + - "axis argument must be a scalar."); - } + if (axis.type().rank() != 0) + throw new IllegalArgumentException("ExpandDims in " + name + ": Axis argument must be a scalar."); OrderedTensorType inputType = inputs.get(0).type().get(); int dimensionToInsert = (int)axis.asDouble(); @@ -48,7 +43,7 @@ public class ExpandDims extends IntermediateOperation { dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { @@ -66,12 +61,10 @@ public class ExpandDims extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(2)) { - return null; - } + if ( ! allInputFunctionsPresent(2)) return null; // multiply with a generated tensor created from the reduced dimensions - TensorType.Builder typeBuilder = new TensorType.Builder(); + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); for (String name : expandDimensions) { typeBuilder.indexed(name, 1); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 3b77f9527ca..0ee54f839bc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -9,6 +9,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; @@ -17,6 +18,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.function.Function; +import java.util.stream.Collectors; /** * Wraps an imported operation node and produces the respective Vespa tensor @@ -161,6 +163,19 @@ public abstract class IntermediateOperation { } /** + * Returns the largest value type among the input value types. + * This should only be called after it has been verified that input types are available. + * + * @throws IllegalArgumentException if a type cannot be uniquely determined + * @throws RuntimeException if called when input types are not available + */ + TensorType.Value resultValueType() { + return TensorType.Value.largestOf(inputs.stream() + .map(input -> input.type().get().type().valueType()) + .collect(Collectors.toList())); + } + + /** * A method signature input and output has the form name:index. * This returns the name part without the index. */ diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index fed95e13bb7..c2d75153586 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -22,13 +22,12 @@ public class Join extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType a = largestInput().type().get(); OrderedTensorType b = smallestInput().type().get(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); int sizeDifference = a.rank() - b.rank(); for (int i = 0; i < a.rank(); ++i) { TensorType.Dimension aDim = a.dimensions().get(i); @@ -52,12 +51,8 @@ public class Join extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - if (!allInputFunctionsPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; IntermediateOperation a = largestInput(); IntermediateOperation b = smallestInput(); @@ -92,9 +87,8 @@ public class Join extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } + if ( ! allInputTypesPresent(2)) return; + OrderedTensorType a = largestInput().type().get(); OrderedTensorType b = smallestInput().type().get(); int sizeDifference = a.rank() - b.rank(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 1dbfd6e40dc..9a76662529d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -17,10 +17,9 @@ public class MatMul extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + if ( ! allInputTypesPresent(2)) return null; + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); return typeBuilder.build(); @@ -28,9 +27,8 @@ public class MatMul extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType aType = inputs.get(0).type().get(); OrderedTensorType bType = inputs.get(1).type().get(); if (aType.type().rank() < 2 || bType.type().rank() < 2) @@ -48,9 +46,8 @@ public class MatMul extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } + if ( ! allInputTypesPresent(2)) return; + List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); @@ -69,4 +66,5 @@ public class MatMul extends IntermediateOperation { renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java index 4be220db9d5..d8e9950c61f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java @@ -32,13 +32,11 @@ public class Mean extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + IntermediateOperation reductionIndices = inputs.get(1); - if (!reductionIndices.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Mean in " + name + ": " + - "reduction indices must be a constant."); + if ( ! reductionIndices.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Mean in " + name + ": Reduction indices must be a constant."); } Tensor indices = reductionIndices.getConstantValue().get().asTensor(); reduceDimensions = new ArrayList<>(); @@ -59,14 +57,14 @@ public class Mean extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + + TensorFunction inputFunction = inputs.get(0).function().get(); TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions - TensorType.Builder typeBuilder = new TensorType.Builder(); + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); for (String name : reduceDimensions) { typeBuilder.indexed(name, 1); } @@ -99,9 +97,9 @@ public class Mean extends IntermediateOperation { } private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); for (TensorType.Dimension dimension: inputType.type().dimensions()) { - if (!reduceDimensions.contains(dimension.name())) { + if ( ! reduceDimensions.contains(dimension.name())) { builder.add(dimension); } else if (keepDimensions) { builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 18f3cc1cc39..4a0fe236c9f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -32,18 +32,16 @@ public class Reshape extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + IntermediateOperation newShape = inputs.get(1); - if (!newShape.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Reshape in " + name + ": " + - "shape input must be a constant."); - } + if ( ! newShape.getConstantValue().isPresent()) + throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant."); + Tensor shape = newShape.getConstantValue().get().asTensor(); OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); int dimensionIndex = 0; for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { Tensor.Cell cell = cellIterator.next(); @@ -61,12 +59,9 @@ public class Reshape extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - if (!allInputFunctionsPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; + OrderedTensorType inputType = inputs.get(0).type().get(); TensorFunction inputFunction = inputs.get(0).function().get(); return reshape(inputFunction, inputType.type(), type.type()); @@ -80,9 +75,8 @@ public class Reshape extends IntermediateOperation { } public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) { + if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); - } // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, // then use the dimension order of the new shape to roll back into a tensor. @@ -96,20 +90,17 @@ public class Reshape extends IntermediateOperation { TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); Generate transformTensor = new Generate(transformationType, - new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); - - TensorFunction outputFunction = new Reduce( - new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); + new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); - return outputFunction; + return new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); } private static ExpressionNode unrollTensorExpression(TensorType type) { - if (type.rank() == 0) { + if (type.rank() == 0) return new ConstantNode(DoubleValue.zero); - } + List<ExpressionNode> children = new ArrayList<>(); List<ArithmeticOperator> operators = new ArrayList<>(); int size = 1; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java index 361729a8c14..79f3012c327 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java @@ -19,11 +19,10 @@ public class Shape extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) { - return null; - } + if ( ! allInputTypesPresent(1)) return null; + OrderedTensorType inputType = inputs.get(0).type().get(); - return new OrderedTensorType.Builder() + return new OrderedTensorType.Builder(resultValueType()) .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size())) .build(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index 2eeefcbe8a2..52d40144f61 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -25,9 +25,8 @@ public class Squeeze extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) { - return null; - } + if ( ! allInputTypesPresent(1)) return null; + OrderedTensorType inputType = inputs.get(0).type().get(); squeezeDimensions = new ArrayList<>(); @@ -51,9 +50,8 @@ public class Squeeze extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(1)) { - return null; - } + if ( ! allInputFunctionsPresent(1)) return null; + TensorFunction inputFunction = inputs.get(0).function().get(); return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); } @@ -73,7 +71,7 @@ public class Squeeze extends IntermediateOperation { } private OrderedTensorType reducedType(OrderedTensorType inputType) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); for (TensorType.Dimension dimension: inputType.type().dimensions()) { if ( ! squeezeDimensions.contains(dimension.name())) { builder.add(dimension); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java index 6c92ffa6055..a4fe38cce95 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import org.tensorflow.DataType; import org.tensorflow.framework.TensorProto; import java.nio.ByteBuffer; @@ -27,7 +28,7 @@ public class TensorConverter { } private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { - TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix); + TensorType type = toVespaTensorType(tfTensor, dimensionPrefix); Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); for (int i = 0; i < values.size(); i++) @@ -53,10 +54,10 @@ public class TensorConverter { return builder.build(); } - private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) { - TensorType.Builder b = new TensorType.Builder(); + private static TensorType toVespaTensorType(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { + TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType())); int dimensionIndex = 0; - for (long dimensionSize : shape) { + for (long dimensionSize : tfTensor.shape()) { if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize); } @@ -85,7 +86,7 @@ public class TensorConverter { case INT64: return new LongValues(tfTensor); } throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + - tfTensor.dataType() + " to a Vespa tensor"); + tfTensor.dataType() + " to a Vespa tensor"); } private static Values readValuesOf(TensorProto tensorProto) { @@ -107,6 +108,21 @@ public class TensorConverter { throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); } + /** TensorFlow has two different DataType classes. This must be kept in sync with TypeConverter.toValueType */ + static TensorType.Value toValueType(DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + /** Allows reading values from buffers of various numeric types as bytes */ private static abstract class Values { abstract double get(int i); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java index 63a605ce97a..3e825026b0e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java @@ -5,6 +5,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.DataType; import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.TensorShapeProto; @@ -22,7 +23,7 @@ class TypeConverter { if (shape != null) { if (shape.getDimCount() != type.rank()) { throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + - "does not match Vespa shape"); + "does not match Vespa shape"); } for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) { int vespaIndex = type.dimensionMap(tensorFlowIndex); @@ -30,7 +31,7 @@ class TypeConverter { TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + - "does not match Vespa dimensions"); + "does not match Vespa dimensions"); } } } @@ -38,16 +39,24 @@ class TypeConverter { private static TensorShapeProto tensorFlowShape(NodeDef node) { AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); - if (attrValueList == null) { + if (attrValueList == null) throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "does not exist"); - } - if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { + "does not exist"); + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "is not of expected type"); - } - List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); - return shapeList.get(0); // support multiple outputs? + "is not of expected type"); + + return attrValueList.getList().getShape(0); // support multiple outputs? + } + + private static DataType tensorFlowValueType(NodeDef node) { + AttrValue attrValueList = node.getAttrMap().get("dtypes"); + if (attrValueList == null) + return DataType.DT_DOUBLE; // default. This will usually (always?) be used. TODO: How can we do better? + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) + return DataType.DT_DOUBLE; // default + + return attrValueList.getList().getType(0); // support multiple outputs? } static OrderedTensorType fromTensorFlowType(NodeDef node) { @@ -55,8 +64,8 @@ class TypeConverter { } private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); TensorShapeProto shape = tensorFlowShape(node); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node))); for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); @@ -69,4 +78,26 @@ class TypeConverter { return builder.build(); } + /** TensorFlow has two different DataType classes. This must be kept in sync with TensorConverter.toValueType */ + static TensorType.Value toValueType(DataType dataType) { + switch (dataType) { + case DT_FLOAT: return TensorType.Value.FLOAT; + case DT_DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case DT_BOOL: return TensorType.Value.FLOAT; + case DT_BFLOAT16: return TensorType.Value.FLOAT; + case DT_HALF: return TensorType.Value.FLOAT; + case DT_INT8: return TensorType.Value.FLOAT; + case DT_INT16: return TensorType.Value.FLOAT; + case DT_INT32: return TensorType.Value.FLOAT; + case DT_INT64: return TensorType.Value.DOUBLE; + case DT_UINT8: return TensorType.Value.FLOAT; + case DT_UINT16: return TensorType.Value.FLOAT; + case DT_UINT32: return TensorType.Value.FLOAT; + case DT_UINT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java index afe699d6e05..61f332327be 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java @@ -13,9 +13,10 @@ public class OrderedTensorTypeTestCase { @Test public void testToFromSpec() { String spec = "tensor(b[],c{},a[3])"; + String orderedSpec = "tensor(a[3],b[],c{})"; OrderedTensorType type = OrderedTensorType.fromSpec(spec); - assertEquals(spec, type.toString()); - assertEquals("tensor(a[3],b[],c{})", type.type().toString()); + assertEquals(orderedSpec, type.toString()); + assertEquals(orderedSpec, type.type().toString()); } } diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 79c633b9617..b8c51f4e33d 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -886,6 +886,7 @@ "public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()", "public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()", "public final com.yahoo.tensor.TensorType tensorTypeArgument()", + "public final com.yahoo.tensor.TensorType$Value optionalTensorValueTypeParameter()", "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder)", "public final java.lang.String tensorFunctionName()", "public final com.yahoo.searchlib.rankingexpression.rule.Function unaryFunctionName()", diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 2f173ad0266..c83de4ced0a 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -598,9 +598,12 @@ Reduce.Aggregator tensorReduceAggregator() : TensorType tensorTypeArgument() : { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder; + TensorType.Value valueType; } { + valueType = optionalTensorValueTypeParameter() + { builder = new TensorType.Builder(valueType); } <LBRACE> ( tensorTypeDimension(builder) ) ? ( <COMMA> tensorTypeDimension(builder) ) * @@ -608,6 +611,15 @@ TensorType tensorTypeArgument() : { return builder.build(); } } +TensorType.Value optionalTensorValueTypeParameter() : +{ + String valueType = "double"; +} +{ + ( <LT> valueType = identifier() <GT> )? + { return TensorTypeParser.toValueType(valueType); } +} + // NOTE: Only indexed bound dimensions are parsed currently, as that is what we need void tensorTypeDimension(TensorType.Builder builder) : { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index f2122bb5da9..f7e38862883 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -238,6 +238,8 @@ public class EvaluationTestCase { "{{x:0}:1}", "{}", "{{y:0,z:0}:1}"); tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:3 }", "tensor(x{}):{ {x:1}:5 }"); + tester.assertEvaluates("tensor<float>(x{}):{}", + "tensor0 * tensor1", "{ {x:0}:3 }", "tensor<float>(x{}):{ {x:1}:5 }"); tester.assertEvaluates("{ {x:0}:15 }", "tensor0 * tensor1", "{ {x:0}:3 }", "{ {x:0}:5 }"); tester.assertEvaluates("{ {x:0,y:0}:15 }", diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java index ba0db4de5e1..488930a8eb9 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java @@ -40,7 +40,7 @@ public class EvaluationTester { int argumentIndex = 0; for (String argumentString : tensorArgumentStrings) { Tensor argument; - if (argumentString.startsWith("tensor(")) // explicitly decided type + if (argumentString.startsWith("tensor")) // explicitly decided type argument = Tensor.from(argumentString); else // use mappedTensors+dimensions in tensor to decide type argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString); diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 239efa0f89c..b071566ae31 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -947,7 +947,7 @@ "public java.lang.String toString()", "public boolean equals(java.lang.Object)", "public long denseSubspaceSize()", - "public static com.yahoo.tensor.TensorType createPartialType(java.util.List)" + "public static com.yahoo.tensor.TensorType createPartialType(com.yahoo.tensor.TensorType$Value, java.util.List)" ], "fields": [] }, @@ -1162,11 +1162,11 @@ ], "methods": [ "public void <init>()", - "public void <init>(com.yahoo.tensor.TensorType$ValueType)", + "public void <init>(com.yahoo.tensor.TensorType$Value)", "public varargs void <init>(com.yahoo.tensor.TensorType[])", - "public varargs void <init>(com.yahoo.tensor.TensorType$ValueType, com.yahoo.tensor.TensorType[])", + "public varargs void <init>(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType[])", "public void <init>(java.lang.Iterable)", - "public void <init>(com.yahoo.tensor.TensorType$ValueType, java.lang.Iterable)", + "public void <init>(com.yahoo.tensor.TensorType$Value, java.lang.Iterable)", "public int rank()", "public com.yahoo.tensor.TensorType$Builder set(com.yahoo.tensor.TensorType$Dimension)", "public com.yahoo.tensor.TensorType$Builder indexed(java.lang.String, long)", @@ -1270,7 +1270,7 @@ ], "fields": [] }, - "com.yahoo.tensor.TensorType$ValueType": { + "com.yahoo.tensor.TensorType$Value": { "superClass": "java.lang.Enum", "interfaces": [], "attributes": [ @@ -1279,12 +1279,14 @@ "enum" ], "methods": [ - "public static com.yahoo.tensor.TensorType$ValueType[] values()", - "public static com.yahoo.tensor.TensorType$ValueType valueOf(java.lang.String)" + "public static com.yahoo.tensor.TensorType$Value[] values()", + "public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)", + "public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)", + "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)" ], "fields": [ - "public static final enum com.yahoo.tensor.TensorType$ValueType DOUBLE", - "public static final enum com.yahoo.tensor.TensorType$ValueType FLOAT" + "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE", + "public static final enum com.yahoo.tensor.TensorType$Value FLOAT" ] }, "com.yahoo.tensor.TensorType": { @@ -1294,9 +1296,8 @@ "public" ], "methods": [ - "public final com.yahoo.tensor.TensorType$ValueType valueType()", - "public final com.yahoo.tensor.TensorType valueType(com.yahoo.tensor.TensorType$ValueType)", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", + "public com.yahoo.tensor.TensorType$Value valueType()", "public int rank()", "public java.util.List dimensions()", "public java.util.Set dimensionNames()", @@ -1325,7 +1326,7 @@ "methods": [ "public void <init>()", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", - "public static java.util.List dimensionsFromSpec(java.lang.String)" + "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 08878edeb83..c06cb2a0986 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -319,7 +319,7 @@ public class MixedTensor implements Tensor { } public TensorType createBoundType() { - TensorType.Builder typeBuilder = new TensorType.Builder(); + TensorType.Builder typeBuilder = new TensorType.Builder(type().valueType()); for (int i = 0; i < type.dimensions().size(); ++i) { TensorType.Dimension dimension = type.dimensions().get(i); if (!dimension.isIndexed()) { @@ -355,8 +355,8 @@ public class MixedTensor implements Tensor { this.type = type; this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); this.indexedDimensions = type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList()); - this.sparseType = createPartialType(mappedDimensions); - this.denseType = createPartialType(indexedDimensions); + this.sparseType = createPartialType(type.valueType(), mappedDimensions); + this.denseType = createPartialType(type.valueType(), indexedDimensions); } public long indexOf(TensorAddress address) { @@ -476,8 +476,8 @@ public class MixedTensor implements Tensor { } - public static TensorType createPartialType(List<TensorType.Dimension> dimensions) { - TensorType.Builder builder = new TensorType.Builder(); + public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { + TensorType.Builder builder = new TensorType.Builder(valueType); for (TensorType.Dimension dimension : dimensions) { builder.set(dimension); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 998f3170aa0..45a9992c9ad 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -18,7 +18,7 @@ class TensorParser { TensorType typeFromString = TensorTypeParser.fromSpec(typeString); if (type.isPresent() && ! type.get().equals(typeFromString)) throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + - "passed type " + type); + "passed type " + type.get()); return tensorFromValueString(valueString, typeFromString); } else if (tensorString.startsWith("{")) { @@ -48,7 +48,7 @@ class TensorParser { addressBody = addressBody.substring(1); // remove key start if (addressBody.isEmpty()) return TensorType.empty; // Empty key - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE); for (String elementString : addressBody.split(",")) { String[] pair = elementString.split(":"); if (pair.length != 2) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index bded55405c0..5bd44cbc327 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -25,8 +25,29 @@ import java.util.stream.Collectors; public class TensorType { /** The permissible cell value types. Default is double. */ - // Types added here must also be added to TensorTypeParser.parseValueTypeSpec - public enum Value { DOUBLE, FLOAT}; + public enum Value { + + // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below + DOUBLE, FLOAT; + + public static Value largestOf(List<Value> values) { + if (values.isEmpty()) return Value.DOUBLE; // Default + Value largest = null; + for (Value value : values) { + if (largest == null) + largest = value; + else + largest = largestOf(largest, value); + } + return largest; + } + + public static Value largestOf(Value value1, Value value2) { + if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE; + return FLOAT; + } + + }; /** The empty tensor type - which is the same as a double */ public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList()); @@ -170,7 +191,7 @@ public class TensorType { if (this.equals(other)) return Optional.of(this); // shortcut if (this.dimensions.size() != other.dimensions.size()) return Optional.empty(); - Builder b = new Builder(); + Builder b = new Builder(TensorType.Value.largestOf(valueType, other.valueType)); for (int i = 0; i < dimensions.size(); i++) { Dimension thisDim = this.dimensions().get(i); Dimension otherDim = other.dimensions().get(i); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index a5733f1cc4c..d5f77be0dd0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -13,6 +13,7 @@ import java.util.regex.Pattern; * Class for parsing a tensor type spec. * * @author geirst + * @author bratseth */ public class TensorTypeParser { @@ -54,17 +55,24 @@ public class TensorTypeParser { return new TensorType.Builder(valueType, dimensions).build(); } + public static TensorType.Value toValueType(String valueTypeString) { + switch (valueTypeString) { + case "double" : return TensorType.Value.DOUBLE; + case "float" : return TensorType.Value.FLOAT; + default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + + " but was '" + valueTypeString + "'"); + } + } + private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) { if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">")) throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>")); - String valueType = valueTypeSpec.substring(1, valueTypeSpec.length() - 1); - switch (valueType) { - case "double" : return TensorType.Value.DOUBLE; - case "float" : return TensorType.Value.FLOAT; - default : throw formatException(fullSpecString, - "Value type must be either 'double' or 'float'" + - " but was '" + valueType + "'"); + try { + return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1)); + } + catch (IllegalArgumentException e) { + throw formatException(fullSpecString, e.getMessage()); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 91ab4f9d046..a0a257bb909 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -141,7 +141,11 @@ public class Concat extends PrimitiveTensorFunction { if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); - Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType()) + .indexed(dimensionName, 1) + .build()) + .cell(1,0) + .build(); return tensor.multiply(unitTensor); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 62ee471fcf4..062e0d92e80 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -386,13 +386,12 @@ public class Join extends PrimitiveTensorFunction { return true; } - /** - * Returns common dimension of a and b as a new tensor type - */ + /** Returns common dimension of a and b as a new tensor type */ private static TensorType commonDimensions(Tensor a, Tensor b) { - TensorType.Builder typeBuilder = new TensorType.Builder(); TensorType aType = a.type(); TensorType bType = b.type(); + TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(), + bType.valueType())); for (int i = 0; i < aType.dimensions().size(); ++i) { TensorType.Dimension aDim = aType.dimensions().get(i); for (int j = 0; j < bType.dimensions().size(); ++j) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 54d7710c9dc..017dc3920e6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -61,8 +61,8 @@ public class Reduce extends PrimitiveTensorFunction { } public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) { - if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all - TensorType.Builder b = new TensorType.Builder(); + TensorType.Builder b = new TensorType.Builder(inputType.valueType()); + if (reduceDimensions.isEmpty()) return b.build(); // means reduce all for (TensorType.Dimension dimension : inputType.dimensions()) { if ( ! reduceDimensions.contains(dimension.name())) b.dimension(dimension); @@ -109,8 +109,8 @@ public class Reduce extends PrimitiveTensorFunction { } private static TensorType type(TensorType argumentType, List<String> dimensions) { - if (dimensions.isEmpty()) return TensorType.empty; // means reduce all - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(argumentType.valueType()); + if (dimensions.isEmpty()) return builder.build(); // means reduce all for (TensorType.Dimension dimension : argumentType.dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index b268e33b418..db950e6c8b9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -268,7 +268,8 @@ public class ReduceJoin extends CompositeTensorFunction { } private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(), + b.type().valueType())); for (TensorType.Dimension aDim : a.type().dimensions()) { for (TensorType.Dimension bDim : b.type().dimensions()) { if (aDim.name().equals(bDim.name())) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index e18af235d59..5694684956e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -75,7 +75,7 @@ public class Rename extends PrimitiveTensorFunction { } private TensorType type(TensorType type) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(type.valueType()); for (TensorType.Dimension dimension : type.dimensions()) builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); return builder.build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index acaeb3ef5ba..284dfea2141 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -78,7 +78,7 @@ class MixedBinaryFormat implements BinaryFormat { TensorType serializedType = decodeType(buffer); if ( ! serializedType.isAssignableTo(type)) throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + - " cannot be assigned to type " + type); + " cannot be assigned to type " + type); } else { type = decodeType(buffer); @@ -103,7 +103,7 @@ class MixedBinaryFormat implements BinaryFormat { private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) { List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); - TensorType sparseType = MixedTensor.createPartialType(sparseDimensions); + TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions); long denseSubspaceSize = builder.denseSubspaceSize(); int numBlocks = 1; diff --git a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java index 9602bdb8d94..f6fed9d33ed 100644 --- a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java @@ -69,16 +69,6 @@ public class BoundingBoxParserTestCase { all1234(parser); } - /** - * Tests various legal inputs and print the output - */ - @Test - public void testPrint() { - String here = "n=63.418417 E=10.433033 S=37.7 W=-122.02"; - parser = new BoundingBoxParser(here); - System.out.println(here+" -> "+parser); - } - @Test public void testGeoPlanetExample() { /* example XML: diff --git a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java index e8ceab44c78..7cf4bddaa01 100644 --- a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java @@ -57,7 +57,6 @@ public class BinaryFormatTestCase { @Test public void testZigZagConversion() { - System.out.println("test zigzag conversion"); assertThat(encode_zigzag(0), is((long)0)); assertThat(decode_zigzag(encode_zigzag(0)), is(0L)); @@ -88,7 +87,6 @@ public class BinaryFormatTestCase { @Test public void testDoubleConversion() { - System.out.println("test double conversion"); assertThat(encode_double(0.0), is(0L)); assertThat(decode_double(encode_double(0.0)), is(0.0)); @@ -116,7 +114,6 @@ public class BinaryFormatTestCase { @Test public void testTypeAndMetaMangling() { - System.out.println("test type and meta mangling"); for (byte type = 0; type < TYPE_LIMIT; ++type) { for (int meta = 0; meta < META_LIMIT; ++meta) { byte mangled = encode_type_and_meta(type, meta); @@ -126,10 +123,8 @@ public class BinaryFormatTestCase { } } - // was testCmprUlong @Test - public void testCmprLong() { - System.out.println("test compressed long"); + public void testCompressedLong() { { long value = 0; byte[] wanted = { 0 }; @@ -217,11 +212,8 @@ public class BinaryFormatTestCase { // testWriteBytes -> buffered IO test // testReadByte -> buffered IO test // testReadBytes -> buffered IO test - @Test - public void testTypeAndSize() { - System.out.println("test type and size conversion"); - + public void testTypeAndSizeConversion() { for (byte type = 0; type < TYPE_LIMIT; ++type) { for (long size = 0; size < 500; ++size) { BufferedOutput expect = new BufferedOutput(); @@ -271,8 +263,7 @@ public class BinaryFormatTestCase { } @Test - public void testTypeAndBytes() { - System.out.println("test encoding and decoding of type and bytes"); + public void testEncodingAndDecodingOfTypeAndBytes() { for (byte type = 0; type < TYPE_LIMIT; ++type) { for (int n = 0; n < MAX_NUM_SIZE; ++n) { for (int pre = 0; (pre == 0) || (pre < n); ++pre) { @@ -307,9 +298,7 @@ public class BinaryFormatTestCase { } @Test - public void testEmpty() { - System.out.println("test encoding empty slime"); - + public void testEncodingEmptySlime() { Slime slime = new Slime(); BufferedOutput expect = new BufferedOutput(); expect.put((byte)0); // num symbols @@ -321,8 +310,7 @@ public class BinaryFormatTestCase { } @Test - public void testBasic() { - System.out.println("test encoding slime holding a single basic value"); + public void testEncodingSlimeHoldingASingleBasicValue() { { Slime slime = new Slime(); slime.setBool(false); @@ -427,8 +415,7 @@ public class BinaryFormatTestCase { } @Test - public void testArray() { - System.out.println("test encoding slime holding an array of various basic values"); + public void testEncodingSlimeArray() { Slime slime = new Slime(); Cursor c = slime.setArray(); byte[] data = { 'd', 'a', 't', 'a' }; @@ -452,8 +439,7 @@ public class BinaryFormatTestCase { } @Test - public void testObject() { - System.out.println("test encoding slime holding an object of various basic values"); + public void testEncodingSlimeObject() { Slime slime = new Slime(); Cursor c = slime.setObject(); byte[] data = { 'd', 'a', 't', 'a' }; @@ -478,8 +464,7 @@ public class BinaryFormatTestCase { } @Test - public void testNesting() { - System.out.println("test encoding slime holding a more complex structure"); + public void testEncodingComplexSlimeStructure() { Slime slime = new Slime(); Cursor c1 = slime.setObject(); c1.setLong("bar", 10); @@ -503,8 +488,7 @@ public class BinaryFormatTestCase { } @Test - public void testSymbolReuse() { - System.out.println("test encoding slime reusing symbols"); + public void testEncodingSlimeReusingSymbols() { Slime slime = new Slime(); Cursor c1 = slime.setArray(); { @@ -533,8 +517,7 @@ public class BinaryFormatTestCase { } @Test - public void testOptionalDecodeOrder() { - System.out.println("test decoding slime with different symbol order"); + public void testDecodingSlimeWithDifferentSymbolOrder() { byte[] data = { 5, // num symbols 1, 'd', 1, 'e', 1, 'f', 1, 'b', 1, 'c', // symbol table @@ -564,4 +547,5 @@ public class BinaryFormatTestCase { assertThat(c.field("f").asData(), is(expd)); assertThat(c.entry(5).valid(), is(false)); // not ARRAY } + } |