diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-02 16:04:01 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-02 16:04:01 -0700 |
commit | 0cd53c0204a8caf9fba1847d7f422cc51248f615 (patch) | |
tree | 587931b21b27cd130d52690671c805cfbb79c6ec /model-integration | |
parent | 1a60bfebe04b418b0bbde8ebd8b05904e4df1760 (diff) |
Add notEqual constraints for dimensions of the same tensor argument
Diffstat (limited to 'model-integration')
3 files changed, 53 insertions, 12 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 9fe4560b7c6..d2d7367585c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -13,6 +13,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; import java.util.stream.Collectors; /** @@ -23,6 +25,8 @@ import java.util.stream.Collectors; */ public class DimensionRenamer { + private static final Logger log = Logger.getLogger(DimensionRenamer.class.getName()); + private final String dimensionPrefix; private final ListMap<String, Integer> variables = new ListMap<>(); private final Map<Arc, Constraint> constraints = new HashMap<>(); @@ -72,7 +76,7 @@ public class DimensionRenamer { * multiple times. * * This requires having constraints that result in an absolute ordering: - * equals, lesserThan and greaterThan do that, but adding notEquals does + * equal, lessThan and greaterThan do that, but adding notEqual does * not typically result in a guaranteed ordering. If that is needed, the * algorithm below needs to be adapted with a backtracking (tree) search * to find solutions. @@ -88,14 +92,14 @@ public class DimensionRenamer { if ( ! solved) { renames.clear(); Map<Arc, Constraint> hardConstraints = new HashMap<>(); - constraints.entrySet().stream().filter(e -> ! e.getValue().isSoft()) - .forEach(e -> hardConstraints.put(e.getKey(), e.getValue())); + constraints.entrySet().stream().filter(e -> !e.getValue().isSoft()) + .forEach(e -> hardConstraints.put(e.getKey(), e.getValue())); if (hardConstraints.size() < constraints.size()) solved = trySolve(variables, hardConstraints, maxIterations, renames); - } - if ( ! solved) { - throw new IllegalArgumentException("Could not find a dimension naming solution" + - " given constraints\n" + constraintsToString()); + if ( ! solved) { + throw new IllegalArgumentException("Could not find a dimension naming solution " + + " given constraints\n" + constraintsToString(hardConstraints)); + } } // Todo: handle failure more gracefully: @@ -128,7 +132,17 @@ public class DimensionRenamer { } void solve() { + log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints)); + System.out.println("Rename problem:\n" + constraintsToString(constraints)); renames = solve(100000); + log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames)); + System.out.println("Rename solution:\n" + renamesToString(renames)); + } + + private static String renamesToString(Map<String, Integer> renames) { + return renames.entrySet().stream() + .map(e -> " " + e.getKey() + " -> " + e.getValue()) + .collect(Collectors.joining("\n")); } private static void initialize(ListMap<String, Integer> variables) { @@ -182,10 +196,10 @@ public class DimensionRenamer { return revised; } - private String constraintsToString() { + private static String constraintsToString(Map<Arc, Constraint> constraints) { return constraints.entrySet().stream() .filter(e -> ! e.getValue().isOpposite()) - .map(e -> e.getKey().from + " " + e.getValue() + " " + e.getKey().to + " (origin: " + e.getKey().operation + ")") + .map(e -> " " + e.getKey().from + " " + e.getValue() + " " + e.getKey().to + " (origin: " + e.getKey().operation + ")") .collect(Collectors.joining("\n")); } @@ -244,6 +258,7 @@ public class DimensionRenamer { boolean isOpposite() { return opposite; } public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); } + public static Constraint notEqual(boolean soft) { return new NotEqualConstraint(soft, false); } public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); } public static Constraint greaterThan(boolean soft) { return new GreaterThanConstraint(soft, false); } @@ -266,6 +281,23 @@ public class DimensionRenamer { } + private static class NotEqualConstraint extends Constraint { + + private NotEqualConstraint(boolean soft, boolean opposite) { + super(soft, opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return ! Objects.equals(x, y); } + + @Override + public Constraint opposite() { return new NotEqualConstraint(isSoft(), true); } + + @Override + public String toString() { return "!="; } + + } + private static class LessThanConstraint extends Constraint { private LessThanConstraint(boolean soft, boolean opposite) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java index 9115dc99b82..8ea9c9a258d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -130,12 +130,15 @@ public class OrderedTensorType { } public OrderedTensorType rename(DimensionRenamer renamer) { + System.out.println("Renaming " + this); List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); for (TensorType.Dimension dimension : dimensions) { String oldName = dimension.name(); Optional<String> newName = renamer.dimensionNameOf(oldName); if (!newName.isPresent()) return this; // presumably, already renamed + if ( ! oldName.equals(newName.get())) + System.out.println(" Renaming " + oldName + " to " + newName.get()); TensorType.Dimension.Type dimensionType = dimension.type(); if (dimensionType == TensorType.Dimension.Type.indexedBound) { renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index d5671889e01..2d746bf338c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -29,7 +29,7 @@ public class Argument extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); - if (!standardNamingType.equals(type)) { + if ( ! standardNamingType.equals(type)) { List<String> renameFrom = standardNamingType.dimensionNames(); List<String> renameTo = type.dimensionNames(); output = new Rename(output, renameFrom, renameTo); @@ -39,8 +39,14 @@ public class Argument extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); + for (int i = 0; i < type.dimensions().size(); i++) { + renamer.addDimension(type.dimensions().get(i).name()); + + // Each dimension is distinct: + for (int j = i + 1; j < type.dimensions().size(); j++) + renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), + DimensionRenamer.Constraint.notEqual(false), + this); } } |