diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-03 09:47:49 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-03 09:47:49 -0700 |
commit | 0ce6fa7cbdf71fd39cb5bb18accfa84a20e7e120 (patch) | |
tree | 0afa3697a10fad159883a402a2f367ee8175c027 | |
parent | fdc6a48fe913de5f5a84c6eb42123543c1d2ee46 (diff) |
Inject rename to relax constraint proof of concept
10 files changed, 122 insertions, 26 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 0579af13154..9821870e38b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -2,12 +2,17 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.MatMul; +import ai.vespa.rankingexpression.importer.operations.Rename; import com.yahoo.collections.ListMap; import com.yahoo.lang.MutableInteger; +import com.yahoo.text.ExpressionFormatter; import java.util.ArrayDeque; +import java.util.ArrayList; import java.util.Deque; import java.util.HashMap; +import java.util.HashSet; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -28,17 +33,22 @@ public class DimensionRenamer { private static final Logger log = Logger.getLogger(DimensionRenamer.class.getName()); private final String dimensionPrefix; + + /** The graph we are renaming the dimensions of */ + private final IntermediateGraph graph; + private final ListMap<String, Integer> variables = new ListMap<>(); private final ListMap<Arc, Constraint> constraints = new ListMap<>(); /** The solution to this, or null if no solution is found (yet) */ private Map<String, Integer> renames = null; - public DimensionRenamer() { - this("d"); + public DimensionRenamer(IntermediateGraph graph) { + this(graph, "d"); } - public DimensionRenamer(String dimensionPrefix) { + public DimensionRenamer(IntermediateGraph graph, String dimensionPrefix) { + this.graph = graph; this.dimensionPrefix = dimensionPrefix; } @@ -84,12 +94,32 @@ public class DimensionRenamer { * @return the solution in the form of the renames to perform */ private Map<String, Integer> solve(int maxIterations) { - variables.freeze(); + // variables.freeze(); Map<String, Integer> renames = new HashMap<>(); // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts boolean solved = trySolve(variables, constraints, maxIterations, renames); if ( ! solved) { + IntermediateOperation operation = graph.operations().get("dense_out/MatMul"); + if (operation != null && operation instanceof MatMul) { + IntermediateOperation arg0 = operation.inputs().get(0); + List<IntermediateOperation> inputs = new ArrayList<>(operation.inputs()); + inputs.set(0, new Rename(arg0.modelName(), "Dot_ExpandDims_1", "renamed_0", arg0)); + IntermediateOperation newOperation = operation.withInputs(inputs); + graph.put("dense_out/MatMul", newOperation); + + for (Arc key : new HashSet<>(constraints.keySet())) { + if (key.operation == operation) + constraints.removeAll(key); + } + addDimension("renamed_0"); + newOperation.addDimensionNameConstraints(this); + + renames.clear(); + solved = trySolve(variables, constraints, maxIterations, renames); + } + } + if ( ! solved) { renames.clear(); ListMap<Arc, Constraint> hardConstraints = new ListMap<>(); boolean anyRemoved = copyHard(constraints, hardConstraints); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 0c570261ae7..a9be1bbd40e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -262,9 +262,12 @@ public class ImportedModel implements ImportedMlModel { /** Returns the expression this output references as an imported function */ public ImportedMlFunction outputFunction(String outputName, String functionName) { + RankingExpression outputExpression = owner().expressions().get(outputs.get(outputName)); + if (outputExpression == null) + throw new IllegalArgumentException("Missing output '" + outputName + "' in " + this); return new ImportedMlFunction(functionName, new ArrayList<>(inputs.values()), - owner().expressions().get(outputs.get(outputName)).getRoot().toString(), + outputExpression.getRoot().toString(), asStrings(inputMap()), Optional.empty()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index 54d4bd3cb0a..6c583d960bd 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -3,9 +3,11 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.MatMul; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; @@ -20,7 +22,7 @@ import java.util.Set; public class IntermediateGraph { private final String modelName; - private final Map<String, IntermediateOperation> index = new HashMap<>(); + private final Map<String, IntermediateOperation> operations = new HashMap<>(); private final Map<String, GraphSignature> signatures = new HashMap<>(); private static class GraphSignature { @@ -37,11 +39,11 @@ public class IntermediateGraph { } public IntermediateOperation put(String key, IntermediateOperation operation) { - return index.put(key, operation); + return operations.put(key, operation); } public IntermediateOperation get(String key) { - return index.get(key); + return operations.get(key); } public Set<String> signatures() { @@ -61,11 +63,11 @@ public class IntermediateGraph { } public boolean alreadyImported(String key) { - return index.containsKey(key); + return operations.containsKey(key); } - public Collection<IntermediateOperation> operations() { - return index.values(); + public Map<String, IntermediateOperation> operations() { + return operations; } void optimize() { @@ -76,16 +78,16 @@ public class IntermediateGraph { * Find dimension names to avoid excessive renaming while evaluating the model. */ private void renameDimensions() { - DimensionRenamer renamer = new DimensionRenamer(); + DimensionRenamer renamer = new DimensionRenamer(this); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - addDimensionNameConstraints(index.get(output), renamer); + addDimensionNameConstraints(operations.get(output), renamer); } } renamer.solve(); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - renameDimensions(index.get(output), renamer); + renameDimensions(operations.get(output), renamer); } } } @@ -111,7 +113,7 @@ public class IntermediateGraph { public String toFullString() { StringBuilder b = new StringBuilder(); - for (var input : index.entrySet()) + for (var input : operations.entrySet()) b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n"); return b.toString(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 19c2026d457..b587a9200ec 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -227,7 +227,7 @@ public abstract class ModelImporter implements MlModelImporter { * for fast model weight updates. */ private static void logVariableTypes(IntermediateGraph graph) { - for (IntermediateOperation operation : graph.operations()) { + for (IntermediateOperation operation : graph.operations().values()) { if ( ! (operation instanceof Constant)) continue; if ( ! operation.type().isPresent()) continue; // will not happen log.info("Importing model variable " + operation.name() + " as " + operation.vespaName() + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 878fa1ca1b1..cc9985af6d4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -58,6 +59,8 @@ public abstract class IntermediateOperation { protected abstract OrderedTensorType lazyGetType(); protected abstract TensorFunction lazyGetFunction(); + public String modelName() { return modelName; } + /** Returns the Vespa tensor type of this operation if it exists */ public Optional<OrderedTensorType> type() { if (type == null) { @@ -189,6 +192,16 @@ public abstract class IntermediateOperation { .collect(Collectors.toList())); } + public IntermediateOperation withInputs(List<IntermediateOperation> inputs) { + throw new UnsupportedOperationException(); + } + + public String toFullString() { return toString(); } + + String asString(Optional<OrderedTensorType> type) { + return type.map(t -> t.toString()).orElse("(unknown)"); + } + /** * A method signature input and output has the form name:index. * This returns the name part without the index. @@ -217,10 +230,4 @@ public abstract class IntermediateOperation { Optional<List<Value>> getList(String key); } - public String toFullString() { return toString(); } - - String asString(Optional<OrderedTensorType> type) { - return type.map(t -> t.toString()).orElse("(unknown)"); - } - } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 6c6b51a27a5..9158eeea02b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -81,6 +81,11 @@ public class MatMul extends IntermediateOperation { } @Override + public MatMul withInputs(List<IntermediateOperation> inputs) { + return new MatMul(modelName(), name(), inputs); + } + + @Override public String toFullString() { return "\t" + lazyGetType() + ":\tMatMul(" + inputs().get(0).toFullString() + ", " + inputs().get(1).toFullString() + ")"; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java new file mode 100644 index 00000000000..264ee6b9dff --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java @@ -0,0 +1,49 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +/** + * Renames a tensor dimension to relax dimension constraints + * + * @author bratseth + */ +public class Rename extends IntermediateOperation { + + private final String from, to; + + public Rename(String modelName, String from, String to, IntermediateOperation input) { + super(modelName, "rename", List.of(input)); + this.from = from; + this.to = to; + } + + @Override + boolean allInputFunctionsPresent(int expected) { + return super.allInputFunctionsPresent(expected); + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().orElse(null); + if (inputType == null) return null; + + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(inputType.type().valueType()); + for (TensorType.Dimension dimension : inputType.dimensions()) + builder.add(dimension.withName(dimension.name().equals(from) ? to : dimension.name())); + return builder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(1)) return null; + return new com.yahoo.tensor.functions.Rename(inputs.get(0).function().orElse(null), from, to); + } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java index ee282c7d988..793258868ee 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java @@ -9,7 +9,7 @@ public class DimensionRenamerTest { @Test public void testMnistRenaming() { - DimensionRenamer renamer = new DimensionRenamer(); + DimensionRenamer renamer = new DimensionRenamer(new IntermediateGraph("test")); renamer.addDimension("first_dimension_of_x"); renamer.addDimension("second_dimension_of_x"); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java index be0ab4b894a..6500a380190 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java @@ -20,9 +20,9 @@ public class Issue9662TestCase { Assert.assertEquals("Should have no skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ImportedMlFunction output = signature.outputFunction("y", "y"); + ImportedMlFunction output = signature.outputFunction("output", "output"); assertNotNull(output); - model.assertEqualResultSum("input", "dnn/outputs/add", 0.00001); + model.assertEqualResultSum("input_embedding_user_guid", "dense_out/MatMul", 0.00001); } } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index 9d2f8cf0692..75fa2ed7933 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -49,7 +49,7 @@ public class TestableTensorFlowModel { public ImportedModel get() { return model; } - /** Compare that summing the tensors produce the same result to within some tolerance delta */ + /** Compare that computing the expressions produce the same result to within some tolerance delta */ public void assertEqualResultSum(String inputName, String operationName, double delta) { Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); Context context = contextFrom(model); |