diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java | 421 |
1 files changed, 299 insertions, 122 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 9e9f66be700..0f563a75b11 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -2,179 +2,179 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.Rename; +import com.yahoo.collections.ListMap; -import java.util.ArrayDeque; import java.util.ArrayList; -import java.util.Collections; -import java.util.Deque; +import java.util.Comparator; import java.util.HashMap; -import java.util.Iterator; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.stream.Collectors; /** - * A constraint satisfier to find suitable dimension names to reduce the + * A constraint solver which finds suitable dimension names to reduce the * amount of necessary renaming during evaluation of an imported model. * * @author lesters + * @author bratseth */ public class DimensionRenamer { + private static final Logger log = Logger.getLogger(DimensionRenamer.class.getName()); + private final String dimensionPrefix; - private final Map<String, List<Integer>> variables = new HashMap<>(); - private final Map<Arc, Constraint> constraints = new HashMap<>(); - private final Map<String, Integer> renames = new HashMap<>(); - private int iterations = 0; + /** The graph we are renaming the dimensions of */ + private final IntermediateGraph graph; + + /** The set of dimensions to find a solution for */ + private final Set<String> dimensions = new HashSet<>(); + + /** The constraints on the dimension name assignment */ + private final ListMap<Arc, Constraint> constraints = new ListMap<>(); + + /** The solution to this, or null if no solution is found yet */ + private Map<String, Integer> renames = null; - public DimensionRenamer() { - this("d"); + public DimensionRenamer(IntermediateGraph graph) { + this(graph, "d"); } - public DimensionRenamer(String dimensionPrefix) { + public DimensionRenamer(IntermediateGraph graph, String dimensionPrefix) { + this.graph = graph; this.dimensionPrefix = dimensionPrefix; } - /** - * Add a dimension name variable. - */ - public void addDimension(String name) { - variables.computeIfAbsent(name, d -> new ArrayList<>()); - } + /** Add a dimension to the set of dimensions to be renamed */ + public void addDimension(String name) { dimensions.add(name); } + + /** Add a constraint between dimension names */ + public void addConstraint(String from, String to, Constraint constraint, IntermediateOperation operation) { + if (constraint instanceof EqualConstraint && from.equals(to)) return; - /** - * Add a constraint between dimension names. - */ - public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) { Arc arc = new Arc(from, to, operation); - Arc opposite = arc.opposite(); - constraints.put(arc, pred); - constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric + constraints.put(arc, constraint); + constraints.put(arc.opposite(), constraint.opposite()); // make constraint graph symmetric } - /** - * Retrieve resulting name of dimension after solving for constraints. - */ - public Optional<String> dimensionNameOf(String name) { - if (!renames.containsKey(name)) { - return Optional.empty(); - } - return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); + void solve() { + log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints)); + renames = solve(100000); + log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames)); } - /** - * Perform iterative arc consistency until we have found a solution. After - * an initial iteration, the variables (dimensions) will have multiple - * valid values. Find a single valid assignment by iteratively locking one - * dimension after another, and running the arc consistency algorithm - * multiple times. - * - * This requires having constraints that result in an absolute ordering: - * equals, lesserThan and greaterThan do that, but adding notEquals does - * not typically result in a guaranteed ordering. If that is needed, the - * algorithm below needs to be adapted with a backtracking (tree) search - * to find solutions. - */ - private void solve(int maxIterations) { - initialize(); - - // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts + private Map<String, Integer> solve(int maxIterations) { + Map<String, Integer> solution = solveWithOrWithoutSoftConstraints(maxIterations); + if (solution != null) return solution; - for (String dimension : variables.keySet()) { - List<Integer> values = variables.get(dimension); - if (values.size() > 1) { - if (!ac3()) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution."); - } - values.sort(Integer::compare); - variables.put(dimension, Collections.singletonList(values.get(0))); - } - renames.put(dimension, variables.get(dimension).get(0)); - if (iterations > maxIterations) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + - maxIterations + " iterations"); - } + for (RenameTarget target : prioritizedRenameTargets()) { + System.out.println("Trying rename " + target); + target.insertRename(this); + solution = solveWithOrWithoutSoftConstraints(maxIterations); + if (solution != null) return solution; + target.uninsertRename(this); } - - // Todo: handle failure more gracefully: - // If a solution can't be found, look at the operation node in the arc - // with the most remaining constraints, and inject a rename operation. - // Then run this algorithm again. + throw new IllegalArgumentException("Could not find a dimension naming solution " + + "given constraints\n" + constraintsToString(constraints)); } - void solve() { - solve(100000); + private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) { + Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations); + if ( solution == null) { + ListMap<Arc, Constraint> hardConstraints = new ListMap<>(); + boolean anyRemoved = copyHard(constraints, hardConstraints); + if (anyRemoved) + solution = NamingConstraintSolver.solve(dimensions, hardConstraints, maxIterations); + } + return solution; } - private void initialize() { - for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { - List<Integer> values = variable.getValue(); - for (int i = 0; i < variables.size(); ++i) { - values.add(i); // invariant: values are in increasing order + /** Removes soft constraints and returns whether something was removed */ + private boolean copyHard(ListMap<Arc, Constraint> source, ListMap<Arc, Constraint> target) { + boolean removed = false; + for (var entry : source.entrySet()) { + Arc arc = entry.getKey(); + for (Constraint constraint : entry.getValue()) { + if ( ! constraint.isSoft()) + target.put(arc, constraint); + else + removed = true; } } + return removed; } - private boolean ac3() { - Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); - while (!workList.isEmpty()) { - Arc arc = workList.pop(); - iterations += 1; - if (revise(arc)) { - if (variables.get(arc.from).size() == 0) { - return false; // no solution found - } - for (Arc constraint : constraints.keySet()) { - if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { - workList.add(constraint); - } - } - } + private List<RenameTarget> prioritizedRenameTargets() { + Map<IntermediateOperation, Integer> constraintsPerOperation = new HashMap<>(); + + for (var constraint : constraints.entrySet()) { + constraintsPerOperation.compute(constraint.getKey().operation, + (operation, count) -> count == null ? 1 : ++count); } - return true; - } + List<IntermediateOperation> prioritizedOperations = + constraintsPerOperation.entrySet().stream() + .sorted(Comparator.comparingInt(entry -> - entry.getValue())) + .map(entry -> entry.getKey()) + .collect(Collectors.toList()); - private boolean revise(Arc arc) { - boolean revised = false; - for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { - Integer from = fromIterator.next(); - boolean satisfied = false; - for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { - Integer to = toIterator.next(); - if (constraints.get(arc).test(from, to)) { - satisfied = true; + List<RenameTarget> targets = new ArrayList<>(); + for (IntermediateOperation operation : prioritizedOperations) { + for (int i = 0; i < operation.inputs().size(); i++) { + Optional<OrderedTensorType> inputType = operation.inputs().get(i).type(); + if (inputType.isEmpty()) continue; + for (String dimensionName : inputType.get().dimensionNames()) { + RenameTarget target = new RenameTarget(operation, i, dimensionName, graph); + if (target.rootKey != null) // TODO: Inserting renames under non-roots is not implemented + targets.add(target); } } - if (!satisfied) { - fromIterator.remove(); - revised = true; - } } - return revised; - } - - public interface Constraint { - boolean test(Integer x, Integer y); + return targets; } - public static boolean equals(Integer x, Integer y) { - return Objects.equals(x, y); + /** + * Retrieve resulting name of a dimension after solving for constraints, or empty if no + * solution is found yet, or this dimension was not added before finding a solution. + */ + public Optional<String> dimensionNameOf(String name) { + if ( renames == null || ! renames.containsKey(name)) + return Optional.empty(); + return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); } - public static boolean lesserThan(Integer x, Integer y) { - return x < y; + private static String renamesToString(Map<String, Integer> renames) { + return renames.entrySet().stream() + .map(e -> " " + e.getKey() + " -> " + e.getValue()) + .collect(Collectors.joining("\n")); } - public static boolean greaterThan(Integer x, Integer y) { - return x > y; + private static String constraintsToString(ListMap<Arc, Constraint> constraints) { + StringBuilder b = new StringBuilder(); + for (var entry : constraints.entrySet()) { + Arc arc = entry.getKey(); + for (Constraint constraint : entry.getValue()) { + if (constraint.isOpposite()) continue; // noise + b.append(" "); + if (constraint.isSoft()) + b.append("(soft) "); + b.append(arc.from).append(" ").append(constraint).append(" ").append(arc.to); + b.append(" (origin: ").append(arc.operation).append(")\n"); + } + } + return b.toString(); } - private static class Arc { + static class Arc { - private final String from; - private final String to; + final String from; + final String to; private final IntermediateOperation operation; Arc(String from, String to, IntermediateOperation operation) { @@ -194,7 +194,7 @@ public class DimensionRenamer { @Override public boolean equals(Object obj) { - if (obj == null || !(obj instanceof Arc)) { + if (!(obj instanceof Arc)) { return false; } Arc other = (Arc) obj; @@ -203,8 +203,185 @@ public class DimensionRenamer { @Override public String toString() { - return String.format("%s -> %s", from, to); + return from + " -> " + to; + } + } + + public static abstract class Constraint { + + private final boolean soft, opposite; + + protected Constraint(boolean soft, boolean opposite) { + this.soft = soft; + this.opposite = opposite; + } + + abstract boolean test(Integer x, Integer y); + abstract Constraint opposite(); + + /** Returns whether this constraint can be violated if that is necessary to achieve a solution */ + boolean isSoft() { return soft; } + + /** Returns whether this is an opposite of another constraint */ + boolean isOpposite() { return opposite; } + + public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); } + public static Constraint notEqual(boolean soft) { return new NotEqualConstraint(soft, false); } + public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); } + public static Constraint greaterThan(boolean soft) { return new GreaterThanConstraint(soft, false); } + + } + + private static class EqualConstraint extends Constraint { + + private EqualConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return Objects.equals(x, y); } + + @Override + public Constraint opposite() { return new EqualConstraint(isSoft(), true); } + + @Override + public String toString() { return "=="; } + + } + + private static class NotEqualConstraint extends Constraint { + + private NotEqualConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return ! Objects.equals(x, y); } + + @Override + public Constraint opposite() { return new NotEqualConstraint(isSoft(), true); } + + @Override + public String toString() { return "!="; } + + } + + private static class LessThanConstraint extends Constraint { + + private LessThanConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return x < y; } + + @Override + public Constraint opposite() { return new GreaterThanConstraint(isSoft(), true); } + + @Override + public String toString() { return "<"; } + + } + + private static class GreaterThanConstraint extends Constraint { + + private GreaterThanConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return x > y; } + + @Override + public Constraint opposite() { return new LessThanConstraint(isSoft(), true); } + + @Override + public String toString() { return ">"; } + + } + + /** + * An operation and an input number which we may want to insert a rename operation at. + * That is, we may want to change op(..., input, ...) to op(..., rename(input), ...). + * + * This class is (and must remain) immutable. + */ + private static class RenameTarget { + + final IntermediateOperation operation; + final int inputNumber; + final String dimensionName; + final IntermediateGraph graph; + + /** + * Returns the key of this operation in the root operations of the graph, + * or null if it is not a root operation + */ + final String rootKey; + + public RenameTarget(IntermediateOperation operation, int inputNumber, String dimensionName, IntermediateGraph graph) { + this.operation = operation; + this.inputNumber = inputNumber; + this.dimensionName = dimensionName; + this.rootKey = findRootKey(operation, graph); + this.graph = graph; + } + + public IntermediateOperation input() { + return operation.inputs().get(inputNumber); + } + + private static String findRootKey(IntermediateOperation operation, IntermediateGraph graph) { + for (var entry : graph.operations().entrySet()) { + if (entry.getValue() == operation) + return entry.getKey(); + } + return null; + } + + /** Inserts a rename operation if possible. Returns whether an operation was inserted. */ + private boolean insertRename(DimensionRenamer renamer) { + Rename rename = new Rename(operation.modelName(), + dimensionName, + renamer.dimensionPrefix + renamer.dimensions.size(), + input()); + + List<IntermediateOperation> newInputs = new ArrayList<>(operation.inputs()); + newInputs.set(inputNumber, rename); + IntermediateOperation newOperation = operation.withInputs(newInputs); + if (rootKey == null) + throw new IllegalStateException("Renaming non-roots is not implemented"); + graph.put(rootKey, newOperation); + + removeConstraintsOf(operation, renamer); + rename.addDimensionNameConstraints(renamer); + newOperation.addDimensionNameConstraints(renamer); + return true; + } + + /** Undo what insertRenameOperation has done: Set back the original operation and remove+add constraints */ + private void uninsertRename(DimensionRenamer renamer) { + IntermediateOperation newOperation = graph.operations().get(rootKey); + Rename rename = (Rename)newOperation.inputs().get(inputNumber); + graph.put(rootKey, operation); + + removeConstraintsOf(rename, renamer); + removeConstraintsOf(newOperation, renamer); + operation.addDimensionNameConstraints(renamer); + } + + private void removeConstraintsOf(IntermediateOperation operation, DimensionRenamer renamer) { + for (Arc key : new HashSet<>(renamer.constraints.keySet())) { + if (key.operation == operation) + renamer.constraints.removeAll(key); + } + } + + @Override + public String toString() { + return operation + ", input " + inputNumber; } + } } |