aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/ml
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-09-24 12:41:07 +0200
committerLester Solbakken <lesters@oath.com>2019-09-24 12:41:07 +0200
commite1723e3e7c40997ecc099ab0ccfd4f6b8ba3e221 (patch)
tree833c17c59ad261b7c7b842dad92965d642a3c05e /config-model/src/main/java/com/yahoo/vespa/model/ml
parentd082531b8c6244de5bc99ed887f706be3a1084df (diff)
Avoid renaming batch reduced dimensions
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java18
1 files changed, 18 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 877b1ac72a9..4acb47df179 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -409,6 +409,24 @@ public class ConvertedModel {
return reduceBatchDimensionExpression(tensorFunction, typeContext);
}
}
+ // Modify any renames in expression to disregard batch dimension
+ else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) {
+ TensorFunction childFunction = (((TensorFunctionNode) children.get(0)).function());
+ TensorType childType = childFunction.type(typeContext);
+ Rename rename = (Rename) tensorFunction;
+ List<String> from = new ArrayList<>();
+ List<String> to = new ArrayList<>();
+ for (TensorType.Dimension dimension : childType.dimensions()) {
+ int i = rename.fromDimensions().indexOf(dimension.name());
+ if (i < 0) {
+ throw new IllegalArgumentException("Rename does not contain dimension '" +
+ dimension + "' in child expression type: " + childType);
+ }
+ from.add(rename.fromDimensions().get(i));
+ to.add(rename.toDimensions().get(i));
+ }
+ return new TensorFunctionNode(new Rename(childFunction, from, to));
+ }
}
}
if (node instanceof ReferenceNode) {