diff options
author | Jon Bratseth <bratseth@oath.com> | 2019-04-04 08:49:06 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-04-04 08:49:06 +0200 |
commit | f6e6d076a237bc5d08cf618a8caa2a683b3c78e6 (patch) | |
tree | 001e4d544e7e8354c163a44f8eb5c8d4e4e497d9 /config-model | |
parent | c8c842d622eb744504fe0b7b15044602b85ec0ee (diff) | |
parent | 8c23296c0feb1c418706f847c7b78ae926180859 (diff) |
Merge pull request #9003 from vespa-engine/bratseth/tensor-value-type
Bratseth/tensor value type
Diffstat (limited to 'config-model')
4 files changed, 8 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index a0f35dbefe6..75b3af47954 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -191,7 +191,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement else { // default dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString(); } - return Optional.of(new TensorType.Builder().mapped(dimension).build()); + + // TODO: Determine the type of the weighted set/vector and use that as value type + return Optional.of(new TensorType.Builder(TensorType.Value.DOUBLE).mapped(dimension).build()); } /** Binds the given list of formal arguments to their actual values */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index f197e2dfe6d..e12cc60b041 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -453,10 +453,9 @@ public class ConvertedModel { */ // TODO: determine when this is not necessary! private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); + if (after.equals(before)) return node; + + TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); for (TensorType.Dimension dimension : before.dimensions()) { if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { typeBuilder.indexed(dimension.name(), 1); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java index 5c96635fd8f..80440ac8eb4 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java @@ -144,7 +144,7 @@ public class RankingExpressionWithTensorTestCase { @Test public void requireThatInvalidTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); - exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: Failed parsing element 'x' in type spec 'tensor(x)'"); + exception.expectMessage("For constant tensor 'my_tensor' in rank profile 'my_profile': Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(x)'. Dimension 'x' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])"); RankProfileSearchFixture f = new RankProfileSearchFixture( " rank-profile my_profile {\n" + " constants {\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java index 2fcf5809ea5..f53ca15635f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorFieldTestCase.java @@ -39,7 +39,7 @@ public class TensorFieldTestCase { @Test public void requireThatIllegalTensorTypeSpecThrowsException() throws ParseException { exception.expect(IllegalArgumentException.class); - exception.expectMessage("Field type: Illegal tensor type spec: Failed parsing element 'invalid' in type spec 'tensor(invalid)'"); + exception.expectMessage("Field type: Illegal tensor type spec: A tensor type spec must be on the form tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was 'tensor(invalid)'. Dimension 'invalid' is on the wrong format. Examples: tensor(x[]), tensor<float>(name{}, x[10])"); SearchBuilder.createFromString(getSd("field f1 type tensor(invalid) { indexing: attribute }")); } |