diff options
author | Lester Solbakken <lesters@oath.com> | 2019-09-24 12:41:07 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-09-24 12:41:07 +0200 |
commit | e1723e3e7c40997ecc099ab0ccfd4f6b8ba3e221 (patch) | |
tree | 833c17c59ad261b7c7b842dad92965d642a3c05e /config-model/src/main/java/com/yahoo/vespa/model/ml | |
parent | d082531b8c6244de5bc99ed887f706be3a1084df (diff) |
Avoid renaming batch reduced dimensions
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java | 18 |
1 files changed, 18 insertions, 0 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 877b1ac72a9..4acb47df179 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -409,6 +409,24 @@ public class ConvertedModel { return reduceBatchDimensionExpression(tensorFunction, typeContext); } } + // Modify any renames in expression to disregard batch dimension + else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) { + TensorFunction childFunction = (((TensorFunctionNode) children.get(0)).function()); + TensorType childType = childFunction.type(typeContext); + Rename rename = (Rename) tensorFunction; + List<String> from = new ArrayList<>(); + List<String> to = new ArrayList<>(); + for (TensorType.Dimension dimension : childType.dimensions()) { + int i = rename.fromDimensions().indexOf(dimension.name()); + if (i < 0) { + throw new IllegalArgumentException("Rename does not contain dimension '" + + dimension + "' in child expression type: " + childType); + } + from.add(rename.fromDimensions().get(i)); + to.add(rename.toDimensions().get(i)); + } + return new TensorFunctionNode(new Rename(childFunction, from, to)); + } } } if (node instanceof ReferenceNode) { |