diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
commit | 5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch) | |
tree | 2b65d4f48b92bf7ec846b3efd5d5259244bc234a /config-model | |
parent | 6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff) |
Add tensor value type
Diffstat (limited to 'config-model')
4 files changed, 8 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index a0f35dbefe6..75b3af47954 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -191,7 +191,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement else { // default dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString(); } - return Optional.of(new TensorType.Builder().mapped(dimension).build()); + + // TODO: Determine the type of the weighted set/vector and use that as value type + return Optional.of(new TensorType.Builder(TensorType.Value.DOUBLE).mapped(dimension).build()); } /** Binds the given list of formal arguments to their actual values */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index f197e2dfe6d..e12cc60b041 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -453,10 +453,9 @@ public class ConvertedModel { */ // TODO: determine when this is not necessary! private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); + if (after.equals(before)) return node; + + TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); for (TensorType.Dimension dimension : before.dimensions()) { if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { typeBuilder.indexed(dimension.name(), 1); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java index 5c96635fd8f..80440ac8eb4 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java @@ -144,7 +144,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); - exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'"); + exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(x)'. Dimension 'x' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])"); RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " constants {\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java index 2fcf5809ea5..f53ca15635f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java @@ -39,7 +39,7 @@ public class TensorFieldTestCase { @Test public void requireThatIllegalTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); - exception.expectMessage("Field type: Illegal tensor type spec: Failed parsing element 'invalid' in type spec 'tensor(invalid)'"); + exception.expectMessage("Field type: Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(invalid)'. Dimension 'invalid' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])"); SearchBuilder.createFromString(getSd("field f1 type tensor(invalid) { indexing: attribute }")); } |