diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
commit | 5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch) | |
tree | 2b65d4f48b92bf7ec846b3efd5d5259244bc234a /config-model/src/main/java | |
parent | 6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff) |
Add tensor value type
Diffstat (limited to 'config-model/src/main/java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java | 4 | ||||
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java | 7 |
2 files changed, 6 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index a0f35dbefe6..75b3af47954 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -191,7 +191,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement else { // default dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString(); } - return Optional.of(new TensorType.Builder().mapped(dimension).build()); + + // TODO: Determine the type of the weighted set/vector and use that as value type + return Optional.of(new TensorType.Builder(TensorType.Value.DOUBLE).mapped(dimension).build()); } /** Binds the given list of formal arguments to their actual values */ diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index f197e2dfe6d..e12cc60b041 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -453,10 +453,9 @@ public class ConvertedModel { */ // TODO: determine when this is not necessary! private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); + if (after.equals(before)) return node; + + TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); for (TensorType.Dimension dimension : before.dimensions()) { if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { typeBuilder.indexed(dimension.name(), 1); |