diff options
author | Jon Bratseth <bratseth@oath.com> | 2019-07-08 14:47:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-07-08 14:47:54 +0200 |
commit | d0b6a8a2fe100ade8d3aac5689bead29118480ad (patch) | |
tree | 466faa85c0f1a6d8311d62983e23c00126240b85 | |
parent | 4e8a65ed3701c814459b5ce58291d9764446d873 (diff) | |
parent | 76e924dcde2613c7956a50c29dbcc082e2b3b59c (diff) |
Merge pull request #9944 from vespa-engine/bratseth/output-immediate-graph
Bratseth/output immediate graph
37 files changed, 2579 insertions, 178 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java index 61ce9d98e69..b2f5d104890 100644 --- a/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java @@ -73,14 +73,13 @@ public class BlendingSearcher extends Searcher { } /** - * Produce a single blended result list from a group of hitgroups. + * Produce a single blended hit list from a group of hitgroups. * - * It is assumed that the results are ordered in hitgroups. If not, the blend will not be performed + * This assumes that all hits are organized into hitgroups. If not, blending will not be performed. */ protected Result blendResults(Result result, Query q, int offset, int hits, Execution execution) { //Assert that there are more than one hitgroup and that there are only hitgroups on the lowest level - boolean foundNonGroup = false; Iterator<Hit> hitIterator = result.hits().iterator(); List<HitGroup> groups = new ArrayList<>(); @@ -89,14 +88,14 @@ public class BlendingSearcher extends Searcher { if (hit instanceof HitGroup) { groups.add((HitGroup)hit); hitIterator.remove(); - } else if(!hit.isMeta()) { + } else if ( ! 
hit.isMeta()) { foundNonGroup = true; } } - if(foundNonGroup) { + if( foundNonGroup) { result.hits().addError(ErrorMessage.createUnspecifiedError("Blendingsearcher could not blend - there are toplevel hits" + - " that are not hitgroups")); + " that are not hitgroups")); return result; } if (groups.size() == 0) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 9e9f66be700..0f563a75b11 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -2,179 +2,179 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.Rename; +import com.yahoo.collections.ListMap; -import java.util.ArrayDeque; import java.util.ArrayList; -import java.util.Collections; -import java.util.Deque; +import java.util.Comparator; import java.util.HashMap; -import java.util.Iterator; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.stream.Collectors; /** - * A constraint satisfier to find suitable dimension names to reduce the + * A constraint solver which finds suitable dimension names to reduce the * amount of necessary renaming during evaluation of an imported model. 
* * @author lesters + * @author bratseth */ public class DimensionRenamer { + private static final Logger log = Logger.getLogger(DimensionRenamer.class.getName()); + private final String dimensionPrefix; - private final Map<String, List<Integer>> variables = new HashMap<>(); - private final Map<Arc, Constraint> constraints = new HashMap<>(); - private final Map<String, Integer> renames = new HashMap<>(); - private int iterations = 0; + /** The graph we are renaming the dimensions of */ + private final IntermediateGraph graph; + + /** The set of dimensions to find a solution for */ + private final Set<String> dimensions = new HashSet<>(); + + /** The constraints on the dimension name assignment */ + private final ListMap<Arc, Constraint> constraints = new ListMap<>(); + + /** The solution to this, or null if no solution is found yet */ + private Map<String, Integer> renames = null; - public DimensionRenamer() { - this("d"); + public DimensionRenamer(IntermediateGraph graph) { + this(graph, "d"); } - public DimensionRenamer(String dimensionPrefix) { + public DimensionRenamer(IntermediateGraph graph, String dimensionPrefix) { + this.graph = graph; this.dimensionPrefix = dimensionPrefix; } - /** - * Add a dimension name variable. - */ - public void addDimension(String name) { - variables.computeIfAbsent(name, d -> new ArrayList<>()); - } + /** Add a dimension to the set of dimensions to be renamed */ + public void addDimension(String name) { dimensions.add(name); } + + /** Add a constraint between dimension names */ + public void addConstraint(String from, String to, Constraint constraint, IntermediateOperation operation) { + if (constraint instanceof EqualConstraint && from.equals(to)) return; - /** - * Add a constraint between dimension names. 
- */ - public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) { Arc arc = new Arc(from, to, operation); - Arc opposite = arc.opposite(); - constraints.put(arc, pred); - constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric + constraints.put(arc, constraint); + constraints.put(arc.opposite(), constraint.opposite()); // make constraint graph symmetric } - /** - * Retrieve resulting name of dimension after solving for constraints. - */ - public Optional<String> dimensionNameOf(String name) { - if (!renames.containsKey(name)) { - return Optional.empty(); - } - return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); + void solve() { + log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints)); + renames = solve(100000); + log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames)); } - /** - * Perform iterative arc consistency until we have found a solution. After - * an initial iteration, the variables (dimensions) will have multiple - * valid values. Find a single valid assignment by iteratively locking one - * dimension after another, and running the arc consistency algorithm - * multiple times. - * - * This requires having constraints that result in an absolute ordering: - * equals, lesserThan and greaterThan do that, but adding notEquals does - * not typically result in a guaranteed ordering. If that is needed, the - * algorithm below needs to be adapted with a backtracking (tree) search - * to find solutions. 
- */ - private void solve(int maxIterations) { - initialize(); - - // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts + private Map<String, Integer> solve(int maxIterations) { + Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations); + if (solution != null) return solution; - for (String dimension : variables.keySet()) { - List<Integer> values = variables.get(dimension); - if (values.size() > 1) { - if (!ac3()) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution."); - } - values.sort(Integer::compare); - variables.put(dimension, Collections.singletonList(values.get(0))); - } - renames.put(dimension, variables.get(dimension).get(0)); - if (iterations > maxIterations) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + - maxIterations + " iterations"); - } + for (RenameTarget target : prioritizedRenameTargets()) { + System.out.println("Trying rename " + target); + target.insertRename(this); + solution = solveWithOrWithoutSoftConstraints(maxIterations); + if (solution != null) return solution; + target.uninsertRename(this); } - - // Todo: handle failure more gracefully: - // If a solution can't be found, look at the operation node in the arc - // with the most remaining constraints, and inject a rename operation. - // Then run this algorithm again. 
+ throw new IllegalArgumentException("Could not find a dimension naming solution " + + "given constraints\n" + constraintsToString(constraints)); } - void solve() { - solve(100000); + private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) { + Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations); + if ( solution == null) { + ListMap<Arc, Constraint> hardConstraints = new ListMap<>(); + boolean anyRemoved = copyHard(constraints, hardConstraints); + if (anyRemoved) + solution = NamingConstraintSolver.solve(dimensions, hardConstraints, maxIterations); + } + return solution; } - private void initialize() { - for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { - List<Integer> values = variable.getValue(); - for (int i = 0; i < variables.size(); ++i) { - values.add(i); // invariant: values are in increasing order + /** Removes soft constraints and returns whether something was removed */ + private boolean copyHard(ListMap<Arc, Constraint> source, ListMap<Arc, Constraint> target) { + boolean removed = false; + for (var entry : source.entrySet()) { + Arc arc = entry.getKey(); + for (Constraint constraint : entry.getValue()) { + if ( ! 
constraint.isSoft()) + target.put(arc, constraint); + else + removed = true; } } + return removed; } - private boolean ac3() { - Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); - while (!workList.isEmpty()) { - Arc arc = workList.pop(); - iterations += 1; - if (revise(arc)) { - if (variables.get(arc.from).size() == 0) { - return false; // no solution found - } - for (Arc constraint : constraints.keySet()) { - if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { - workList.add(constraint); - } - } - } + private List<RenameTarget> prioritizedRenameTargets() { + Map<IntermediateOperation, Integer> constraintsPerOperation = new HashMap<>(); + + for (var constraint : constraints.entrySet()) { + constraintsPerOperation.compute(constraint.getKey().operation, + (operation, count) -> count == null ? 1 : ++count); } - return true; - } + List<IntermediateOperation> prioritizedOperations = + constraintsPerOperation.entrySet().stream() + .sorted(Comparator.comparingInt(entry -> - entry.getValue())) + .map(entry -> entry.getKey()) + .collect(Collectors.toList()); - private boolean revise(Arc arc) { - boolean revised = false; - for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { - Integer from = fromIterator.next(); - boolean satisfied = false; - for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { - Integer to = toIterator.next(); - if (constraints.get(arc).test(from, to)) { - satisfied = true; + List<RenameTarget> targets = new ArrayList<>(); + for (IntermediateOperation operation : prioritizedOperations) { + for (int i = 0; i < operation.inputs().size(); i++) { + Optional<OrderedTensorType> inputType = operation.inputs().get(i).type(); + if (inputType.isEmpty()) continue; + for (String dimensionName : inputType.get().dimensionNames()) { + RenameTarget target = new RenameTarget(operation, i, dimensionName, graph); + if (target.rootKey != null) // TODO: 
Inserting renames under non-roots is not implemented + targets.add(target); } } - if (!satisfied) { - fromIterator.remove(); - revised = true; - } } - return revised; - } - - public interface Constraint { - boolean test(Integer x, Integer y); + return targets; } - public static boolean equals(Integer x, Integer y) { - return Objects.equals(x, y); + /** + * Retrieve resulting name of a dimension after solving for constraints, or empty if no + * solution is found yet, or this dimension was not added before finding a solution. + */ + public Optional<String> dimensionNameOf(String name) { + if ( renames == null || ! renames.containsKey(name)) + return Optional.empty(); + return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); } - public static boolean lesserThan(Integer x, Integer y) { - return x < y; + private static String renamesToString(Map<String, Integer> renames) { + return renames.entrySet().stream() + .map(e -> " " + e.getKey() + " -> " + e.getValue()) + .collect(Collectors.joining("\n")); } - public static boolean greaterThan(Integer x, Integer y) { - return x > y; + private static String constraintsToString(ListMap<Arc, Constraint> constraints) { + StringBuilder b = new StringBuilder(); + for (var entry : constraints.entrySet()) { + Arc arc = entry.getKey(); + for (Constraint constraint : entry.getValue()) { + if (constraint.isOpposite()) continue; // noise + b.append(" "); + if (constraint.isSoft()) + b.append("(soft) "); + b.append(arc.from).append(" ").append(constraint).append(" ").append(arc.to); + b.append(" (origin: ").append(arc.operation).append(")\n"); + } + } + return b.toString(); } - private static class Arc { + static class Arc { - private final String from; - private final String to; + final String from; + final String to; private final IntermediateOperation operation; Arc(String from, String to, IntermediateOperation operation) { @@ -194,7 +194,7 @@ public class DimensionRenamer { @Override public boolean equals(Object 
obj) { - if (obj == null || !(obj instanceof Arc)) { + if (!(obj instanceof Arc)) { return false; } Arc other = (Arc) obj; @@ -203,8 +203,185 @@ public class DimensionRenamer { @Override public String toString() { - return String.format("%s -> %s", from, to); + return from + " -> " + to; + } + } + + public static abstract class Constraint { + + private final boolean soft, opposite; + + protected Constraint(boolean soft, boolean opposite) { + this.soft = soft; + this.opposite = opposite; + } + + abstract boolean test(Integer x, Integer y); + abstract Constraint opposite(); + + /** Returns whether this constraint can be violated if that is necessary to achieve a solution */ + boolean isSoft() { return soft; } + + /** Returns whether this is an opposite of another constraint */ + boolean isOpposite() { return opposite; } + + public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); } + public static Constraint notEqual(boolean soft) { return new NotEqualConstraint(soft, false); } + public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); } + public static Constraint greaterThan(boolean soft) { return new GreaterThanConstraint(soft, false); } + + } + + private static class EqualConstraint extends Constraint { + + private EqualConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return Objects.equals(x, y); } + + @Override + public Constraint opposite() { return new EqualConstraint(isSoft(), true); } + + @Override + public String toString() { return "=="; } + + } + + private static class NotEqualConstraint extends Constraint { + + private NotEqualConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return ! 
Objects.equals(x, y); } + + @Override + public Constraint opposite() { return new NotEqualConstraint(isSoft(), true); } + + @Override + public String toString() { return "!="; } + + } + + private static class LessThanConstraint extends Constraint { + + private LessThanConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return x < y; } + + @Override + public Constraint opposite() { return new GreaterThanConstraint(isSoft(), true); } + + @Override + public String toString() { return "<"; } + + } + + private static class GreaterThanConstraint extends Constraint { + + private GreaterThanConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return x > y; } + + @Override + public Constraint opposite() { return new LessThanConstraint(isSoft(), true); } + + @Override + public String toString() { return ">"; } + + } + + /** + * An operation and an input number which we may want to insert a rename operation at. + * That is, we may want to change op(..., input, ...) to op(..., rename(input), ...). + * + * This class is (and must remain) immutable. 
+ */ + private static class RenameTarget { + + final IntermediateOperation operation; + final int inputNumber; + final String dimensionName; + final IntermediateGraph graph; + + /** + * Returns the key of this operation in the root operations of the graph, + * or null if it is not a root operation + */ + final String rootKey; + + public RenameTarget(IntermediateOperation operation, int inputNumber, String dimensionName, IntermediateGraph graph) { + this.operation = operation; + this.inputNumber = inputNumber; + this.dimensionName = dimensionName; + this.rootKey = findRootKey(operation, graph); + this.graph = graph; + } + + public IntermediateOperation input() { + return operation.inputs().get(inputNumber); + } + + private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) { + for (var entry : graph.operations().entrySet()) { + if (entry.getValue() == operation) + return entry.getKey(); + } + return null; + } + + /** Inserts a rename operation if possible. Returns whether an operation was inserted. 
*/ + private boolean insertRename(DimensionRenamer renamer) { + Rename rename = new Rename(operation.modelName(), + dimensionName, + renamer.dimensionPrefix + renamer.dimensions.size(), + input()); + + List<IntermediateOperation> newInputs = new ArrayList<>(operation.inputs()); + newInputs.set(inputNumber, rename); + IntermediateOperation newOperation = operation.withInputs(newInputs); + if (rootKey == null) + throw new IllegalStateException("Renaming non-roots is not implemented"); + graph.put(rootKey, newOperation); + + removeConstraintsOf(operation, renamer); + rename.addDimensionNameConstraints(renamer); + newOperation.addDimensionNameConstraints(renamer); + return true; + } + + /** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */ + private void uninsertRename(DimensionRenamer renamer) { + IntermediateOperation newOperation = graph.operations().get(rootKey); + Rename rename = (Rename)newOperation.inputs().get(inputNumber); + graph.put(rootKey, operation); + + removeConstraintsOf(rename, renamer); + removeConstraintsOf(newOperation, renamer); + operation.addDimensionNameConstraints(renamer); + } + + private void removeConstraintsOf(IntermediateOperation operation, DimensionRenamer renamer) { + for (Arc key : new HashSet<>(renamer.constraints.keySet())) { + if (key.operation == operation) + renamer.constraints.removeAll(key); + } + } + + @Override + public String toString() { + return operation + ", input " + inputNumber; } + } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 0c570261ae7..a9be1bbd40e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -262,9 +262,12 @@ public class ImportedModel implements ImportedMlModel 
{ /** Returns the expression this output references as an imported function */ public ImportedMlFunction outputFunction(String outputName, String functionName) { + RankingExpression outputExpression = owner().expressions().get(outputs.get(outputName)); + if (outputExpression == null) + throw new IllegalArgumentException("Missing output '" + outputName + "' in " + this); return new ImportedMlFunction(functionName, new ArrayList<>(inputs.values()), - owner().expressions().get(outputs.get(outputName)).getRoot().toString(), + outputExpression.getRoot().toString(), asStrings(inputMap()), Optional.empty()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index aec98d06874..6c583d960bd 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -3,9 +3,11 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.MatMul; import java.util.Collection; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Set; @@ -20,7 +22,7 @@ import java.util.Set; public class IntermediateGraph { private final String modelName; - private final Map<String, IntermediateOperation> index = new HashMap<>(); + private final Map<String, IntermediateOperation> operations = new HashMap<>(); private final Map<String, GraphSignature> signatures = new HashMap<>(); private static class GraphSignature { @@ -37,11 +39,11 @@ public class IntermediateGraph { } public IntermediateOperation put(String key, IntermediateOperation operation) { - return index.put(key, operation); + return operations.put(key, operation); } public IntermediateOperation get(String key) { - 
return index.get(key); + return operations.get(key); } public Set<String> signatures() { @@ -61,11 +63,11 @@ public class IntermediateGraph { } public boolean alreadyImported(String key) { - return index.containsKey(key); + return operations.containsKey(key); } - public Collection<IntermediateOperation> operations() { - return index.values(); + public Map<String, IntermediateOperation> operations() { + return operations; } void optimize() { @@ -76,16 +78,16 @@ public class IntermediateGraph { * Find dimension names to avoid excessive renaming while evaluating the model. */ private void renameDimensions() { - DimensionRenamer renamer = new DimensionRenamer(); + DimensionRenamer renamer = new DimensionRenamer(this); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - addDimensionNameConstraints(index.get(output), renamer); + addDimensionNameConstraints(operations.get(output), renamer); } } renamer.solve(); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - renameDimensions(index.get(output), renamer); + renameDimensions(operations.get(output), renamer); } } } @@ -104,4 +106,16 @@ public class IntermediateGraph { } } + @Override + public String toString() { + return "intermediate graph for '" + modelName + "'"; + } + + public String toFullString() { + StringBuilder b = new StringBuilder(); + for (var input : operations.entrySet()) + b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n"); + return b.toString(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 99bfa08db43..b587a9200ec 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -11,12 +11,14 @@ import 
com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.text.ExpressionFormatter; import com.yahoo.yolean.Exceptions; import java.io.File; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Level; import java.util.logging.Logger; /** @@ -50,6 +52,8 @@ public abstract class ModelImporter implements MlModelImporter { */ protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { ImportedModel model = new ImportedModel(graph.name(), modelSource); + log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" + + ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString())); graph.optimize(); @@ -223,7 +227,7 @@ public abstract class ModelImporter implements MlModelImporter { * for fast model weight updates. */ private static void logVariableTypes(IntermediateGraph graph) { - for (IntermediateOperation operation : graph.operations()) { + for (IntermediateOperation operation : graph.operations().values()) { if ( ! (operation instanceof Constant)) continue; if ( ! operation.type().isPresent()) continue; // will not happen log.info("Importing model variable " + operation.name() + " as " + operation.vespaName() + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java new file mode 100644 index 00000000000..21cc6b27dad --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java @@ -0,0 +1,126 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.rankingexpression.importer; + +import com.yahoo.collections.ListMap; + +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Set; + +/** + * Solves a dimension naming constraint problem. + * + * @author lesters + * @author bratseth + */ +class NamingConstraintSolver { + + private final ListMap<String, Integer> possibleAssignments; + private final ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints; + + private int iterations = 0; + private final int maxIterations; + + private NamingConstraintSolver(Set<String> dimensions, + ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints, + int maxIterations) { + this.possibleAssignments = allPossibilities(dimensions); + this.constraints = constraints; + this.maxIterations = maxIterations; + } + + /** Returns a list containing a list of all assignment possibilities for each of the given dimensions */ + private static ListMap<String, Integer> allPossibilities(Set<String> dimensions) { + ListMap<String, Integer> all = new ListMap<>(); + for (String dimension : dimensions) { + for (int i = 0; i < dimensions.size(); ++i) + all.put(dimension, i); + } + return all; + } + + /** + * Try the solve the constraint problem given in the arguments, and put the result in renames. + * + * This is done by performing iterative arc consistency until we have found a solution. + * After an initial iteration, the dimensions will have multiple + * valid values. Find a single valid assignment by iteratively locking one + * dimension after another, and running the arc consistency algorithm + * multiple times. 
+ * + * This requires having constraints that result in an absolute ordering: + * equal, lessThan and greaterThan do that, but not necessarily notEqual + * If that is needed, the algorithm needs to be adapted with a backtracking + * (tree) search + * + * @return the solution in the form of the renames to perform + */ + private Map<String, Integer> trySolve() { + // TODO: Evaluate possible improved efficiency by using a heuristic such as min-conflicts + + Map<String, Integer> solution = new HashMap<>(); + for (String dimension : possibleAssignments.keySet()) { + List<Integer> values = possibleAssignments.get(dimension); + if (values.size() > 1) { + if ( ! ac3()) return null; + values.sort(Integer::compare); + possibleAssignments.replace(dimension, values.get(0)); + } + solution.put(dimension, possibleAssignments.get(dimension).get(0)); // Pick the first available solution + if (iterations > maxIterations) return null; + } + return solution; + } + + private boolean ac3() { + Deque<DimensionRenamer.Arc> workList = new ArrayDeque<>(constraints.keySet()); + while ( ! workList.isEmpty()) { + DimensionRenamer.Arc arc = workList.pop(); + iterations++; + if (revise(arc)) { + if (possibleAssignments.get(arc.from).isEmpty()) return false; + + for (DimensionRenamer.Arc constraint : constraints.keySet()) { + if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) + workList.add(constraint); + } + } + } + return true; + } + + private boolean revise(DimensionRenamer.Arc arc) { + boolean revised = false; + for (Iterator<Integer> fromIterator = possibleAssignments.get(arc.from).iterator(); fromIterator.hasNext(); ) { + Integer from = fromIterator.next(); + boolean satisfied = false; + for (Iterator<Integer> toIterator = possibleAssignments.get(arc.to).iterator(); toIterator.hasNext(); ) { + Integer to = toIterator.next(); + if (constraints.get(arc).stream().allMatch(constraint -> constraint.test(from, to))) + satisfied = true; + } + if ( ! 
satisfied) { + fromIterator.remove(); + revised = true; + } + } + return revised; + } + + /** + * Attempts to solve the given naming problem. The input maps are never modified. + * + * @return the solution as a map from existing names to name ids represented as integers, or NULL + * if no solution could be found + */ + public static Map<String, Integer> solve(Set<String> dimensions, + ListMap<DimensionRenamer.Arc, DimensionRenamer.Constraint> constraints, + int maxIterations) { + return new NamingConstraintSolver(dimensions, constraints, maxIterations).trySolve(); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java index 9115dc99b82..1cb8f3a2951 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -7,8 +7,11 @@ import com.yahoo.tensor.TensorTypeParser; import java.util.ArrayList; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.stream.Collectors; /** @@ -131,11 +134,17 @@ public class OrderedTensorType { public OrderedTensorType rename(DimensionRenamer renamer) { List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); + Map<String, String> new2Old = new HashMap<>(); // Just to create meaningful error messages for (TensorType.Dimension dimension : dimensions) { String oldName = dimension.name(); Optional<String> newName = renamer.dimensionNameOf(oldName); - if (!newName.isPresent()) - return this; // presumably, already renamed + if ( newName.isEmpty()) return this; // presumably already renamed + + if (new2Old.containsKey(newName.get())) + throw new IllegalArgumentException("Can not rename '" + oldName + "' to '" + 
newName.get() + "' in " + this + + " as '" + new2Old.get(newName.get()) + "' should also be renamed to it"); + new2Old.put(newName.get(), oldName); + TensorType.Dimension.Type dimensionType = dimension.type(); if (dimensionType == TensorType.Dimension.Type.indexedBound) { renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index d6ea00ca453..9f62a27a3b9 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -29,7 +29,7 @@ public class Argument extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); - if (!standardNamingType.equals(type)) { + if ( ! standardNamingType.equals(type)) { List<String> renameFrom = standardNamingType.dimensionNames(); List<String> renameTo = type.dimensionNames(); output = new Rename(output, renameFrom, renameTo); @@ -39,9 +39,7 @@ public class Argument extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } + addConstraintsFrom(type, renamer); } @Override @@ -54,4 +52,22 @@ public class Argument extends IntermediateOperation { return false; } + @Override + public Argument withInputs(List<IntermediateOperation> inputs) { + if ( ! 
inputs.isEmpty()) + throw new IllegalArgumentException("Argument cannot take inputs"); + return new Argument(modelName(), name(), type); + } + + @Override + public String operationName() { return "Argument"; } + + @Override + public String toString() { return "Argument(" + standardNamingType + ")"; } + + @Override + public String toFullString() { + return "\t" + lazyGetType() + ":\tArgument(" + standardNamingType + ")"; + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index 7ae50a0549d..7787caa83ce 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.functions.TensorFunction; import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; public class ConcatV2 extends IntermediateOperation { @@ -89,7 +90,7 @@ public class ConcatV2 extends IntermediateOperation { OrderedTensorType b = inputs.get(i).type().get(); String bDim = b.dimensions().get(concatDimensionIndex).name(); String aDim = a.dimensions().get(concatDimensionIndex).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); } } @@ -99,4 +100,12 @@ public class ConcatV2 extends IntermediateOperation { concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName); } + @Override + public ConcatV2 withInputs(List<IntermediateOperation> inputs) { + return new ConcatV2(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "ConcatV2"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java 
b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 41d421b1f5a..d13c1ad5f3c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -62,9 +62,7 @@ public class Const extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } + addConstraintsFrom(type, renamer); } @Override @@ -86,4 +84,23 @@ public class Const extends IntermediateOperation { } return value.get(); } + + @Override + public Const withInputs(List<IntermediateOperation> inputs) { + return new Const(modelName(), name(), inputs, attributeMap, type); + } + + @Override + public String operationName() { return "Const"; } + + @Override + public String toString() { + return "Const(" + type + ")"; + } + + @Override + public String toFullString() { + return "\t" + lazyGetType() + ":\tConst(" + type + ")"; + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java index a1cc83296b0..1eaaf705220 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -8,6 +8,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; import java.util.Collections; +import java.util.List; import java.util.Optional; public class Constant extends IntermediateOperation { @@ -48,9 +49,7 @@ public class Constant extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : 
type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } + addConstraintsFrom(type, renamer); } @Override @@ -58,4 +57,24 @@ public class Constant extends IntermediateOperation { return true; } + @Override + public Constant withInputs(List<IntermediateOperation> inputs) { + if ( ! inputs.isEmpty()) + throw new IllegalArgumentException("Constant cannot take inputs"); + return new Constant(modelName(), name(), type); + } + + @Override + public String operationName() { return "Constant"; } + + @Override + public String toString() { + return "Constant(" + type + ")"; + } + + @Override + public String toFullString() { + return "\t" + lazyGetType() + ":\tConstant(" + type + ")"; + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index c64b9ded601..e6cc96d48ad 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -30,7 +30,7 @@ public class ExpandDims extends IntermediateOperation { if ( ! allInputTypesPresent(2)) return null; IntermediateOperation axisOperation = inputs().get(1); - if (!axisOperation.getConstantValue().isPresent()) { + if ( ! 
axisOperation.getConstantValue().isPresent()) { throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); @@ -47,18 +47,23 @@ public class ExpandDims extends IntermediateOperation { expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { - if (dimensionIndex == dimensionToInsert) { - String name = String.format("%s_%d", vespaName(), dimensionIndex); - expandDimensions.add(name); - typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); - } + if (dimensionIndex == dimensionToInsert) + addDimension(dimensionIndex, typeBuilder); typeBuilder.add(dimension); dimensionIndex++; } - + if (dimensionToInsert == inputType.dimensions().size()) { // Insert last dimension + addDimension(dimensionIndex, typeBuilder); + } return typeBuilder.build(); } + private void addDimension(int dimensionIndex, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + expandDimensions.add(name); + typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); + } + @Override protected TensorFunction lazyGetFunction() { if ( ! allInputFunctionsPresent(2)) return null; @@ -88,7 +93,7 @@ public class ExpandDims extends IntermediateOperation { List<String> renamedDimensions = new ArrayList<>(expandDimensions.size()); for (String name : expandDimensions) { Optional<String> newName = renamer.dimensionNameOf(name); - if (!newName.isPresent()) { + if ( ! 
newName.isPresent()) { return; // presumably, already renamed } renamedDimensions.add(newName.get()); @@ -96,4 +101,12 @@ public class ExpandDims extends IntermediateOperation { expandDimensions = renamedDimensions; } + @Override + public ExpandDims withInputs(List<IntermediateOperation> inputs) { + return new ExpandDims(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "ExpandDims"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java index c2787aa14d4..5463f645355 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java @@ -32,4 +32,12 @@ public class Identity extends IntermediateOperation { return inputs.get(0).function().orElse(null); } + @Override + public Identity withInputs(List<IntermediateOperation> inputs) { + return new Identity(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Identity"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 0ee54f839bc..c3980b8fe93 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import 
com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -58,6 +59,8 @@ public abstract class IntermediateOperation { protected abstract OrderedTensorType lazyGetType(); protected abstract TensorFunction lazyGetFunction(); + public String modelName() { return modelName; } + /** Returns the Vespa tensor type of this operation if it exists */ public Optional<OrderedTensorType> type() { if (type == null) { @@ -99,6 +102,20 @@ public abstract class IntermediateOperation { /** Add dimension name constraints for this operation */ public void addDimensionNameConstraints(DimensionRenamer renamer) { } + /** Convenience method to add dimensions and constraints of the given tensor type */ + protected void addConstraintsFrom(OrderedTensorType type, DimensionRenamer renamer) { + for (int i = 0; i < type.dimensions().size(); i++) { + renamer.addDimension(type.dimensions().get(i).name()); + + // Each dimension is distinct: + for (int j = i + 1; j < type.dimensions().size(); j++) { + renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), + DimensionRenamer.Constraint.notEqual(false), + this); + } + } + } + /** Performs dimension rename for this operation */ public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } @@ -175,6 +192,12 @@ public abstract class IntermediateOperation { .collect(Collectors.toList())); } + public abstract IntermediateOperation withInputs(List<IntermediateOperation> inputs); + + String asString(Optional<OrderedTensorType> type) { + return type.map(t -> t.toString()).orElse("(unknown)"); + } + /** * A method signature input and output has the form name:index. * This returns the name part without the index. 
@@ -203,4 +226,19 @@ public abstract class IntermediateOperation { Optional<List<Value>> getList(String key); } + public abstract String operationName(); + + @Override + public String toString() { + return operationName() + + inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) + + ")"; + } + + public String toFullString() { + return "\t" + lazyGetType() + ":\t" + operationName() + + inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) + + ")"; + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index c2d75153586..adb54474812 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -95,7 +95,7 @@ public class Join extends IntermediateOperation { for (int i = 0; i < b.rank(); ++i) { String bDim = b.dimensions().get(i).name(); String aDim = a.dimensions().get(i + sizeDifference).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); } } @@ -111,4 +111,12 @@ public class Join extends IntermediateOperation { return a.rank() < b.rank() ? 
inputs.get(0) : inputs.get(1); } + @Override + public Join withInputs(List<IntermediateOperation> inputs) { + return new Join(modelName(), name(), inputs, operator); + } + + @Override + public String operationName() { return "Join"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java index e0842d820f9..ea39e289c48 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -34,4 +34,12 @@ public class Map extends IntermediateOperation { return new com.yahoo.tensor.functions.Map(input.get(), operator); } + @Override + public Map withInputs(List<IntermediateOperation> inputs) { + return new Map(modelName(), name(), inputs, operator); + } + + @Override + public String operationName() { return "Map"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 9a76662529d..434261c6077 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.text.ExpressionFormatter; import java.util.List; import java.util.Optional; @@ -51,20 +52,40 @@ public class MatMul extends IntermediateOperation { List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + 
assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); + assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + String aDim0 = aDimensions.get(0).name(); String aDim1 = aDimensions.get(1).name(); String bDim0 = bDimensions.get(0).name(); String bDim1 = bDimensions.get(1).name(); // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); + renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this); // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); + renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this); // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); + renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this); + renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this); + } + + private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { + if (dimensions.size() >= 2) return; + + + throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + + " but got just " + dimensions + " from\n" + + ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); } + @Override + public MatMul withInputs(List<IntermediateOperation> inputs) { + return new MatMul(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "MatMul"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java 
b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java index d8e9950c61f..215edf88c4f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java @@ -91,6 +91,11 @@ public class Mean extends IntermediateOperation { reduceDimensions = renamedDimensions; } + @Override + public Mean withInputs(List<IntermediateOperation> inputs) { + return new Mean(modelName(), name(), inputs, attributeMap); + } + private boolean shouldKeepDimensions() { Optional<Value> keepDims = attributeMap.get("keep_dims"); return keepDims.isPresent() && keepDims.get().asBoolean(); @@ -108,4 +113,7 @@ public class Mean extends IntermediateOperation { return builder.build(); } + @Override + public String operationName() { return "Mean"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java index ce0c58971d0..671cfe852a7 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java @@ -32,4 +32,12 @@ public class Merge extends IntermediateOperation { return null; } + @Override + public Merge withInputs(List<IntermediateOperation> inputs) { + return new Merge(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Merge"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java index 4c5ce33b1b5..35d89cf6ab6 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java +++ 
b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java @@ -23,4 +23,12 @@ public class NoOp extends IntermediateOperation { return null; } + @Override + public NoOp withInputs(List<IntermediateOperation> inputs) { + return new NoOp(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "NoOp"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java index e5e5c29f8f1..177ef8d5e17 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java @@ -45,4 +45,12 @@ public class PlaceholderWithDefault extends IntermediateOperation { return true; // not true if we add to function } + @Override + public PlaceholderWithDefault withInputs(List<IntermediateOperation> inputs) { + return new PlaceholderWithDefault(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "PlaceholdeWithDefault"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java new file mode 100644 index 00000000000..abc431233be --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java @@ -0,0 +1,67 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +/** + * Renames a tensor dimension to relax dimension constraints + * + * @author bratseth + */ +public class Rename extends IntermediateOperation { + + private final String from, to; + + public Rename(String modelName, String from, String to, IntermediateOperation input) { + super(modelName, "rename", List.of(input)); + this.from = from; + this.to = to; + } + + @Override + boolean allInputFunctionsPresent(int expected) { + return super.allInputFunctionsPresent(expected); + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().orElse(null); + if (inputType == null) return null; + + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(inputType.type().valueType()); + for (TensorType.Dimension dimension : inputType.dimensions()) + builder.add(dimension.withName(dimension.name().equals(from) ? to : dimension.name())); + return builder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! 
allInputFunctionsPresent(1)) return null; + return new com.yahoo.tensor.functions.Rename(inputs.get(0).function().orElse(null), from, to); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + renamer.addDimension(to); + } + + @Override + public Rename withInputs(List<IntermediateOperation> inputs) { + if (inputs.size() != 1) + throw new IllegalArgumentException("Rename require 1 input, not " + inputs.size()); + return new Rename(modelName(), from, to, inputs.get(0)); + } + + @Override + public String operationName() { return "Rename"; } + +} + + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 4a0fe236c9f..a210ed13f5d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -74,6 +74,11 @@ public class Reshape extends IntermediateOperation { } } + @Override + public Reshape withInputs(List<IntermediateOperation> inputs) { + return new Reshape(modelName(), name(), inputs); + } + public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { if ( ! 
OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); @@ -119,4 +124,7 @@ public class Reshape extends IntermediateOperation { return new ArithmeticNode(children, operators); } + @Override + public String operationName() { return "Reshape"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index dc690329a8d..35a1b6e2b0e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -81,8 +81,16 @@ public class Select extends IntermediateOperation { String bDim1 = bDimensions.get(1).name(); // These tensors should have the same dimension names - renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this); - renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this); + renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(false), this); + renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this); } + @Override + public Select withInputs(List<IntermediateOperation> inputs) { + return new Select(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Select"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java index 79f3012c327..57175092b5c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java @@ -37,6 +37,11 @@ public class Shape extends IntermediateOperation { 
return true; } + @Override + public Shape withInputs(List<IntermediateOperation> inputs) { + return new Shape(modelName(), name(), inputs); + } + private void createConstantValue() { if (!allInputTypesPresent(1)) { return; @@ -50,4 +55,7 @@ public class Shape extends IntermediateOperation { this.setConstantValue(new TensorValue(builder.build())); } + @Override + public String operationName() { return "Shape"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index cdacbe1656a..032ffb88a46 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -37,4 +37,12 @@ public class Softmax extends IntermediateOperation { return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension); } + @Override + public Softmax withInputs(List<IntermediateOperation> inputs) { + return new Softmax(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "SoftMax"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index 52d40144f61..56d9b542093 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -70,6 +70,11 @@ public class Squeeze extends IntermediateOperation { squeezeDimensions = renamedDimensions; } + @Override + public Squeeze withInputs(List<IntermediateOperation> inputs) { + return new Squeeze(modelName(), name(), inputs, attributeMap); + } + private OrderedTensorType reducedType(OrderedTensorType inputType) { OrderedTensorType.Builder 
builder = new OrderedTensorType.Builder(resultValueType()); for (TensorType.Dimension dimension: inputType.type().dimensions()) { @@ -80,4 +85,7 @@ public class Squeeze extends IntermediateOperation { return builder.build(); } + @Override + public String operationName() { return "Squeeze"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java index 46b95233d11..c8cd235f50e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -107,4 +107,12 @@ public class Sum extends IntermediateOperation { return builder.build(); } + @Override + public Sum withInputs(List<IntermediateOperation> inputs) { + return new Sum(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Sum"; } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java index 39702690bfa..4beafc68909 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java @@ -42,6 +42,14 @@ public class Switch extends IntermediateOperation { return predicate == port ? 
inputs().get(0).function().get() : null; } + @Override + public Switch withInputs(List<IntermediateOperation> inputs) { + return new Switch(modelName(), name(), inputs, port); + } + + @Override + public String operationName() { return "Switch"; } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java index cf8dd6e8e71..793258868ee 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java @@ -9,7 +9,7 @@ public class DimensionRenamerTest { @Test public void testMnistRenaming() { - DimensionRenamer renamer = new DimensionRenamer(); + DimensionRenamer renamer = new DimensionRenamer(new IntermediateGraph("test")); renamer.addDimension("first_dimension_of_x"); renamer.addDimension("second_dimension_of_x"); @@ -18,17 +18,17 @@ public class DimensionRenamerTest { renamer.addDimension("first_dimension_of_b"); // which dimension to join on matmul - renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null); + renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer.Constraint.equal(false), null); // other dimensions in matmul can't be equal - renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null); + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer.Constraint.lessThan(false), null); // for efficiency, put dimension to join on innermost - renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null); - renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null); + renamer.addConstraint("first_dimension_of_x", 
"second_dimension_of_x", DimensionRenamer.Constraint.lessThan(true), null); + renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer.Constraint.greaterThan(true), null); // bias - renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null); + renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer.Constraint.equal(false), null); renamer.solve(); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java new file mode 100644 index 00000000000..6500a380190 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java @@ -0,0 +1,28 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.tensorflow; + +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class Issue9662TestCase { + + @Test + public void testImporting() { + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/9662"); + ImportedModel.Signature signature = model.get().signature("serving_default"); + Assert.assertEquals("Should have no skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); + + ImportedMlFunction output = signature.outputFunction("output", "output"); + assertNotNull(output); + model.assertEqualResultSum("input_embedding_user_guid", "dense_out/MatMul", 0.00001); + } + +} diff --git 
a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index 9d2f8cf0692..75fa2ed7933 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -49,7 +49,7 @@ public class TestableTensorFlowModel { public ImportedModel get() { return model; } - /** Compare that summing the tensors produce the same result to within some tolerance delta */ + /** Compare that computing the expressions produce the same result to within some tolerance delta */ public void assertEqualResultSum(String inputName, String operationName, double delta) { Tensor tfResult = tensorFlowExecute(tensorFlowModel, inputName, operationName); Context context = contextFrom(model); diff --git a/model-integration/src/test/models/tensorflow/9662/saved_model.pbtxt b/model-integration/src/test/models/tensorflow/9662/saved_model.pbtxt new file mode 100644 index 00000000000..83c601edfc0 --- /dev/null +++ b/model-integration/src/test/models/tensorflow/9662/saved_model.pbtxt @@ -0,0 +1,1318 @@ +saved_model_schema_version: 1 +meta_graphs { + meta_info_def { + stripped_op_list { + op { + name: "Add" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + type: DT_STRING + } + } + } + } + op { + name: "BiasAdd" + input_arg { + name: "value" + type_attr: "T" + } + input_arg { + name: "bias" + type_attr: "T" + } + output_arg { + name: 
"output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "data_format" + type: "string" + default_value { + s: "NHWC" + } + allowed_values { + list { + s: "NHWC" + s: "NCHW" + } + } + } + } + op { + name: "Const" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "value" + type: "tensor" + } + attr { + name: "dtype" + type: "type" + } + } + op { + name: "ExpandDims" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "dim" + type_attr: "Tdim" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "T" + type: "type" + } + attr { + name: "Tdim" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + op { + name: "MatMul" + input_arg { + name: "a" + type_attr: "T" + } + input_arg { + name: "b" + type_attr: "T" + } + output_arg { + name: "product" + type_attr: "T" + } + attr { + name: "transpose_a" + type: "bool" + default_value { + b: false + } + } + attr { + name: "transpose_b" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Maximum" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 
+ type: DT_INT64 + } + } + } + is_commutative: true + } + op { + name: "Mul" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + is_commutative: true + } + op { + name: "Placeholder" + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + default_value { + shape { + unknown_rank: true + } + } + } + } + op { + name: "PlaceholderWithDefault" + input_arg { + name: "input" + type_attr: "dtype" + } + output_arg { + name: "output" + type_attr: "dtype" + } + attr { + name: "dtype" + type: "type" + } + attr { + name: "shape" + type: "shape" + } + } + op { + name: "Rsqrt" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sigmoid" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Square" + input_arg { + name: "x" + type_attr: "T" + } + output_arg { + name: "y" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { 
+ name: "Sub" + input_arg { + name: "x" + type_attr: "T" + } + input_arg { + name: "y" + type_attr: "T" + } + output_arg { + name: "z" + type_attr: "T" + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_BFLOAT16 + type: DT_HALF + type: DT_FLOAT + type: DT_DOUBLE + type: DT_UINT8 + type: DT_INT8 + type: DT_UINT16 + type: DT_INT16 + type: DT_INT32 + type: DT_INT64 + type: DT_COMPLEX64 + type: DT_COMPLEX128 + } + } + } + } + op { + name: "Sum" + input_arg { + name: "input" + type_attr: "T" + } + input_arg { + name: "reduction_indices" + type_attr: "Tidx" + } + output_arg { + name: "output" + type_attr: "T" + } + attr { + name: "keep_dims" + type: "bool" + default_value { + b: false + } + } + attr { + name: "T" + type: "type" + allowed_values { + list { + type: DT_FLOAT + type: DT_DOUBLE + type: DT_INT32 + type: DT_UINT8 + type: DT_INT16 + type: DT_INT8 + type: DT_COMPLEX64 + type: DT_INT64 + type: DT_QINT8 + type: DT_QUINT8 + type: DT_QINT32 + type: DT_BFLOAT16 + type: DT_UINT16 + type: DT_COMPLEX128 + type: DT_HALF + type: DT_UINT32 + type: DT_UINT64 + } + } + } + attr { + name: "Tidx" + type: "type" + default_value { + type: DT_INT32 + } + allowed_values { + list { + type: DT_INT32 + type: DT_INT64 + } + } + } + } + } + tags: "serve" + tensorflow_version: "1.13.1" + tensorflow_git_version: "b\'v1.13.1-0-g6612da8951\'" + } + graph_def { + node { + name: "keras_learning_phase/input" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_BOOL + tensor_shape { + } + bool_val: false + } + } + } + } + node { + name: "keras_learning_phase" + op: "PlaceholderWithDefault" + input: "keras_learning_phase/input" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_BOOL + } + } + attr { + key: "shape" + value { + shape { + } + } + } 
+ } + node { + name: "Dot/l2_normalize/Maximum/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 9.999999960041972e-13 + } + } + } + } + node { + name: "Dot/l2_normalize/Sum/reduction_indices" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: 1 + } + } + } + } + node { + name: "dense_out/kernel" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + dim { + size: 1 + } + } + float_val: 0.1835838258266449 + } + } + } + } + node { + name: "input_embedding_user_guid" + op: "Placeholder" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + } + node { + name: "Dot/l2_normalize_1/Square" + op: "Square" + input: "input_embedding_user_guid" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + } + } + node { + name: "Dot/l2_normalize/Sum" + op: "Sum" + input: "Dot/l2_normalize_1/Square" + input: "Dot/l2_normalize/Sum/reduction_indices" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + 
list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: true + } + } + } + node { + name: "Dot/l2_normalize/Maximum" + op: "Maximum" + input: "Dot/l2_normalize/Sum" + input: "Dot/l2_normalize/Maximum/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "Dot/l2_normalize_1/Rsqrt" + op: "Rsqrt" + input: "Dot/l2_normalize/Maximum" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "Dot/l2_normalize_1" + op: "Mul" + input: "input_embedding_user_guid" + input: "Dot/l2_normalize_1/Rsqrt" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + } + } + node { + name: "Dot/Mul" + op: "Mul" + input: "Dot/l2_normalize_1" + input: "Dot/l2_normalize_1" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + } + } + node { + name: "Dot/Sum" + op: "Sum" + input: "Dot/Mul" + input: "Dot/l2_normalize/Sum/reduction_indices" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + } + } + } + } + attr { + key: "keep_dims" + value { + b: false + } + } + } + node { + name: "batch_normalization_v1/moving_variance" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + 
tensor_shape { + dim { + size: 1 + } + } + float_val: 1.0 + } + } + } + } + node { + name: "Dot/ExpandDims" + op: "ExpandDims" + input: "Dot/Sum" + input: "Dot/l2_normalize/Sum/reduction_indices" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tdim" + value { + type: DT_INT32 + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/moving_mean" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + dim { + size: 1 + } + } + float_val: 0.0 + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/add/y" + op: "Const" + attr { + key: "_output_shapes" + value { + list { + shape { + } + } + } + } + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "value" + value { + tensor { + dtype: DT_FLOAT + tensor_shape { + } + float_val: 0.0010000000474974513 + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/add" + op: "Add" + input: "batch_normalization_v1/moving_variance" + input: "batch_normalization_v1/batchnorm/add/y" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/Rsqrt" + op: "Rsqrt" + input: "batch_normalization_v1/batchnorm/add" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/mul" + op: "Mul" + input: "batch_normalization_v1/batchnorm/Rsqrt" + input: "batch_normalization_v1/moving_variance" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: 
"_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/mul_1" + op: "Mul" + input: "Dot/ExpandDims" + input: "batch_normalization_v1/batchnorm/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/mul_2" + op: "Mul" + input: "batch_normalization_v1/moving_mean" + input: "batch_normalization_v1/batchnorm/mul" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/sub" + op: "Sub" + input: "batch_normalization_v1/moving_mean" + input: "batch_normalization_v1/batchnorm/mul_2" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: 1 + } + } + } + } + } + } + node { + name: "batch_normalization_v1/batchnorm/add_1" + op: "Add" + input: "batch_normalization_v1/batchnorm/mul_1" + input: "batch_normalization_v1/batchnorm/sub" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + node { + name: "dense_out/MatMul" + op: "MatMul" + input: "batch_normalization_v1/batchnorm/add_1" + input: "dense_out/kernel" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + attr { + key: "transpose_a" + value { + b: false + } + } + attr { + key: "transpose_b" + value { + b: false + } + } + } + node { + name: "dense_out/BiasAdd" + op: "BiasAdd" + input: "dense_out/MatMul" + input: "batch_normalization_v1/moving_mean" + attr { + key: "T" + value { + 
type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + attr { + key: "data_format" + value { + s: "NHWC" + } + } + } + node { + name: "dense_out/Sigmoid" + op: "Sigmoid" + input: "dense_out/BiasAdd" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "_output_shapes" + value { + list { + shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + } + } + versions { + } + } + signature_def { + key: "serving_default" + value { + inputs { + key: "input_embedding_user_guid" + value { + name: "input_embedding_user_guid:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 32 + } + } + } + } + outputs { + key: "output" + value { + name: "dense_out/Sigmoid:0" + dtype: DT_FLOAT + tensor_shape { + dim { + size: -1 + } + dim { + size: 1 + } + } + } + } + method_name: "tensorflow/serving/predict" + } + } +} diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index a16127931e9..6f37b9edea4 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2576,6 +2576,20 @@ ], "fields": [] }, + "com.yahoo.text.ExpressionFormatter": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public java.lang.String format(java.lang.String)", + "public static java.lang.String on(java.lang.String)", + "public static com.yahoo.text.ExpressionFormatter withLineLength(int)", + "public static com.yahoo.text.ExpressionFormatter inTwoColumnMode(int, int)" + ], + "fields": [] + }, "com.yahoo.text.ForwardWriter": { "superClass": "com.yahoo.text.GenericWriter", "interfaces": [], diff --git a/vespajlib/src/main/java/com/yahoo/collections/ListMap.java b/vespajlib/src/main/java/com/yahoo/collections/ListMap.java index e851362a99d..479850beb1a 100644 --- a/vespajlib/src/main/java/com/yahoo/collections/ListMap.java +++ b/vespajlib/src/main/java/com/yahoo/collections/ListMap.java @@ -23,6 
+23,12 @@ public class ListMap<K, V> { this(HashMap.class); } + /** Copy constructor. This will not be frozen even if the argument map is */ + public ListMap(ListMap<K, V> original) { + map = new HashMap<>(); + original.map.forEach((k, v) -> this.map.put(k, new ArrayList<>(v))); + } + @SuppressWarnings("unchecked") public ListMap(@SuppressWarnings("rawtypes") Class<? extends Map> implementation) { try { @@ -45,6 +51,27 @@ public class ListMap<K, V> { list.add(value); } + /** Put a key without adding a new value, such that there is an empty list of values if no values are already added */ + public void put(K key) { + List<V> list = map.get(key); + if (list == null) { + list = new ArrayList<>(); + map.put(key, list); + } + } + + /** Put this map in the state where it has just the given value of the given key */ + public void replace(K key, V value) { + List<V> list = map.get(key); + if (list == null) { + put(key); + } + else { + list.clear(); + list.add(value); + } + } + public void removeAll(K key) { map.remove(key); } @@ -73,13 +100,13 @@ public class ListMap<K, V> { /** * Returns the List containing the elements with this key, or an empty list - * if there are no elements for this key. The list returned is unmodifiable. + * if there are no elements for this key. + * The returned list can be modified to add and remove values if the value exists. */ public List<V> get(K key) { List<V> list = map.get(key); - if (list == null) - return ImmutableList.of();; - return ImmutableList.copyOf(list); + if (list == null) return ImmutableList.of();; + return list; } /** The same as get */ diff --git a/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java b/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java new file mode 100644 index 00000000000..280b75f9cbb --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java @@ -0,0 +1,180 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. 
See LICENSE in the project root. +package com.yahoo.text; + +/** + * Formats any parenthesis expression. + * In addition to the obvious this can also operate in "two column mode", + * wherein each chunk that will be formatted on a separate line may optionally + * contain a prefix marked by a start and end tab sign which will be printed in a left column of the given fixed size. + * The prefix itself is not formatted but will be cut if too long. + * + * @author bratseth + */ +public class ExpressionFormatter { + + private static final int indentUnit = 2; + + /** The size of the first column, or 0 if none */ + private final int firstColumnLength; + + /** + * The desired size of the second column (or the entire line if no first column), + * or 0 to split into multiple lines as much as possible. + * Setting this collects larger chunks to one line across markup + * but will not split too long lines that have no markup. + */ + private final int secondColumnLength; + + private ExpressionFormatter(int firstColumnLength, int secondColumnLength) { + this.firstColumnLength = firstColumnLength; + this.secondColumnLength = secondColumnLength; + } + + public String format(String parenthesisExpression) { + StringBuilder b = new StringBuilder(); + format(parenthesisExpression, 0, b); + while (b.length() > 0 && Character.isWhitespace(b.charAt(b.length() - 1))) + b.setLength(b.length() - 1); + return b.toString(); + } + + private void format(String expression, int indent, StringBuilder b) { + if (expression.isEmpty()) return; + expression = appendFirstColumn(expression, b); + + Markup next = Markup.next(expression); + + appendIndent( ! next.isClose() || next.position() > 0 ? 
indent : indent - 2, b); + + int endOfBalancedChunk = endOfBalancedChunk(expression, Math.max(0, secondColumnLength - indent)); + if (next.isEmpty()) { + b.append(expression); + } + else if (endOfBalancedChunk > 0) { + b.append(expression, 0, endOfBalancedChunk + 1).append("\n"); + format(expression.substring(endOfBalancedChunk + 1), indent, b); + } + else if (next.isComma()) { + b.append(expression, 0, next.position() + 1).append("\n"); + format(expression.substring(next.position() + 1), indent, b); + } + else { + if ( next.isClose() && next.position() > 0) { // content before end parenthesis: content, newline, then end parenthesis + b.append(expression, 0, next.position()).append("\n"); + appendFirstColumn(")", b); + appendIndent(indent - 2, b); + b.append(")\n"); + } + else { + b.append(expression, 0, next.position() + 1).append("\n"); + } + format(expression.substring(next.position() + 1), indent + (next.isOpen() ? indentUnit : -indentUnit), b); + } + } + + /** Returns the position of the end of a balanced chunk of at most the given size, or 0 if there is no such chunk */ + private int endOfBalancedChunk(String expression, int maxSize) { + int chunkSize = 0; + int i = 0; + int nesting = 0; + while (i < maxSize && i < expression.length()) { + if (expression.charAt(i) == '\t') return chunkSize; + if (expression.charAt(i) == '(') nesting++; + if (expression.charAt(i) == ')') nesting--; + if (nesting < 0) return chunkSize; + if (nesting == 0 && ( expression.charAt(i)==')' || expression.charAt(i)==',')) + chunkSize = i; + i++; + } + return chunkSize; + } + + private String appendFirstColumn(String expression, StringBuilder b) { + if (firstColumnLength == 0) return expression; + + while (expression.charAt(0) == ' ') + expression = expression.substring(1); + + if (expression.charAt(0) == '\t') { + int tab2 = expression.indexOf('\t', 1); + if (tab2 >= 0) { + String firstColumn = expression.substring(1, tab2); + b.append(asSize(firstColumnLength, firstColumn)).append(" 
"); + return expression.substring(tab2 + 1); + } + } + appendIndent(firstColumnLength + 1, b); + return expression; + } + + private void appendIndent(int indent, StringBuilder b) { + b.append(" ".repeat(Math.max(0, indent))); + } + + private String asSize(int size, String s) { + if (s.length() > size) + return s.substring(0, size); + else + return s + " ".repeat(size - s.length()); + } + + /** Convenience method creating a formatter and using it to format the given expression */ + public static String on(String parenthesisExpression) { + return new ExpressionFormatter(0, 80).format(parenthesisExpression); + } + + public static ExpressionFormatter withLineLength(int maxLineLength) { + return new ExpressionFormatter(0, maxLineLength); + } + + public static ExpressionFormatter inTwoColumnMode(int firstColumnSize, int secondColumnSize) { + return new ExpressionFormatter(firstColumnSize, secondColumnSize); + } + + /** Contains the next position of each kind of markup, or Integer.MAX_VALUE if not present */ + private static class Markup { + + final int open, close, comma; + + private Markup(int open, int close, int comma) { + this.open = open; + this.close = close; + this.comma = comma; + } + + int position() { + return Math.min(Math.min(open, close), comma); + } + + boolean isOpen() { + return open < close && open < comma; + } + + boolean isClose() { + return close < open && close < comma; + } + + boolean isComma() { + return comma < open && comma < close; + } + + boolean isEmpty() { + return open == Integer.MAX_VALUE && close == Integer.MAX_VALUE && comma == Integer.MAX_VALUE; + } + + static Markup next(String expression) { + int nextOpen = expression.indexOf('('); + int nextClose = expression.indexOf(')'); + int nextComma = expression.indexOf(','); + if (nextOpen < 0) + nextOpen = Integer.MAX_VALUE; + if (nextClose < 0) + nextClose = Integer.MAX_VALUE; + if (nextComma < 0) + nextComma = Integer.MAX_VALUE; + return new Markup(nextOpen, nextClose, nextComma); + } + + } 
+ +} diff --git a/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java b/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java new file mode 100644 index 00000000000..7251ccef521 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java @@ -0,0 +1,190 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.text; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class ExpressionFormatterTest { + + @Test + public void testBasic() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " )\n" + + " )\n" + + ")"; + assertPrettyPrint(expected, "foo(bar(baz()))", 0); + } + + @Test + public void testBasicDense() { + assertPrettyPrint("foo(bar(baz()))", "foo(bar(baz()))", 50); + } + + @Test + public void testArgument() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " hello world\n" + + " )\n" + + " )\n" + + ")"; + assertPrettyPrint(expected, "foo(bar(baz(hello world)))", 0); + } + + @Test + public void testMultipleArguments() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " hello world,\n" + + " 37\n" + + " )\n" + + " )\n" + + ")"; + assertPrettyPrint(expected, "foo(bar(baz(hello world,37)))", 0); + } + + @Test + public void testMultipleArgumentsSemiDense() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(hi,37),\n" + + " baz(\n" + + " hello world,\n" + + " 37\n" + + " )\n" + + " )\n" + + ")"; + assertPrettyPrint(expected, "foo(bar(baz(hi,37),baz(hello world,37)))", 15); + } + + @Test + public void testUnmatchedStart() { + String expected = + "foo(\n" + + " (\n" + + " bar(\n" + + " baz(\n" + + " )\n" + + " )\n" + + " )"; + assertPrettyPrint(expected, "foo((bar(baz()))", 0); + } + + @Test + public void testUnmatchedEnd() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " )\n" + + " )\n" + + ")\n" + + ")"; 
+ assertPrettyPrint(expected, "foo(bar(baz())))", 0); + } + + @Test + public void testNoParenthesis() { + String expected = + "foo bar baz"; + assertPrettyPrint(expected, "foo bar baz", 0); + } + + @Test + public void testEmpty() { + String expected = + ""; + assertPrettyPrint(expected, "", 0); + } + + @Test + public void test2ColumnMode() { + String expected = + "1: foo(\n" + + " bar(\n" + + " baz(\n" + + "2: hello world\n" + + " )\n" + + "t(o )\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0); + assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world)\tt(o)@olong:\t))")); + } + + @Test + public void test2ColumnModeMultipleArguments() { + String expected = + "1: foo(\n" + + " bar(\n" + + " baz(\n" + + "2: hello world,\n" + + "3: 37\n" + + " )\n" + + "t(o )\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0); + assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world,\t3:\t37)\tt(o)@olong:\t))")); + } + + @Test + public void test2ColumnModeMultipleArgumentsSemiDense() { + String expected = + "1: foo(\n" + + " bar(\n" + + " baz(hi,37),\n" + + " boz(\n" + + "2: hello world,\n" + + "3: 5\n" + + " )\n" + + "t(o )\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 15); + assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(hi,37),boz(\t2:\thello world,\t3:\t5)\tt(o)@olong:\t))")); + } + + @Test + public void test2ColumnModeMultipleArgumentsWithSpaces() { + String expected = + " foo(\n" + + "1: bar(\n" + + " baz(\n" + + "2: hello world,\n" + + "3: 37\n" + + " )\n" + + "t(o )\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0); + assertEquals(expected, pp.format("foo(\t1:\tbar(baz(\t2:\thello world, \t3:\t37)\tt(o)@olong:\t))")); + } + + @Test + public void testTwoColumnLambdaFunction() { + String expected = + " join(\n" + + " a,\n" + + " join(\n" + + " b, c, f(a, b)(a * b)\n" + + " )\n" + + " , f(a, b)(a * b)\n" + + " )"; + 
ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(5, 25); + assertEquals(expected, pp.format("join(a, join(b, c, f(a, b)(a * b)), f(a, b)(a * b))")); + } + + private void assertPrettyPrint(String expected, String expression, int lineLength) { + assertEquals(expected, ExpressionFormatter.withLineLength(lineLength).format(expression)); + } + +} |