summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-04 18:00:33 -0700
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-04 18:00:33 -0700
commita6f80f9006e76b11cfc3d78643736b921760cb96 (patch)
tree35b9b9a4b8b70633eab7b67f795111fb0d5bead3 /model-integration
parent5492ad488db1c460a08fa92890205c37c1456db6 (diff)
Try renaming under all root operations (only)
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java131
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java15
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java5
22 files changed, 225 insertions, 25 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index 22fabe3ada7..89ba36f0d39 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -2,17 +2,19 @@
package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
-import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.Rename;
import com.yahoo.collections.ListMap;
import java.util.ArrayList;
+import java.util.Comparator;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
+import java.util.TreeMap;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.stream.Collectors;
@@ -68,34 +70,55 @@ public class DimensionRenamer {
}
private Map<String, Integer> solve(int maxIterations) {
- int renamesTried = 0;
- while (renamesTried++ <= dimensions.size()) {
- Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations);
+ Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations);
+ if (solution != null) return solution;
+
+ for (RenameTarget renameTarget : prioritizedRenameTargets()) {
+ System.out.println("Trying rename " + renameTarget);
+ insertRenameOperation(renameTarget);
+ solution = solveWithOrWithoutSoftConstraints(maxIterations);
if (solution != null) return solution;
- if ( ! insertRenameOperation()) return null;
+ rollbackRenameOperation(renameTarget);
}
throw new IllegalArgumentException("Could not find a dimension naming solution " +
"given constraints\n" + constraintsToString(constraints));
}
- private boolean insertRenameOperation() {
- IntermediateOperation operation = graph.operations().get("dense_out/MatMul");
- if (operation instanceof MatMul) {
- IntermediateOperation arg0 = operation.inputs().get(0);
- List<IntermediateOperation> inputs = new ArrayList<>(operation.inputs());
- inputs.set(0, new Rename(arg0.modelName(), "Dot_ExpandDims_1", "renamed_0", arg0));
- IntermediateOperation newOperation = operation.withInputs(inputs);
- graph.put("dense_out/MatMul", newOperation);
-
- for (Arc key : new HashSet<>(constraints.keySet())) {
- if (key.operation == operation)
- constraints.removeAll(key);
- }
- addDimension("renamed_0");
- newOperation.addDimensionNameConstraints(this);
- return true;
+ /** Inserts a rename operation if possible. Returns whether an operation was inserted. */
+ private boolean insertRenameOperation(RenameTarget target) {
+ Rename rename = new Rename(target.operation.modelName(),
+ "Dot_ExpandDims_1", "renamed_0",
+ target.input());
+
+ List<IntermediateOperation> newInputs = new ArrayList<>(target.operation.inputs());
+ newInputs.set(target.inputNumber, rename);
+ IntermediateOperation newOperation = target.operation.withInputs(newInputs);
+ if (target.rootKey == null)
+ throw new IllegalStateException("Renaming non-roots is not implemented");
+ graph.put(target.rootKey, newOperation);
+
+ removeConstraintsOf(target.operation);
+ rename.addDimensionNameConstraints(this);
+ newOperation.addDimensionNameConstraints(this);
+ return true;
+ }
+
+ /** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */
+ private void rollbackRenameOperation(RenameTarget target) {
+ IntermediateOperation newOperation = graph.operations().get(target.rootKey);
+ Rename rename = (Rename)newOperation.inputs().get(target.inputNumber);
+ graph.put(target.rootKey, target.operation);
+
+ removeConstraintsOf(rename);
+ removeConstraintsOf(newOperation);
+ target.operation.addDimensionNameConstraints(this);
+ }
+
+ private void removeConstraintsOf(IntermediateOperation operation) {
+ for (Arc key : new HashSet<>(constraints.keySet())) {
+ if (key.operation == operation)
+ constraints.removeAll(key);
}
- return false;
}
private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) {
@@ -124,6 +147,30 @@ public class DimensionRenamer {
return removed;
}
+ private List<RenameTarget> prioritizedRenameTargets() {
+ Map<IntermediateOperation, Integer> constraintsPerOperation = new HashMap<>();
+
+ for (var constraint : constraints.entrySet()) {
+ constraintsPerOperation.compute(constraint.getKey().operation,
+ (operation, count) -> count == null ? 1 : ++count);
+ }
+ List<IntermediateOperation> prioritizedOperations =
+ constraintsPerOperation.entrySet().stream()
+ .sorted(Comparator.comparingInt(entry -> - entry.getValue()))
+ .map(entry -> entry.getKey())
+ .collect(Collectors.toList());
+
+ List<RenameTarget> targets = new ArrayList<>();
+ for (IntermediateOperation operation : prioritizedOperations) {
+ for (int i = 0; i < operation.inputs().size(); i++) {
+ RenameTarget target = new RenameTarget(operation, i, graph);
+ if (target.rootKey != null) // Inserting renames under non-roots is not implemented
+ targets.add(new RenameTarget(operation, i, graph));
+ }
+ }
+ return targets;
+ }
+
/**
* Retrieve resulting name of a dimension after solving for constraints, or empty if no
* solution is found yet, or this dimension was not added before finding a solution.
@@ -285,4 +332,44 @@ public class DimensionRenamer {
}
+ /**
+ * An operation and an input number which we may want to insert a rename operation at.
+ * That is, we may want to change op(..., input, ...) to op(..., rename(input), ...).
+ */
+ private static class RenameTarget {
+
+ final IntermediateOperation operation;
+ final int inputNumber;
+
+ /**
+ * Returns the key of this operation in the root operations of the graph,
+ * or null if it is not a root operation
+ */
+ final String rootKey;
+
+ public RenameTarget(IntermediateOperation operation, int inputNumber, IntermediateGraph graph) {
+ this.operation = operation;
+ this.inputNumber = inputNumber;
+ this.rootKey = findRootKey(operation, graph);
+ }
+
+ public IntermediateOperation input() {
+ return operation.inputs().get(inputNumber);
+ }
+
+ private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) {
+ for (var entry : graph.operations().entrySet()) {
+ if (entry.getValue() == operation)
+ return entry.getKey();
+ }
+ return null;
+ }
+
+ @Override
+ public String toString() {
+ return operation + ", input " + inputNumber;
+ }
+
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
index 8fe70f7eefb..b03e6889ab5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
@@ -53,6 +53,13 @@ public class Argument extends IntermediateOperation {
}
@Override
+ public Argument withInputs(List<IntermediateOperation> inputs) {
+ if ( ! inputs.isEmpty())
+ throw new IllegalArgumentException("Argument cannot take inputs");
+ return new Argument(modelName(), name(), type);
+ }
+
+ @Override
public String toString() {
return "\t" + lazyGetType() + ":\tArgument(" + standardNamingType + ")";
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
index c211b434176..b82a4458873 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -99,4 +99,9 @@ public class ConcatV2 extends IntermediateOperation {
concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName);
}
+ @Override
+ public ConcatV2 withInputs(List<IntermediateOperation> inputs) {
+ return new ConcatV2(modelName(), name(), inputs);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
index c48a99c2716..f9b5d770aa0 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -86,6 +86,11 @@ public class Const extends IntermediateOperation {
}
@Override
+ public Const withInputs(List<IntermediateOperation> inputs) {
+ return new Const(modelName(), name(), inputs, attributeMap, type);
+ }
+
+ @Override
public String toString() {
return "\t" + lazyGetType() + ":\tConst(" + type + ")";
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
index 8c6e69584e0..42cdfe27e24 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
@@ -8,6 +8,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.Collections;
+import java.util.List;
import java.util.Optional;
public class Constant extends IntermediateOperation {
@@ -56,4 +57,11 @@ public class Constant extends IntermediateOperation {
return true;
}
+ @Override
+ public Constant withInputs(List<IntermediateOperation> inputs) {
+ if ( ! inputs.isEmpty())
+ throw new IllegalArgumentException("Constant cannot take inputs");
+ return new Constant(modelName(), name(), type);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index 91783aed7a5..569bd51cc0c 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -102,6 +102,11 @@ public class ExpandDims extends IntermediateOperation {
}
@Override
+ public ExpandDims withInputs(List<IntermediateOperation> inputs) {
+ return new ExpandDims(modelName(), name(), inputs);
+ }
+
+ @Override
public String toString() {
return "ExpandDims(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + expandDimensions + ")";
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
index c2787aa14d4..d2fc08fc877 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
@@ -32,4 +32,9 @@ public class Identity extends IntermediateOperation {
return inputs.get(0).function().orElse(null);
}
+ @Override
+ public Identity withInputs(List<IntermediateOperation> inputs) {
+ return new Identity(modelName(), name(), inputs);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index cc9985af6d4..1bc424d1641 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -192,9 +192,7 @@ public abstract class IntermediateOperation {
.collect(Collectors.toList()));
}
- public IntermediateOperation withInputs(List<IntermediateOperation> inputs) {
- throw new UnsupportedOperationException();
- }
+ public abstract IntermediateOperation withInputs(List<IntermediateOperation> inputs);
public String toFullString() { return toString(); }
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index 2c5c38b76cc..2c62fcf62a5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -112,6 +112,11 @@ public class Join extends IntermediateOperation {
}
@Override
+ public Join withInputs(List<IntermediateOperation> inputs) {
+ return new Join(modelName(), name(), inputs, operator);
+ }
+
+ @Override
public String toString() {
return "Join(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + operator + ")";
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
index 52aec71fa3f..cbdb15cd364 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
@@ -35,6 +35,11 @@ public class Map extends IntermediateOperation {
}
@Override
+ public Map withInputs(List<IntermediateOperation> inputs) {
+ return new Map(modelName(), name(), inputs, operator);
+ }
+
+ @Override
public String toString() {
return "Map(" + asString(inputs().get(0).type()) + ", " + operator + ")";
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
index d8e9950c61f..691e9966bb0 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
@@ -91,6 +91,11 @@ public class Mean extends IntermediateOperation {
reduceDimensions = renamedDimensions;
}
+ @Override
+ public Mean withInputs(List<IntermediateOperation> inputs) {
+ return new Mean(modelName(), name(), inputs, attributeMap);
+ }
+
private boolean shouldKeepDimensions() {
Optional<Value> keepDims = attributeMap.get("keep_dims");
return keepDims.isPresent() && keepDims.get().asBoolean();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
index ce0c58971d0..1bf92848d10 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
@@ -32,4 +32,9 @@ public class Merge extends IntermediateOperation {
return null;
}
+ @Override
+ public Merge withInputs(List<IntermediateOperation> inputs) {
+ return new Merge(modelName(), name(), inputs);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
index 4c5ce33b1b5..5ecaaa10c57 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
@@ -23,4 +23,9 @@ public class NoOp extends IntermediateOperation {
return null;
}
+ @Override
+ public NoOp withInputs(List<IntermediateOperation> inputs) {
+ return new NoOp(modelName(), name(), inputs);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
index e5e5c29f8f1..b74e5176862 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
@@ -45,4 +45,9 @@ public class PlaceholderWithDefault extends IntermediateOperation {
return true; // not true if we add to function
}
+ @Override
+ public PlaceholderWithDefault withInputs(List<IntermediateOperation> inputs) {
+ return new PlaceholderWithDefault(modelName(), name(), inputs);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
index 264ee6b9dff..2dad0fd641e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
@@ -1,6 +1,7 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.operations;
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
@@ -46,4 +47,18 @@ public class Rename extends IntermediateOperation {
return new com.yahoo.tensor.functions.Rename(inputs.get(0).function().orElse(null), from, to);
}
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ renamer.addDimension(to);
+ }
+
+ @Override
+ public Rename withInputs(List<IntermediateOperation> inputs) {
+ if (inputs.size() != 1)
+ throw new IllegalArgumentException("Rename require 1 input, not " + inputs.size());
+ return new Rename(modelName(), from, to, inputs.get(0));
+ }
+
}
+
+
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 4a0fe236c9f..96a52e347d1 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -74,6 +74,11 @@ public class Reshape extends IntermediateOperation {
}
}
+ @Override
+ public Reshape withInputs(List<IntermediateOperation> inputs) {
+ return new Reshape(modelName(), name(), inputs);
+ }
+
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType)))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
index 2484310e829..b02609ca1d9 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
@@ -85,4 +85,9 @@ public class Select extends IntermediateOperation {
renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this);
}
+ @Override
+ public Select withInputs(List<IntermediateOperation> inputs) {
+ return new Select(modelName(), name(), inputs);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
index 79f3012c327..5aaf379df62 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
@@ -37,6 +37,11 @@ public class Shape extends IntermediateOperation {
return true;
}
+ @Override
+ public Shape withInputs(List<IntermediateOperation> inputs) {
+ return new Shape(modelName(), name(), inputs);
+ }
+
private void createConstantValue() {
if (!allInputTypesPresent(1)) {
return;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index cdacbe1656a..b95462baea5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -37,4 +37,9 @@ public class Softmax extends IntermediateOperation {
return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension);
}
+ @Override
+ public Softmax withInputs(List<IntermediateOperation> inputs) {
+ return new Softmax(modelName(), name(), inputs);
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
index 52d40144f61..db6548ce4ff 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
@@ -70,6 +70,11 @@ public class Squeeze extends IntermediateOperation {
squeezeDimensions = renamedDimensions;
}
+ @Override
+ public Squeeze withInputs(List<IntermediateOperation> inputs) {
+ return new Squeeze(modelName(), name(), inputs, attributeMap);
+ }
+
private OrderedTensorType reducedType(OrderedTensorType inputType) {
OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
index 046ab2a1646..7de6c5b0fb8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
@@ -108,6 +108,11 @@ public class Sum extends IntermediateOperation {
}
@Override
+ public Sum withInputs(List<IntermediateOperation> inputs) {
+ return new Sum(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
public String toString() {
return "Sum(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + reduceDimensions + ")";
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
index 39702690bfa..77ec720e645 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
@@ -42,6 +42,11 @@ public class Switch extends IntermediateOperation {
return predicate == port ? inputs().get(0).function().get() : null;
}
+ @Override
+ public Switch withInputs(List<IntermediateOperation> inputs) {
+ return new Switch(modelName(), name(), inputs, port);
+ }
+
}