summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
commit5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch)
tree2b65d4f48b92bf7ec846b3efd5d5259244bc234a /config-model/src/main/java
parent6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff)
Add tensor value type
Diffstat (limited to 'config-model/src/main/java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java7
2 files changed, 6 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index a0f35dbefe6..75b3af47954 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -191,7 +191,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
else { // default
dimension = ((ReferenceNode)arg0).reference().arguments().expressions().get(0).toString();
}
- return Optional.of(new TensorType.Builder().mapped(dimension).build());
+
+ // TODO: Determine the type of the weighted set/vector and use that as value type
+ return Optional.of(new TensorType.Builder(TensorType.Value.DOUBLE).mapped(dimension).build());
}
/** Binds the given list of formal arguments to their actual values */
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index f197e2dfe6d..e12cc60b041 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -453,10 +453,9 @@ public class ConvertedModel {
*/
// TODO: determine when this is not necessary!
private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) {
- return node;
- }
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ if (after.equals(before)) return node;
+
+ TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType());
for (TensorType.Dimension dimension : before.dimensions()) {
if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
typeBuilder.indexed(dimension.name(), 1);