summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@verizonmedia.com> 2019-07-03 09:47:49 -0700
committer Jon Bratseth <bratseth@verizonmedia.com> 2019-07-03 09:47:49 -0700
commit 0ce6fa7cbdf71fd39cb5bb18accfa84a20e7e120 (patch)
tree 0afa3697a10fad159883a402a2f367ee8175c027
parent fdc6a48fe913de5f5a84c6eb42123543c1d2ee46 (diff)
Inject rename to relax constraint proof of concept
-rw-r--r-- model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java 38
-rw-r--r-- model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java 5
-rw-r--r-- model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java 22
-rw-r--r-- model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java 2
-rw-r--r-- model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java 19
-rw-r--r-- model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java 5
-rw-r--r-- model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java 49
-rw-r--r-- model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java 2
-rw-r--r-- model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java 4
-rw-r--r-- model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java 2
10 files changed, 122 insertions, 26 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index 0579af13154..9821870e38b 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -2,12 +2,17 @@
package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import ai.vespa.rankingexpression.importer.operations.MatMul;
+import ai.vespa.rankingexpression.importer.operations.Rename;
import com.yahoo.collections.ListMap;
import com.yahoo.lang.MutableInteger;
+import com.yahoo.text.ExpressionFormatter;
import java.util.ArrayDeque;
+import java.util.ArrayList;
import java.util.Deque;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
@@ -28,17 +33,22 @@ public class DimensionRenamer {
private static final Logger log = Logger.getLogger(DimensionRenamer.class.getName());
private final String dimensionPrefix;
+
+ /** The graph we are renaming the dimensions of */
+ private final IntermediateGraph graph;
+
private final ListMap<String, Integer> variables = new ListMap<>();
private final ListMap<Arc, Constraint> constraints = new ListMap<>();
/** The solution to this, or null if no solution is found (yet) */
private Map<String, Integer> renames = null;
- public DimensionRenamer() {
- this("d");
+ public DimensionRenamer(IntermediateGraph graph) {
+ this(graph, "d");
}
- public DimensionRenamer(String dimensionPrefix) {
+ public DimensionRenamer(IntermediateGraph graph, String dimensionPrefix) {
+ this.graph = graph;
this.dimensionPrefix = dimensionPrefix;
}
@@ -84,12 +94,32 @@ public class DimensionRenamer {
* @return the solution in the form of the renames to perform
*/
private Map<String, Integer> solve(int maxIterations) {
- variables.freeze();
+ // variables.freeze();
Map<String, Integer> renames = new HashMap<>();
// Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
boolean solved = trySolve(variables, constraints, maxIterations, renames);
if ( ! solved) {
+ IntermediateOperation operation = graph.operations().get("dense_out/MatMul");
+ if (operation != null && operation instanceof MatMul) {
+ IntermediateOperation arg0 = operation.inputs().get(0);
+ List<IntermediateOperation> inputs = new ArrayList<>(operation.inputs());
+ inputs.set(0, new Rename(arg0.modelName(), "Dot_ExpandDims_1", "renamed_0", arg0));
+ IntermediateOperation newOperation = operation.withInputs(inputs);
+ graph.put("dense_out/MatMul", newOperation);
+
+ for (Arc key : new HashSet<>(constraints.keySet())) {
+ if (key.operation == operation)
+ constraints.removeAll(key);
+ }
+ addDimension("renamed_0");
+ newOperation.addDimensionNameConstraints(this);
+
+ renames.clear();
+ solved = trySolve(variables, constraints, maxIterations, renames);
+ }
+ }
+ if ( ! solved) {
renames.clear();
ListMap<Arc, Constraint> hardConstraints = new ListMap<>();
boolean anyRemoved = copyHard(constraints, hardConstraints);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 0c570261ae7..a9be1bbd40e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -262,9 +262,12 @@ public class ImportedModel implements ImportedMlModel {
/** Returns the expression this output references as an imported function */
public ImportedMlFunction outputFunction(String outputName, String functionName) {
+ RankingExpression outputExpression = owner().expressions().get(outputs.get(outputName));
+ if (outputExpression == null)
+ throw new IllegalArgumentException("Missing output '" + outputName + "' in " + this);
return new ImportedMlFunction(functionName,
new ArrayList<>(inputs.values()),
- owner().expressions().get(outputs.get(outputName)).getRoot().toString(),
+ outputExpression.getRoot().toString(),
asStrings(inputMap()),
Optional.empty());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index 54d4bd3cb0a..6c583d960bd 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -3,9 +3,11 @@
package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import ai.vespa.rankingexpression.importer.operations.MatMul;
import java.util.Collection;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -20,7 +22,7 @@ import java.util.Set;
public class IntermediateGraph {
private final String modelName;
- private final Map<String, IntermediateOperation> index = new HashMap<>();
+ private final Map<String, IntermediateOperation> operations = new HashMap<>();
private final Map<String, GraphSignature> signatures = new HashMap<>();
private static class GraphSignature {
@@ -37,11 +39,11 @@ public class IntermediateGraph {
}
public IntermediateOperation put(String key, IntermediateOperation operation) {
- return index.put(key, operation);
+ return operations.put(key, operation);
}
public IntermediateOperation get(String key) {
- return index.get(key);
+ return operations.get(key);
}
public Set<String> signatures() {
@@ -61,11 +63,11 @@ public class IntermediateGraph {
}
public boolean alreadyImported(String key) {
- return index.containsKey(key);
+ return operations.containsKey(key);
}
- public Collection<IntermediateOperation> operations() {
- return index.values();
+ public Map<String, IntermediateOperation> operations() {
+ return operations;
}
void optimize() {
@@ -76,16 +78,16 @@ public class IntermediateGraph {
* Find dimension names to avoid excessive renaming while evaluating the model.
*/
private void renameDimensions() {
- DimensionRenamer renamer = new DimensionRenamer();
+ DimensionRenamer renamer = new DimensionRenamer(this);
for (String signature : signatures()) {
for (String output : outputs(signature).values()) {
- addDimensionNameConstraints(index.get(output), renamer);
+ addDimensionNameConstraints(operations.get(output), renamer);
}
}
renamer.solve();
for (String signature : signatures()) {
for (String output : outputs(signature).values()) {
- renameDimensions(index.get(output), renamer);
+ renameDimensions(operations.get(output), renamer);
}
}
}
@@ -111,7 +113,7 @@ public class IntermediateGraph {
public String toFullString() {
StringBuilder b = new StringBuilder();
- for (var input : index.entrySet())
+ for (var input : operations.entrySet())
b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n");
return b.toString();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 19c2026d457..b587a9200ec 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -227,7 +227,7 @@ public abstract class ModelImporter implements MlModelImporter {
* for fast model weight updates.
*/
private static void logVariableTypes(IntermediateGraph graph) {
- for (IntermediateOperation operation : graph.operations()) {
+ for (IntermediateOperation operation : graph.operations().values()) {
if ( ! (operation instanceof Constant)) continue;
if ( ! operation.type().isPresent()) continue; // will not happen
log.info("Importing model variable " + operation.name() + " as " + operation.vespaName() +
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 878fa1ca1b1..cc9985af6d4 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -3,6 +3,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -58,6 +59,8 @@ public abstract class IntermediateOperation {
protected abstract OrderedTensorType lazyGetType();
protected abstract TensorFunction lazyGetFunction();
+ public String modelName() { return modelName; }
+
/** Returns the Vespa tensor type of this operation if it exists */
public Optional<OrderedTensorType> type() {
if (type == null) {
@@ -189,6 +192,16 @@ public abstract class IntermediateOperation {
.collect(Collectors.toList()));
}
+ public IntermediateOperation withInputs(List<IntermediateOperation> inputs) {
+ throw new UnsupportedOperationException();
+ }
+
+ public String toFullString() { return toString(); }
+
+ String asString(Optional<OrderedTensorType> type) {
+ return type.map(t -> t.toString()).orElse("(unknown)");
+ }
+
/**
* A method signature input and output has the form name:index.
* This returns the name part without the index.
@@ -217,10 +230,4 @@ public abstract class IntermediateOperation {
Optional<List<Value>> getList(String key);
}
- public String toFullString() { return toString(); }
-
- String asString(Optional<OrderedTensorType> type) {
- return type.map(t -> t.toString()).orElse("(unknown)");
- }
-
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 6c6b51a27a5..9158eeea02b 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -81,6 +81,11 @@ public class MatMul extends IntermediateOperation {
}
@Override
+ public MatMul withInputs(List<IntermediateOperation> inputs) {
+ return new MatMul(modelName(), name(), inputs);
+ }
+
+ @Override
public String toFullString() {
return "\t" + lazyGetType() + ":\tMatMul(" + inputs().get(0).toFullString() + ", " + inputs().get(1).toFullString() + ")";
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
new file mode 100644
index 00000000000..264ee6b9dff
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
@@ -0,0 +1,49 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+/**
+ * Renames a tensor dimension to relax dimension constraints
+ *
+ * @author bratseth
+ */
+public class Rename extends IntermediateOperation {
+
+ private final String from, to;
+
+ public Rename(String modelName, String from, String to, IntermediateOperation input) {
+ super(modelName, "rename", List.of(input));
+ this.from = from;
+ this.to = to;
+ }
+
+ @Override
+ boolean allInputFunctionsPresent(int expected) {
+ return super.allInputFunctionsPresent(expected);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(1)) return null;
+
+ OrderedTensorType inputType = inputs.get(0).type().orElse(null);
+ if (inputType == null) return null;
+
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(inputType.type().valueType());
+ for (TensorType.Dimension dimension : inputType.dimensions())
+ builder.add(dimension.withName(dimension.name().equals(from) ? to : dimension.name()));
+ return builder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputFunctionsPresent(1)) return null;
+ return new com.yahoo.tensor.functions.Rename(inputs.get(0).function().orElse(null), from, to);
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
index ee282c7d988..793258868ee 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
@@ -9,7 +9,7 @@ public class DimensionRenamerTest {
@Test
public void testMnistRenaming() {
- DimensionRenamer renamer = new DimensionRenamer();
+ DimensionRenamer renamer = new DimensionRenamer(new IntermediateGraph("test"));
renamer.addDimension("first_dimension_of_x");
renamer.addDimension("second_dimension_of_x");
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java
index be0ab4b894a..6500a380190 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java
@@ -20,9 +20,9 @@ public class Issue9662TestCase {
Assert.assertEquals("Should have no skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ImportedMlFunction output = signature.outputFunction("y", "y");
+ ImportedMlFunction output = signature.outputFunction("output", "output");
assertNotNull(output);
- model.assertEqualResultSum("input", "dnn/outputs/add", 0.00001);
+ model.assertEqualResultSum("input_embedding_user_guid", "dense_out/MatMul", 0.00001);
}
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
index 9d2f8cf0692..75fa2ed7933 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
@@ -49,7 +49,7 @@ public class TestableTensorFlowModel {
public ImportedModel get() { return model; }
- /** Compare that summing the tensors produce the same result to within some tolerance delta */
+ /** Checks that computing the expressions produces the same result to within some tolerance delta */
public void assertEqualResultSum(String inputName, String operationName, double delta) {
Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName);
Context context = contextFrom(model);