summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-02 13:43:55 -0700
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-02 13:43:55 -0700
commitb2cd383345dd42311c7a5005a6232c941bc95dcf (patch)
tree9c41c4dd809e38f7ed14856f034b5538f8abe1d5
parent6c8e1b26bc33ba89f8fed9354fe2666dc796a485 (diff)
Output conflixcting constraints
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java111
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java4
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java10
7 files changed, 102 insertions, 40 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index 9e9f66be700..282ab6df25b 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -13,6 +13,8 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
+import java.util.function.Predicate;
+import java.util.stream.Collectors;
/**
* A constraint satisfier to find suitable dimension names to reduce the
@@ -47,11 +49,10 @@ public class DimensionRenamer {
/**
* Add a constraint between dimension names.
*/
- public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) {
+ public void addConstraint(String from, String to, Constraint constraint, IntermediateOperation operation) {
Arc arc = new Arc(from, to, operation);
- Arc opposite = arc.opposite();
- constraints.put(arc, pred);
- constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
+ constraints.put(arc, constraint);
+ constraints.put(arc.opposite(), constraint.opposite()); // make constraint graph symmetric
}
/**
@@ -85,8 +86,9 @@ public class DimensionRenamer {
for (String dimension : variables.keySet()) {
List<Integer> values = variables.get(dimension);
if (values.size() > 1) {
- if (!ac3()) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
+ if ( ! ac3()) {
+ throw new IllegalArgumentException("Dimension renamer unable to find a solution" +
+ " given constraints\n" + constraintsToString());
}
values.sort(Integer::compare);
variables.put(dimension, Collections.singletonList(values.get(0)));
@@ -94,7 +96,8 @@ public class DimensionRenamer {
renames.put(dimension, variables.get(dimension).get(0));
if (iterations > maxIterations) {
throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
- maxIterations + " iterations");
+ maxIterations + " iterations for dimension '" + dimension + "'" +
+ "' given constraints\n" + constraintsToString());
}
}
@@ -155,20 +158,11 @@ public class DimensionRenamer {
return revised;
}
- public interface Constraint {
- boolean test(Integer x, Integer y);
- }
-
- public static boolean equals(Integer x, Integer y) {
- return Objects.equals(x, y);
- }
-
- public static boolean lesserThan(Integer x, Integer y) {
- return x < y;
- }
-
- public static boolean greaterThan(Integer x, Integer y) {
- return x > y;
+ private String constraintsToString() {
+ return constraints.entrySet().stream()
+ .filter(e -> ! e.getValue().isOpposite())
+ .map(e -> e.getKey().from + " " + e.getValue() + " " + e.getKey().to + " (origin: " + e.getKey().operation + ")")
+ .collect(Collectors.joining("\n"));
}
private static class Arc {
@@ -194,7 +188,7 @@ public class DimensionRenamer {
@Override
public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof Arc)) {
+ if (!(obj instanceof Arc)) {
return false;
}
Arc other = (Arc) obj;
@@ -203,8 +197,79 @@ public class DimensionRenamer {
@Override
public String toString() {
- return String.format("%s -> %s", from, to);
+ return from + " -> " + to;
+ }
+ }
+
+ public static abstract class Constraint {
+
+ private final boolean opposite;
+
+ protected Constraint(boolean opposite) {
+ this.opposite = opposite;
+ }
+
+ abstract boolean test(Integer x, Integer y);
+ abstract Constraint opposite();
+
+ /** Returns whether this is an opposite of another constraint */
+ boolean isOpposite() { return opposite; }
+
+ public static Constraint equal() { return new EqualConstraint(false); }
+ public static Constraint lessThan() { return new LessThanConstraint(false); }
+ public static Constraint greaterThan() { return new GreaterThanConstraint(false); }
+
+ }
+
+ private static class EqualConstraint extends Constraint {
+
+ private EqualConstraint(boolean opposite) {
+ super(opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return Objects.equals(x, y); }
+
+ @Override
+ public Constraint opposite() { return new EqualConstraint(true); }
+
+ @Override
+ public String toString() { return "=="; }
+
+ }
+
+ private static class LessThanConstraint extends Constraint {
+
+ private LessThanConstraint(boolean opposite) {
+ super(opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return x < y; }
+
+ @Override
+ public Constraint opposite() { return new GreaterThanConstraint(true); }
+
+ @Override
+ public String toString() { return "<"; }
+
+ }
+
+ private static class GreaterThanConstraint extends Constraint {
+
+ private GreaterThanConstraint(boolean opposite) {
+ super(opposite);
}
+
+ @Override
+ public boolean test(Integer x, Integer y) { return x > y; }
+
+ @Override
+ public Constraint opposite() { return new LessThanConstraint(true); }
+
+ @Override
+ public String toString() { return ">"; }
+
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
index 7ae50a0549d..f5bae4c47d5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -89,7 +89,7 @@ public class ConcatV2 extends IntermediateOperation {
OrderedTensorType b = inputs.get(i).type().get();
String bDim = b.dimensions().get(concatDimensionIndex).name();
String aDim = a.dimensions().get(concatDimensionIndex).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(), this);
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index 8f029fc9c4a..91783aed7a5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -52,7 +52,7 @@ public class ExpandDims extends IntermediateOperation {
typeBuilder.add(dimension);
dimensionIndex++;
}
- if (dimensionIndex == inputType.dimensions().size()) { // Insert last dimension
+ if (dimensionToInsert == inputType.dimensions().size()) { // Insert last dimension
addDimension(dimensionIndex, typeBuilder);
}
return typeBuilder.build();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index 5c7acc8a0ee..a04725f29fa 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -95,7 +95,7 @@ public class Join extends IntermediateOperation {
for (int i = 0; i < b.rank(); ++i) {
String bDim = b.dimensions().get(i).name();
String aDim = a.dimensions().get(i + sizeDifference).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(), this);
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index cf6cc722b9e..0a4821b8727 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -55,23 +55,20 @@ public class MatMul extends IntermediateOperation {
assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
assertTwoDimensions(bDimensions, inputs.get(1), "second argument");
- System.out.println("Dimensions in a: " + aDimensions);
- System.out.println("Dimensions in b: " + bDimensions);
-
String aDim0 = aDimensions.get(0).name();
String aDim1 = aDimensions.get(1).name();
String bDim0 = bDimensions.get(0).name();
String bDim1 = bDimensions.get(1).name();
// The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(), this);
// The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
+ renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(), this);
// For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
+ renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(), this);
+ renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(), this);
}
private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
index dc690329a8d..b7fd91df1e4 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
@@ -81,8 +81,8 @@ public class Select extends IntermediateOperation {
String bDim1 = bDimensions.get(1).name();
// These tensors should have the same dimension names
- renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this);
- renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(), this);
+ renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(), this);
}
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
index cf8dd6e8e71..d030ece9aa3 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
@@ -18,17 +18,17 @@ public class DimensionRenamerTest {
renamer.addDimension("first_dimension_of_b");
// which dimension to join on matmul
- renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null);
+ renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer.Constraint.equal(), null);
// other dimensions in matmul can't be equal
- renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null);
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer.Constraint.lessThan(), null);
// for efficiency, put dimension to join on innermost
- renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null);
- renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null);
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer.Constraint.lessThan(), null);
+ renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer.Constraint.greaterThan(), null);
// bias
- renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null);
+ renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer.Constraint.equal(), null);
renamer.solve();