aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-11-22 10:52:58 +0100
committerLester Solbakken <lesters@oath.com>2019-11-22 10:52:58 +0100
commitf3e934cdeae3fceb6bf952dde2f5b0b90b02bfa7 (patch)
treed160c6eaa8f2e38f4a5897b94cc8195327f2e020 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations
parent8b3c453b66f891a59ca80bfc47abe63be1b9bace (diff)
Actually insert rename into evaluation tree when needed in model import
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java36
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java11
2 files changed, 45 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 6d0cdfc5021..a3f5aa6c130 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -213,6 +213,42 @@ public abstract class IntermediateOperation {
return result;
}
+ /** Insert an operation between an input and this one */
+ public void insert(IntermediateOperation operationToInsert, int inputNumber) {
+ if ( operationToInsert.inputs.size() > 0 ) {
+ throw new IllegalArgumentException("Operation to insert to '" + name + "' has " +
+ "existing inputs which is not supported.");
+ }
+ IntermediateOperation previousInputOperation = inputs.get(inputNumber);
+ int outputNumber = findOutputNumber(previousInputOperation, this);
+ if (outputNumber == -1) {
+ throw new IllegalArgumentException("Input '" + previousInputOperation.name + "' to '" +
+ name + "' does not have '" + name + "' as output.");
+ }
+ previousInputOperation.outputs.set(outputNumber, operationToInsert);
+ operationToInsert.inputs.add(previousInputOperation);
+ operationToInsert.outputs.add(this);
+ inputs.set(inputNumber, operationToInsert);
+ }
+
+ /** Remove an operation between an input and this one */
+ public void uninsert(int inputNumber) {
+ IntermediateOperation operationToRemove = inputs.get(inputNumber);
+ IntermediateOperation newInputOperation = operationToRemove.inputs.get(0);
+ int outputNumber = findOutputNumber(newInputOperation, operationToRemove);
+ newInputOperation.outputs.set(outputNumber, this);
+ inputs.set(inputNumber, newInputOperation);
+ }
+
+ private int findOutputNumber(IntermediateOperation output, IntermediateOperation op) {
+ for (int i = 0; i < output.outputs.size(); ++i) {
+ if (output.outputs.get(i).equals(op)) {
+ return i;
+ }
+ }
+ return -1;
+ }
+
/**
* Returns the largest value type among the input value types.
* This should only be called after it has been verified that input types are available.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
index abc431233be..e040ae62149 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
@@ -6,6 +6,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import java.util.Collections;
import java.util.List;
/**
@@ -15,10 +16,10 @@ import java.util.List;
*/
public class Rename extends IntermediateOperation {
- private final String from, to;
+ private String from, to;
public Rename(String modelName, String from, String to, IntermediateOperation input) {
- super(modelName, "rename", List.of(input));
+ super(modelName, "rename", input != null ? List.of(input) : Collections.emptyList());
this.from = from;
this.to = to;
}
@@ -52,6 +53,12 @@ public class Rename extends IntermediateOperation {
renamer.addDimension(to);
}
+ public void renameDimensions(DimensionRenamer renamer) {
+ type = type.rename(renamer);
+ from = renamer.dimensionNameOf(from).orElse(from);
+ to = renamer.dimensionNameOf(to).orElse(to);
+ }
+
@Override
public Rename withInputs(List<IntermediateOperation> inputs) {
if (inputs.size() != 1)