diff options
5 files changed, 22 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 877b1ac72a9..4acb47df179 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -409,6 +409,24 @@ public class ConvertedModel { return reduceBatchDimensionExpression(tensorFunction, typeContext); } } + // Modify any renames in expression to disregard batch dimension + else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) { + TensorFunction childFunction = (((TensorFunctionNode) children.get(0)).function()); + TensorType childType = childFunction.type(typeContext); + Rename rename = (Rename) tensorFunction; + List<String> from = new ArrayList<>(); + List<String> to = new ArrayList<>(); + for (TensorType.Dimension dimension : childType.dimensions()) { + int i = rename.fromDimensions().indexOf(dimension.name()); + if (i < 0) { + throw new IllegalArgumentException("Rename does not contain dimension '" + + dimension + "' in child expression type: " + childType); + } + from.add(rename.fromDimensions().get(i)); + to.add(rename.toDimensions().get(i)); + } + return new TensorFunctionNode(new Rename(childFunction, from, to)); + } } } if (node instanceof ReferenceNode) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index 9f62a27a3b9..dad4508bc61 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -67,7 +67,7 @@ public class Argument extends IntermediateOperation { @Override public String toFullString() { - return "\t" + lazyGetType() + ":\tArgument(" + standardNamingType + ")"; + return "\t" + type + ":\tArgument(" + standardNamingType + ")"; } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 3ad5cb1d19f..fc895b07d53 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -102,7 +102,7 @@ public class Const extends IntermediateOperation { @Override public String toFullString() { - return "\t" + lazyGetType() + ":\tConst(" + type + ")"; + return "\t" + type + ":\tConst(" + getConstantValue().get() + ")"; } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java index 1eaaf705220..ad56eefe5f2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -74,7 +74,7 @@ public class Constant extends IntermediateOperation { @Override public String toFullString() { - return "\t" + lazyGetType() + ":\tConstant(" + type + ")"; + return "\t" + type + ":\tConstant(" + type + ")"; } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 9c9fed89585..26b376cce1c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -226,7 +226,7 @@ public abstract class IntermediateOperation { } public String toFullString() { - return "\t" + lazyGetType() + ":\t" + operationName() + "(" + + return "\t" + type + ":\t" + operationName() + "(" + inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) + ")"; } |