From 08c7c9919ee6f60047cd57b3afece54dfa7dda52 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 2 Jul 2019 14:09:19 -0700 Subject: Forfeit soft constraints when necessary --- .../importer/DimensionRenamer.java | 71 +++++++++++++--------- .../importer/operations/ConcatV2.java | 2 +- .../importer/operations/Join.java | 2 +- .../importer/operations/MatMul.java | 8 +-- .../importer/operations/Select.java | 4 +- .../importer/DimensionRenamerTest.java | 10 +-- 6 files changed, 56 insertions(+), 41 deletions(-) (limited to 'model-integration/src') diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index 282ab6df25b..af6008b07e3 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -82,29 +82,40 @@ public class DimensionRenamer { initialize(); // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts + boolean solved = trySolve(maxIterations); + if ( ! solved) { + List softConstraints = constraints.entrySet().stream() + .filter(e -> e.getValue().isSoft()) + .map(e -> e.getKey()) + .collect(Collectors.toList()); + if ( ! softConstraints.isEmpty()) { + softConstraints.forEach(softConstraint -> constraints.remove(softConstraint)); + trySolve(maxIterations); + } + } + if ( ! solved) { + throw new IllegalArgumentException("Could not find a dimension naming solution" + + " given constraints\n" + constraintsToString()); + } + + // Todo: handle failure more gracefully: + // If a solution can't be found, look at the operation node in the arc + // with the most remaining constraints, and inject a rename operation. + // Then run this algorithm again. + } + private boolean trySolve(int maxIterations) { for (String dimension : variables.keySet()) { List values = variables.get(dimension); if (values.size() > 1) { - if ( ! ac3()) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution" + - " given constraints\n" + constraintsToString()); - } + if ( ! ac3()) return false; values.sort(Integer::compare); variables.put(dimension, Collections.singletonList(values.get(0))); } renames.put(dimension, variables.get(dimension).get(0)); - if (iterations > maxIterations) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + - maxIterations + " iterations for dimension '" + dimension + "'" + - "' given constraints\n" + constraintsToString()); - } + if (iterations > maxIterations) return false; } - - // Todo: handle failure more gracefully: - // If a solution can't be found, look at the operation node in the arc - // with the most remaining constraints, and inject a rename operation. - // Then run this algorithm again. + return true; } void solve() { @@ -203,35 +214,39 @@ public class DimensionRenamer { public static abstract class Constraint { - private final boolean opposite; + private final boolean soft, opposite; - protected Constraint(boolean opposite) { + protected Constraint(boolean soft, boolean opposite) { + this.soft = soft; this.opposite = opposite; } abstract boolean test(Integer x, Integer y); abstract Constraint opposite(); + /** Returns whether this constraint can be violated if that is necessary to achieve a solution */ + boolean isSoft() { return soft; } + /** Returns whether this is an opposite of another constraint */ boolean isOpposite() { return opposite; } - public static Constraint equal() { return new EqualConstraint(false); } - public static Constraint lessThan() { return new LessThanConstraint(false); } - public static Constraint greaterThan() { return new GreaterThanConstraint(false); } + public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); } + public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); } + public static Constraint greaterThan(boolean soft) { return new GreaterThanConstraint(soft, false); } } private static class EqualConstraint extends Constraint { - private EqualConstraint(boolean opposite) { - super(opposite); + private EqualConstraint(boolean soft, boolean opposite) { + super(soft, opposite); } @Override public boolean test(Integer x, Integer y) { return Objects.equals(x, y); } @Override - public Constraint opposite() { return new EqualConstraint(true); } + public Constraint opposite() { return new EqualConstraint(isSoft(), true); } @Override public String toString() { return "=="; } @@ -240,15 +255,15 @@ public class DimensionRenamer { private static class LessThanConstraint extends Constraint { - private LessThanConstraint(boolean opposite) { - super(opposite); + private LessThanConstraint(boolean soft, boolean opposite) { + super(soft, opposite); } @Override public boolean test(Integer x, Integer y) { return x < y; } @Override - public Constraint opposite() { return new GreaterThanConstraint(true); } + public Constraint opposite() { return new GreaterThanConstraint(isSoft(), true); } @Override public String toString() { return "<"; } @@ -257,15 +272,15 @@ public class DimensionRenamer { private static class GreaterThanConstraint extends Constraint { - private GreaterThanConstraint(boolean opposite) { - super(opposite); + private GreaterThanConstraint(boolean soft, boolean opposite) { + super(soft, opposite); } @Override public boolean test(Integer x, Integer y) { return x > y; } @Override - public Constraint opposite() { return new LessThanConstraint(true); } + public Constraint opposite() { return new LessThanConstraint(isSoft(), true); } @Override public String toString() { return ">"; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index f5bae4c47d5..c211b434176 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -89,7 +89,7 @@ public class ConcatV2 extends IntermediateOperation { OrderedTensorType b = inputs.get(i).type().get(); String bDim = b.dimensions().get(concatDimensionIndex).name(); String aDim = a.dimensions().get(concatDimensionIndex).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(), this); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index a04725f29fa..2c5c38b76cc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -95,7 +95,7 @@ public class Join extends IntermediateOperation { for (int i = 0; i < b.rank(); ++i) { String bDim = b.dimensions().get(i).name(); String aDim = a.dimensions().get(i + sizeDifference).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(), this); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 0a4821b8727..6c6b51a27a5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -61,14 +61,14 @@ public class MatMul extends IntermediateOperation { String bDim1 = bDimensions.get(1).name(); // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(), this); + renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this); // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(), this); + renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this); // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(), this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(), this); + renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this); + renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this); } private void assertTwoDimensions(List dimensions, IntermediateOperation supplier, String inputDescription) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index b7fd91df1e4..2484310e829 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -81,8 +81,8 @@ public class Select extends IntermediateOperation { String bDim1 = bDimensions.get(1).name(); // These tensors should have the same dimension names - renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(), this); - renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(), this); + renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(false), this); + renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this); } } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java index d030ece9aa3..ee282c7d988 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java @@ -18,17 +18,17 @@ public class DimensionRenamerTest { renamer.addDimension("first_dimension_of_b"); // which dimension to join on matmul - renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer.Constraint.equal(), null); + renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer.Constraint.equal(false), null); // other dimensions in matmul can't be equal - renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer.Constraint.lessThan(), null); + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer.Constraint.lessThan(false), null); // for efficiency, put dimension to join on innermost - renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer.Constraint.lessThan(), null); - renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer.Constraint.greaterThan(), null); + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer.Constraint.lessThan(true), null); + renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer.Constraint.greaterThan(true), null); // bias - renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer.Constraint.equal(), null); + renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer.Constraint.equal(false), null); renamer.solve(); -- cgit v1.2.3