summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-02 14:09:19 -0700
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-02 14:09:19 -0700
commit08c7c9919ee6f60047cd57b3afece54dfa7dda52 (patch)
tree26871add6d51d1ff7c77736b0dfa948d48a4debb
parentb2cd383345dd42311c7a5005a6232c941bc95dcf (diff)
Forfeit soft constraints when necessary
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java71
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java4
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java10
6 files changed, 56 insertions, 41 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index 282ab6df25b..af6008b07e3 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -82,29 +82,40 @@ public class DimensionRenamer {
initialize();
// Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
+ boolean solved = trySolve(maxIterations);
+ if ( ! solved) {
+ List<Arc> softConstraints = constraints.entrySet().stream()
+ .filter(e -> e.getValue().isSoft())
+ .map(e -> e.getKey())
+ .collect(Collectors.toList());
+ if ( ! softConstraints.isEmpty()) {
+ softConstraints.forEach(softConstraint -> constraints.remove(softConstraint));
+ trySolve(maxIterations);
+ }
+ }
+ if ( ! solved) {
+ throw new IllegalArgumentException("Could not find a dimension naming solution" +
+ " given constraints\n" + constraintsToString());
+ }
+
+ // Todo: handle failure more gracefully:
+ // If a solution can't be found, look at the operation node in the arc
+ // with the most remaining constraints, and inject a rename operation.
+ // Then run this algorithm again.
+ }
+ private boolean trySolve(int maxIterations) {
for (String dimension : variables.keySet()) {
List<Integer> values = variables.get(dimension);
if (values.size() > 1) {
- if ( ! ac3()) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution" +
- " given constraints\n" + constraintsToString());
- }
+ if ( ! ac3()) return false;
values.sort(Integer::compare);
variables.put(dimension, Collections.singletonList(values.get(0)));
}
renames.put(dimension, variables.get(dimension).get(0));
- if (iterations > maxIterations) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
- maxIterations + " iterations for dimension '" + dimension + "'" +
- "' given constraints\n" + constraintsToString());
- }
+ if (iterations > maxIterations) return false;
}
-
- // Todo: handle failure more gracefully:
- // If a solution can't be found, look at the operation node in the arc
- // with the most remaining constraints, and inject a rename operation.
- // Then run this algorithm again.
+ return true;
}
void solve() {
@@ -203,35 +214,39 @@ public class DimensionRenamer {
public static abstract class Constraint {
- private final boolean opposite;
+ private final boolean soft, opposite;
- protected Constraint(boolean opposite) {
+ protected Constraint(boolean soft, boolean opposite) {
+ this.soft = soft;
this.opposite = opposite;
}
abstract boolean test(Integer x, Integer y);
abstract Constraint opposite();
+ /** Returns whether this constraint can be violated if that is necessary to achieve a solution */
+ boolean isSoft() { return soft; }
+
/** Returns whether this is an opposite of another constraint */
boolean isOpposite() { return opposite; }
- public static Constraint equal() { return new EqualConstraint(false); }
- public static Constraint lessThan() { return new LessThanConstraint(false); }
- public static Constraint greaterThan() { return new GreaterThanConstraint(false); }
+ public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); }
+ public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); }
+ public static Constraint greaterThan(boolean soft) { return new GreaterThanConstraint(soft, false); }
}
private static class EqualConstraint extends Constraint {
- private EqualConstraint(boolean opposite) {
- super(opposite);
+ private EqualConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
}
@Override
public boolean test(Integer x, Integer y) { return Objects.equals(x, y); }
@Override
- public Constraint opposite() { return new EqualConstraint(true); }
+ public Constraint opposite() { return new EqualConstraint(isSoft(), true); }
@Override
public String toString() { return "=="; }
@@ -240,15 +255,15 @@ public class DimensionRenamer {
private static class LessThanConstraint extends Constraint {
- private LessThanConstraint(boolean opposite) {
- super(opposite);
+ private LessThanConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
}
@Override
public boolean test(Integer x, Integer y) { return x < y; }
@Override
- public Constraint opposite() { return new GreaterThanConstraint(true); }
+ public Constraint opposite() { return new GreaterThanConstraint(isSoft(), true); }
@Override
public String toString() { return "<"; }
@@ -257,15 +272,15 @@ public class DimensionRenamer {
private static class GreaterThanConstraint extends Constraint {
- private GreaterThanConstraint(boolean opposite) {
- super(opposite);
+ private GreaterThanConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
}
@Override
public boolean test(Integer x, Integer y) { return x > y; }
@Override
- public Constraint opposite() { return new LessThanConstraint(true); }
+ public Constraint opposite() { return new LessThanConstraint(isSoft(), true); }
@Override
public String toString() { return ">"; }
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
index f5bae4c47d5..c211b434176 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -89,7 +89,7 @@ public class ConcatV2 extends IntermediateOperation {
OrderedTensorType b = inputs.get(i).type().get();
String bDim = b.dimensions().get(concatDimensionIndex).name();
String aDim = a.dimensions().get(concatDimensionIndex).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(), this);
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index a04725f29fa..2c5c38b76cc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -95,7 +95,7 @@ public class Join extends IntermediateOperation {
for (int i = 0; i < b.rank(); ++i) {
String bDim = b.dimensions().get(i).name();
String aDim = a.dimensions().get(i + sizeDifference).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(), this);
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 0a4821b8727..6c6b51a27a5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -61,14 +61,14 @@ public class MatMul extends IntermediateOperation {
String bDim1 = bDimensions.get(1).name();
// The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(), this);
+ renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this);
// The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(), this);
+ renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this);
// For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(), this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(), this);
+ renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this);
+ renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this);
}
private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
index b7fd91df1e4..2484310e829 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
@@ -81,8 +81,8 @@ public class Select extends IntermediateOperation {
String bDim1 = bDimensions.get(1).name();
// These tensors should have the same dimension names
- renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(), this);
- renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(), this);
+ renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(false), this);
+ renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this);
}
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
index d030ece9aa3..ee282c7d988 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
@@ -18,17 +18,17 @@ public class DimensionRenamerTest {
renamer.addDimension("first_dimension_of_b");
// which dimension to join on matmul
- renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer.Constraint.equal(), null);
+ renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer.Constraint.equal(false), null);
// other dimensions in matmul can't be equal
- renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer.Constraint.lessThan(), null);
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer.Constraint.lessThan(false), null);
// for efficiency, put dimension to join on innermost
- renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer.Constraint.lessThan(), null);
- renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer.Constraint.greaterThan(), null);
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer.Constraint.lessThan(true), null);
+ renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer.Constraint.greaterThan(true), null);
// bias
- renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer.Constraint.equal(), null);
+ renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer.Constraint.equal(false), null);
renamer.solve();