diff options
author | Lester Solbakken <lesters@oath.com> | 2019-11-22 10:52:58 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2019-11-22 10:52:58 +0100 |
commit | f3e934cdeae3fceb6bf952dde2f5b0b90b02bfa7 (patch) | |
tree | d160c6eaa8f2e38f4a5897b94cc8195327f2e020 /model-integration | |
parent | 8b3c453b66f891a59ca80bfc47abe63be1b9bace (diff) |
Actually insert rename into evaluation tree when needed in model import
Diffstat (limited to 'model-integration')
3 files changed, 52 insertions, 34 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index e9be35b6f84..bcddfcbfc13 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -121,7 +121,7 @@ public class DimensionRenamer { List<IntermediateOperation> prioritizedOperations = constraintsPerOperation.entrySet().stream() .sorted(Comparator.comparingInt(entry -> - entry.getValue())) - .map(entry -> entry.getKey()) + .map(Map.Entry::getKey) .collect(Collectors.toList()); List<RenameTarget> targets = new ArrayList<>(); @@ -131,8 +131,7 @@ public class DimensionRenamer { if (inputType.isEmpty()) continue; for (String dimensionName : inputType.get().dimensionNames()) { RenameTarget target = new RenameTarget(operation, i, dimensionName, graph); - if (target.rootKey != null) // TODO: Inserting renames under non-roots is not implemented - targets.add(target); + targets.add(target); } } } @@ -313,17 +312,10 @@ public class DimensionRenamer { final String dimensionName; final IntermediateGraph graph; - /** - * Returns the key of this operation in the root operations of the graph, - * or null if it is not a root operation - */ - final String rootKey; - public RenameTarget(IntermediateOperation operation, int inputNumber, String dimensionName, IntermediateGraph graph) { this.operation = operation; this.inputNumber = inputNumber; this.dimensionName = dimensionName; - this.rootKey = findRootKey(operation, graph); this.graph = graph; } @@ -331,42 +323,25 @@ public class DimensionRenamer { return operation.inputs().get(inputNumber); } - private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) { - for (var entry : graph.operations().entrySet()) { - if (entry.getValue() == operation) - return entry.getKey(); - } - return null; - } - /** Inserts a rename operation if possible. Returns whether an operation was inserted. */ private boolean insertRename(DimensionRenamer renamer) { Rename rename = new Rename(operation.modelName(), dimensionName, renamer.dimensionPrefix + renamer.dimensions.size(), input()); - - List<IntermediateOperation> newInputs = new ArrayList<>(operation.inputs()); - newInputs.set(inputNumber, rename); - IntermediateOperation newOperation = operation.withInputs(newInputs); - if (rootKey == null) - throw new IllegalStateException("Renaming non-roots is not implemented"); - graph.put(rootKey, newOperation); - + operation.insert(rename, inputNumber); removeConstraintsOf(operation, renamer); rename.addDimensionNameConstraints(renamer); - newOperation.addDimensionNameConstraints(renamer); + operation.addDimensionNameConstraints(renamer); return true; } /** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */ private void uninsertRename(DimensionRenamer renamer) { - IntermediateOperation newOperation = graph.operations().get(rootKey); - Rename rename = (Rename)newOperation.inputs().get(inputNumber); - graph.put(rootKey, operation); - + Rename rename = (Rename)operation.inputs().get(inputNumber); + operation.uninsert(inputNumber); removeConstraintsOf(rename, renamer); - removeConstraintsOf(newOperation, renamer); + removeConstraintsOf(operation, renamer); operation.addDimensionNameConstraints(renamer); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 6d0cdfc5021..a3f5aa6c130 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -213,6 +213,42 @@ public abstract class IntermediateOperation { return result; } + /** Insert an operation between an input and this one */ + public void insert(IntermediateOperation operationToInsert, int inputNumber) { + if ( operationToInsert.inputs.size() > 0 ) { + throw new IllegalArgumentException("Operation to insert to '" + name + "' has " + + "existing inputs which is not supported."); + } + IntermediateOperation previousInputOperation = inputs.get(inputNumber); + int outputNumber = findOutputNumber(previousInputOperation, this); + if (outputNumber == -1) { + throw new IllegalArgumentException("Input '" + previousInputOperation.name + "' to '" + + name + "' does not have '" + name + "' as output."); + } + previousInputOperation.outputs.set(outputNumber, operationToInsert); + operationToInsert.inputs.add(previousInputOperation); + operationToInsert.outputs.add(this); + inputs.set(inputNumber, operationToInsert); + } + + /** Remove an operation between an input and this one */ + public void uninsert(int inputNumber) { + IntermediateOperation operationToRemove = inputs.get(inputNumber); + IntermediateOperation newInputOperation = operationToRemove.inputs.get(0); + int outputNumber = findOutputNumber(newInputOperation, operationToRemove); + newInputOperation.outputs.set(outputNumber, this); + inputs.set(inputNumber, newInputOperation); + } + + private int findOutputNumber(IntermediateOperation output, IntermediateOperation op) { + for (int i = 0; i < output.outputs.size(); ++i) { + if (output.outputs.get(i).equals(op)) { + return i; + } + } + return -1; + } + /** * Returns the largest value type among the input value types. * This should only be called after it has been verified that input types are available. diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java index abc431233be..e040ae62149 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java @@ -6,6 +6,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; +import java.util.Collections; import java.util.List; /** @@ -15,10 +16,10 @@ import java.util.List; */ public class Rename extends IntermediateOperation { - private final String from, to; + private String from, to; public Rename(String modelName, String from, String to, IntermediateOperation input) { - super(modelName, "rename", List.of(input)); + super(modelName, "rename", input != null ? List.of(input) : Collections.emptyList()); this.from = from; this.to = to; } @@ -52,6 +53,12 @@ public class Rename extends IntermediateOperation { renamer.addDimension(to); } + public void renameDimensions(DimensionRenamer renamer) { + type = type.rename(renamer); + from = renamer.dimensionNameOf(from).orElse(from); + to = renamer.dimensionNameOf(to).orElse(to); + } + @Override public Rename withInputs(List<IntermediateOperation> inputs) { if (inputs.size() != 1) |