summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-02 16:04:01 -0700
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-02 16:04:01 -0700
commit0cd53c0204a8caf9fba1847d7f422cc51248f615 (patch)
tree587931b21b27cd130d52690671c805cfbb79c6ec
parent1a60bfebe04b418b0bbde8ebd8b05904e4df1760 (diff)
Add notEqual constraints for dimensions of the same tensor argument
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java50
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java12
3 files changed, 53 insertions, 12 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index 9fe4560b7c6..d2d7367585c 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -13,6 +13,8 @@ import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
+import java.util.logging.Level;
+import java.util.logging.Logger;
import java.util.stream.Collectors;
/**
@@ -23,6 +25,8 @@ import java.util.stream.Collectors;
*/
public class DimensionRenamer {
+ private static final Logger log = Logger.getLogger(DimensionRenamer.class.getName());
+
private final String dimensionPrefix;
private final ListMap<String, Integer> variables = new ListMap<>();
private final Map<Arc, Constraint> constraints = new HashMap<>();
@@ -72,7 +76,7 @@ public class DimensionRenamer {
* multiple times.
*
* This requires having constraints that result in an absolute ordering:
- * equals, lesserThan and greaterThan do that, but adding notEquals does
+ * equal, lessThan and greaterThan do that, but adding notEqual does
* not typically result in a guaranteed ordering. If that is needed, the
* algorithm below needs to be adapted with a backtracking (tree) search
* to find solutions.
@@ -88,14 +92,14 @@ public class DimensionRenamer {
if ( ! solved) {
renames.clear();
Map<Arc, Constraint> hardConstraints = new HashMap<>();
- constraints.entrySet().stream().filter(e -> ! e.getValue().isSoft())
- .forEach(e -> hardConstraints.put(e.getKey(), e.getValue()));
+ constraints.entrySet().stream().filter(e -> !e.getValue().isSoft())
+ .forEach(e -> hardConstraints.put(e.getKey(), e.getValue()));
if (hardConstraints.size() < constraints.size())
solved = trySolve(variables, hardConstraints, maxIterations, renames);
- }
- if ( ! solved) {
- throw new IllegalArgumentException("Could not find a dimension naming solution" +
- " given constraints\n" + constraintsToString());
+ if ( ! solved) {
+ throw new IllegalArgumentException("Could not find a dimension naming solution " +
+ " given constraints\n" + constraintsToString(hardConstraints));
+ }
}
// Todo: handle failure more gracefully:
@@ -128,7 +132,17 @@ public class DimensionRenamer {
}
void solve() {
+ log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints));
+ System.out.println("Rename problem:\n" + constraintsToString(constraints));
renames = solve(100000);
+ log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames));
+ System.out.println("Rename solution:\n" + renamesToString(renames));
+ }
+
+ private static String renamesToString(Map<String, Integer> renames) {
+ return renames.entrySet().stream()
+ .map(e -> " " + e.getKey() + " -> " + e.getValue())
+ .collect(Collectors.joining("\n"));
}
private static void initialize(ListMap<String, Integer> variables) {
@@ -182,10 +196,10 @@ public class DimensionRenamer {
return revised;
}
- private String constraintsToString() {
+ private static String constraintsToString(Map<Arc, Constraint> constraints) {
return constraints.entrySet().stream()
.filter(e -> ! e.getValue().isOpposite())
- .map(e -> e.getKey().from + " " + e.getValue() + " " + e.getKey().to + " (origin: " + e.getKey().operation + ")")
+ .map(e -> " " + e.getKey().from + " " + e.getValue() + " " + e.getKey().to + " (origin: " + e.getKey().operation + ")")
.collect(Collectors.joining("\n"));
}
@@ -244,6 +258,7 @@ public class DimensionRenamer {
boolean isOpposite() { return opposite; }
public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); }
+ public static Constraint notEqual(boolean soft) { return new NotEqualConstraint(soft, false); }
public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); }
public static Constraint greaterThan(boolean soft) { return new GreaterThanConstraint(soft, false); }
@@ -266,6 +281,23 @@ public class DimensionRenamer {
}
+ private static class NotEqualConstraint extends Constraint {
+
+ private NotEqualConstraint(boolean soft, boolean opposite) {
+ super(soft, opposite);
+ }
+
+ @Override
+ public boolean test(Integer x, Integer y) { return ! Objects.equals(x, y); }
+
+ @Override
+ public Constraint opposite() { return new NotEqualConstraint(isSoft(), true); }
+
+ @Override
+ public String toString() { return "!="; }
+
+ }
+
private static class LessThanConstraint extends Constraint {
private LessThanConstraint(boolean soft, boolean opposite) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
index 9115dc99b82..8ea9c9a258d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
@@ -130,12 +130,15 @@ public class OrderedTensorType {
}
public OrderedTensorType rename(DimensionRenamer renamer) {
+ System.out.println("Renaming " + this);
List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
for (TensorType.Dimension dimension : dimensions) {
String oldName = dimension.name();
Optional<String> newName = renamer.dimensionNameOf(oldName);
if (!newName.isPresent())
return this; // presumably, already renamed
+ if ( ! oldName.equals(newName.get()))
+ System.out.println(" Renaming " + oldName + " to " + newName.get());
TensorType.Dimension.Type dimensionType = dimension.type();
if (dimensionType == TensorType.Dimension.Type.indexedBound) {
renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
index d5671889e01..2d746bf338c 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
@@ -29,7 +29,7 @@ public class Argument extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
- if (!standardNamingType.equals(type)) {
+ if ( ! standardNamingType.equals(type)) {
List<String> renameFrom = standardNamingType.dimensionNames();
List<String> renameTo = type.dimensionNames();
output = new Rename(output, renameFrom, renameTo);
@@ -39,8 +39,14 @@ public class Argument extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ renamer.addDimension(type.dimensions().get(i).name());
+
+ // Each dimension is distinct:
+ for (int j = i + 1; j < type.dimensions().size(); j++)
+ renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(),
+ DimensionRenamer.Constraint.notEqual(false),
+ this);
}
}