summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2019-07-08 14:47:54 +0200
committerGitHub <noreply@github.com>2019-07-08 14:47:54 +0200
commitd0b6a8a2fe100ade8d3aac5689bead29118480ad (patch)
tree466faa85c0f1a6d8311d62983e23c00126240b85
parent4e8a65ed3701c814459b5ce58291d9764446d873 (diff)
parent76e924dcde2613c7956a50c29dbcc082e2b3b59c (diff)
Merge pull request #9944 from vespa-engine/bratseth/output-immediate-graph
Bratseth/output immediate graph
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java421
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java32
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java126
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java13
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java24
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java23
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java25
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java29
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java38
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java29
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java67
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java8
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java12
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java28
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java2
-rw-r--r--model-integration/src/test/models/tensorflow/9662/saved_model.pbtxt1318
-rw-r--r--vespajlib/abi-spec.json14
-rw-r--r--vespajlib/src/main/java/com/yahoo/collections/ListMap.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java180
-rw-r--r--vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java190
37 files changed, 2579 insertions, 178 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java
index 61ce9d98e69..b2f5d104890 100644
--- a/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java
+++ b/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java
@@ -73,14 +73,13 @@ public class BlendingSearcher extends Searcher {
}
/**
- * Produce a single blended result list from a group of hitgroups.
+ * Produce a single blended hit list from a group of hitgroups.
*
- * It is assumed that the results are ordered in hitgroups. If not, the blend will not be performed
+ * This assumes that all hits are organized into hitgroups. If not, blending will not be performed.
*/
protected Result blendResults(Result result, Query q, int offset, int hits, Execution execution) {
//Assert that there are more than one hitgroup and that there are only hitgroups on the lowest level
-
boolean foundNonGroup = false;
Iterator<Hit> hitIterator = result.hits().iterator();
List<HitGroup> groups = new ArrayList<>();
@@ -89,14 +88,14 @@ public class BlendingSearcher extends Searcher {
if (hit instanceof HitGroup) {
groups.add((HitGroup)hit);
hitIterator.remove();
- } else if(!hit.isMeta()) {
+ } else if ( ! hit.isMeta()) {
foundNonGroup = true;
}
}
- if(foundNonGroup) {
+ if( foundNonGroup) {
result.hits().addError(ErrorMessage.createUnspecifiedError("Blendingsearcher could not blend - there are toplevel hits" +
- " that are not hitgroups"));
+ " that are not hitgroups"));
return result;
}
if (groups.size() == 0) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index 9e9f66be700..0f563a75b11 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -2,179 +2,179 @@
package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import ai.vespa.rankingexpression.importer.operations.Rename;
+import com.yahoo.collections.ListMap;
-import java.util.ArrayDeque;
import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Deque;
+import java.util.Comparator;
import java.util.HashMap;
-import java.util.Iterator;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
+import java.util.Set;
+import java.util.logging.Level;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
/**
- * A constraint satisfier to find suitable dimension names to reduce the
+ * A constraint solver which finds suitable dimension names to reduce the
* amount of necessary renaming during evaluation of an imported model.
*
* @author lesters
+ * @author bratseth
*/
public class DimensionRenamer {
+ private static final Logger log = Logger.getLogger(DimensionRenamer.class.getName());
+
private final String dimensionPrefix;
- private final Map<String, List<Integer>> variables = new HashMap<>();
- private final Map<Arc, Constraint> constraints = new HashMap<>();
- private final Map<String, Integer> renames = new HashMap<>();
- private int iterations = 0;
+ /** The graph we are renaming the dimensions of */
+ private final IntermediateGraph graph;
+
+ /** The set of dimensions to find a solution for */
+ private final Set<String> dimensions = new HashSet<>();
+
+ /** The constraints on the dimension name assignment */
+ private final ListMap<Arc, Constraint> constraints = new ListMap<>();
+
+ /** The solution to this, or null if no solution is found yet */
+ private Map<String, Integer> renames = null;
- public DimensionRenamer() {
- this("d");
+ public DimensionRenamer(IntermediateGraph graph) {
+ this(graph, "d");
}
- public DimensionRenamer(String dimensionPrefix) {
+ public DimensionRenamer(IntermediateGraph graph, String dimensionPrefix) {
+ this.graph = graph;
this.dimensionPrefix = dimensionPrefix;
}
- /**
- * Add a dimension name variable.
- */
- public void addDimension(String name) {
- variables.computeIfAbsent(name, d -> new ArrayList<>());
- }
+ /** Add a dimension to the set of dimensions to be renamed */
+ public void addDimension(String name) { dimensions.add(name); }
+
+ /** Add a constraint between dimension names */
+ public void addConstraint(String from, String to, Constraint constraint, IntermediateOperation operation) {
+ if (constraint instanceof EqualConstraint && from.equals(to)) return;
- /**
- * Add a constraint between dimension names.
- */
- public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) {
Arc arc = new Arc(from, to, operation);
- Arc opposite = arc.opposite();
- constraints.put(arc, pred);
- constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
+ constraints.put(arc, constraint);
+ constraints.put(arc.opposite(), constraint.opposite()); // make constraint graph symmetric
}
- /**
- * Retrieve resulting name of dimension after solving for constraints.
- */
- public Optional<String> dimensionNameOf(String name) {
- if (!renames.containsKey(name)) {
- return Optional.empty();
- }
- return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
+ void solve() {
+ log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints));
+ renames = solve(100000);
+ log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames));
}
- /**
- * Perform iterative arc consistency until we have found a solution. After
- * an initial iteration, the variables (dimensions) will have multiple
- * valid values. Find a single valid assignment by iteratively locking one
- * dimension after another, and running the arc consistency algorithm
- * multiple times.
- *
- * This requires having constraints that result in an absolute ordering:
- * equals, lesserThan and greaterThan do that, but adding notEquals does
- * not typically result in a guaranteed ordering. If that is needed, the
- * algorithm below needs to be adapted with a backtracking (tree) search
- * to find solutions.
- */
- private void solve(int maxIterations) {
- initialize();
-
- // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
+ private Map<String, Integer> solve(int maxIterations) {
+ Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations);
+ if (solution != null) return solution;
- for (String dimension : variables.keySet()) {
- List<Integer> values = variables.get(dimension);
- if (values.size() > 1) {
- if (!ac3()) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
- }
- values.sort(Integer::compare);
- variables.put(dimension, Collections.singletonList(values.get(0)));
- }
- renames.put(dimension, variables.get(dimension).get(0));
- if (iterations > maxIterations) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
- maxIterations + " iterations");
- }
+ for (RenameTarget target : prioritizedRenameTargets()) {
+ System.out.println("Trying rename " + target);
+ target.insertRename(this);
+ solution = solveWithOrWithoutSoftConstraints(maxIterations);
+ if (solution != null) return solution;
+ target.uninsertRename(this);
}
-
- // Todo: handle failure more gracefully:
- // If a solution can't be found, look at the operation node in the arc
- // with the most remaining constraints, and inject a rename operation.
- // Then run this algorithm again.
+ throw new IllegalArgumentException("Could not find a dimension naming solution " +
+ "given constraints\n" + constraintsToString(constraints));
}
- void solve() {
- solve(100000);
+ private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) {
+ Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations);
+ if ( solution == null) {
+ ListMap<Arc, Constraint> hardConstraints = new ListMap<>();
+ boolean anyRemoved = copyHard(constraints, hardConstraints);
+ if (anyRemoved)
+ solution = NamingConstraintSolver.solve(dimensions, hardConstraints, maxIterations);
+ }
+ return solution;
}
- private void initialize() {
- for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
- List<Integer> values = variable.getValue();
- for (int i = 0; i < variables.size(); ++i) {
- values.add(i); // invariant: values are in increasing order
+ /** Removes soft constraints and returns whether something was removed */
+ private boolean copyHard(ListMap<Arc, Constraint> source, ListMap<Arc, Constraint> target) {
+ boolean removed = false;
+ for (var entry : source.entrySet()) {
+ Arc arc = entry.getKey();
+ for (Constraint constraint : entry.getValue()) {
+ if ( ! constraint.isSoft())
+ target.put(arc, constraint);
+ else
+ removed = true;
}
}
+ return removed;
}
- private boolean ac3() {
- Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
- while (!workList.isEmpty()) {
- Arc arc = workList.pop();
- iterations += 1;
- if (revise(arc)) {
- if (variables.get(arc.from).size() == 0) {
- return false; // no solution found
- }
- for (Arc constraint : constraints.keySet()) {
- if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
- workList.add(constraint);
- }
- }
- }
+ private List<RenameTarget> prioritizedRenameTargets() {
+ Map<IntermediateOperation, Integer> constraintsPerOperation = new HashMap<>();
+
+ for (var constraint : constraints.entrySet()) {
+ constraintsPerOperation.compute(constraint.getKey().operation,
+ (operation, count) -> count == null ? 1 : ++count);
}
- return true;
- }
+ List<IntermediateOperation> prioritizedOperations =
+ constraintsPerOperation.entrySet().stream()
+ .sorted(Comparator.comparingInt(entry -> - entry.getValue()))
+ .map(entry -> entry.getKey())
+ .collect(Collectors.toList());
- private boolean revise(Arc arc) {
- boolean revised = false;
- for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
- Integer from = fromIterator.next();
- boolean satisfied = false;
- for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
- Integer to = toIterator.next();
- if (constraints.get(arc).test(from, to)) {
- satisfied = true;
+ List<RenameTarget> targets = new ArrayList<>();
+ for (IntermediateOperation operation : prioritizedOperations) {
+ for (int i = 0; i < operation.inputs().size(); i++) {
+ Optional<OrderedTensorType> inputType = operation.inputs().get(i).type();
+ if (inputType.isEmpty()) continue;
+ for (String dimensionName : inputType.get().dimensionNames()) {
+ RenameTarget target = new RenameTarget(operation, i, dimensionName, graph);
+ if (target.rootKey != null) // TODO: Inserting renames under non-roots is not implemented
+ targets.add(target);
}
}
- if (!satisfied) {
- fromIterator.remove();
- revised = true;
- }
}
- return revised;
- }
-
- public interface Constraint {
- boolean test(Integer x, Integer y);
+ return targets;
}
- public static boolean equals(Integer x, Integer y) {
- return Objects.equals(x, y);
+ /**
+ * Retrieve resulting name of a dimension after solving for constraints, or empty if no
+ * solution is found yet, or this dimension was not added before finding a solution.
+ */
+ public Optional<String> dimensionNameOf(String name) {
+ if ( renames == null || ! renames.containsKey(name))
+ return Optional.empty();
+ return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
}
- public static boolean lesserThan(Integer x, Integer y) {
- return x < y;
+ private static String renamesToString(Map<String, Integer> renames) {
+ return renames.entrySet().stream()
+ .map(e -> " " + e.getKey() + " -> " + e.getValue())
+ .collect(Collectors.joining("\n"));
}
- public static boolean greaterThan(Integer x, Integer y) {
- return x > y;
+ private static String constraintsToString(ListMap<Arc, Constraint> constraints) {
+ StringBuilder b = new StringBuilder();
+ for (var entry : constraints.entrySet()) {
+ Arc arc = entry.getKey();
+ for (Constraint constraint : entry.getValue()) {
+ if (constraint.isOpposite()) continue; // noise
+ b.append(" ");
+ if (constraint.isSoft())
+ b.append("(soft) ");
+ b.append(arc.from).append(" ").append(constraint).append(" ").append(arc.to);
+ b.append(" (origin: ").append(arc.operation).append(")\n");
+ }
+ }
+ return b.toString();
}
- private static class Arc {
+ static class Arc {
- private final String from;
- private final String to;
+ final String from;
+ final String to;
private final IntermediateOperation operation;
Arc(String from, String to, IntermediateOperation operation) {
@@ -194,7 +194,7 @@ public class DimensionRenamer {
@Override
public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof Arc)) {
+ if (!(obj instanceof Arc)) {
return false;
}
Arc other = (Arc) obj;
@@ -203,8 +203,185 @@ public class DimensionRenamer {
@Override
public String toString() {
- return String.format("%s -> %s", from, to);
+ return from + " -> " + to;
+ }
+ }
+
+ public static abstract class Constraint {
+
+ private final boolean soft, opposite;
+
+ protected Constraint(boolean soft, boolean opposite) {
+ this.soft = soft;
+ this.opposite = opposite;
+ }
+
+ abstract boolean test(Integer x, Integer y);
+ abstract Constraint opposite();
+
+ /** Returns whether this constraint can be violated if that is necessary to achieve a solution */
+ boolean isSoft() { return soft; }
+
+ /** Returns whether this is an opposite of another constraint */
+ boolean isOpposite() { return opposite; }
+
+ public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); }
+ public static Constraint notEqual(boolean soft) { return new NotEqualConstraint(soft, false); }
+ public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); }
+ public static Constraint greaterThan(boolean soft) { return new GreaterThanConstraint(soft, false); }
+
+ }
+
+ private static class EqualConstraint extends Constraint {
+
+ private EqualConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return Objects.equals(x, y); }
+
+ @Override
+ public Constraint opposite() { return new EqualConstraint(isSoft(), true); }
+
+ @Override
+ public String toString() { return "=="; }
+
+ }
+
+ private static class NotEqualConstraint extends Constraint {
+
+ private NotEqualConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return ! Objects.equals(x, y); }
+
+ @Override
+ public Constraint opposite() { return new NotEqualConstraint(isSoft(), true); }
+
+ @Override
+ public String toString() { return "!="; }
+
+ }
+
+ private static class LessThanConstraint extends Constraint {
+
+ private LessThanConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return x < y; }
+
+ @Override
+ public Constraint opposite() { return new GreaterThanConstraint(isSoft(), true); }
+
+ @Override
+ public String toString() { return "<"; }
+
+ }
+
+ private static class GreaterThanConstraint extends Constraint {
+
+ private GreaterThanConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return x > y; }
+
+ @Override
+ public Constraint opposite() { return new LessThanConstraint(isSoft(), true); }
+
+ @Override
+ public String toString() { return ">"; }
+
+ }
+
+ /**
+ * An operation and an input number which we may want to insert a rename operation at.
+ * That is, we may want to change op(..., input, ...) to op(..., rename(input), ...).
+ *
+ * This class is (and must remain) immutable.
+ */
+ private static class RenameTarget {
+
+ final IntermediateOperation operation;
+ final int inputNumber;
+ final String dimensionName;
+ final IntermediateGraph graph;
+
+ /**
+ * Returns the key of this operation in the root operations of the graph,
+ * or null if it is not a root operation
+ */
+ final String rootKey;
+
+ public RenameTarget(IntermediateOperation operation, int inputNumber, String dimensionName, IntermediateGraph graph) {
+ this.operation = operation;
+ this.inputNumber = inputNumber;
+ this.dimensionName = dimensionName;
+ this.rootKey = findRootKey(operation, graph);
+ this.graph = graph;
+ }
+
+ public IntermediateOperation input() {
+ return operation.inputs().get(inputNumber);
+ }
+
+ private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) {
+ for (var entry : graph.operations().entrySet()) {
+ if (entry.getValue() == operation)
+ return entry.getKey();
+ }
+ return null;
+ }
+
+ /** Inserts a rename operation if possible. Returns whether an operation was inserted. */
+ private boolean insertRename(DimensionRenamer renamer) {
+ Rename rename = new Rename(operation.modelName(),
+ dimensionName,
+ renamer.dimensionPrefix + renamer.dimensions.size(),
+ input());
+
+ List<IntermediateOperation> newInputs = new ArrayList<>(operation.inputs());
+ newInputs.set(inputNumber, rename);
+ IntermediateOperation newOperation = operation.withInputs(newInputs);
+ if (rootKey == null)
+ throw new IllegalStateException("Renaming non-roots is not implemented");
+ graph.put(rootKey, newOperation);
+
+ removeConstraintsOf(operation, renamer);
+ rename.addDimensionNameConstraints(renamer);
+ newOperation.addDimensionNameConstraints(renamer);
+ return true;
+ }
+
+ /** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */
+ private void uninsertRename(DimensionRenamer renamer) {
+ IntermediateOperation newOperation = graph.operations().get(rootKey);
+ Rename rename = (Rename)newOperation.inputs().get(inputNumber);
+ graph.put(rootKey, operation);
+
+ removeConstraintsOf(rename, renamer);
+ removeConstraintsOf(newOperation, renamer);
+ operation.addDimensionNameConstraints(renamer);
+ }
+
+ private void removeConstraintsOf(IntermediateOperation operation, DimensionRenamer renamer) {
+ for (Arc key : new HashSet<>(renamer.constraints.keySet())) {
+ if (key.operation == operation)
+ renamer.constraints.removeAll(key);
+ }
+ }
+
+ @Override
+ public String toString() {
+ return operation + ", input " + inputNumber;
}
+
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 0c570261ae7..a9be1bbd40e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -262,9 +262,12 @@ public class ImportedModel implements ImportedMlModel {
/** Returns the expression this output references as an imported function */
public ImportedMlFunction outputFunction(String outputName, String functionName) {
+ RankingExpression outputExpression = owner().expressions().get(outputs.get(outputName));
+ if (outputExpression == null)
+ throw new IllegalArgumentException("Missing output '" + outputName + "' in " + this);
return new ImportedMlFunction(functionName,
new ArrayList<>(inputs.values()),
- owner().expressions().get(outputs.get(outputName)).getRoot().toString(),
+ outputExpression.getRoot().toString(),
asStrings(inputMap()),
Optional.empty());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index aec98d06874..6c583d960bd 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -3,9 +3,11 @@
package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import ai.vespa.rankingexpression.importer.operations.MatMul;
import java.util.Collection;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -20,7 +22,7 @@ import java.util.Set;
public class IntermediateGraph {
private final String modelName;
- private final Map<String, IntermediateOperation> index = new HashMap<>();
+ private final Map<String, IntermediateOperation> operations = new HashMap<>();
private final Map<String, GraphSignature> signatures = new HashMap<>();
private static class GraphSignature {
@@ -37,11 +39,11 @@ public class IntermediateGraph {
}
public IntermediateOperation put(String key, IntermediateOperation operation) {
- return index.put(key, operation);
+ return operations.put(key, operation);
}
public IntermediateOperation get(String key) {
- return index.get(key);
+ return operations.get(key);
}
public Set<String> signatures() {
@@ -61,11 +63,11 @@ public class IntermediateGraph {
}
public boolean alreadyImported(String key) {
- return index.containsKey(key);
+ return operations.containsKey(key);
}
- public Collection<IntermediateOperation> operations() {
- return index.values();
+ public Map<String, IntermediateOperation> operations() {
+ return operations;
}
void optimize() {
@@ -76,16 +78,16 @@ public class IntermediateGraph {
* Find dimension names to avoid excessive renaming while evaluating the model.
*/
private void renameDimensions() {
- DimensionRenamer renamer = new DimensionRenamer();
+ DimensionRenamer renamer = new DimensionRenamer(this);
for (String signature : signatures()) {
for (String output : outputs(signature).values()) {
- addDimensionNameConstraints(index.get(output), renamer);
+ addDimensionNameConstraints(operations.get(output), renamer);
}
}
renamer.solve();
for (String signature : signatures()) {
for (String output : outputs(signature).values()) {
- renameDimensions(index.get(output), renamer);
+ renameDimensions(operations.get(output), renamer);
}
}
}
@@ -104,4 +106,16 @@ public class IntermediateGraph {
}
}
+ @Override
+ public String toString() {
+ return "intermediate graph for '" + modelName + "'";
+ }
+
+ public String toFullString() {
+ StringBuilder b = new StringBuilder();
+ for (var input : operations.entrySet())
+ b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n");
+ return b.toString();
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 99bfa08db43..b587a9200ec 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -11,12 +11,14 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.text.ExpressionFormatter;
import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.logging.Level;
import java.util.logging.Logger;
/**
@@ -50,6 +52,8 @@ public abstract class ModelImporter implements MlModelImporter {
*/
protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
ImportedModel model = new ImportedModel(graph.name(), modelSource);
+ log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" +
+ ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString()));
graph.optimize();
@@ -223,7 +227,7 @@ public abstract class ModelImporter implements MlModelImporter {
* for fast model weight updates.
*/
private static void logVariableTypes(IntermediateGraph graph) {
- for (IntermediateOperation operation : graph.operations()) {
+ for (IntermediateOperation operation : graph.operations().values()) {
if ( ! (operation instanceof Constant)) continue;
if ( ! operation.type().isPresent()) continue; // will not happen
log.info("Importing model variable " + operation.name() + " as " + operation.vespaName() +
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
new file mode 100644
index 00000000000..21cc6b27dad
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
@@ -0,0 +1,126 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import com.yahoo.collections.ListMap;
+
+import java.util.ArrayDeque;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Solves a dimension naming constraint problem.
+ *
+ * @author lesters
+ * @author bratseth
+ */
+class NamingConstraintSolver {
+
+ private final ListMap<String, Integer> possibleAssignments;
+ private final ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints;
+
+ private int iterations = 0;
+ private final int maxIterations;
+
+ private NamingConstraintSolver(Set<String> dimensions,
+ ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints,
+ int maxIterations) {
+ this.possibleAssignments = allPossibilities(dimensions);
+ this.constraints = constraints;
+ this.maxIterations = maxIterations;
+ }
+
+ /** Returns a list containing a list of all assignment possibilities for each of the given dimensions */
+ private static ListMap<String, Integer> allPossibilities(Set<String> dimensions) {
+ ListMap<String, Integer> all = new ListMap<>();
+ for (String dimension : dimensions) {
+ for (int i = 0; i < dimensions.size(); ++i)
+ all.put(dimension, i);
+ }
+ return all;
+ }
+
+ /**
+ * Try the solve the constraint problem given in the arguments, and put the result in renames.
+ *
+ * This is done by performing iterative arc consistency until we have found a solution.
+ * After an initial iteration, the dimensions will have multiple
+ * valid values. Find a single valid assignment by iteratively locking one
+ * dimension after another, and running the arc consistency algorithm
+ * multiple times.
+ *
+ * This requires having constraints that result in an absolute ordering:
+ * equal, lessThan and greaterThan do that, but not necessarily notEqual
+ * If that is needed, the algorithm needs to be adapted with a backtracking
+ * (tree) search
+ *
+ * @return the solution in the form of the renames to perform
+ */
+ private Map<String, Integer> trySolve() {
+ // TODO: Evaluate possible improved efficiency by using a heuristic such as min-conflicts
+
+ Map<String, Integer> solution = new HashMap<>();
+ for (String dimension : possibleAssignments.keySet()) {
+ List<Integer> values = possibleAssignments.get(dimension);
+ if (values.size() > 1) {
+ if ( ! ac3()) return null;
+ values.sort(Integer::compare);
+ possibleAssignments.replace(dimension, values.get(0));
+ }
+ solution.put(dimension, possibleAssignments.get(dimension).get(0)); // Pick the first available solution
+ if (iterations > maxIterations) return null;
+ }
+ return solution;
+ }
+
+ private boolean ac3() {
+ Deque<DimensionRenamer.Arc> workList = new ArrayDeque<>(constraints.keySet());
+ while ( ! workList.isEmpty()) {
+ DimensionRenamer.Arc arc = workList.pop();
+ iterations++;
+ if (revise(arc)) {
+ if (possibleAssignments.get(arc.from).isEmpty()) return false;
+
+ for (DimensionRenamer.Arc constraint : constraints.keySet()) {
+ if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from))
+ workList.add(constraint);
+ }
+ }
+ }
+ return true;
+ }
+
+ private boolean revise(DimensionRenamer.Arc arc) {
+ boolean revised = false;
+ for (Iterator<Integer> fromIterator = possibleAssignments.get(arc.from).iterator(); fromIterator.hasNext(); ) {
+ Integer from = fromIterator.next();
+ boolean satisfied = false;
+ for (Iterator<Integer> toIterator = possibleAssignments.get(arc.to).iterator(); toIterator.hasNext(); ) {
+ Integer to = toIterator.next();
+ if (constraints.get(arc).stream().allMatch(constraint -> constraint.test(from, to)))
+ satisfied = true;
+ }
+ if ( ! satisfied) {
+ fromIterator.remove();
+ revised = true;
+ }
+ }
+ return revised;
+ }
+
+ /**
+ * Attempts to solve the given naming problem. The input maps are never modified.
+ *
+ * @return the solution as a map from existing names to name ids represented as integers, or NULL
+ * if no solution could be found
+ */
+ public static Map<String, Integer> solve(Set<String> dimensions,
+ ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints,
+ int maxIterations) {
+ return new NamingConstraintSolver(dimensions, constraints, maxIterations).trySolve();
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
index 9115dc99b82..1cb8f3a2951 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
@@ -7,8 +7,11 @@ import com.yahoo.tensor.TensorTypeParser;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.stream.Collectors;
/**
@@ -131,11 +134,17 @@ public class OrderedTensorType {
public OrderedTensorType rename(DimensionRenamer renamer) {
List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
+ Map<String, String> new2Old = new HashMap<>(); // Just to create meaningful error messages
for (TensorType.Dimension dimension : dimensions) {
String oldName = dimension.name();
Optional<String> newName = renamer.dimensionNameOf(oldName);
- if (!newName.isPresent())
- return this; // presumably, already renamed
+ if ( newName.isEmpty()) return this; // presumably already renamed
+
+ if (new2Old.containsKey(newName.get()))
+ throw new IllegalArgumentException("Can not rename '" + oldName + "' to '" + newName.get() + "' in " + this +
+ " as '" + new2Old.get(newName.get()) + "' should also be renamed to it");
+ new2Old.put(newName.get(), oldName);
+
TensorType.Dimension.Type dimensionType = dimension.type();
if (dimensionType == TensorType.Dimension.Type.indexedBound) {
renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
index d6ea00ca453..9f62a27a3b9 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
@@ -29,7 +29,7 @@ public class Argument extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
- if (!standardNamingType.equals(type)) {
+ if ( ! standardNamingType.equals(type)) {
List<String> renameFrom = standardNamingType.dimensionNames();
List<String> renameTo = type.dimensionNames();
output = new Rename(output, renameFrom, renameTo);
@@ -39,9 +39,7 @@ public class Argument extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
+ addConstraintsFrom(type, renamer);
}
@Override
@@ -54,4 +52,22 @@ public class Argument extends IntermediateOperation {
return false;
}
+ @Override
+ public Argument withInputs(List<IntermediateOperation> inputs) {
+ if ( ! inputs.isEmpty())
+ throw new IllegalArgumentException("Argument cannot take inputs");
+ return new Argument(modelName(), name(), type);
+ }
+
+ @Override
+ public String operationName() { return "Argument"; }
+
+ @Override
+ public String toString() { return "Argument(" + standardNamingType + ")"; }
+
+ @Override
+ public String toFullString() {
+ return "\t" + lazyGetType() + ":\tArgument(" + standardNamingType + ")";
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
index 7ae50a0549d..7787caa83ce 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
import java.util.Optional;
+import java.util.stream.Collectors;
public class ConcatV2 extends IntermediateOperation {
@@ -89,7 +90,7 @@ public class ConcatV2 extends IntermediateOperation {
OrderedTensorType b = inputs.get(i).type().get();
String bDim = b.dimensions().get(concatDimensionIndex).name();
String aDim = a.dimensions().get(concatDimensionIndex).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
}
}
@@ -99,4 +100,12 @@ public class ConcatV2 extends IntermediateOperation {
concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName);
}
+ @Override
+ public ConcatV2 withInputs(List<IntermediateOperation> inputs) {
+ return new ConcatV2(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "ConcatV2"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
index 41d421b1f5a..d13c1ad5f3c 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -62,9 +62,7 @@ public class Const extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
+ addConstraintsFrom(type, renamer);
}
@Override
@@ -86,4 +84,23 @@ public class Const extends IntermediateOperation {
}
return value.get();
}
+
+ @Override
+ public Const withInputs(List<IntermediateOperation> inputs) {
+ return new Const(modelName(), name(), inputs, attributeMap, type);
+ }
+
+ @Override
+ public String operationName() { return "Const"; }
+
+ @Override
+ public String toString() {
+ return "Const(" + type + ")";
+ }
+
+ @Override
+ public String toFullString() {
+ return "\t" + lazyGetType() + ":\tConst(" + type + ")";
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
index a1cc83296b0..1eaaf705220 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
@@ -8,6 +8,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.Collections;
+import java.util.List;
import java.util.Optional;
public class Constant extends IntermediateOperation {
@@ -48,9 +49,7 @@ public class Constant extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
+ addConstraintsFrom(type, renamer);
}
@Override
@@ -58,4 +57,24 @@ public class Constant extends IntermediateOperation {
return true;
}
+ @Override
+ public Constant withInputs(List<IntermediateOperation> inputs) {
+ if ( ! inputs.isEmpty())
+ throw new IllegalArgumentException("Constant cannot take inputs");
+ return new Constant(modelName(), name(), type);
+ }
+
+ @Override
+ public String operationName() { return "Constant"; }
+
+ @Override
+ public String toString() {
+ return "Constant(" + type + ")";
+ }
+
+ @Override
+ public String toFullString() {
+ return "\t" + lazyGetType() + ":\tConstant(" + type + ")";
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index c64b9ded601..e6cc96d48ad 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -30,7 +30,7 @@ public class ExpandDims extends IntermediateOperation {
if ( ! allInputTypesPresent(2)) return null;
IntermediateOperation axisOperation = inputs().get(1);
- if (!axisOperation.getConstantValue().isPresent()) {
+ if ( ! axisOperation.getConstantValue().isPresent()) {
throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
@@ -47,18 +47,23 @@ public class ExpandDims extends IntermediateOperation {
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
- if (dimensionIndex == dimensionToInsert) {
- String name = String.format("%s_%d", vespaName(), dimensionIndex);
- expandDimensions.add(name);
- typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
- }
+ if (dimensionIndex == dimensionToInsert)
+ addDimension(dimensionIndex, typeBuilder);
typeBuilder.add(dimension);
dimensionIndex++;
}
-
+ if (dimensionToInsert == inputType.dimensions().size()) { // Insert last dimension
+ addDimension(dimensionIndex, typeBuilder);
+ }
return typeBuilder.build();
}
+ private void addDimension(int dimensionIndex, OrderedTensorType.Builder typeBuilder) {
+ String name = String.format("%s_%d", vespaName(), dimensionIndex);
+ expandDimensions.add(name);
+ typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
+ }
+
@Override
protected TensorFunction lazyGetFunction() {
if ( ! allInputFunctionsPresent(2)) return null;
@@ -88,7 +93,7 @@ public class ExpandDims extends IntermediateOperation {
List<String> renamedDimensions = new ArrayList<>(expandDimensions.size());
for (String name : expandDimensions) {
Optional<String> newName = renamer.dimensionNameOf(name);
- if (!newName.isPresent()) {
+ if ( ! newName.isPresent()) {
return; // presumably, already renamed
}
renamedDimensions.add(newName.get());
@@ -96,4 +101,12 @@ public class ExpandDims extends IntermediateOperation {
expandDimensions = renamedDimensions;
}
+ @Override
+ public ExpandDims withInputs(List<IntermediateOperation> inputs) {
+ return new ExpandDims(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "ExpandDims"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
index c2787aa14d4..5463f645355 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
@@ -32,4 +32,12 @@ public class Identity extends IntermediateOperation {
return inputs.get(0).function().orElse(null);
}
+ @Override
+ public Identity withInputs(List<IntermediateOperation> inputs) {
+ return new Identity(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "Identity"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 0ee54f839bc..c3980b8fe93 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -3,6 +3,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -58,6 +59,8 @@ public abstract class IntermediateOperation {
protected abstract OrderedTensorType lazyGetType();
protected abstract TensorFunction lazyGetFunction();
+ public String modelName() { return modelName; }
+
/** Returns the Vespa tensor type of this operation if it exists */
public Optional<OrderedTensorType> type() {
if (type == null) {
@@ -99,6 +102,20 @@ public abstract class IntermediateOperation {
/** Add dimension name constraints for this operation */
public void addDimensionNameConstraints(DimensionRenamer renamer) { }
+ /** Conveinence method to adds dimensions and constraints of the given tensor type */
+ protected void addConstraintsFrom(OrderedTensorType type, DimensionRenamer renamer) {
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ renamer.addDimension(type.dimensions().get(i).name());
+
+ // Each dimension is distinct:
+ for (int j = i + 1; j < type.dimensions().size(); j++) {
+ renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(),
+ DimensionRenamer.Constraint.notEqual(false),
+ this);
+ }
+ }
+ }
+
/** Performs dimension rename for this operation */
public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
@@ -175,6 +192,12 @@ public abstract class IntermediateOperation {
.collect(Collectors.toList()));
}
+ public abstract IntermediateOperation withInputs(List<IntermediateOperation> inputs);
+
+ String asString(Optional<OrderedTensorType> type) {
+ return type.map(t -> t.toString()).orElse("(unknown)");
+ }
+
/**
* A method signature input and output has the form name:index.
* This returns the name part without the index.
@@ -203,4 +226,19 @@ public abstract class IntermediateOperation {
Optional<List<Value>> getList(String key);
}
+ public abstract String operationName();
+
+ @Override
+ public String toString() {
+ return operationName() +
+ inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) +
+ ")";
+ }
+
+ public String toFullString() {
+ return "\t" + lazyGetType() + ":\t" + operationName() +
+ inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) +
+ ")";
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index c2d75153586..adb54474812 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -95,7 +95,7 @@ public class Join extends IntermediateOperation {
for (int i = 0; i < b.rank(); ++i) {
String bDim = b.dimensions().get(i).name();
String aDim = a.dimensions().get(i + sizeDifference).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
}
}
@@ -111,4 +111,12 @@ public class Join extends IntermediateOperation {
return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
}
+ @Override
+ public Join withInputs(List<IntermediateOperation> inputs) {
+ return new Join(modelName(), name(), inputs, operator);
+ }
+
+ @Override
+ public String operationName() { return "Join"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
index e0842d820f9..ea39e289c48 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
@@ -34,4 +34,12 @@ public class Map extends IntermediateOperation {
return new com.yahoo.tensor.functions.Map(input.get(), operator);
}
+ @Override
+ public Map withInputs(List<IntermediateOperation> inputs) {
+ return new Map(modelName(), name(), inputs, operator);
+ }
+
+ @Override
+ public String operationName() { return "Map"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 9a76662529d..434261c6077 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.text.ExpressionFormatter;
import java.util.List;
import java.util.Optional;
@@ -51,20 +52,40 @@ public class MatMul extends IntermediateOperation {
List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+ assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
+ assertTwoDimensions(bDimensions, inputs.get(1), "second argument");
+
String aDim0 = aDimensions.get(0).name();
String aDim1 = aDimensions.get(1).name();
String bDim0 = bDimensions.get(0).name();
String bDim1 = bDimensions.get(1).name();
// The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this);
// The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
+ renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this);
// For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
+ renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this);
+ renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this);
+ }
+
+ private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
+ if (dimensions.size() >= 2) return;
+
+
+ throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
+ " but got just " + dimensions + " from\n" +
+ ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
}
+ @Override
+ public MatMul withInputs(List<IntermediateOperation> inputs) {
+ return new MatMul(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "MatMul"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
index d8e9950c61f..215edf88c4f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
@@ -91,6 +91,11 @@ public class Mean extends IntermediateOperation {
reduceDimensions = renamedDimensions;
}
+ @Override
+ public Mean withInputs(List<IntermediateOperation> inputs) {
+ return new Mean(modelName(), name(), inputs, attributeMap);
+ }
+
private boolean shouldKeepDimensions() {
Optional<Value> keepDims = attributeMap.get("keep_dims");
return keepDims.isPresent() && keepDims.get().asBoolean();
@@ -108,4 +113,7 @@ public class Mean extends IntermediateOperation {
return builder.build();
}
+ @Override
+ public String operationName() { return "Mean"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
index ce0c58971d0..671cfe852a7 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
@@ -32,4 +32,12 @@ public class Merge extends IntermediateOperation {
return null;
}
+ @Override
+ public Merge withInputs(List<IntermediateOperation> inputs) {
+ return new Merge(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "Merge"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
index 4c5ce33b1b5..35d89cf6ab6 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
@@ -23,4 +23,12 @@ public class NoOp extends IntermediateOperation {
return null;
}
+ @Override
+ public NoOp withInputs(List<IntermediateOperation> inputs) {
+ return new NoOp(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "NoOp"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
index e5e5c29f8f1..177ef8d5e17 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
@@ -45,4 +45,12 @@ public class PlaceholderWithDefault extends IntermediateOperation {
return true; // not true if we add to function
}
+ @Override
+ public PlaceholderWithDefault withInputs(List<IntermediateOperation> inputs) {
+ return new PlaceholderWithDefault(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "PlaceholdeWithDefault"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
new file mode 100644
index 00000000000..abc431233be
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
@@ -0,0 +1,67 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+/**
+ * Renames a tensor dimension to relax dimension constraints
+ *
+ * @author bratseth
+ */
+public class Rename extends IntermediateOperation {
+
+ private final String from, to;
+
+ public Rename(String modelName, String from, String to, IntermediateOperation input) {
+ super(modelName, "rename", List.of(input));
+ this.from = from;
+ this.to = to;
+ }
+
+ @Override
+ boolean allInputFunctionsPresent(int expected) {
+ return super.allInputFunctionsPresent(expected);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(1)) return null;
+
+ OrderedTensorType inputType = inputs.get(0).type().orElse(null);
+ if (inputType == null) return null;
+
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(inputType.type().valueType());
+ for (TensorType.Dimension dimension : inputType.dimensions())
+ builder.add(dimension.withName(dimension.name().equals(from) ? to : dimension.name()));
+ return builder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputFunctionsPresent(1)) return null;
+ return new com.yahoo.tensor.functions.Rename(inputs.get(0).function().orElse(null), from, to);
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ renamer.addDimension(to);
+ }
+
+ @Override
+ public Rename withInputs(List<IntermediateOperation> inputs) {
+ if (inputs.size() != 1)
+ throw new IllegalArgumentException("Rename require 1 input, not " + inputs.size());
+ return new Rename(modelName(), from, to, inputs.get(0));
+ }
+
+ @Override
+ public String operationName() { return "Rename"; }
+
+}
+
+
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 4a0fe236c9f..a210ed13f5d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -74,6 +74,11 @@ public class Reshape extends IntermediateOperation {
}
}
+ @Override
+ public Reshape withInputs(List<IntermediateOperation> inputs) {
+ return new Reshape(modelName(), name(), inputs);
+ }
+
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType)))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
@@ -119,4 +124,7 @@ public class Reshape extends IntermediateOperation {
return new ArithmeticNode(children, operators);
}
+ @Override
+ public String operationName() { return "Reshape"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
index dc690329a8d..35a1b6e2b0e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
@@ -81,8 +81,16 @@ public class Select extends IntermediateOperation {
String bDim1 = bDimensions.get(1).name();
// These tensors should have the same dimension names
- renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this);
- renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(false), this);
+ renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this);
}
+ @Override
+ public Select withInputs(List<IntermediateOperation> inputs) {
+ return new Select(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "Select"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
index 79f3012c327..57175092b5c 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
@@ -37,6 +37,11 @@ public class Shape extends IntermediateOperation {
return true;
}
+ @Override
+ public Shape withInputs(List<IntermediateOperation> inputs) {
+ return new Shape(modelName(), name(), inputs);
+ }
+
private void createConstantValue() {
if (!allInputTypesPresent(1)) {
return;
@@ -50,4 +55,7 @@ public class Shape extends IntermediateOperation {
this.setConstantValue(new TensorValue(builder.build()));
}
+ @Override
+ public String operationName() { return "Shape"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index cdacbe1656a..032ffb88a46 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -37,4 +37,12 @@ public class Softmax extends IntermediateOperation {
return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension);
}
+ @Override
+ public Softmax withInputs(List<IntermediateOperation> inputs) {
+ return new Softmax(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "SoftMax"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
index 52d40144f61..56d9b542093 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
@@ -70,6 +70,11 @@ public class Squeeze extends IntermediateOperation {
squeezeDimensions = renamedDimensions;
}
+ @Override
+ public Squeeze withInputs(List<IntermediateOperation> inputs) {
+ return new Squeeze(modelName(), name(), inputs, attributeMap);
+ }
+
private OrderedTensorType reducedType(OrderedTensorType inputType) {
OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
@@ -80,4 +85,7 @@ public class Squeeze extends IntermediateOperation {
return builder.build();
}
+ @Override
+ public String operationName() { return "Squeeze"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
index 46b95233d11..c8cd235f50e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
@@ -107,4 +107,12 @@ public class Sum extends IntermediateOperation {
return builder.build();
}
+ @Override
+ public Sum withInputs(List<IntermediateOperation> inputs) {
+ return new Sum(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "Sum"; }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
index 39702690bfa..4beafc68909 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
@@ -42,6 +42,14 @@ public class Switch extends IntermediateOperation {
return predicate == port ? inputs().get(0).function().get() : null;
}
+ @Override
+ public Switch withInputs(List<IntermediateOperation> inputs) {
+ return new Switch(modelName(), name(), inputs, port);
+ }
+
+ @Override
+ public String operationName() { return "Switch"; }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
index cf8dd6e8e71..793258868ee 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
@@ -9,7 +9,7 @@ public class DimensionRenamerTest {
@Test
public void testMnistRenaming() {
- DimensionRenamer renamer = new DimensionRenamer();
+ DimensionRenamer renamer = new DimensionRenamer(new IntermediateGraph("test"));
renamer.addDimension("first_dimension_of_x");
renamer.addDimension("second_dimension_of_x");
@@ -18,17 +18,17 @@ public class DimensionRenamerTest {
renamer.addDimension("first_dimension_of_b");
// which dimension to join on matmul
- renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null);
+ renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer.Constraint.equal(false), null);
// other dimensions in matmul can't be equal
- renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null);
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer.Constraint.lessThan(false), null);
// for efficiency, put dimension to join on innermost
- renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null);
- renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null);
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer.Constraint.lessThan(true), null);
+ renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer.Constraint.greaterThan(true), null);
// bias
- renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null);
+ renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer.Constraint.equal(false), null);
renamer.solve();
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java
new file mode 100644
index 00000000000..6500a380190
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java
@@ -0,0 +1,28 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.tensorflow;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author bratseth
+ */
+public class Issue9662TestCase {
+
+ @Test
+ public void testImporting() {
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/9662");
+ ImportedModel.Signature signature = model.get().signature("serving_default");
+ Assert.assertEquals("Should have no skipped outputs",
+ 0, model.get().signature("serving_default").skippedOutputs().size());
+
+ ImportedMlFunction output = signature.outputFunction("output", "output");
+ assertNotNull(output);
+ model.assertEqualResultSum("input_embedding_user_guid", "dense_out/MatMul", 0.00001);
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
index 9d2f8cf0692..75fa2ed7933 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
@@ -49,7 +49,7 @@ public class TestableTensorFlowModel {
public ImportedModel get() { return model; }
- /** Compare that summing the tensors produce the same result to within some tolerance delta */
+ /** Compare that computing the expressions produce the same result to within some tolerance delta */
public void assertEqualResultSum(String inputName, String operationName, double delta) {
Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName);
Context context = contextFrom(model);
diff --git a/model-integration/src/test/models/tensorflow/9662/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/9662/saved_model.pbtxt
new file mode 100644
index 00000000000..83c601edfc0
--- /dev/null
+++ b/model-integration/src/test/models/tensorflow/9662/saved_model.pbtxt
@@ -0,0 +1,1318 @@
+saved_model_schema_version: 1
+meta_graphs {
+ meta_info_def {
+ stripped_op_list {
+ op {
+ name: "Add"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ type: DT_STRING
+ }
+ }
+ }
+ }
+ op {
+ name: "BiasAdd"
+ input_arg {
+ name: "value"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "bias"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "data_format"
+ type: "string"
+ default_value {
+ s: "NHWC"
+ }
+ allowed_values {
+ list {
+ s: "NHWC"
+ s: "NCHW"
+ }
+ }
+ }
+ }
+ op {
+ name: "Const"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "value"
+ type: "tensor"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ }
+ op {
+ name: "ExpandDims"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "dim"
+ type_attr: "Tdim"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ }
+ attr {
+ name: "Tdim"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ op {
+ name: "MatMul"
+ input_arg {
+ name: "a"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "b"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "product"
+ type_attr: "T"
+ }
+ attr {
+ name: "transpose_a"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "transpose_b"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Maximum"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "Mul"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ is_commutative: true
+ }
+ op {
+ name: "Placeholder"
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ default_value {
+ shape {
+ unknown_rank: true
+ }
+ }
+ }
+ }
+ op {
+ name: "PlaceholderWithDefault"
+ input_arg {
+ name: "input"
+ type_attr: "dtype"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "dtype"
+ }
+ attr {
+ name: "dtype"
+ type: "type"
+ }
+ attr {
+ name: "shape"
+ type: "shape"
+ }
+ }
+ op {
+ name: "Rsqrt"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Sigmoid"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Square"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Sub"
+ input_arg {
+ name: "x"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "y"
+ type_attr: "T"
+ }
+ output_arg {
+ name: "z"
+ type_attr: "T"
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_BFLOAT16
+ type: DT_HALF
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_UINT8
+ type: DT_INT8
+ type: DT_UINT16
+ type: DT_INT16
+ type: DT_INT32
+ type: DT_INT64
+ type: DT_COMPLEX64
+ type: DT_COMPLEX128
+ }
+ }
+ }
+ }
+ op {
+ name: "Sum"
+ input_arg {
+ name: "input"
+ type_attr: "T"
+ }
+ input_arg {
+ name: "reduction_indices"
+ type_attr: "Tidx"
+ }
+ output_arg {
+ name: "output"
+ type_attr: "T"
+ }
+ attr {
+ name: "keep_dims"
+ type: "bool"
+ default_value {
+ b: false
+ }
+ }
+ attr {
+ name: "T"
+ type: "type"
+ allowed_values {
+ list {
+ type: DT_FLOAT
+ type: DT_DOUBLE
+ type: DT_INT32
+ type: DT_UINT8
+ type: DT_INT16
+ type: DT_INT8
+ type: DT_COMPLEX64
+ type: DT_INT64
+ type: DT_QINT8
+ type: DT_QUINT8
+ type: DT_QINT32
+ type: DT_BFLOAT16
+ type: DT_UINT16
+ type: DT_COMPLEX128
+ type: DT_HALF
+ type: DT_UINT32
+ type: DT_UINT64
+ }
+ }
+ }
+ attr {
+ name: "Tidx"
+ type: "type"
+ default_value {
+ type: DT_INT32
+ }
+ allowed_values {
+ list {
+ type: DT_INT32
+ type: DT_INT64
+ }
+ }
+ }
+ }
+ }
+ tags: "serve"
+ tensorflow_version: "1.13.1"
+ tensorflow_git_version: "b\'v1.13.1-0-g6612da8951\'"
+ }
+ graph_def {
+ node {
+ name: "keras_learning_phase/input"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_BOOL
+ tensor_shape {
+ }
+ bool_val: false
+ }
+ }
+ }
+ }
+ node {
+ name: "keras_learning_phase"
+ op: "PlaceholderWithDefault"
+ input: "keras_learning_phase/input"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_BOOL
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "Dot/l2_normalize/Maximum/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 9.999999960041972e-13
+ }
+ }
+ }
+ }
+ node {
+ name: "Dot/l2_normalize/Sum/reduction_indices"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: 1
+ }
+ }
+ }
+ }
+ node {
+ name: "dense_out/kernel"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ dim {
+ size: 1
+ }
+ }
+ float_val: 0.1835838258266449
+ }
+ }
+ }
+ }
+ node {
+ name: "input_embedding_user_guid"
+ op: "Placeholder"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 32
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 32
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Dot/l2_normalize_1/Square"
+ op: "Square"
+ input: "input_embedding_user_guid"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 32
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Dot/l2_normalize/Sum"
+ op: "Sum"
+ input: "Dot/l2_normalize_1/Square"
+ input: "Dot/l2_normalize/Sum/reduction_indices"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: true
+ }
+ }
+ }
+ node {
+ name: "Dot/l2_normalize/Maximum"
+ op: "Maximum"
+ input: "Dot/l2_normalize/Sum"
+ input: "Dot/l2_normalize/Maximum/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Dot/l2_normalize_1/Rsqrt"
+ op: "Rsqrt"
+ input: "Dot/l2_normalize/Maximum"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Dot/l2_normalize_1"
+ op: "Mul"
+ input: "input_embedding_user_guid"
+ input: "Dot/l2_normalize_1/Rsqrt"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 32
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Dot/Mul"
+ op: "Mul"
+ input: "Dot/l2_normalize_1"
+ input: "Dot/l2_normalize_1"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 32
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "Dot/Sum"
+ op: "Sum"
+ input: "Dot/Mul"
+ input: "Dot/l2_normalize/Sum/reduction_indices"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "batch_normalization_v1/moving_variance"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 1.0
+ }
+ }
+ }
+ }
+ node {
+ name: "Dot/ExpandDims"
+ op: "ExpandDims"
+ input: "Dot/Sum"
+ input: "Dot/l2_normalize/Sum/reduction_indices"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tdim"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "batch_normalization_v1/moving_mean"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 1
+ }
+ }
+ float_val: 0.0
+ }
+ }
+ }
+ }
+ node {
+ name: "batch_normalization_v1/batchnorm/add/y"
+ op: "Const"
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ }
+ }
+ }
+ }
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ }
+ float_val: 0.0010000000474974513
+ }
+ }
+ }
+ }
+ node {
+ name: "batch_normalization_v1/batchnorm/add"
+ op: "Add"
+ input: "batch_normalization_v1/moving_variance"
+ input: "batch_normalization_v1/batchnorm/add/y"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "batch_normalization_v1/batchnorm/Rsqrt"
+ op: "Rsqrt"
+ input: "batch_normalization_v1/batchnorm/add"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "batch_normalization_v1/batchnorm/mul"
+ op: "Mul"
+ input: "batch_normalization_v1/batchnorm/Rsqrt"
+ input: "batch_normalization_v1/moving_variance"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "batch_normalization_v1/batchnorm/mul_1"
+ op: "Mul"
+ input: "Dot/ExpandDims"
+ input: "batch_normalization_v1/batchnorm/mul"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "batch_normalization_v1/batchnorm/mul_2"
+ op: "Mul"
+ input: "batch_normalization_v1/moving_mean"
+ input: "batch_normalization_v1/batchnorm/mul"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "batch_normalization_v1/batchnorm/sub"
+ op: "Sub"
+ input: "batch_normalization_v1/moving_mean"
+ input: "batch_normalization_v1/batchnorm/mul_2"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "batch_normalization_v1/batchnorm/add_1"
+ op: "Add"
+ input: "batch_normalization_v1/batchnorm/mul_1"
+ input: "batch_normalization_v1/batchnorm/sub"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ node {
+ name: "dense_out/MatMul"
+ op: "MatMul"
+ input: "batch_normalization_v1/batchnorm/add_1"
+ input: "dense_out/kernel"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "transpose_a"
+ value {
+ b: false
+ }
+ }
+ attr {
+ key: "transpose_b"
+ value {
+ b: false
+ }
+ }
+ }
+ node {
+ name: "dense_out/BiasAdd"
+ op: "BiasAdd"
+ input: "dense_out/MatMul"
+ input: "batch_normalization_v1/moving_mean"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ attr {
+ key: "data_format"
+ value {
+ s: "NHWC"
+ }
+ }
+ }
+ node {
+ name: "dense_out/Sigmoid"
+ op: "Sigmoid"
+ input: "dense_out/BiasAdd"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "_output_shapes"
+ value {
+ list {
+ shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ }
+ }
+ versions {
+ }
+ }
+ signature_def {
+ key: "serving_default"
+ value {
+ inputs {
+ key: "input_embedding_user_guid"
+ value {
+ name: "input_embedding_user_guid:0"
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 32
+ }
+ }
+ }
+ }
+ outputs {
+ key: "output"
+ value {
+ name: "dense_out/Sigmoid:0"
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: -1
+ }
+ dim {
+ size: 1
+ }
+ }
+ }
+ }
+ method_name: "tensorflow/serving/predict"
+ }
+ }
+}
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index a16127931e9..6f37b9edea4 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -2576,6 +2576,20 @@
],
"fields": []
},
+ "com.yahoo.text.ExpressionFormatter": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public java.lang.String format(java.lang.String)",
+ "public static java.lang.String on(java.lang.String)",
+ "public static com.yahoo.text.ExpressionFormatter withLineLength(int)",
+ "public static com.yahoo.text.ExpressionFormatter inTwoColumnMode(int, int)"
+ ],
+ "fields": []
+ },
"com.yahoo.text.ForwardWriter": {
"superClass": "com.yahoo.text.GenericWriter",
"interfaces": [],
diff --git a/vespajlib/src/main/java/com/yahoo/collections/ListMap.java b/vespajlib/src/main/java/com/yahoo/collections/ListMap.java
index e851362a99d..479850beb1a 100644
--- a/vespajlib/src/main/java/com/yahoo/collections/ListMap.java
+++ b/vespajlib/src/main/java/com/yahoo/collections/ListMap.java
@@ -23,6 +23,12 @@ public class ListMap<K, V> {
this(HashMap.class);
}
+ /** Copy constructor. This will not be frozen even if the argument map is */
+ public ListMap(ListMap<K, V> original) {
+ map = new HashMap<>();
+ original.map.forEach((k, v) -> this.map.put(k, new ArrayList<>(v)));
+ }
+
@SuppressWarnings("unchecked")
public ListMap(@SuppressWarnings("rawtypes") Class<? extends Map> implementation) {
try {
@@ -45,6 +51,27 @@ public class ListMap<K, V> {
list.add(value);
}
+ /** Put a key without adding a new value, such that there is an empty list of values if no values are already added */
+ public void put(K key) {
+ List<V> list = map.get(key);
+ if (list == null) {
+ list = new ArrayList<>();
+ map.put(key, list);
+ }
+ }
+
+ /** Put this map in the state where it has just the given value of the given key */
+ public void replace(K key, V value) {
+ List<V> list = map.get(key);
+ if (list == null) {
+ put(key);
+ }
+ else {
+ list.clear();
+ list.add(value);
+ }
+ }
+
public void removeAll(K key) {
map.remove(key);
}
@@ -73,13 +100,13 @@ public class ListMap<K, V> {
/**
* Returns the List containing the elements with this key, or an empty list
- * if there are no elements for this key. The list returned is unmodifiable.
+ * if there are no elements for this key.
+ * The returned list can be modified to add and remove values if the value exists.
*/
public List<V> get(K key) {
List<V> list = map.get(key);
- if (list == null)
- return ImmutableList.of();;
- return ImmutableList.copyOf(list);
+ if (list == null) return ImmutableList.of();;
+ return list;
}
/** The same as get */
diff --git a/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java b/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java
new file mode 100644
index 00000000000..280b75f9cbb
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java
@@ -0,0 +1,180 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.text;
+
+/**
+ * Formats any parenthesis expression.
+ * In addition to the obvious this can also operate in "two column mode",
+ * wherein each chunk that will be formatted on a separate line may optionally
+ * contain a prefix marked by a start and end tab sign which will be printed in a left column of the given fixed size.
+ * The prefix itself is not formatted but will be cut if too long.
+ *
+ * @author bratseth
+ */
+public class ExpressionFormatter {
+
+ private static final int indentUnit = 2;
+
+ /** The size of the first column, or 0 if none */
+ private final int firstColumnLength;
+
+ /**
+ * The desired size of the second column (or the entire line if no first column),
+ * or 0 to split into multiple lines as much as possible.
+ * Setting this collects larger chunks to one line across markup
+ * but will not split too long lines that have no markup.
+ */
+ private final int secondColumnLength;
+
+ private ExpressionFormatter(int firstColumnLength, int secondColumnLength) {
+ this.firstColumnLength = firstColumnLength;
+ this.secondColumnLength = secondColumnLength;
+ }
+
+ public String format(String parenthesisExpression) {
+ StringBuilder b = new StringBuilder();
+ format(parenthesisExpression, 0, b);
+ while (b.length() > 0 && Character.isWhitespace(b.charAt(b.length() - 1)))
+ b.setLength(b.length() - 1);
+ return b.toString();
+ }
+
+ private void format(String expression, int indent, StringBuilder b) {
+ if (expression.isEmpty()) return;
+ expression = appendFirstColumn(expression, b);
+
+ Markup next = Markup.next(expression);
+
+ appendIndent( ! next.isClose() || next.position() > 0 ? indent : indent - 2, b);
+
+ int endOfBalancedChunk = endOfBalancedChunk(expression, Math.max(0, secondColumnLength - indent));
+ if (next.isEmpty()) {
+ b.append(expression);
+ }
+ else if (endOfBalancedChunk > 0) {
+ b.append(expression, 0, endOfBalancedChunk + 1).append("\n");
+ format(expression.substring(endOfBalancedChunk + 1), indent, b);
+ }
+ else if (next.isComma()) {
+ b.append(expression, 0, next.position() + 1).append("\n");
+ format(expression.substring(next.position() + 1), indent, b);
+ }
+ else {
+ if ( next.isClose() && next.position() > 0) { // content before end parenthesis: content, newline, then end parenthesis
+ b.append(expression, 0, next.position()).append("\n");
+ appendFirstColumn(")", b);
+ appendIndent(indent - 2, b);
+ b.append(")\n");
+ }
+ else {
+ b.append(expression, 0, next.position() + 1).append("\n");
+ }
+ format(expression.substring(next.position() + 1), indent + (next.isOpen() ? indentUnit : -indentUnit), b);
+ }
+ }
+
+ /** Returns the position of the end of a balanced chunk of at most the given size, or 0 if there is no such chunk */
+ private int endOfBalancedChunk(String expression, int maxSize) {
+ int chunkSize = 0;
+ int i = 0;
+ int nesting = 0;
+ while (i < maxSize && i < expression.length()) {
+ if (expression.charAt(i) == '\t') return chunkSize;
+ if (expression.charAt(i) == '(') nesting++;
+ if (expression.charAt(i) == ')') nesting--;
+ if (nesting < 0) return chunkSize;
+ if (nesting == 0 && ( expression.charAt(i)==')' || expression.charAt(i)==','))
+ chunkSize = i;
+ i++;
+ }
+ return chunkSize;
+ }
+
+ private String appendFirstColumn(String expression, StringBuilder b) {
+ if (firstColumnLength == 0) return expression;
+
+ while (expression.charAt(0) == ' ')
+ expression = expression.substring(1);
+
+ if (expression.charAt(0) == '\t') {
+ int tab2 = expression.indexOf('\t', 1);
+ if (tab2 >= 0) {
+ String firstColumn = expression.substring(1, tab2);
+ b.append(asSize(firstColumnLength, firstColumn)).append(" ");
+ return expression.substring(tab2 + 1);
+ }
+ }
+ appendIndent(firstColumnLength + 1, b);
+ return expression;
+ }
+
+ private void appendIndent(int indent, StringBuilder b) {
+ b.append(" ".repeat(Math.max(0, indent)));
+ }
+
+ private String asSize(int size, String s) {
+ if (s.length() > size)
+ return s.substring(0, size);
+ else
+ return s + " ".repeat(size - s.length());
+ }
+
+ /** Convenience method creating a formatter and using it to format the given expression */
+ public static String on(String parenthesisExpression) {
+ return new ExpressionFormatter(0, 80).format(parenthesisExpression);
+ }
+
+ public static ExpressionFormatter withLineLength(int maxLineLength) {
+ return new ExpressionFormatter(0, maxLineLength);
+ }
+
+ public static ExpressionFormatter inTwoColumnMode(int firstColumnSize, int secondColumnSize) {
+ return new ExpressionFormatter(firstColumnSize, secondColumnSize);
+ }
+
+ /** Contains the next position of each kind of markup, or Integer.MAX_VALUE if not present */
+ private static class Markup {
+
+ final int open, close, comma;
+
+ private Markup(int open, int close, int comma) {
+ this.open = open;
+ this.close = close;
+ this.comma = comma;
+ }
+
+ int position() {
+ return Math.min(Math.min(open, close), comma);
+ }
+
+ boolean isOpen() {
+ return open < close && open < comma;
+ }
+
+ boolean isClose() {
+ return close < open && close < comma;
+ }
+
+ boolean isComma() {
+ return comma < open && comma < close;
+ }
+
+ boolean isEmpty() {
+ return open == Integer.MAX_VALUE && close == Integer.MAX_VALUE && comma == Integer.MAX_VALUE;
+ }
+
+ static Markup next(String expression) {
+ int nextOpen = expression.indexOf('(');
+ int nextClose = expression.indexOf(')');
+ int nextComma = expression.indexOf(',');
+ if (nextOpen < 0)
+ nextOpen = Integer.MAX_VALUE;
+ if (nextClose < 0)
+ nextClose = Integer.MAX_VALUE;
+ if (nextComma < 0)
+ nextComma = Integer.MAX_VALUE;
+ return new Markup(nextOpen, nextClose, nextComma);
+ }
+
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java b/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java
new file mode 100644
index 00000000000..7251ccef521
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java
@@ -0,0 +1,190 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.text;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class ExpressionFormatterTest {
+
+ @Test
+ public void testBasic() {
+ String expected =
+ "foo(\n" +
+ " bar(\n" +
+ " baz(\n" +
+ " )\n" +
+ " )\n" +
+ ")";
+ assertPrettyPrint(expected, "foo(bar(baz()))", 0);
+ }
+
+ @Test
+ public void testBasicDense() {
+ assertPrettyPrint("foo(bar(baz()))", "foo(bar(baz()))", 50);
+ }
+
+ @Test
+ public void testArgument() {
+ String expected =
+ "foo(\n" +
+ " bar(\n" +
+ " baz(\n" +
+ " hello world\n" +
+ " )\n" +
+ " )\n" +
+ ")";
+ assertPrettyPrint(expected, "foo(bar(baz(hello world)))", 0);
+ }
+
+ @Test
+ public void testMultipleArguments() {
+ String expected =
+ "foo(\n" +
+ " bar(\n" +
+ " baz(\n" +
+ " hello world,\n" +
+ " 37\n" +
+ " )\n" +
+ " )\n" +
+ ")";
+ assertPrettyPrint(expected, "foo(bar(baz(hello world,37)))", 0);
+ }
+
+ @Test
+ public void testMultipleArgumentsSemiDense() {
+ String expected =
+ "foo(\n" +
+ " bar(\n" +
+ " baz(hi,37),\n" +
+ " baz(\n" +
+ " hello world,\n" +
+ " 37\n" +
+ " )\n" +
+ " )\n" +
+ ")";
+ assertPrettyPrint(expected, "foo(bar(baz(hi,37),baz(hello world,37)))", 15);
+ }
+
+ @Test
+ public void testUnmatchedStart() {
+ String expected =
+ "foo(\n" +
+ " (\n" +
+ " bar(\n" +
+ " baz(\n" +
+ " )\n" +
+ " )\n" +
+ " )";
+ assertPrettyPrint(expected, "foo((bar(baz()))", 0);
+ }
+
+ @Test
+ public void testUnmatchedEnd() {
+ String expected =
+ "foo(\n" +
+ " bar(\n" +
+ " baz(\n" +
+ " )\n" +
+ " )\n" +
+ ")\n" +
+ ")";
+ assertPrettyPrint(expected, "foo(bar(baz())))", 0);
+ }
+
+ @Test
+ public void testNoParenthesis() {
+ String expected =
+ "foo bar baz";
+ assertPrettyPrint(expected, "foo bar baz", 0);
+ }
+
+ @Test
+ public void testEmpty() {
+ String expected =
+ "";
+ assertPrettyPrint(expected, "", 0);
+ }
+
+ @Test
+ public void test2ColumnMode() {
+ String expected =
+ "1: foo(\n" +
+ " bar(\n" +
+ " baz(\n" +
+ "2: hello world\n" +
+ " )\n" +
+ "t(o )\n" +
+ " )";
+ ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0);
+ assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world)\tt(o)@olong:\t))"));
+ }
+
+ @Test
+ public void test2ColumnModeMultipleArguments() {
+ String expected =
+ "1: foo(\n" +
+ " bar(\n" +
+ " baz(\n" +
+ "2: hello world,\n" +
+ "3: 37\n" +
+ " )\n" +
+ "t(o )\n" +
+ " )";
+ ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0);
+ assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world,\t3:\t37)\tt(o)@olong:\t))"));
+ }
+
+ @Test
+ public void test2ColumnModeMultipleArgumentsSemiDense() {
+ String expected =
+ "1: foo(\n" +
+ " bar(\n" +
+ " baz(hi,37),\n" +
+ " boz(\n" +
+ "2: hello world,\n" +
+ "3: 5\n" +
+ " )\n" +
+ "t(o )\n" +
+ " )";
+ ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 15);
+ assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(hi,37),boz(\t2:\thello world,\t3:\t5)\tt(o)@olong:\t))"));
+ }
+
+ @Test
+ public void test2ColumnModeMultipleArgumentsWithSpaces() {
+ String expected =
+ " foo(\n" +
+ "1: bar(\n" +
+ " baz(\n" +
+ "2: hello world,\n" +
+ "3: 37\n" +
+ " )\n" +
+ "t(o )\n" +
+ " )";
+ ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0);
+ assertEquals(expected, pp.format("foo(\t1:\tbar(baz(\t2:\thello world, \t3:\t37)\tt(o)@olong:\t))"));
+ }
+
+ @Test
+ public void testTwoColumnLambdaFunction() {
+ String expected =
+ " join(\n" +
+ " a,\n" +
+ " join(\n" +
+ " b, c, f(a, b)(a * b)\n" +
+ " )\n" +
+ " , f(a, b)(a * b)\n" +
+ " )";
+ ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(5, 25);
+ assertEquals(expected, pp.format("join(a, join(b, c, f(a, b)(a * b)), f(a, b)(a * b))"));
+ }
+
+ private void assertPrettyPrint(String expected, String expression, int lineLength) {
+ assertEquals(expected, ExpressionFormatter.withLineLength(lineLength).format(expression));
+ }
+
+}