From a6f80f9006e76b11cfc3d78643736b921760cb96 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Thu, 4 Jul 2019 18:00:33 -0700 Subject: Try renaming under all root operations (only) --- .../importer/DimensionRenamer.java | 131 +++++++++++++++++---- .../importer/operations/Argument.java | 7 ++ .../importer/operations/ConcatV2.java | 5 + .../importer/operations/Const.java | 5 + .../importer/operations/Constant.java | 8 ++ .../importer/operations/ExpandDims.java | 5 + .../importer/operations/Identity.java | 5 + .../importer/operations/IntermediateOperation.java | 4 +- .../importer/operations/Join.java | 5 + .../rankingexpression/importer/operations/Map.java | 5 + .../importer/operations/Mean.java | 5 + .../importer/operations/Merge.java | 5 + .../importer/operations/NoOp.java | 5 + .../operations/PlaceholderWithDefault.java | 5 + .../importer/operations/Rename.java | 15 +++ .../importer/operations/Reshape.java | 5 + .../importer/operations/Select.java | 5 + .../importer/operations/Shape.java | 5 + .../importer/operations/Softmax.java | 5 + .../importer/operations/Squeeze.java | 5 + .../rankingexpression/importer/operations/Sum.java | 5 + .../importer/operations/Switch.java | 5 + 22 files changed, 225 insertions(+), 25 deletions(-) (limited to 'model-integration') diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 22fabe3ada7..89ba36f0d39 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -2,17 +2,19 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; -import ai.vespa.rankingexpression.importer.operations.MatMul; import ai.vespa.rankingexpression.importer.operations.Rename; import com.yahoo.collections.ListMap; import java.util.ArrayList; +import java.util.Comparator; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; +import java.util.TreeMap; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -68,34 +70,55 @@ public class DimensionRenamer { } private Map solve(int maxIterations) { - int renamesTried = 0; - while (renamesTried++ <= dimensions.size()) { - Map solution = solveWithOrWithoutSoftConstraints(maxIterations); + Map solution = solveWithOrWithoutSoftConstraints(maxIterations); + if (solution != null) return solution; + + for (RenameTarget renameTarget : prioritizedRenameTargets()) { + System.out.println("Trying rename " + renameTarget); + insertRenameOperation(renameTarget); + solution = solveWithOrWithoutSoftConstraints(maxIterations); if (solution != null) return solution; - if ( ! insertRenameOperation()) return null; + rollbackRenameOperation(renameTarget); } throw new IllegalArgumentException("Could not find a dimension naming solution " + "given constraints\n" + constraintsToString(constraints)); } - private boolean insertRenameOperation() { - IntermediateOperation operation = graph.operations().get("dense_out/MatMul"); - if (operation instanceof MatMul) { - IntermediateOperation arg0 = operation.inputs().get(0); - List inputs = new ArrayList<>(operation.inputs()); - inputs.set(0, new Rename(arg0.modelName(), "Dot_ExpandDims_1", "renamed_0", arg0)); - IntermediateOperation newOperation = operation.withInputs(inputs); - graph.put("dense_out/MatMul", newOperation); - - for (Arc key : new HashSet<>(constraints.keySet())) { - if (key.operation == operation) - constraints.removeAll(key); - } - addDimension("renamed_0"); - newOperation.addDimensionNameConstraints(this); - return true; + /** Inserts a rename operation if possible. Returns whether an operation was inserted. */ + private boolean insertRenameOperation(RenameTarget target) { + Rename rename = new Rename(target.operation.modelName(), + "Dot_ExpandDims_1", "renamed_0", + target.input()); + + List newInputs = new ArrayList<>(target.operation.inputs()); + newInputs.set(target.inputNumber, rename); + IntermediateOperation newOperation = target.operation.withInputs(newInputs); + if (target.rootKey == null) + throw new IllegalStateException("Renaming non-roots is not implemented"); + graph.put(target.rootKey, newOperation); + + removeConstraintsOf(target.operation); + rename.addDimensionNameConstraints(this); + newOperation.addDimensionNameConstraints(this); + return true; + } + + /** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */ + private void rollbackRenameOperation(RenameTarget target) { + IntermediateOperation newOperation = graph.operations().get(target.rootKey); + Rename rename = (Rename)newOperation.inputs().get(target.inputNumber); + graph.put(target.rootKey, target.operation); + + removeConstraintsOf(rename); + removeConstraintsOf(newOperation); + target.operation.addDimensionNameConstraints(this); + } + + private void removeConstraintsOf(IntermediateOperation operation) { + for (Arc key : new HashSet<>(constraints.keySet())) { + if (key.operation == operation) + constraints.removeAll(key); } - return false; } private Map solveWithOrWithoutSoftConstraints(int maxIterations) { @@ -124,6 +147,30 @@ public class DimensionRenamer { return removed; } + private List prioritizedRenameTargets() { + Map constraintsPerOperation = new HashMap<>(); + + for (var constraint : constraints.entrySet()) { + constraintsPerOperation.compute(constraint.getKey().operation, + (operation, count) -> count == null ? 1 : ++count); + } + List prioritizedOperations = + constraintsPerOperation.entrySet().stream() + .sorted(Comparator.comparingInt(entry -> - entry.getValue())) + .map(entry -> entry.getKey()) + .collect(Collectors.toList()); + + List targets = new ArrayList<>(); + for (IntermediateOperation operation : prioritizedOperations) { + for (int i = 0; i < operation.inputs().size(); i++) { + RenameTarget target = new RenameTarget(operation, i, graph); + if (target.rootKey != null) // Inserting renames under non-roots is not implemented + targets.add(new RenameTarget(operation, i, graph)); + } + } + return targets; + } + /** * Retrieve resulting name of a dimension after solving for constraints, or empty if no * solution is found yet, or this dimension was not added before finding a solution. @@ -285,4 +332,44 @@ public class DimensionRenamer { } + /** + * An operation and an input number which we may want to insert a rename operation at. + * That is, we may want to change op(..., input, ...) to op(..., rename(input), ...). + */ + private static class RenameTarget { + + final IntermediateOperation operation; + final int inputNumber; + + /** + * Returns the key of this operation in the root operations of the graph, + * or null if it is not a root operation + */ + final String rootKey; + + public RenameTarget(IntermediateOperation operation, int inputNumber, IntermediateGraph graph) { + this.operation = operation; + this.inputNumber = inputNumber; + this.rootKey = findRootKey(operation, graph); + } + + public IntermediateOperation input() { + return operation.inputs().get(inputNumber); + } + + private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) { + for (var entry : graph.operations().entrySet()) { + if (entry.getValue() == operation) + return entry.getKey(); + } + return null; + } + + @Override + public String toString() { + return operation + ", input " + inputNumber; + } + + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index 8fe70f7eefb..b03e6889ab5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -52,6 +52,13 @@ public class Argument extends IntermediateOperation { return false; } + @Override + public Argument withInputs(List inputs) { + if ( ! inputs.isEmpty()) + throw new IllegalArgumentException("Argument cannot take inputs"); + return new Argument(modelName(), name(), type); + } + @Override public String toString() { return "\t" + lazyGetType() + ":\tArgument(" + standardNamingType + ")"; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index c211b434176..b82a4458873 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -99,4 +99,9 @@ public class ConcatV2 extends IntermediateOperation { concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName); } + @Override + public ConcatV2 withInputs(List inputs) { + return new ConcatV2(modelName(), name(), inputs); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index c48a99c2716..f9b5d770aa0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -85,6 +85,11 @@ public class Const extends IntermediateOperation { return value.get(); } + @Override + public Const withInputs(List inputs) { + return new Const(modelName(), name(), inputs, attributeMap, type); + } + @Override public String toString() { return "\t" + lazyGetType() + ":\tConst(" + type + ")"; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java index 8c6e69584e0..42cdfe27e24 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -8,6 +8,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; import java.util.Collections; +import java.util.List; import java.util.Optional; public class Constant extends IntermediateOperation { @@ -56,4 +57,11 @@ public class Constant extends IntermediateOperation { return true; } + @Override + public Constant withInputs(List inputs) { + if ( ! inputs.isEmpty()) + throw new IllegalArgumentException("Constant cannot take inputs"); + return new Constant(modelName(), name(), type); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 91783aed7a5..569bd51cc0c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -101,6 +101,11 @@ public class ExpandDims extends IntermediateOperation { expandDimensions = renamedDimensions; } + @Override + public ExpandDims withInputs(List inputs) { + return new ExpandDims(modelName(), name(), inputs); + } + @Override public String toString() { return "ExpandDims(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + expandDimensions + ")"; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java index c2787aa14d4..d2fc08fc877 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java @@ -32,4 +32,9 @@ public class Identity extends IntermediateOperation { return inputs.get(0).function().orElse(null); } + @Override + public Identity withInputs(List inputs) { + return new Identity(modelName(), name(), inputs); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index cc9985af6d4..1bc424d1641 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -192,9 +192,7 @@ public abstract class IntermediateOperation { .collect(Collectors.toList())); } - public IntermediateOperation withInputs(List inputs) { - throw new UnsupportedOperationException(); - } + public abstract IntermediateOperation withInputs(List inputs); public String toFullString() { return toString(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index 2c5c38b76cc..2c62fcf62a5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -111,6 +111,11 @@ public class Join extends IntermediateOperation { return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); } + @Override + public Join withInputs(List inputs) { + return new Join(modelName(), name(), inputs, operator); + } + @Override public String toString() { return "Join(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + operator + ")"; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java index 52aec71fa3f..cbdb15cd364 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -34,6 +34,11 @@ public class Map extends IntermediateOperation { return new com.yahoo.tensor.functions.Map(input.get(), operator); } + @Override + public Map withInputs(List inputs) { + return new Map(modelName(), name(), inputs, operator); + } + @Override public String toString() { return "Map(" + asString(inputs().get(0).type()) + ", " + operator + ")"; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java index d8e9950c61f..691e9966bb0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java @@ -91,6 +91,11 @@ public class Mean extends IntermediateOperation { reduceDimensions = renamedDimensions; } + @Override + public Mean withInputs(List inputs) { + return new Mean(modelName(), name(), inputs, attributeMap); + } + private boolean shouldKeepDimensions() { Optional keepDims = attributeMap.get("keep_dims"); return keepDims.isPresent() && keepDims.get().asBoolean(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java index ce0c58971d0..1bf92848d10 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java @@ -32,4 +32,9 @@ public class Merge extends IntermediateOperation { return null; } + @Override + public Merge withInputs(List inputs) { + return new Merge(modelName(), name(), inputs); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java index 4c5ce33b1b5..5ecaaa10c57 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java @@ -23,4 +23,9 @@ public class NoOp extends IntermediateOperation { return null; } + @Override + public NoOp withInputs(List inputs) { + return new NoOp(modelName(), name(), inputs); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java index e5e5c29f8f1..b74e5176862 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java @@ -45,4 +45,9 @@ public class PlaceholderWithDefault extends IntermediateOperation { return true; // not true if we add to function } + @Override + public PlaceholderWithDefault withInputs(List inputs) { + return new PlaceholderWithDefault(modelName(), name(), inputs); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java index 264ee6b9dff..2dad0fd641e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java @@ -1,6 +1,7 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.operations; +import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; @@ -46,4 +47,18 @@ public class Rename extends IntermediateOperation { return new com.yahoo.tensor.functions.Rename(inputs.get(0).function().orElse(null), from, to); } + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + renamer.addDimension(to); + } + + @Override + public Rename withInputs(List inputs) { + if (inputs.size() != 1) + throw new IllegalArgumentException("Rename require 1 input, not " + inputs.size()); + return new Rename(modelName(), from, to, inputs.get(0)); + } + } + + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 4a0fe236c9f..96a52e347d1 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -74,6 +74,11 @@ public class Reshape extends IntermediateOperation { } } + @Override + public Reshape withInputs(List inputs) { + return new Reshape(modelName(), name(), inputs); + } + public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index 2484310e829..b02609ca1d9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -85,4 +85,9 @@ public class Select extends IntermediateOperation { renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this); } + @Override + public Select withInputs(List inputs) { + return new Select(modelName(), name(), inputs); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java index 79f3012c327..5aaf379df62 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java @@ -37,6 +37,11 @@ public class Shape extends IntermediateOperation { return true; } + @Override + public Shape withInputs(List inputs) { + return new Shape(modelName(), name(), inputs); + } + private void createConstantValue() { if (!allInputTypesPresent(1)) { return; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index cdacbe1656a..b95462baea5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -37,4 +37,9 @@ public class Softmax extends IntermediateOperation { return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension); } + @Override + public Softmax withInputs(List inputs) { + return new Softmax(modelName(), name(), inputs); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index 52d40144f61..db6548ce4ff 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -70,6 +70,11 @@ public class Squeeze extends IntermediateOperation { squeezeDimensions = renamedDimensions; } + @Override + public Squeeze withInputs(List inputs) { + return new Squeeze(modelName(), name(), inputs, attributeMap); + } + private OrderedTensorType reducedType(OrderedTensorType inputType) { OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); for (TensorType.Dimension dimension: inputType.type().dimensions()) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java index 046ab2a1646..7de6c5b0fb8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -107,6 +107,11 @@ public class Sum extends IntermediateOperation { return builder.build(); } + @Override + public Sum withInputs(List inputs) { + return new Sum(modelName(), name(), inputs, attributeMap); + } + @Override public String toString() { return "Sum(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + reduceDimensions + ")"; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java index 39702690bfa..77ec720e645 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java @@ -42,6 +42,11 @@ public class Switch extends IntermediateOperation { return predicate == port ? inputs().get(0).function().get() : null; } + @Override + public Switch withInputs(List inputs) { + return new Switch(modelName(), name(), inputs, port); + } + } -- cgit v1.2.3