summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-11-22 10:52:58 +0100
committerLester Solbakken <lesters@oath.com>2019-11-22 10:52:58 +0100
commitf3e934cdeae3fceb6bf952dde2f5b0b90b02bfa7 (patch)
treed160c6eaa8f2e38f4a5897b94cc8195327f2e020 /model-integration
parent8b3c453b66f891a59ca80bfc47abe63be1b9bace (diff)
Actually insert rename into evaluation tree when needed in model import
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java39
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java36
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java11
3 files changed, 52 insertions, 34 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index e9be35b6f84..bcddfcbfc13 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -121,7 +121,7 @@ public class DimensionRenamer {
List<IntermediateOperation> prioritizedOperations =
constraintsPerOperation.entrySet().stream()
.sorted(Comparator.comparingInt(entry -> - entry.getValue()))
- .map(entry -> entry.getKey())
+ .map(Map.Entry::getKey)
.collect(Collectors.toList());
List<RenameTarget> targets = new ArrayList<>();
@@ -131,8 +131,7 @@ public class DimensionRenamer {
if (inputType.isEmpty()) continue;
for (String dimensionName : inputType.get().dimensionNames()) {
RenameTarget target = new RenameTarget(operation, i, dimensionName, graph);
- if (target.rootKey != null) // TODO: Inserting renames under non-roots is not implemented
- targets.add(target);
+ targets.add(target);
}
}
}
@@ -313,17 +312,10 @@ public class DimensionRenamer {
final String dimensionName;
final IntermediateGraph graph;
- /**
- * Returns the key of this operation in the root operations of the graph,
- * or null if it is not a root operation
- */
- final String rootKey;
-
public RenameTarget(IntermediateOperation operation, int inputNumber, String dimensionName, IntermediateGraph graph) {
this.operation = operation;
this.inputNumber = inputNumber;
this.dimensionName = dimensionName;
- this.rootKey = findRootKey(operation, graph);
this.graph = graph;
}
@@ -331,42 +323,25 @@ public class DimensionRenamer {
return operation.inputs().get(inputNumber);
}
- private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) {
- for (var entry : graph.operations().entrySet()) {
- if (entry.getValue() == operation)
- return entry.getKey();
- }
- return null;
- }
-
/** Inserts a rename operation if possible. Returns whether an operation was inserted. */
private boolean insertRename(DimensionRenamer renamer) {
Rename rename = new Rename(operation.modelName(),
dimensionName,
renamer.dimensionPrefix + renamer.dimensions.size(),
input());
-
- List<IntermediateOperation> newInputs = new ArrayList<>(operation.inputs());
- newInputs.set(inputNumber, rename);
- IntermediateOperation newOperation = operation.withInputs(newInputs);
- if (rootKey == null)
- throw new IllegalStateException("Renaming non-roots is not implemented");
- graph.put(rootKey, newOperation);
-
+ operation.insert(rename, inputNumber);
removeConstraintsOf(operation, renamer);
rename.addDimensionNameConstraints(renamer);
- newOperation.addDimensionNameConstraints(renamer);
+ operation.addDimensionNameConstraints(renamer);
return true;
}
/** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */
private void uninsertRename(DimensionRenamer renamer) {
- IntermediateOperation newOperation = graph.operations().get(rootKey);
- Rename rename = (Rename)newOperation.inputs().get(inputNumber);
- graph.put(rootKey, operation);
-
+ Rename rename = (Rename)operation.inputs().get(inputNumber);
+ operation.uninsert(inputNumber);
removeConstraintsOf(rename, renamer);
- removeConstraintsOf(newOperation, renamer);
+ removeConstraintsOf(operation, renamer);
operation.addDimensionNameConstraints(renamer);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 6d0cdfc5021..a3f5aa6c130 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -213,6 +213,42 @@ public abstract class IntermediateOperation {
return result;
}
+ /** Insert an operation between an input and this one */
+ public void insert(IntermediateOperation operationToInsert, int inputNumber) {
+ if ( operationToInsert.inputs.size() > 0 ) {
+ throw new IllegalArgumentException("Operation to insert to '" + name + "' has " +
+ "existing inputs which is not supported.");
+ }
+ IntermediateOperation previousInputOperation = inputs.get(inputNumber);
+ int outputNumber = findOutputNumber(previousInputOperation, this);
+ if (outputNumber == -1) {
+ throw new IllegalArgumentException("Input '" + previousInputOperation.name + "' to '" +
+ name + "' does not have '" + name + "' as output.");
+ }
+ previousInputOperation.outputs.set(outputNumber, operationToInsert);
+ operationToInsert.inputs.add(previousInputOperation);
+ operationToInsert.outputs.add(this);
+ inputs.set(inputNumber, operationToInsert);
+ }
+
+ /** Remove an operation between an input and this one */
+ public void uninsert(int inputNumber) {
+ IntermediateOperation operationToRemove = inputs.get(inputNumber);
+ IntermediateOperation newInputOperation = operationToRemove.inputs.get(0);
+ int outputNumber = findOutputNumber(newInputOperation, operationToRemove);
+ newInputOperation.outputs.set(outputNumber, this);
+ inputs.set(inputNumber, newInputOperation);
+ }
+
+ private int findOutputNumber(IntermediateOperation output, IntermediateOperation op) {
+ for (int i = 0; i < output.outputs.size(); ++i) {
+ if (output.outputs.get(i).equals(op)) {
+ return i;
+ }
+ }
+ return -1;
+ }
+
/**
* Returns the largest value type among the input value types.
* This should only be called after it has been verified that input types are available.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
index abc431233be..e040ae62149 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
@@ -6,6 +6,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import java.util.Collections;
import java.util.List;
/**
@@ -15,10 +16,10 @@ import java.util.List;
*/
public class Rename extends IntermediateOperation {
- private final String from, to;
+ private String from, to;
public Rename(String modelName, String from, String to, IntermediateOperation input) {
- super(modelName, "rename", List.of(input));
+ super(modelName, "rename", input != null ? List.of(input) : Collections.emptyList());
this.from = from;
this.to = to;
}
@@ -52,6 +53,12 @@ public class Rename extends IntermediateOperation {
renamer.addDimension(to);
}
+ public void renameDimensions(DimensionRenamer renamer) {
+ type = type.rename(renamer);
+ from = renamer.dimensionNameOf(from).orElse(from);
+ to = renamer.dimensionNameOf(to).orElse(to);
+ }
+
@Override
public Rename withInputs(List<IntermediateOperation> inputs) {
if (inputs.size() != 1)