diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-02 13:43:55 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-02 13:43:55 -0700 |
commit | b2cd383345dd42311c7a5005a6232c941bc95dcf (patch) | |
tree | 9c41c4dd809e38f7ed14856f034b5538f8abe1d5 | |
parent | 6c8e1b26bc33ba89f8fed9354fe2666dc796a485 (diff) |
Output conflixcting constraints
7 files changed, 102 insertions, 40 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 9e9f66be700..282ab6df25b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -13,6 +13,8 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.function.Predicate; +import java.util.stream.Collectors; /** * A constraint satisfier to find suitable dimension names to reduce the @@ -47,11 +49,10 @@ public class DimensionRenamer { /** * Add a constraint between dimension names. */ - public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) { + public void addConstraint(String from, String to, Constraint constraint, IntermediateOperation operation) { Arc arc = new Arc(from, to, operation); - Arc opposite = arc.opposite(); - constraints.put(arc, pred); - constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric + constraints.put(arc, constraint); + constraints.put(arc.opposite(), constraint.opposite()); // make constraint graph symmetric } /** @@ -85,8 +86,9 @@ public class DimensionRenamer { for (String dimension : variables.keySet()) { List<Integer> values = variables.get(dimension); if (values.size() > 1) { - if (!ac3()) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution."); + if ( ! ac3()) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution" + + " given constraints\n" + constraintsToString()); } values.sort(Integer::compare); variables.put(dimension, Collections.singletonList(values.get(0))); @@ -94,7 +96,8 @@ public class DimensionRenamer { renames.put(dimension, variables.get(dimension).get(0)); if (iterations > maxIterations) { throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + - maxIterations + " iterations"); + maxIterations + " iterations for dimension '" + dimension + "'" + + "' given constraints\n" + constraintsToString()); } } @@ -155,20 +158,11 @@ public class DimensionRenamer { return revised; } - public interface Constraint { - boolean test(Integer x, Integer y); - } - - public static boolean equals(Integer x, Integer y) { - return Objects.equals(x, y); - } - - public static boolean lesserThan(Integer x, Integer y) { - return x < y; - } - - public static boolean greaterThan(Integer x, Integer y) { - return x > y; + private String constraintsToString() { + return constraints.entrySet().stream() + .filter(e -> ! e.getValue().isOpposite()) + .map(e -> e.getKey().from + " " + e.getValue() + " " + e.getKey().to + " (origin: " + e.getKey().operation + ")") + .collect(Collectors.joining("\n")); } private static class Arc { @@ -194,7 +188,7 @@ public class DimensionRenamer { @Override public boolean equals(Object obj) { - if (obj == null || !(obj instanceof Arc)) { + if (!(obj instanceof Arc)) { return false; } Arc other = (Arc) obj; @@ -203,8 +197,79 @@ public class DimensionRenamer { @Override public String toString() { - return String.format("%s -> %s", from, to); + return from + " -> " + to; + } + } + + public static abstract class Constraint { + + private final boolean opposite; + + protected Constraint(boolean opposite) { + this.opposite = opposite; + } + + abstract boolean test(Integer x, Integer y); + abstract Constraint opposite(); + + /** Returns whether this is an opposite of another constraint */ + boolean isOpposite() { return opposite; } + + public static Constraint equal() { return new EqualConstraint(false); } + public static Constraint lessThan() { return new LessThanConstraint(false); } + public static Constraint greaterThan() { return new GreaterThanConstraint(false); } + + } + + private static class EqualConstraint extends Constraint { + + private EqualConstraint(boolean opposite) { + super(opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return Objects.equals(x, y); } + + @Override + public Constraint opposite() { return new EqualConstraint(true); } + + @Override + public String toString() { return "=="; } + + } + + private static class LessThanConstraint extends Constraint { + + private LessThanConstraint(boolean opposite) { + super(opposite); + } + + @Override + public boolean test(Integer x, Integer y) { return x < y; } + + @Override + public Constraint opposite() { return new GreaterThanConstraint(true); } + + @Override + public String toString() { return "<"; } + + } + + private static class GreaterThanConstraint extends Constraint { + + private GreaterThanConstraint(boolean opposite) { + super(opposite); } + + @Override + public boolean test(Integer x, Integer y) { return x > y; } + + @Override + public Constraint opposite() { return new LessThanConstraint(true); } + + @Override + public String toString() { return ">"; } + } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index 7ae50a0549d..f5bae4c47d5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -89,7 +89,7 @@ public class ConcatV2 extends IntermediateOperation { OrderedTensorType b = inputs.get(i).type().get(); String bDim = b.dimensions().get(concatDimensionIndex).name(); String aDim = a.dimensions().get(concatDimensionIndex).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(), this); } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 8f029fc9c4a..91783aed7a5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -52,7 +52,7 @@ public class ExpandDims extends IntermediateOperation { typeBuilder.add(dimension); dimensionIndex++; } - if (dimensionIndex == inputType.dimensions().size()) { // Insert last dimension + if (dimensionToInsert == inputType.dimensions().size()) { // Insert last dimension addDimension(dimensionIndex, typeBuilder); } return typeBuilder.build(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index 5c7acc8a0ee..a04725f29fa 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -95,7 +95,7 @@ public class Join extends IntermediateOperation { for (int i = 0; i < b.rank(); ++i) { String bDim = b.dimensions().get(i).name(); String aDim = a.dimensions().get(i + sizeDifference).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(), this); } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index cf6cc722b9e..0a4821b8727 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -55,23 +55,20 @@ public class MatMul extends IntermediateOperation { assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); - System.out.println("Dimensions in a: " + aDimensions); - System.out.println("Dimensions in b: " + bDimensions); - String aDim0 = aDimensions.get(0).name(); String aDim1 = aDimensions.get(1).name(); String bDim0 = bDimensions.get(0).name(); String bDim1 = bDimensions.get(1).name(); // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); + renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(), this); // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); + renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(), this); // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); + renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(), this); + renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(), this); } private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index dc690329a8d..b7fd91df1e4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -81,8 +81,8 @@ public class Select extends IntermediateOperation { String bDim1 = bDimensions.get(1).name(); // These tensors should have the same dimension names - renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this); - renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this); + renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(), this); + renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(), this); } } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java index cf8dd6e8e71..d030ece9aa3 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java @@ -18,17 +18,17 @@ public class DimensionRenamerTest { renamer.addDimension("first_dimension_of_b"); // which dimension to join on matmul - renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null); + renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer.Constraint.equal(), null); // other dimensions in matmul can't be equal - renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null); + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer.Constraint.lessThan(), null); // for efficiency, put dimension to join on innermost - renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null); - renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null); + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer.Constraint.lessThan(), null); + renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer.Constraint.greaterThan(), null); // bias - renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null); + renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer.Constraint.equal(), null); renamer.solve(); |