summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java7
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java2
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java4
-rw-r--r--container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java47
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java2
-rw-r--r--document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java40
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java24
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java26
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java21
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java15
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java18
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java22
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java41
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java26
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java53
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java5
-rw-r--r--searchlib/abi-spec.json1
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj14
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java2
-rw-r--r--vespajlib/abi-spec.json25
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java22
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java10
-rw-r--r--vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java38
40 files changed, 338 insertions, 249 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index a0f35dbefe6..75b3af47954 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -191,7 +191,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
else { // default
dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString();
}
- return Optional.of(new TensorType.Builder().mapped(dimension).build());
+
+ // TODO: Determine the type of the weighted set/vector and use that as value type
+ return Optional.of(new TensorType.Builder(TensorType.Value.DOUBLE).mapped(dimension).build());
}
/** Binds the given list of formal arguments to their actual values */
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index f197e2dfe6d..e12cc60b041 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -453,10 +453,9 @@ public class ConvertedModel {
*/
// TODO: determine when this is not necessary!
private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) {
- return node;
- }
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ if (after.equals(before)) return node;
+
+ TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType());
for (TensorType.Dimension dimension : before.dimensions()) {
if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
typeBuilder.indexed(dimension.name(), 1);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
index 5c96635fd8f..80440ac8eb4 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
@@ -144,7 +144,7 @@ public class RankingExpressionWithTensorTestCase {
@Test
public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'");
+ exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(x)'. Dimension 'x' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])");
RankProfileSearchFixture f = new RankProfileSearchFixture(
" rank-profile my_profile {\n" +
" constants {\n" +
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java
index 2fcf5809ea5..f53ca15635f 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java
@@ -39,7 +39,7 @@ public class TensorFieldTestCase {
@Test
public void requireThatIllegalTensorTypeSpecThrowsException() throws ParseException {
exception.expect(IllegalArgumentException.class);
- exception.expectMessage("Field type: Illegal tensor type spec: Failed parsing element 'invalid' in type spec 'tensor(invalid)'");
+ exception.expectMessage("Field type: Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(invalid)'. Dimension 'invalid' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])");
SearchBuilder.createFromString(getSd("field f1 type tensor(invalid) { indexing: attribute }"));
}
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
index 8eaf4cc08cb..c05c3589a30 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java
@@ -77,7 +77,7 @@ public class QueryProfileTypeTestCase {
type.addField(new FieldDescription("myBoolean", FieldType.fromString("boolean", registry)), registry);
type.addField(new FieldDescription("ranking.features.query(myTensor1)", FieldType.fromString("tensor(a{},b{})", registry)), registry);
type.addField(new FieldDescription("ranking.features.query(myTensor2)", FieldType.fromString("tensor(x[2],y[2])", registry)), registry);
- type.addField(new FieldDescription("ranking.features.query(myTensor3)", FieldType.fromString("tensor(x{})",registry)), registry);
+ type.addField(new FieldDescription("ranking.features.query(myTensor3)", FieldType.fromString("tensor<float>(x{})",registry)), registry);
type.addField(new FieldDescription("myQuery", FieldType.fromString("query", registry)), registry);
type.addField(new FieldDescription("myQueryProfile", FieldType.fromString("query-profile", registry),"qp"), registry);
}
@@ -136,7 +136,7 @@ public class QueryProfileTypeTestCase {
assertEquals(true, properties.get("myBoolean"));
assertEquals(Tensor.from(tensorString1), properties.get("ranking.features.query(myTensor1)"));
assertEquals(Tensor.from("tensor(x[2],y[2])", tensorString2), properties.get("ranking.features.query(myTensor2)"));
- assertEquals(Tensor.from("tensor(x{})", tensorString3), properties.get("ranking.features.query(myTensor3)"));
+ assertEquals(Tensor.from("tensor<float>(x{})", tensorString3), properties.get("ranking.features.query(myTensor3)"));
// TODO: assertEquals(..., cprofile.get("myQuery"));
assertEquals("value1", properties.get("myQueryProfile.anyString"));
assertEquals("value1", properties.get("QP.anyString"));
diff --git a/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java
index 3fa7f1ee47e..b5c4166e4de 100644
--- a/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/yql/UserInputTestCase.java
@@ -3,6 +3,7 @@ package com.yahoo.search.yql;
import static org.junit.Assert.*;
+import com.yahoo.search.query.QueryTree;
import org.apache.http.client.utils.URIBuilder;
import org.junit.After;
import org.junit.Before;
@@ -29,20 +30,20 @@ public class UserInputTestCase {
@Before
public void setUp() throws Exception {
- searchChain = new Chain<Searcher>(new MinimalQueryInserter());
+ searchChain = new Chain<>(new MinimalQueryInserter());
context = Execution.Context.createContextStub(null);
execution = new Execution(searchChain, context);
}
@After
- public void tearDown() throws Exception {
+ public void tearDown() {
searchChain = null;
context = null;
execution = null;
}
@Test
- public final void testSimpleUserInput() {
+ public void testSimpleUserInput() {
{
URIBuilder builder = searchUri();
builder.setParameter("yql",
@@ -70,7 +71,7 @@ public class UserInputTestCase {
}
@Test
- public final void testRawUserInput() {
+ public void testRawUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"grammar\": \"raw\"}]userInput(\"nal le\");");
@@ -79,7 +80,7 @@ public class UserInputTestCase {
}
@Test
- public final void testSegmentedUserInput() {
+ public void testSegmentedUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"grammar\": \"segment\"}]userInput(\"nal le\");");
@@ -88,7 +89,7 @@ public class UserInputTestCase {
}
@Test
- public final void testSegmentedNoiseUserInput() {
+ public void testSegmentedNoiseUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"grammar\": \"segment\"}]userInput(\"^^^^^^^^\");");
@@ -97,7 +98,7 @@ public class UserInputTestCase {
}
@Test
- public final void testCustomDefaultIndexUserInput() {
+ public void testCustomDefaultIndexUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"defaultIndex\": \"glompf\"}]userInput(\"nalle\");");
@@ -106,7 +107,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputStemming() {
+ public void testAnnotatedUserInputStemming() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"stem\": false}]userInput(\"nalle\");");
@@ -117,7 +118,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputUnrankedTerms() {
+ public void testAnnotatedUserInputUnrankedTerms() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"ranked\": false}]userInput(\"nalle\");");
@@ -128,7 +129,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputFiltersTerms() {
+ public void testAnnotatedUserInputFiltersTerms() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"filter\": true}]userInput(\"nalle\");");
@@ -139,7 +140,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputCaseNormalization() {
+ public void testAnnotatedUserInputCaseNormalization() {
URIBuilder builder = searchUri();
builder.setParameter(
"yql",
@@ -151,7 +152,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputAccentRemoval() {
+ public void testAnnotatedUserInputAccentRemoval() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"accentDrop\": false}]userInput(\"nalle\");");
@@ -162,7 +163,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnotatedUserInputPositionData() {
+ public void testAnnotatedUserInputPositionData() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where [{\"usePositionData\": false}]userInput(\"nalle\");");
@@ -173,7 +174,7 @@ public class UserInputTestCase {
}
@Test
- public final void testQueryPropertiesAsStringArguments() {
+ public void testQueryPropertiesAsStringArguments() {
URIBuilder builder = searchUri();
builder.setParameter("nalle", "bamse");
builder.setParameter("meta", "syntactic");
@@ -197,7 +198,7 @@ public class UserInputTestCase {
}
@Test
- public final void testEmptyUserInput() {
+ public void testEmptyUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where userInput(\"\");");
@@ -205,7 +206,7 @@ public class UserInputTestCase {
}
@Test
- public final void testEmptyUserInputFromQueryProperty() {
+ public void testEmptyUserInputFromQueryProperty() {
URIBuilder builder = searchUri();
builder.setParameter("foo", "");
builder.setParameter("yql",
@@ -214,7 +215,7 @@ public class UserInputTestCase {
}
@Test
- public final void testEmptyQueryProperty() {
+ public void testEmptyQueryProperty() {
URIBuilder builder = searchUri();
builder.setParameter("foo", "");
builder.setParameter("yql", "select * from sources * where bar contains \"a\" and nonEmpty(foo contains @foo);");
@@ -222,7 +223,7 @@ public class UserInputTestCase {
}
@Test
- public final void testEmptyQueryPropertyInsideExpression() {
+ public void testEmptyQueryPropertyInsideExpression() {
URIBuilder builder = searchUri();
builder.setParameter("foo", "");
builder.setParameter("yql",
@@ -231,7 +232,7 @@ public class UserInputTestCase {
}
@Test
- public final void testCompositeWithoutArguments() {
+ public void testCompositeWithoutArguments() {
URIBuilder builder = searchUri();
builder.setParameter("yql", "select * from sources * where bar contains \"a\" and foo contains phrase();");
searchAndAssertNoErrors(builder);
@@ -241,7 +242,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAnnoyingPlacementOfNonEmpty() {
+ public void testAnnoyingPlacementOfNonEmpty() {
URIBuilder builder = searchUri();
builder.setParameter("yql",
"select * from sources * where bar contains \"a\" and foo contains nonEmpty(phrase(\"a\", \"b\"));");
@@ -254,7 +255,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAllowEmptyUserInput() {
+ public void testAllowEmptyUserInput() {
URIBuilder builder = searchUri();
builder.setParameter("foo", "");
builder.setParameter("yql", "select * from sources * where [{\"allowEmpty\": true}]userInput(@foo);");
@@ -262,7 +263,7 @@ public class UserInputTestCase {
}
@Test
- public final void testAllowEmptyNullFromQueryParsing() {
+ public void testAllowEmptyNullFromQueryParsing() {
URIBuilder builder = searchUri();
builder.setParameter("foo", ",,,,,,,,");
builder.setParameter("yql", "select * from sources * where [{\"allowEmpty\": true}]userInput(@foo);");
@@ -270,7 +271,7 @@ public class UserInputTestCase {
}
@Test
- public final void testDisallowEmptyNullFromQueryParsing() {
+ public void testDisallowEmptyNullFromQueryParsing() {
URIBuilder builder = searchUri();
builder.setParameter("foo", ",,,,,,,,");
builder.setParameter("yql", "select * from sources * where userInput(@foo);");
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
index 2773f9d31da..435c8fcdc65 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java
@@ -38,7 +38,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> {
* Converts the given tensor type to a type that is compatible for being used in this update (has only mapped dimensions).
*/
public static TensorType convertDimensionsToMapped(TensorType type) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(type.valueType());
type.dimensions().stream().forEach(dim -> builder.mapped(dim.name()));
return builder.build();
}
diff --git a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
index 335cda8e133..981120af145 100644
--- a/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
+++ b/document/src/main/java/com/yahoo/document/update/TensorRemoveUpdate.java
@@ -97,7 +97,7 @@ public class TensorRemoveUpdate extends ValueUpdate<TensorFieldValue> {
}
public static TensorType extractSparseDimensions(TensorType type) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(type.valueType());
type.dimensions().stream().filter(dim -> ! dim.isIndexed()).forEach(dim -> builder.mapped(dim.name()));
return builder.build();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
index c4acfeb3235..9c8f6238731 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
@@ -29,9 +29,17 @@ public class OrderedTensorType {
private final long[] innerSizesVespa;
private final int[] dimensionMap;
- private OrderedTensorType(List<TensorType.Dimension> dimensions) {
+ private OrderedTensorType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) {
this.dimensions = Collections.unmodifiableList(dimensions);
- this.type = new TensorType.Builder(dimensions).build();
+ this.type = new TensorType.Builder(valueType, dimensions).build();
+ this.innerSizesOriginal = new long[dimensions.size()];
+ this.innerSizesVespa = new long[dimensions.size()];
+ this.dimensionMap = createDimensionMap();
+ }
+
+ private OrderedTensorType(TensorType type) {
+ this.dimensions = type.dimensions();
+ this.type = type;
this.innerSizesOriginal = new long[dimensions.size()];
this.innerSizesVespa = new long[dimensions.size()];
this.dimensionMap = createDimensionMap();
@@ -136,11 +144,11 @@ public class OrderedTensorType {
renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
}
}
- return new OrderedTensorType(renamedDimensions);
+ return new OrderedTensorType(type.valueType(), renamedDimensions);
}
public OrderedTensorType rename(String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.valueType());
for (int i = 0; i < dimensions.size(); ++ i) {
String dimensionName = dimensionPrefix + i;
Optional<Long> dimSize = dimensions.get(i).size();
@@ -154,7 +162,7 @@ public class OrderedTensorType {
}
public static OrderedTensorType standardType(OrderedTensorType type) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.type().valueType());
for (int i = 0; i < type.dimensions().size(); ++ i) {
TensorType.Dimension dim = type.dimensions().get(i);
String dimensionName = "d" + i;
@@ -193,18 +201,18 @@ public class OrderedTensorType {
* where dimensions are listed in the order of this rather than the natural order of their names.
*/
public static OrderedTensorType fromSpec(String typeSpec) {
- return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
+ return new OrderedTensorType(TensorType.fromSpec(typeSpec));
}
- public static OrderedTensorType fromDimensionList(List<Long> dims) {
- return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ...
+ public static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions) {
+ return fromDimensionList(valueType, dimensions, "d"); // standard naming convention: d0, d1, ...
}
- private static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < dims.size(); ++ i) {
+ private static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions, String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(valueType);
+ for (int i = 0; i < dimensions.size(); ++ i) {
String dimensionName = dimensionPrefix + i;
- Long dimSize = dims.get(i);
+ Long dimSize = dimensions.get(i);
if (dimSize >= 0) {
builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
} else {
@@ -216,9 +224,15 @@ public class OrderedTensorType {
public static class Builder {
+ private final TensorType.Value valueType;
private final List<TensorType.Dimension> dimensions;
public Builder() {
+ this(TensorType.Value.DOUBLE);
+ }
+
+ public Builder(TensorType.Value valueType) {
+ this.valueType = valueType;
this.dimensions = new ArrayList<>();
}
@@ -228,7 +242,7 @@ public class OrderedTensorType {
}
public OrderedTensorType build() {
- return new OrderedTensorType(dimensions);
+ return new OrderedTensorType(valueType, dimensions);
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index dd2add973e4..5cc1defc010 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -16,8 +16,10 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Shape;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
+import onnx.Onnx.TensorProto.DataType;
import java.util.List;
import java.util.stream.Collectors;
@@ -114,7 +116,8 @@ class GraphImporter {
} else if (isConstantTensor(name, onnxGraph)) {
Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
- OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList());
+ OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(toValueType(tensorProto.getDataType()),
+ tensorProto.getDimsList());
operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
@@ -133,6 +136,25 @@ class GraphImporter {
return operation;
}
+ private static TensorType.Value toValueType(DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.DOUBLE;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
+ case UINT32: return TensorType.Value.FLOAT;
+ case UINT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
Onnx.TensorProto tensor = getConstantTensor(name, graph);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index f251a14213b..79b399f2c6f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -36,7 +36,7 @@ class TypeConverter {
private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(TensorType.Value.DOUBLE);
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
index 1a564661ccb..7ae50a0549d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -21,20 +21,15 @@ public class ConcatV2 extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
- return null;
- }
+ if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null;
IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
- if (!concatDimOp.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "concat dimension must be a constant.");
- }
+ if ( ! concatDimOp.getConstantValue().isPresent())
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a constant.");
+
Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
- if (concatDimTensor.type().rank() != 0) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "concat dimension must be a scalar.");
- }
+ if (concatDimTensor.type().rank() != 0)
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a scalar.");
OrderedTensorType aType = inputs.get(0).type().get();
concatDimensionIndex = (int)concatDimTensor.asDouble();
@@ -42,10 +37,9 @@ public class ConcatV2 extends IntermediateOperation {
for (int i = 1; i < inputs.size() - 1; ++i) {
OrderedTensorType bType = inputs.get(i).type().get();
- if (bType.rank() != aType.rank()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "inputs must have save rank.");
- }
+ if (bType.rank() != aType.rank())
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Inputs must have the same rank.");
+
for (int j = 0; j < aType.rank(); ++j) {
long dimSizeA = aType.dimensions().get(j).size().orElse(-1L);
long dimSizeB = bType.dimensions().get(j).size().orElse(-1L);
@@ -58,7 +52,7 @@ public class ConcatV2 extends IntermediateOperation {
}
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
int dimensionIndex = 0;
for (TensorType.Dimension dimension : aType.dimensions()) {
if (dimensionIndex == concatDimensionIndex) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index 8ae6d81b8d4..c64b9ded601 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -27,20 +27,15 @@ public class ExpandDims extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
IntermediateOperation axisOperation = inputs().get(1);
if (!axisOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
- "axis must be a constant.");
+ throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
- if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
- "axis argument must be a scalar.");
- }
+ if (axis.type().rank() != 0)
+ throw new IllegalArgumentException("ExpandDims in " + name + ": Axis argument must be a scalar.");
OrderedTensorType inputType = inputs.get(0).type().get();
int dimensionToInsert = (int)axis.asDouble();
@@ -48,7 +43,7 @@ public class ExpandDims extends IntermediateOperation {
dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
@@ -66,12 +61,10 @@ public class ExpandDims extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
+ if ( ! allInputFunctionsPresent(2)) return null;
// multiply with a generated tensor created from the reduced dimensions
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
for (String name : expandDimensions) {
typeBuilder.indexed(name, 1);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 3b77f9527ca..0ee54f839bc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -9,6 +9,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
@@ -17,6 +18,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
+import java.util.stream.Collectors;
/**
* Wraps an imported operation node and produces the respective Vespa tensor
@@ -161,6 +163,19 @@ public abstract class IntermediateOperation {
}
/**
+ * Returns the largest value type among the input value types.
+ * This should only be called after it has been verified that input types are available.
+ *
+ * @throws IllegalArgumentException if a type cannot be uniquely determined
+ * @throws RuntimeException if called when input types are not available
+ */
+ TensorType.Value resultValueType() {
+ return TensorType.Value.largestOf(inputs.stream()
+ .map(input -> input.type().get().type().valueType())
+ .collect(Collectors.toList()));
+ }
+
+ /**
* A method signature input and output has the form name:index.
* This returns the name part without the index.
*/
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index fed95e13bb7..c2d75153586 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -22,13 +22,12 @@ public class Join extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
OrderedTensorType a = largestInput().type().get();
OrderedTensorType b = smallestInput().type().get();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
int sizeDifference = a.rank() - b.rank();
for (int i = 0; i < a.rank(); ++i) {
TensorType.Dimension aDim = a.dimensions().get(i);
@@ -52,12 +51,8 @@ public class Join extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+ if ( ! allInputFunctionsPresent(2)) return null;
IntermediateOperation a = largestInput();
IntermediateOperation b = smallestInput();
@@ -92,9 +87,8 @@ public class Join extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
+ if ( ! allInputTypesPresent(2)) return;
+
OrderedTensorType a = largestInput().type().get();
OrderedTensorType b = smallestInput().type().get();
int sizeDifference = a.rank() - b.rank();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 1dbfd6e40dc..9a76662529d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -17,10 +17,9 @@ public class MatMul extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ if ( ! allInputTypesPresent(2)) return null;
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
return typeBuilder.build();
@@ -28,9 +27,8 @@ public class MatMul extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
OrderedTensorType aType = inputs.get(0).type().get();
OrderedTensorType bType = inputs.get(1).type().get();
if (aType.type().rank() < 2 || bType.type().rank() < 2)
@@ -48,9 +46,8 @@ public class MatMul extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
+ if ( ! allInputTypesPresent(2)) return;
+
List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
@@ -69,4 +66,5 @@ public class MatMul extends IntermediateOperation {
renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
}
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
index 4be220db9d5..d8e9950c61f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
@@ -32,13 +32,11 @@ public class Mean extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
IntermediateOperation reductionIndices = inputs.get(1);
- if (!reductionIndices.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Mean in " + name + ": " +
- "reduction indices must be a constant.");
+ if ( ! reductionIndices.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("Mean in " + name + ": Reduction indices must be a constant.");
}
Tensor indices = reductionIndices.getConstantValue().get().asTensor();
reduceDimensions = new ArrayList<>();
@@ -59,14 +57,14 @@ public class Mean extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
+
TensorFunction inputFunction = inputs.get(0).function().get();
TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
if (shouldKeepDimensions()) {
// multiply with a generated tensor created from the reduced dimensions
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
for (String name : reduceDimensions) {
typeBuilder.indexed(name, 1);
}
@@ -99,9 +97,9 @@ public class Mean extends IntermediateOperation {
}
private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
- if (!reduceDimensions.contains(dimension.name())) {
+ if ( ! reduceDimensions.contains(dimension.name())) {
builder.add(dimension);
} else if (keepDimensions) {
builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 18f3cc1cc39..4a0fe236c9f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -32,18 +32,16 @@ public class Reshape extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
IntermediateOperation newShape = inputs.get(1);
- if (!newShape.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Reshape in " + name + ": " +
- "shape input must be a constant.");
- }
+ if ( ! newShape.getConstantValue().isPresent())
+ throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant.");
+
Tensor shape = newShape.getConstantValue().get().asTensor();
OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType());
int dimensionIndex = 0;
for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
Tensor.Cell cell = cellIterator.next();
@@ -61,12 +59,9 @@ public class Reshape extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+ if ( ! allInputFunctionsPresent(2)) return null;
+
OrderedTensorType inputType = inputs.get(0).type().get();
TensorFunction inputFunction = inputs.get(0).function().get();
return reshape(inputFunction, inputType.type(), type.type());
@@ -80,9 +75,8 @@ public class Reshape extends IntermediateOperation {
}
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) {
+ if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType)))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
- }
// Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
// then use the dimension order of the new shape to roll back into a tensor.
@@ -96,20 +90,17 @@ public class Reshape extends IntermediateOperation {
TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
Generate transformTensor = new Generate(transformationType,
- new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
-
- TensorFunction outputFunction = new Reduce(
- new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
+ new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
- return outputFunction;
+ return new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
}
private static ExpressionNode unrollTensorExpression(TensorType type) {
- if (type.rank() == 0) {
+ if (type.rank() == 0)
return new ConstantNode(DoubleValue.zero);
- }
+
List<ExpressionNode> children = new ArrayList<>();
List<ArithmeticOperator> operators = new ArrayList<>();
int size = 1;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
index 361729a8c14..79f3012c327 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
@@ -19,11 +19,10 @@ public class Shape extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1)) {
- return null;
- }
+ if ( ! allInputTypesPresent(1)) return null;
+
OrderedTensorType inputType = inputs.get(0).type().get();
- return new OrderedTensorType.Builder()
+ return new OrderedTensorType.Builder(resultValueType())
.add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
.build();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
index 2eeefcbe8a2..52d40144f61 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
@@ -25,9 +25,8 @@ public class Squeeze extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1)) {
- return null;
- }
+ if ( ! allInputTypesPresent(1)) return null;
+
OrderedTensorType inputType = inputs.get(0).type().get();
squeezeDimensions = new ArrayList<>();
@@ -51,9 +50,8 @@ public class Squeeze extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(1)) {
- return null;
- }
+ if ( ! allInputFunctionsPresent(1)) return null;
+
TensorFunction inputFunction = inputs.get(0).function().get();
return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
}
@@ -73,7 +71,7 @@ public class Squeeze extends IntermediateOperation {
}
private OrderedTensorType reducedType(OrderedTensorType inputType) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if ( ! squeezeDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
index 6c92ffa6055..a4fe38cce95 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import org.tensorflow.DataType;
import org.tensorflow.framework.TensorProto;
import java.nio.ByteBuffer;
@@ -27,7 +28,7 @@ public class TensorConverter {
}
private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
- TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix);
+ TensorType type = toVespaTensorType(tfTensor, dimensionPrefix);
Values values = readValuesOf(tfTensor);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
for (int i = 0; i < values.size(); i++)
@@ -53,10 +54,10 @@ public class TensorConverter {
return builder.build();
}
- private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) {
- TensorType.Builder b = new TensorType.Builder();
+ private static TensorType toVespaTensorType(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
+ TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType()));
int dimensionIndex = 0;
- for (long dimensionSize : shape) {
+ for (long dimensionSize : tfTensor.shape()) {
if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
}
@@ -85,7 +86,7 @@ public class TensorConverter {
case INT64: return new LongValues(tfTensor);
}
throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
- tfTensor.dataType() + " to a Vespa tensor");
+ tfTensor.dataType() + " to a Vespa tensor");
}
private static Values readValuesOf(TensorProto tensorProto) {
@@ -107,6 +108,21 @@ public class TensorConverter {
throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
}
+ /** TensorFlow has two different DataType classes. This must be kept in sync with TypeConverter.toValueType */
+ static TensorType.Value toValueType(DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case UINT8: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
/** Allows reading values from buffers of various numeric types as bytes */
private static abstract class Values {
abstract double get(int i);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
index 63a605ce97a..3e825026b0e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
@@ -5,6 +5,7 @@ package ai.vespa.rankingexpression.importer.tensorflow;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.DataType;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorShapeProto;
@@ -22,7 +23,7 @@ class TypeConverter {
if (shape != null) {
if (shape.getDimCount() != type.rank()) {
throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
+ "does not match Vespa shape");
}
for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) {
int vespaIndex = type.dimensionMap(tensorFlowIndex);
@@ -30,7 +31,7 @@ class TypeConverter {
TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
+ "does not match Vespa dimensions");
}
}
}
@@ -38,16 +39,24 @@ class TypeConverter {
private static TensorShapeProto tensorFlowShape(NodeDef node) {
AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
- if (attrValueList == null) {
+ if (attrValueList == null)
throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
- }
- if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
+ "does not exist");
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST)
throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
- }
- List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
- return shapeList.get(0); // support multiple outputs?
+ "is not of expected type");
+
+ return attrValueList.getList().getShape(0); // support multiple outputs?
+ }
+
+ private static DataType tensorFlowValueType(NodeDef node) {
+ AttrValue attrValueList = node.getAttrMap().get("dtypes");
+ if (attrValueList == null)
+ return DataType.DT_DOUBLE; // default. This will usually (always?) be used. TODO: How can we do better?
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST)
+ return DataType.DT_DOUBLE; // default
+
+ return attrValueList.getList().getType(0); // support multiple outputs?
}
static OrderedTensorType fromTensorFlowType(NodeDef node) {
@@ -55,8 +64,8 @@ class TypeConverter {
}
private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
TensorShapeProto shape = tensorFlowShape(node);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node)));
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
@@ -69,4 +78,26 @@ class TypeConverter {
return builder.build();
}
+ /** TensorFlow has two different DataType classes. This must be kept in sync with TensorConverter.toValueType */
+ static TensorType.Value toValueType(DataType dataType) {
+ switch (dataType) {
+ case DT_FLOAT: return TensorType.Value.FLOAT;
+ case DT_DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case DT_BOOL: return TensorType.Value.FLOAT;
+ case DT_BFLOAT16: return TensorType.Value.FLOAT;
+ case DT_HALF: return TensorType.Value.FLOAT;
+ case DT_INT8: return TensorType.Value.FLOAT;
+ case DT_INT16: return TensorType.Value.FLOAT;
+ case DT_INT32: return TensorType.Value.FLOAT;
+ case DT_INT64: return TensorType.Value.DOUBLE;
+ case DT_UINT8: return TensorType.Value.FLOAT;
+ case DT_UINT16: return TensorType.Value.FLOAT;
+ case DT_UINT32: return TensorType.Value.FLOAT;
+ case DT_UINT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
index afe699d6e05..61f332327be 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
@@ -13,9 +13,10 @@ public class OrderedTensorTypeTestCase {
@Test
public void testToFromSpec() {
String spec = "tensor(b[],c{},a[3])";
+ String orderedSpec = "tensor(a[3],b[],c{})";
OrderedTensorType type = OrderedTensorType.fromSpec(spec);
- assertEquals(spec, type.toString());
- assertEquals("tensor(a[3],b[],c{})", type.type().toString());
+ assertEquals(orderedSpec, type.toString());
+ assertEquals(orderedSpec, type.type().toString());
}
}
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 79c633b9617..b8c51f4e33d 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -886,6 +886,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()",
"public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()",
"public final com.yahoo.tensor.TensorType tensorTypeArgument()",
+ "public final com.yahoo.tensor.TensorType$Value optionalTensorValueTypeParameter()",
"public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder)",
"public final java.lang.String tensorFunctionName()",
"public final com.yahoo.searchlib.rankingexpression.rule.Function unaryFunctionName()",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 2f173ad0266..c83de4ced0a 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -598,9 +598,12 @@ Reduce.Aggregator tensorReduceAggregator() :
TensorType tensorTypeArgument() :
{
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder;
+ TensorType.Value valueType;
}
{
+ valueType = optionalTensorValueTypeParameter()
+ { builder = new TensorType.Builder(valueType); }
<LBRACE>
( tensorTypeDimension(builder) ) ?
( <COMMA> tensorTypeDimension(builder) ) *
@@ -608,6 +611,15 @@ TensorType tensorTypeArgument() :
{ return builder.build(); }
}
+TensorType.Value optionalTensorValueTypeParameter() :
+{
+ String valueType = "double";
+}
+{
+ ( <LT> valueType = identifier() <GT> )?
+ { return TensorTypeParser.toValueType(valueType); }
+}
+
// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need
void tensorTypeDimension(TensorType.Builder builder) :
{
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index f2122bb5da9..f7e38862883 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -238,6 +238,8 @@ public class EvaluationTestCase {
"{{x:0}:1}", "{}", "{{y:0,z:0}:1}");
tester.assertEvaluates("tensor(x{}):{}",
"tensor0 * tensor1", "{ {x:0}:3 }", "tensor(x{}):{ {x:1}:5 }");
+ tester.assertEvaluates("tensor<float>(x{}):{}",
+ "tensor0 * tensor1", "{ {x:0}:3 }", "tensor<float>(x{}):{ {x:1}:5 }");
tester.assertEvaluates("{ {x:0}:15 }",
"tensor0 * tensor1", "{ {x:0}:3 }", "{ {x:0}:5 }");
tester.assertEvaluates("{ {x:0,y:0}:15 }",
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
index ba0db4de5e1..488930a8eb9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
@@ -40,7 +40,7 @@ public class EvaluationTester {
int argumentIndex = 0;
for (String argumentString : tensorArgumentStrings) {
Tensor argument;
- if (argumentString.startsWith("tensor(")) // explicitly decided type
+ if (argumentString.startsWith("tensor")) // explicitly decided type
argument = Tensor.from(argumentString);
else // use mappedTensors+dimensions in tensor to decide type
argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString);
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 239efa0f89c..b071566ae31 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -947,7 +947,7 @@
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
"public long denseSubspaceSize()",
- "public static com.yahoo.tensor.TensorType createPartialType(java.util.List)"
+ "public static com.yahoo.tensor.TensorType createPartialType(com.yahoo.tensor.TensorType$Value, java.util.List)"
],
"fields": []
},
@@ -1162,11 +1162,11 @@
],
"methods": [
"public void <init>()",
- "public void <init>(com.yahoo.tensor.TensorType$ValueType)",
+ "public void <init>(com.yahoo.tensor.TensorType$Value)",
"public varargs void <init>(com.yahoo.tensor.TensorType[])",
- "public varargs void <init>(com.yahoo.tensor.TensorType$ValueType, com.yahoo.tensor.TensorType[])",
+ "public varargs void <init>(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType[])",
"public void <init>(java.lang.Iterable)",
- "public void <init>(com.yahoo.tensor.TensorType$ValueType, java.lang.Iterable)",
+ "public void <init>(com.yahoo.tensor.TensorType$Value, java.lang.Iterable)",
"public int rank()",
"public com.yahoo.tensor.TensorType$Builder set(com.yahoo.tensor.TensorType$Dimension)",
"public com.yahoo.tensor.TensorType$Builder indexed(java.lang.String, long)",
@@ -1270,7 +1270,7 @@
],
"fields": []
},
- "com.yahoo.tensor.TensorType$ValueType": {
+ "com.yahoo.tensor.TensorType$Value": {
"superClass": "java.lang.Enum",
"interfaces": [],
"attributes": [
@@ -1279,12 +1279,14 @@
"enum"
],
"methods": [
- "public static com.yahoo.tensor.TensorType$ValueType[] values()",
- "public static com.yahoo.tensor.TensorType$ValueType valueOf(java.lang.String)"
+ "public static com.yahoo.tensor.TensorType$Value[] values()",
+ "public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)",
+ "public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)",
+ "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)"
],
"fields": [
- "public static final enum com.yahoo.tensor.TensorType$ValueType DOUBLE",
- "public static final enum com.yahoo.tensor.TensorType$ValueType FLOAT"
+ "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE",
+ "public static final enum com.yahoo.tensor.TensorType$Value FLOAT"
]
},
"com.yahoo.tensor.TensorType": {
@@ -1294,9 +1296,8 @@
"public"
],
"methods": [
- "public final com.yahoo.tensor.TensorType$ValueType valueType()",
- "public final com.yahoo.tensor.TensorType valueType(com.yahoo.tensor.TensorType$ValueType)",
"public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
+ "public com.yahoo.tensor.TensorType$Value valueType()",
"public int rank()",
"public java.util.List dimensions()",
"public java.util.Set dimensionNames()",
@@ -1325,7 +1326,7 @@
"methods": [
"public void <init>()",
"public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
- "public static java.util.List dimensionsFromSpec(java.lang.String)"
+ "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)"
],
"fields": []
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 08878edeb83..c06cb2a0986 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -319,7 +319,7 @@ public class MixedTensor implements Tensor {
}
public TensorType createBoundType() {
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ TensorType.Builder typeBuilder = new TensorType.Builder(type().valueType());
for (int i = 0; i < type.dimensions().size(); ++i) {
TensorType.Dimension dimension = type.dimensions().get(i);
if (!dimension.isIndexed()) {
@@ -355,8 +355,8 @@ public class MixedTensor implements Tensor {
this.type = type;
this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
this.indexedDimensions = type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList());
- this.sparseType = createPartialType(mappedDimensions);
- this.denseType = createPartialType(indexedDimensions);
+ this.sparseType = createPartialType(type.valueType(), mappedDimensions);
+ this.denseType = createPartialType(type.valueType(), indexedDimensions);
}
public long indexOf(TensorAddress address) {
@@ -476,8 +476,8 @@ public class MixedTensor implements Tensor {
}
- public static TensorType createPartialType(List<TensorType.Dimension> dimensions) {
- TensorType.Builder builder = new TensorType.Builder();
+ public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) {
+ TensorType.Builder builder = new TensorType.Builder(valueType);
for (TensorType.Dimension dimension : dimensions) {
builder.set(dimension);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 998f3170aa0..45a9992c9ad 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -18,7 +18,7 @@ class TensorParser {
TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
if (type.isPresent() && ! type.get().equals(typeFromString))
throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
- "passed type " + type);
+ "passed type " + type.get());
return tensorFromValueString(valueString, typeFromString);
}
else if (tensorString.startsWith("{")) {
@@ -48,7 +48,7 @@ class TensorParser {
addressBody = addressBody.substring(1); // remove key start
if (addressBody.isEmpty()) return TensorType.empty; // Empty key
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE);
for (String elementString : addressBody.split(",")) {
String[] pair = elementString.split(":");
if (pair.length != 2)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index bded55405c0..5bd44cbc327 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -25,8 +25,29 @@ import java.util.stream.Collectors;
public class TensorType {
/** The permissible cell value types. Default is double. */
- // Types added here must also be added to TensorTypeParser.parseValueTypeSpec
- public enum Value { DOUBLE, FLOAT};
+ public enum Value {
+
+ // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
+ DOUBLE, FLOAT;
+
+ public static Value largestOf(List<Value> values) {
+ if (values.isEmpty()) return Value.DOUBLE; // Default
+ Value largest = null;
+ for (Value value : values) {
+ if (largest == null)
+ largest = value;
+ else
+ largest = largestOf(largest, value);
+ }
+ return largest;
+ }
+
+ public static Value largestOf(Value value1, Value value2) {
+ if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE;
+ return FLOAT;
+ }
+
+ };
/** The empty tensor type - which is the same as a double */
public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList());
@@ -170,7 +191,7 @@ public class TensorType {
if (this.equals(other)) return Optional.of(this); // shortcut
if (this.dimensions.size() != other.dimensions.size()) return Optional.empty();
- Builder b = new Builder();
+ Builder b = new Builder(TensorType.Value.largestOf(valueType, other.valueType));
for (int i = 0; i < dimensions.size(); i++) {
Dimension thisDim = this.dimensions().get(i);
Dimension otherDim = other.dimensions().get(i);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index a5733f1cc4c..d5f77be0dd0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -13,6 +13,7 @@ import java.util.regex.Pattern;
* Class for parsing a tensor type spec.
*
* @author geirst
+ * @author bratseth
*/
public class TensorTypeParser {
@@ -54,17 +55,24 @@ public class TensorTypeParser {
return new TensorType.Builder(valueType, dimensions).build();
}
+ public static TensorType.Value toValueType(String valueTypeString) {
+ switch (valueTypeString) {
+ case "double" : return TensorType.Value.DOUBLE;
+ case "float" : return TensorType.Value.FLOAT;
+ default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
+ " but was '" + valueTypeString + "'");
+ }
+ }
+
private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
- String valueType = valueTypeSpec.substring(1, valueTypeSpec.length() - 1);
- switch (valueType) {
- case "double" : return TensorType.Value.DOUBLE;
- case "float" : return TensorType.Value.FLOAT;
- default : throw formatException(fullSpecString,
- "Value type must be either 'double' or 'float'" +
- " but was '" + valueType + "'");
+ try {
+ return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
+ }
+ catch (IllegalArgumentException e) {
+ throw formatException(fullSpecString, e.getMessage());
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 91ab4f9d046..a0a257bb909 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -141,7 +141,11 @@ public class Concat extends PrimitiveTensorFunction {
if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
throw new IllegalArgumentException("Concat requires an indexed tensor, " +
"but got a tensor with type " + tensor.type());
- Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType())
+ .indexed(dimensionName, 1)
+ .build())
+ .cell(1,0)
+ .build();
return tensor.multiply(unitTensor);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 62ee471fcf4..062e0d92e80 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -386,13 +386,12 @@ public class Join extends PrimitiveTensorFunction {
return true;
}
- /**
- * Returns common dimension of a and b as a new tensor type
- */
+ /** Returns common dimension of a and b as a new tensor type */
private static TensorType commonDimensions(Tensor a, Tensor b) {
- TensorType.Builder typeBuilder = new TensorType.Builder();
TensorType aType = a.type();
TensorType bType = b.type();
+ TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(),
+ bType.valueType()));
for (int i = 0; i < aType.dimensions().size(); ++i) {
TensorType.Dimension aDim = aType.dimensions().get(i);
for (int j = 0; j < bType.dimensions().size(); ++j) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 54d7710c9dc..017dc3920e6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -61,8 +61,8 @@ public class Reduce extends PrimitiveTensorFunction {
}
public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
- if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all
- TensorType.Builder b = new TensorType.Builder();
+ TensorType.Builder b = new TensorType.Builder(inputType.valueType());
+ if (reduceDimensions.isEmpty()) return b.build(); // means reduce all
for (TensorType.Dimension dimension : inputType.dimensions()) {
if ( ! reduceDimensions.contains(dimension.name()))
b.dimension(dimension);
@@ -109,8 +109,8 @@ public class Reduce extends PrimitiveTensorFunction {
}
private static TensorType type(TensorType argumentType, List<String> dimensions) {
- if (dimensions.isEmpty()) return TensorType.empty; // means reduce all
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(argumentType.valueType());
+ if (dimensions.isEmpty()) return builder.build(); // means reduce all
for (TensorType.Dimension dimension : argumentType.dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
builder.dimension(dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index b268e33b418..db950e6c8b9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -268,7 +268,8 @@ public class ReduceJoin extends CompositeTensorFunction {
}
private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(),
+ b.type().valueType()));
for (TensorType.Dimension aDim : a.type().dimensions()) {
for (TensorType.Dimension bDim : b.type().dimensions()) {
if (aDim.name().equals(bDim.name())) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index e18af235d59..5694684956e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -75,7 +75,7 @@ public class Rename extends PrimitiveTensorFunction {
}
private TensorType type(TensorType type) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(type.valueType());
for (TensorType.Dimension dimension : type.dimensions())
builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
return builder.build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index acaeb3ef5ba..284dfea2141 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -78,7 +78,7 @@ class MixedBinaryFormat implements BinaryFormat {
TensorType serializedType = decodeType(buffer);
if ( ! serializedType.isAssignableTo(type))
throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
- " cannot be assigned to type " + type);
+ " cannot be assigned to type " + type);
}
else {
type = decodeType(buffer);
@@ -103,7 +103,7 @@ class MixedBinaryFormat implements BinaryFormat {
private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) {
List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
- TensorType sparseType = MixedTensor.createPartialType(sparseDimensions);
+ TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions);
long denseSubspaceSize = builder.denseSubspaceSize();
int numBlocks = 1;
diff --git a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java
index 9602bdb8d94..f6fed9d33ed 100644
--- a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java
@@ -69,16 +69,6 @@ public class BoundingBoxParserTestCase {
all1234(parser);
}
- /**
- * Tests various legal inputs and print the output
- */
- @Test
- public void testPrint() {
- String here = "n=63.418417 E=10.433033 S=37.7 W=-122.02";
- parser = new BoundingBoxParser(here);
- System.out.println(here+" -> "+parser);
- }
-
@Test
public void testGeoPlanetExample() {
/* example XML:
diff --git a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java
index e8ceab44c78..7cf4bddaa01 100644
--- a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java
@@ -57,7 +57,6 @@ public class BinaryFormatTestCase {
@Test
public void testZigZagConversion() {
- System.out.println("test zigzag conversion");
assertThat(encode_zigzag(0), is((long)0));
assertThat(decode_zigzag(encode_zigzag(0)), is(0L));
@@ -88,7 +87,6 @@ public class BinaryFormatTestCase {
@Test
public void testDoubleConversion() {
- System.out.println("test double conversion");
assertThat(encode_double(0.0), is(0L));
assertThat(decode_double(encode_double(0.0)), is(0.0));
@@ -116,7 +114,6 @@ public class BinaryFormatTestCase {
@Test
public void testTypeAndMetaMangling() {
- System.out.println("test type and meta mangling");
for (byte type = 0; type < TYPE_LIMIT; ++type) {
for (int meta = 0; meta < META_LIMIT; ++meta) {
byte mangled = encode_type_and_meta(type, meta);
@@ -126,10 +123,8 @@ public class BinaryFormatTestCase {
}
}
- // was testCmprUlong
@Test
- public void testCmprLong() {
- System.out.println("test compressed long");
+ public void testCompressedLong() {
{
long value = 0;
byte[] wanted = { 0 };
@@ -217,11 +212,8 @@ public class BinaryFormatTestCase {
// testWriteBytes -> buffered IO test
// testReadByte -> buffered IO test
// testReadBytes -> buffered IO test
-
@Test
- public void testTypeAndSize() {
- System.out.println("test type and size conversion");
-
+ public void testTypeAndSizeConversion() {
for (byte type = 0; type < TYPE_LIMIT; ++type) {
for (long size = 0; size < 500; ++size) {
BufferedOutput expect = new BufferedOutput();
@@ -271,8 +263,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testTypeAndBytes() {
- System.out.println("test encoding and decoding of type and bytes");
+ public void testEncodingAndDecodingOfTypeAndBytes() {
for (byte type = 0; type < TYPE_LIMIT; ++type) {
for (int n = 0; n < MAX_NUM_SIZE; ++n) {
for (int pre = 0; (pre == 0) || (pre < n); ++pre) {
@@ -307,9 +298,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testEmpty() {
- System.out.println("test encoding empty slime");
-
+ public void testEncodingEmptySlime() {
Slime slime = new Slime();
BufferedOutput expect = new BufferedOutput();
expect.put((byte)0); // num symbols
@@ -321,8 +310,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testBasic() {
- System.out.println("test encoding slime holding a single basic value");
+ public void testEncodingSlimeHoldingASingleBasicValue() {
{
Slime slime = new Slime();
slime.setBool(false);
@@ -427,8 +415,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testArray() {
- System.out.println("test encoding slime holding an array of various basic values");
+ public void testEncodingSlimeArray() {
Slime slime = new Slime();
Cursor c = slime.setArray();
byte[] data = { 'd', 'a', 't', 'a' };
@@ -452,8 +439,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testObject() {
- System.out.println("test encoding slime holding an object of various basic values");
+ public void testEncodingSlimeObject() {
Slime slime = new Slime();
Cursor c = slime.setObject();
byte[] data = { 'd', 'a', 't', 'a' };
@@ -478,8 +464,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testNesting() {
- System.out.println("test encoding slime holding a more complex structure");
+ public void testEncodingComplexSlimeStructure() {
Slime slime = new Slime();
Cursor c1 = slime.setObject();
c1.setLong("bar", 10);
@@ -503,8 +488,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testSymbolReuse() {
- System.out.println("test encoding slime reusing symbols");
+ public void testEncodingSlimeReusingSymbols() {
Slime slime = new Slime();
Cursor c1 = slime.setArray();
{
@@ -533,8 +517,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testOptionalDecodeOrder() {
- System.out.println("test decoding slime with different symbol order");
+ public void testDecodingSlimeWithDifferentSymbolOrder() {
byte[] data = {
5, // num symbols
1, 'd', 1, 'e', 1, 'f', 1, 'b', 1, 'c', // symbol table
@@ -564,4 +547,5 @@ public class BinaryFormatTestCase {
assertThat(c.field("f").asData(), is(expd));
assertThat(c.entry(5).valid(), is(false)); // not ARRAY
}
+
}