summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-22 14:27:58 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-22 14:27:58 +0100
commitb288e61f7af7331656a1850fbdc58cc95fd1bbad (patch)
tree9d41fa770d2890585a902f41a89c41040ed764be /model-integration
parent3c4020645b13be560c14e60969e50e3ad41e3d3c (diff)
Move all importing to model-integration
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/pom.xml8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java210
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java226
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java109
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java107
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java232
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java235
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java26
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java57
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java108
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java89
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java61
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java106
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java35
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java191
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java120
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java37
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java72
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java113
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java35
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java26
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java48
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java131
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java88
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java54
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java85
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java50
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java42
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java4
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java48
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java21
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java2
48 files changed, 2767 insertions, 60 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 28a00dcbdbc..da18d659060 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -20,24 +20,26 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>
+
<dependency>
<groupId>com.yahoo.vespa</groupId>
- <artifactId>component</artifactId>
+ <artifactId>annotations</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
- <artifactId>vespajlib</artifactId>
+ <artifactId>searchlib</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
<dependency>
<groupId>com.yahoo.vespa</groupId>
- <artifactId>searchlib</artifactId>
+ <artifactId>vespajlib</artifactId>
<version>${project.version}</version>
<scope>provided</scope>
</dependency>
+
<dependency>
<groupId>com.google.guava</groupId>
<artifactId>guava</artifactId>
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
new file mode 100644
index 00000000000..9e9f66be700
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -0,0 +1,210 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * A constraint satisfier to find suitable dimension names to reduce the
+ * amount of necessary renaming during evaluation of an imported model.
+ *
+ * @author lesters
+ */
+public class DimensionRenamer {
+
+ private final String dimensionPrefix;
+ private final Map<String, List<Integer>> variables = new HashMap<>();
+ private final Map<Arc, Constraint> constraints = new HashMap<>();
+ private final Map<String, Integer> renames = new HashMap<>();
+
+ private int iterations = 0;
+
+ public DimensionRenamer() {
+ this("d");
+ }
+
+ public DimensionRenamer(String dimensionPrefix) {
+ this.dimensionPrefix = dimensionPrefix;
+ }
+
+ /**
+ * Add a dimension name variable.
+ */
+ public void addDimension(String name) {
+ variables.computeIfAbsent(name, d -> new ArrayList<>());
+ }
+
+ /**
+ * Add a constraint between dimension names.
+ */
+ public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) {
+ Arc arc = new Arc(from, to, operation);
+ Arc opposite = arc.opposite();
+ constraints.put(arc, pred);
+ constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
+ }
+
+ /**
+ * Retrieve resulting name of dimension after solving for constraints.
+ */
+ public Optional<String> dimensionNameOf(String name) {
+ if (!renames.containsKey(name)) {
+ return Optional.empty();
+ }
+ return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
+ }
+
+ /**
+ * Perform iterative arc consistency until we have found a solution. After
+ * an initial iteration, the variables (dimensions) will have multiple
+ * valid values. Find a single valid assignment by iteratively locking one
+ * dimension after another, and running the arc consistency algorithm
+ * multiple times.
+ *
+ * This requires having constraints that result in an absolute ordering:
+ * equals, lesserThan and greaterThan do that, but adding notEquals does
+ * not typically result in a guaranteed ordering. If that is needed, the
+ * algorithm below needs to be adapted with a backtracking (tree) search
+ * to find solutions.
+ */
+ private void solve(int maxIterations) {
+ initialize();
+
+ // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
+
+ for (String dimension : variables.keySet()) {
+ List<Integer> values = variables.get(dimension);
+ if (values.size() > 1) {
+ if (!ac3()) {
+ throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
+ }
+ values.sort(Integer::compare);
+ variables.put(dimension, Collections.singletonList(values.get(0)));
+ }
+ renames.put(dimension, variables.get(dimension).get(0));
+ if (iterations > maxIterations) {
+ throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
+ maxIterations + " iterations");
+ }
+ }
+
+ // Todo: handle failure more gracefully:
+ // If a solution can't be found, look at the operation node in the arc
+ // with the most remaining constraints, and inject a rename operation.
+ // Then run this algorithm again.
+ }
+
+ void solve() {
+ solve(100000);
+ }
+
+ private void initialize() {
+ for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
+ List<Integer> values = variable.getValue();
+ for (int i = 0; i < variables.size(); ++i) {
+ values.add(i); // invariant: values are in increasing order
+ }
+ }
+ }
+
+ private boolean ac3() {
+ Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
+ while (!workList.isEmpty()) {
+ Arc arc = workList.pop();
+ iterations += 1;
+ if (revise(arc)) {
+ if (variables.get(arc.from).size() == 0) {
+ return false; // no solution found
+ }
+ for (Arc constraint : constraints.keySet()) {
+ if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
+ workList.add(constraint);
+ }
+ }
+ }
+ }
+ return true;
+ }
+
+ private boolean revise(Arc arc) {
+ boolean revised = false;
+ for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
+ Integer from = fromIterator.next();
+ boolean satisfied = false;
+ for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
+ Integer to = toIterator.next();
+ if (constraints.get(arc).test(from, to)) {
+ satisfied = true;
+ }
+ }
+ if (!satisfied) {
+ fromIterator.remove();
+ revised = true;
+ }
+ }
+ return revised;
+ }
+
+ public interface Constraint {
+ boolean test(Integer x, Integer y);
+ }
+
+ public static boolean equals(Integer x, Integer y) {
+ return Objects.equals(x, y);
+ }
+
+ public static boolean lesserThan(Integer x, Integer y) {
+ return x < y;
+ }
+
+ public static boolean greaterThan(Integer x, Integer y) {
+ return x > y;
+ }
+
+ private static class Arc {
+
+ private final String from;
+ private final String to;
+ private final IntermediateOperation operation;
+
+ Arc(String from, String to, IntermediateOperation operation) {
+ this.from = from;
+ this.to = to;
+ this.operation = operation;
+ }
+
+ Arc opposite() {
+ return new Arc(to, from, operation);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(from, to);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null || !(obj instanceof Arc)) {
+ return false;
+ }
+ Arc other = (Arc) obj;
+ return Objects.equals(from, other.from) && Objects.equals(to, other.to);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("%s -> %s", from, to);
+ }
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
new file mode 100644
index 00000000000..2866a2c76b2
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -0,0 +1,226 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.collections.Pair;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.regex.Pattern;
+
+/**
+ * The result of importing an ML model into Vespa.
+ *
+ * @author bratseth
+ */
+public class ImportedModel {
+
+ private static final String defaultSignatureName = "default";
+
+ private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
+ private final String name;
+ private final String source;
+
+ private final Map<String, Signature> signatures = new HashMap<>();
+ private final Map<String, TensorType> inputs = new HashMap<>();
+ private final Map<String, Tensor> smallConstants = new HashMap<>();
+ private final Map<String, Tensor> largeConstants = new HashMap<>();
+ private final Map<String, RankingExpression> expressions = new HashMap<>();
+ private final Map<String, RankingExpression> functions = new HashMap<>();
+
+ /**
+ * Creates a new imported model.
+ *
+ * @param name the name of this mode, containing only characters in [A-Za-z0-9_]
+ * @param source the source path (directory or file) of this model
+ */
+ public ImportedModel(String name, String source) {
+ if ( ! nameRegexp.matcher(name).matches())
+ throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'");
+ this.name = name;
+ this.source = source;
+ }
+
+ /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
+ public String name() { return name; }
+
+ /** Returns the source path (directory or file) of this model */
+ public String source() { return source; }
+
+ /** Returns an immutable map of the inputs of this */
+ public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); }
+
+ /**
+ * Returns an immutable map of the small constants of this.
+ * These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
+ */
+ public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
+
+ /**
+ * Returns an immutable map of the large constants of this.
+ * These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
+ * For TensorFlow this corresponds to Variable files stored separately.
+ */
+ public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
+
+ /**
+ * Returns an immutable map of the expressions of this - corresponding to graph nodes
+ * which are not Inputs/Placeholders or Variables (which instead become respectively inputs and constants).
+ * Note that only nodes recursively referenced by a placeholder/input are added.
+ */
+ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
+
+ /**
+ * Returns an immutable map of the functions that are part of this model.
+ * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification.
+ */
+ public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); }
+
+ /** Returns an immutable map of the signatures of this */
+ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
+
+ /** Returns the given signature. If it does not already exist it is added to this. */
+ public Signature signature(String name) {
+ return signatures.computeIfAbsent(name, Signature::new);
+ }
+
+ /** Convenience method for returning a default signature */
+ public Signature defaultSignature() { return signature(defaultSignatureName); }
+
+ public void input(String name, TensorType argumentType) { inputs.put(name, argumentType); }
+ public void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ public void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
+ public void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ public void function(String name, RankingExpression expression) { functions.put(name, expression); }
+
+ /**
+ * Returns all the output expressions of this indexed by name. The names consist of one or two parts
+ * separated by dot, where the first part is the signature name
+ * if signatures are used, or the expression name if signatures are not used and there are multiple
+ * expressions, and the second is the output name if signature names are used.
+ */
+ public List<Pair<String, ExpressionFunction>> outputExpressions() {
+ List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
+ for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
+ for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
+ expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
+ signatureEntry.getValue().outputExpression(outputEntry.getKey())
+ .withName(signatureEntry.getKey() + "." + outputEntry.getKey())));
+ if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
+ expressions.add(new Pair<>(signatureEntry.getKey(),
+ new ExpressionFunction(signatureEntry.getKey(),
+ new ArrayList<>(signatureEntry.getValue().inputs().values()),
+ expressions().get(signatureEntry.getKey()),
+ signatureEntry.getValue().inputMap(),
+ Optional.empty())));
+ }
+ if (signatures().isEmpty()) { // fallback for models without signatures
+ if (expressions().size() == 1) {
+ Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
+ expressions.add(new Pair<>(singleEntry.getKey(),
+ new ExpressionFunction(singleEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ singleEntry.getValue(),
+ inputs,
+ Optional.empty())));
+ }
+ else {
+ for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
+ expressions.add(new Pair<>(expressionEntry.getKey(),
+ new ExpressionFunction(expressionEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ expressionEntry.getValue(),
+ inputs,
+ Optional.empty())));
+ }
+ }
+ }
+ return expressions;
+ }
+
+ /**
+ * A signature is a set of named inputs and outputs, where the inputs maps to input
+ * ("placeholder") names+types, and outputs maps to expressions nodes.
+ * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit
+ * concept of signatures. For now, we handle ONNX models as having a single signature.
+ */
+ public class Signature {
+
+ private final String name;
+ private final Map<String, String> inputs = new LinkedHashMap<>();
+ private final Map<String, String> outputs = new LinkedHashMap<>();
+ private final Map<String, String> skippedOutputs = new HashMap<>();
+ private final List<String> importWarnings = new ArrayList<>();
+
+ Signature(String name) {
+ this.name = name;
+ }
+
+ public String name() { return name; }
+
+ /** Returns the result this is part of */
+ ImportedModel owner() { return ImportedModel.this; }
+
+ /**
+ * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
+ * in this signature to input name in the owning model
+ */
+ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
+
+ /** Returns the name and type of all inputs in this signature as an immutable map */
+ Map<String, TensorType> inputMap() {
+ ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
+ // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to*
+ // in the model, as these are the names which must actually be bound, if we are to avoid creating an
+ // "input mapping" to accomodate this complexity in
+ for (Map.Entry<String, String> inputEntry : inputs().entrySet())
+ inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue()));
+ return inputs.build();
+ }
+
+ /** Returns the type of the input this input references */
+ public TensorType inputArgument(String inputName) { return owner().inputs().get(inputs.get(inputName)); }
+
+ /** Returns an immutable list of the expression names of this */
+ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+
+ /**
+ * Returns an immutable list of the outputs of this which could not be imported,
+ * with a string detailing the reason for each
+ */
+ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
+
+ /**
+ * Returns an immutable list of possibly non-fatal warnings encountered during import.
+ */
+ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
+
+ /** Returns the expression this output references */
+ public ExpressionFunction outputExpression(String outputName) {
+ return new ExpressionFunction(outputName,
+ new ArrayList<>(inputs.values()),
+ owner().expressions().get(outputs.get(outputName)),
+ inputMap(),
+ Optional.empty());
+ }
+
+ @Override
+ public String toString() { return "signature '" + name + "'"; }
+
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, String expressionName) { outputs.put(name, expressionName); }
+ void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
+ void importWarning(String warning) { importWarnings.add(warning); }
+
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
new file mode 100644
index 00000000000..1b7532631e1
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java
@@ -0,0 +1,109 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.path.Path;
+
+import java.io.File;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Optional;
+
+/**
+ * All models imported from the models/ directory in the application package.
+ * If this is empty it may be due to either not having any models in the application package,
+ * or this being created for a ZooKeeper application package, which does not have imported models.
+ *
+ * @author bratseth
+ */
+public class ImportedModels {
+
+ /** All imported models, indexed by their names */
+ private final ImmutableMap<String, ImportedModel> importedModels;
+
+ /** Create a null imported models */
+ public ImportedModels() {
+ importedModels = ImmutableMap.of();
+ }
+
+ public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) {
+ Map<String, ImportedModel> models = new HashMap<>();
+
+ // Find all subdirectories recursively which contains a model we can read
+ importRecursively(modelsDirectory, models, importers);
+ importedModels = ImmutableMap.copyOf(models);
+ }
+
+ private static void importRecursively(File dir,
+ Map<String, ImportedModel> models,
+ Collection<ModelImporter> importers) {
+ if ( ! dir.isDirectory()) return;
+
+ Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
+ Optional<ModelImporter> importer = findImporterOf(child, importers);
+ if (importer.isPresent()) {
+ String name = toName(child);
+ ImportedModel existing = models.get(name);
+ if (existing != null)
+ throw new IllegalArgumentException("The models in " + child + " and " + existing.source() +
+ " both resolve to the model name '" + name + "'");
+ models.put(name, importer.get().importModel(name, child));
+ }
+ else {
+ importRecursively(child, models, importers);
+ }
+ });
+ }
+
+ private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) {
+ return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
+ }
+
+ /**
+ * Returns the model at the given location in the application package.
+ *
+ * @param modelPath the path to this model (file or directory, depending on model type)
+ * under the application package, both from the root or relative to the
+ * models directory works
+ * @return the model at this path or null if none
+ */
+ public ImportedModel get(File modelPath) {
+ return importedModels.get(toName(modelPath));
+ }
+
+ public ImportedModel get(String modelName) {
+ return importedModels.get(modelName);
+ }
+
+ /** Returns an immutable collection of all the imported models */
+ public Collection<ImportedModel> all() {
+ return importedModels.values();
+ }
+
+ private static String toName(File modelFile) {
+ Path modelPath = Path.fromString(modelFile.toString());
+ if (modelFile.isFile())
+ modelPath = stripFileEnding(modelPath);
+ String localPath = concatenateAfterModelsDirectory(modelPath);
+ return localPath.replace('.', '_');
+ }
+
+ private static Path stripFileEnding(Path path) {
+ int dotIndex = path.last().lastIndexOf(".");
+ if (dotIndex <= 0) return path;
+ return path.withLast(path.last().substring(0, dotIndex));
+ }
+
+ private static String concatenateAfterModelsDirectory(Path path) {
+ boolean afterModels = false;
+ StringBuilder result = new StringBuilder();
+ for (String element : path.elements()) {
+ if (afterModels) result.append(element).append("_");
+ if (element.equals("models")) afterModels = true;
+ }
+ return result.substring(0, result.length()-1);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
new file mode 100644
index 00000000000..aec98d06874
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -0,0 +1,107 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.rankingexpression.importer;
+
+import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Holds an intermediate representation of an imported model graph.
+ * After this intermediate representation is constructed, it is used to
+ * simplify and optimize the computational graph and then converted into the
+ * final ImportedModel that holds the Vespa ranking expressions for the model.
+ *
+ * @author lesters
+ */
+public class IntermediateGraph {
+
+ private final String modelName;
+ private final Map<String, IntermediateOperation> index = new HashMap<>();
+ private final Map<String, GraphSignature> signatures = new HashMap<>();
+
+ private static class GraphSignature {
+ final Map<String, String> inputs = new HashMap<>();
+ final Map<String, String> outputs = new HashMap<>();
+ }
+
+ public IntermediateGraph(String modelName) {
+ this.modelName = modelName;
+ }
+
+ public String name() {
+ return modelName;
+ }
+
+ public IntermediateOperation put(String key, IntermediateOperation operation) {
+ return index.put(key, operation);
+ }
+
+ public IntermediateOperation get(String key) {
+ return index.get(key);
+ }
+
+ public Set<String> signatures() {
+ return signatures.keySet();
+ }
+
+ public Map<String, String> inputs(String signature) {
+ return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs;
+ }
+
+ public Map<String, String> outputs(String signature) {
+ return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs;
+ }
+
+ public String defaultSignature() {
+ return "default";
+ }
+
+ public boolean alreadyImported(String key) {
+ return index.containsKey(key);
+ }
+
+ public Collection<IntermediateOperation> operations() {
+ return index.values();
+ }
+
+ void optimize() {
+ renameDimensions();
+ }
+
+ /**
+ * Find dimension names to avoid excessive renaming while evaluating the model.
+ */
+ private void renameDimensions() {
+ DimensionRenamer renamer = new DimensionRenamer();
+ for (String signature : signatures()) {
+ for (String output : outputs(signature).values()) {
+ addDimensionNameConstraints(index.get(output), renamer);
+ }
+ }
+ renamer.solve();
+ for (String signature : signatures()) {
+ for (String output : outputs(signature).values()) {
+ renameDimensions(index.get(output), renamer);
+ }
+ }
+ }
+
+ private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.addDimensionNameConstraints(renamer);
+ }
+ }
+
+ private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.renameDimensions(renamer);
+ }
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
new file mode 100644
index 00000000000..cb095e81147
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -0,0 +1,232 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import ai.vespa.rankingexpression.importer.operations.Constant;
+import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
+
+import java.io.File;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.logging.Logger;
+
+/**
+ * Base class for importing ML models (ONNX/TensorFlow etc.) as native Vespa
+ * ranking expressions. The general mechanism for import is for the
+ * specific ML platform import implementations to create an
+ * IntermediateGraph. This class offers common code to convert the
+ * IntermediateGraph to Vespa ranking expressions and functions.
+ *
+ * @author lesters
+ */
+public abstract class ModelImporter {
+
+ private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
+
+ /** Returns whether the file or directory at the given path is of the type which can be imported by this */
+ public abstract boolean canImport(String modelPath);
+
+ /** Imports the given model */
+ public abstract ImportedModel importModel(String modelName, String modelPath);
+
+ final ImportedModel importModel(String modelName, File modelPath) {
+ return importModel(modelName, modelPath.toString());
+ }
+
+ /**
+ * Takes an IntermediateGraph and converts it to a ImportedModel containing
+ * the actual Vespa ranking expressions.
+ */
+ protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
+ ImportedModel model = new ImportedModel(graph.name(), modelSource);
+
+ graph.optimize();
+
+ importSignatures(graph, model);
+ importExpressions(graph, model);
+ reportWarnings(graph, model);
+ logVariableTypes(graph);
+
+ return model;
+ }
+
+ private static void importSignatures(IntermediateGraph graph, ImportedModel model) {
+ for (String signatureName : graph.signatures()) {
+ ImportedModel.Signature signature = model.signature(signatureName);
+ for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) {
+ signature.input(input.getKey(), input.getValue());
+ }
+ for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) {
+ signature.output(output.getKey(), output.getValue());
+ }
+ }
+ }
+
+ private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ if (outputName.equals(operation.name())) {
+ return true;
+ }
+ }
+ }
+ return false;
+ }
+
+ /**
+ * Convert intermediate representation to Vespa ranking expressions.
+ */
+ private static void importExpressions(IntermediateGraph graph, ImportedModel model) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ try {
+ Optional<TensorFunction> function = importExpression(graph.get(outputName), model);
+ if (!function.isPresent()) {
+ signature.skippedOutput(outputName, "No valid output function could be found.");
+ }
+ }
+ catch (IllegalArgumentException e) {
+ signature.skippedOutput(outputName, Exceptions.toMessageString(e));
+ }
+ }
+ }
+ }
+
+ private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) {
+ if (!operation.type().isPresent()) {
+ return Optional.empty();
+ }
+ if (operation.isConstant()) {
+ return importConstant(operation, model);
+ }
+ importExpressionInputs(operation, model);
+ importRankingExpression(operation, model);
+ importArgumentExpression(operation, model);
+ importFunctionExpression(operation, model);
+
+ return operation.function();
+ }
+
+ private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
+ operation.inputs().forEach(input -> importExpression(input, model));
+ }
+
+ private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
+ String name = operation.vespaName();
+ if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
+ return operation.function();
+ }
+
+ Value value = operation.getConstantValue().orElseThrow(() ->
+ new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
+ "is constant but does not have a value."));
+ if ( ! (value instanceof TensorValue)) {
+ return operation.function(); // scalar values are inserted directly into the expression
+ }
+
+ Tensor tensor = value.asTensor();
+ if (tensor.type().rank() == 0) {
+ model.smallConstant(name, tensor);
+ } else {
+ model.largeConstant(name, tensor);
+ }
+ return operation.function();
+ }
+
+ private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.function().isPresent()) {
+ String name = operation.name();
+ if ( ! model.expressions().containsKey(name)) {
+ TensorFunction function = operation.function().get();
+
+ if (isSignatureOutput(model, operation)) {
+ OrderedTensorType operationType = operation.type().get();
+ OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
+ if ( ! operationType.equals(standardNamingType)) {
+ List<String> renameFrom = operationType.dimensionNames();
+ List<String> renameTo = standardNamingType.dimensionNames();
+ function = new Rename(function, renameFrom, renameTo);
+ }
+ }
+
+ try {
+ // We add all intermediate nodes imported as separate expressions. Only
+ // those referenced from the output will be used. We parse the
+ // TensorFunction here to convert it to a RankingExpression tree.
+ model.expression(name, new RankingExpression(name, function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Imported function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+ }
+
+ private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.isInput()) {
+ // All inputs must have dimensions with standard naming convention: d0, d1, ...
+ OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
+ model.input(operation.vespaName(), standardNamingConvention.type());
+ }
+ }
+
+ private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) {
+ if (operation.rankingExpressionFunction().isPresent()) {
+ TensorFunction function = operation.rankingExpressionFunction().get();
+ try {
+ model.function(operation.rankingExpressionFunctionName(),
+ new RankingExpression(operation.rankingExpressionFunctionName(),
+ function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Model function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+
+ /**
+ * Add any import warnings to the signature in the ImportedModel.
+ */
+ private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
+ for (ImportedModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ reportWarnings(graph.get(outputName), model);
+ }
+ }
+ }
+
+ private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
+ for (String warning : operation.warnings()) {
+ model.defaultSignature().importWarning(warning);
+ }
+ for (IntermediateOperation input : operation.inputs()) {
+ reportWarnings(input, model);
+ }
+ }
+
+ /**
+ * Log all model Variables (i.e file constants) imported as part of this with their ordered type.
+ * This allows users to learn the exact types (including dimension order after renaming) of the Variables
+ * such that these can be converted and fed to a parent document independently of the rest of the model
+ * for fast model weight updates.
+ */
+ private static void logVariableTypes(IntermediateGraph graph) {
+ for (IntermediateOperation operation : graph.operations()) {
+ if ( ! (operation instanceof Constant)) continue;
+ if ( ! operation.type().isPresent()) continue; // will not happen
+ log.info("Importing model variable " + operation.name() + " as " + operation.vespaName() +
+ " of type " + operation.type().get());
+ }
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
new file mode 100644
index 00000000000..c4acfeb3235
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
@@ -0,0 +1,235 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.rankingexpression.importer;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorTypeParser;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * A Vespa tensor type is ordered by the lexicographical ordering of dimension
+ * names. Imported tensors have an explicit ordering of their dimensions.
+ * During import, we need to track the Vespa dimension that matches the
+ * corresponding imported dimension as the ordering can change after
+ * dimension renaming. That is the purpose of this class.
+ *
+ * @author lesters
+ */
+public class OrderedTensorType {
+
+ private final TensorType type;
+ private final List<TensorType.Dimension> dimensions;
+
+ private final long[] innerSizesOriginal;
+ private final long[] innerSizesVespa;
+ private final int[] dimensionMap;
+
+ private OrderedTensorType(List<TensorType.Dimension> dimensions) {
+ this.dimensions = Collections.unmodifiableList(dimensions);
+ this.type = new TensorType.Builder(dimensions).build();
+ this.innerSizesOriginal = new long[dimensions.size()];
+ this.innerSizesVespa = new long[dimensions.size()];
+ this.dimensionMap = createDimensionMap();
+ }
+
+ public TensorType type() { return this.type; }
+
+ public int rank() { return dimensions.size(); }
+
+ public List<TensorType.Dimension> dimensions() {
+ return dimensions;
+ }
+
+ public List<String> dimensionNames() {
+ return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
+ }
+
+ private int[] createDimensionMap() {
+ int numDimensions = dimensions.size();
+ if (numDimensions == 0) {
+ return null;
+ }
+ innerSizesOriginal[numDimensions - 1] = 1;
+ innerSizesVespa[numDimensions - 1] = 1;
+ for (int i = numDimensions - 1; --i >= 0; ) {
+ innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1];
+ innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
+ }
+ int[] mapping = new int[numDimensions];
+ for (int i = 0; i < numDimensions; ++i) {
+ TensorType.Dimension dim1 = dimensions().get(i);
+ for (int j = 0; j < numDimensions; ++j) {
+ TensorType.Dimension dim2 = type.dimensions().get(j);
+ if (dim1.equals(dim2)) {
+ mapping[i] = j;
+ break;
+ }
+ }
+ }
+ return mapping;
+ }
+
+ public int dimensionMap(int originalIndex) {
+ return dimensionMap[originalIndex];
+ }
+
+ /**
+ * When dimension ordering between Vespa and imported differs, i.e.
+ * after dimension renaming, use the dimension map to read in values
+ * so that they are correctly laid out in memory for Vespa.
+ * Used when importing tensors.
+ */
+ public int toDirectIndex(int index) {
+ if (dimensions.size() == 0) {
+ return 0;
+ }
+ if (dimensionMap == null) {
+ throw new IllegalArgumentException("Dimension map is not available");
+ }
+ int directIndex = 0;
+ long rest = index;
+ for (int i = 0; i < dimensions.size(); ++i) {
+ long address = rest / innerSizesOriginal[i];
+ directIndex += innerSizesVespa[dimensionMap[i]] * address;
+ rest %= innerSizesOriginal[i];
+ }
+ return directIndex;
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null || !(obj instanceof OrderedTensorType)) {
+ return false;
+ }
+ OrderedTensorType other = (OrderedTensorType) obj;
+ if (dimensions.size() != dimensions.size()) {
+ return false;
+ }
+ List<TensorType.Dimension> thisDimensions = this.dimensions();
+ List<TensorType.Dimension> otherDimensions = other.dimensions();
+ for (int i = 0; i < thisDimensions.size(); ++i) {
+ if (!thisDimensions.get(i).equals(otherDimensions.get(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ public OrderedTensorType rename(DimensionRenamer renamer) {
+ List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
+ for (TensorType.Dimension dimension : dimensions) {
+ String oldName = dimension.name();
+ Optional<String> newName = renamer.dimensionNameOf(oldName);
+ if (!newName.isPresent())
+ return this; // presumably, already renamed
+ TensorType.Dimension.Type dimensionType = dimension.type();
+ if (dimensionType == TensorType.Dimension.Type.indexedBound) {
+ renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
+ } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
+ renamedDimensions.add(TensorType.Dimension.indexed(newName.get()));
+ } else if (dimensionType == TensorType.Dimension.Type.mapped) {
+ renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
+ }
+ }
+ return new OrderedTensorType(renamedDimensions);
+ }
+
+ public OrderedTensorType rename(String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < dimensions.size(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Optional<Long> dimSize = dimensions.get(i).size();
+ if (dimSize.isPresent() && dimSize.get() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+ public static OrderedTensorType standardType(OrderedTensorType type) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < type.dimensions().size(); ++ i) {
+ TensorType.Dimension dim = type.dimensions().get(i);
+ String dimensionName = "d" + i;
+ if (dim.size().isPresent() && dim.size().get() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+ public static Long tensorSize(TensorType type) {
+ Long size = 1L;
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ size *= dimensionSize(dimension);
+ }
+ return size;
+ }
+
+ public static Long dimensionSize(TensorType.Dimension dim) {
+ return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
+ }
+
+ /**
+ * Returns a string representation of this: A standard tensor type string where dimensions
+ * are listed in the order of this rather than in the natural order of their names.
+ */
+ @Override
+ public String toString() {
+ return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
+ }
+
+ /**
+ * Creates an instance from the string representation of this: A standard tensor type string
+ * where dimensions are listed in the order of this rather than the natural order of their names.
+ */
+ public static OrderedTensorType fromSpec(String typeSpec) {
+ return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
+ }
+
+ public static OrderedTensorType fromDimensionList(List<Long> dims) {
+ return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ private static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (int i = 0; i < dims.size(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Long dimSize = dims.get(i);
+ if (dimSize >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+ public static class Builder {
+
+ private final List<TensorType.Dimension> dimensions;
+
+ public Builder() {
+ this.dimensions = new ArrayList<>();
+ }
+
+ public Builder add(TensorType.Dimension vespaDimension) {
+ this.dimensions.add(vespaDimension);
+ return this;
+ }
+
+ public OrderedTensorType build() {
+ return new OrderedTensorType(dimensions);
+ }
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index caf66baef66..dd2add973e4 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -3,19 +3,19 @@
package ai.vespa.rankingexpression.importer.onnx;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Argument;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.ConcatV2;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Identity;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Map;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.MatMul;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Reshape;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Shape;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.operations.Argument;
+import ai.vespa.rankingexpression.importer.operations.ConcatV2;
+import ai.vespa.rankingexpression.importer.operations.Constant;
+import ai.vespa.rankingexpression.importer.operations.Identity;
+import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import ai.vespa.rankingexpression.importer.operations.Join;
+import ai.vespa.rankingexpression.importer.operations.Map;
+import ai.vespa.rankingexpression.importer.operations.MatMul;
+import ai.vespa.rankingexpression.importer.operations.NoOp;
+import ai.vespa.rankingexpression.importer.operations.Reshape;
+import ai.vespa.rankingexpression.importer.operations.Shape;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
index cb591475d40..0a8a797a847 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
@@ -2,9 +2,9 @@
package ai.vespa.rankingexpression.importer.onnx;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter;
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
+import ai.vespa.rankingexpression.importer.ModelImporter;
import onnx.Onnx;
import java.io.File;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
index a267411e8a9..f3d87d89c27 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
@@ -3,7 +3,7 @@
package ai.vespa.rankingexpression.importer.onnx;
import com.google.protobuf.ByteString;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import onnx.Onnx;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 3da7477fcef..f251a14213b 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -2,7 +2,7 @@
package ai.vespa.rankingexpression.importer.onnx;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import onnx.Onnx;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
new file mode 100644
index 00000000000..d6ea00ca453
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
@@ -0,0 +1,57 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.Collections;
+import java.util.List;
+
+public class Argument extends IntermediateOperation {
+
+ private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
+
+ public Argument(String modelName, String nodeName, OrderedTensorType type) {
+ super(modelName, nodeName, Collections.emptyList());
+ this.type = type.rename(vespaName() + "_");
+ standardNamingType = OrderedTensorType.standardType(type);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return type;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
+ if (!standardNamingType.equals(type)) {
+ List<String> renameFrom = standardNamingType.dimensionNames();
+ List<String> renameTo = type.dimensionNames();
+ output = new Rename(output, renameFrom, renameTo);
+ }
+ return output;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public boolean isInput() {
+ return true;
+ }
+
+ @Override
+ public boolean isConstant() {
+ return false;
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
new file mode 100644
index 00000000000..a21fc5ff2f7
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -0,0 +1,108 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+import java.util.Optional;
+
+public class ConcatV2 extends IntermediateOperation {
+
+ private String concatDimensionName;
+
+ public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
+ return null;
+ }
+
+ IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
+ if (!concatDimOp.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
+ "concat dimension must be a constant.");
+ }
+ Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
+ if (concatDimTensor.type().rank() != 0) {
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
+ "concat dimension must be a scalar.");
+ }
+
+ OrderedTensorType aType = inputs.get(0).type().get();
+
+ int concatDim = (int)concatDimTensor.asDouble();
+ long concatDimSize = aType.dimensions().get(concatDim).size().orElse(-1L);
+
+ for (int i = 1; i < inputs.size() - 1; ++i) {
+ OrderedTensorType bType = inputs.get(i).type().get();
+ if (bType.rank() != aType.rank()) {
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
+ "inputs must have save rank.");
+ }
+ for (int j = 0; j < aType.rank(); ++j) {
+ long dimSizeA = aType.dimensions().get(j).size().orElse(-1L);
+ long dimSizeB = bType.dimensions().get(j).size().orElse(-1L);
+ if (j == concatDim) {
+ concatDimSize += dimSizeB;
+ } else if (dimSizeA != dimSizeB) {
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
+ "input dimension " + j + " differs in input tensors.");
+ }
+ }
+ }
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ int dimensionIndex = 0;
+ for (TensorType.Dimension dimension : aType.dimensions()) {
+ if (dimensionIndex == concatDim) {
+ concatDimensionName = dimension.name();
+ typeBuilder.add(TensorType.Dimension.indexed(concatDimensionName, concatDimSize));
+ } else {
+ typeBuilder.add(dimension);
+ }
+ dimensionIndex++;
+ }
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) {
+ return null;
+ }
+ TensorFunction result = inputs.get(0).function().get();
+ for (int i = 1; i < inputs.size() - 1; ++i) {
+ TensorFunction b = inputs.get(i).function().get();
+ result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName);
+ }
+ return result;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
+ return;
+ }
+ OrderedTensorType a = inputs.get(0).type().get();
+ for (int i = 1; i < inputs.size() - 1; ++i) {
+ OrderedTensorType b = inputs.get(i).type().get();
+ String bDim = b.dimensions().get(i).name();
+ String aDim = a.dimensions().get(i).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+ }
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
new file mode 100644
index 00000000000..41d421b1f5a
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -0,0 +1,89 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+import java.util.Optional;
+
+public class Const extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+
+ public Const(String modelName,
+ String nodeName,
+ List<IntermediateOperation> inputs,
+ AttributeMap attributeMap,
+ OrderedTensorType type) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ this.type = type.rename(vespaName() + "_");
+ setConstantValue(value());
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return type;
+ }
+
+ @Override
+ public Optional<TensorFunction> function() {
+ if (function == null) {
+ function = lazyGetFunction();
+ }
+ return Optional.ofNullable(function);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ ExpressionNode expressionNode;
+ if (type.type().rank() == 0 && getConstantValue().isPresent()) {
+ expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue());
+ } else {
+ expressionNode = new ReferenceNode(Reference.simple("constant", vespaName()));
+ }
+ return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
+ }
+
+ /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
+ @Override
+ public String vespaName() {
+ return modelName + "_" + super.vespaName();
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ setConstantValue(value());
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+ private Value value() {
+ Optional<Value> value = attributeMap.get("value", type);
+ if ( ! value.isPresent()) {
+ throw new IllegalArgumentException("Node '" + name + "' of type " +
+ "const has missing or non-recognized 'value' attribute");
+ }
+ return value.get();
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
new file mode 100644
index 00000000000..a1cc83296b0
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
@@ -0,0 +1,61 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.Collections;
+import java.util.Optional;
+
+public class Constant extends IntermediateOperation {
+
+ private final String modelName;
+
+ public Constant(String modelName, String nodeName, OrderedTensorType type) {
+ super(modelName, nodeName, Collections.emptyList());
+ this.modelName = modelName;
+ this.type = type.rename(vespaName() + "_");
+ }
+
+ /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
+ @Override
+ public String vespaName() {
+ return modelName + "_" + vespaName(name);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return type;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // will be added by function() since this is constant.
+ }
+
+ /**
+ * Constant values are sent in via the constantValueFunction, as the
+ * dimension names and thus the data layout depends on the dimension
+ * renaming which happens after the conversion to intermediate graph.
+ */
+ @Override
+ public Optional<Value> getConstantValue() {
+ return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type));
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
new file mode 100644
index 00000000000..8ae6d81b8d4
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -0,0 +1,106 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+public class ExpandDims extends IntermediateOperation {
+
+ private List<String> expandDimensions;
+
+ public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+
+ IntermediateOperation axisOperation = inputs().get(1);
+ if (!axisOperation.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("ExpandDims in " + name + ": " +
+ "axis must be a constant.");
+ }
+ Tensor axis = axisOperation.getConstantValue().get().asTensor();
+ if (axis.type().rank() != 0) {
+ throw new IllegalArgumentException("ExpandDims in " + name + ": " +
+ "axis argument must be a scalar.");
+ }
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ int dimensionToInsert = (int)axis.asDouble();
+ if (dimensionToInsert < 0) {
+ dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
+ }
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ expandDimensions = new ArrayList<>();
+ int dimensionIndex = 0;
+ for (TensorType.Dimension dimension : inputType.dimensions()) {
+ if (dimensionIndex == dimensionToInsert) {
+ String name = String.format("%s_%d", vespaName(), dimensionIndex);
+ expandDimensions.add(name);
+ typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
+ }
+ typeBuilder.add(dimension);
+ dimensionIndex++;
+ }
+
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(2)) {
+ return null;
+ }
+
+ // multiply with a generated tensor created from the reduced dimensions
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (String name : expandDimensions) {
+ typeBuilder.indexed(name, 1);
+ }
+ TensorType generatedType = typeBuilder.build();
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+ Generate generatedFunction = new Generate(generatedType,
+ new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+ return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ List<String> renamedDimensions = new ArrayList<>(expandDimensions.size());
+ for (String name : expandDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (!newName.isPresent()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ expandDimensions = renamedDimensions;
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
new file mode 100644
index 00000000000..c2787aa14d4
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
@@ -0,0 +1,35 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+public class Identity extends IntermediateOperation {
+
+ public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
+ @Override
+ public String vespaName() {
+ return modelName + "_" + super.vespaName();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1))
+ return null;
+ return inputs.get(0).type().orElse(null);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1))
+ return null;
+ return inputs.get(0).function().orElse(null);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
new file mode 100644
index 00000000000..60fba264635
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -0,0 +1,191 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.function.Function;
+
+/**
+ * Wraps an imported operation node and produces the respective Vespa tensor
+ * operation. During import, a graph of these operations are constructed. Then,
+ * the types are used to deduce sensible dimension names using the
+ * DimensionRenamer. After the types have been renamed, the proper Vespa
+ * expressions can be extracted.
+ *
+ * @author lesters
+ */
+public abstract class IntermediateOperation {
+
+ public final static String FUNCTION_PREFIX = "imported_ml_function_";
+
+ protected final String name;
+ protected final String modelName;
+ protected final List<IntermediateOperation> inputs;
+ protected final List<IntermediateOperation> outputs = new ArrayList<>();
+
+ protected OrderedTensorType type;
+ protected TensorFunction function;
+ protected TensorFunction rankingExpressionFunction = null;
+
+ private final List<String> importWarnings = new ArrayList<>();
+ private Value constantValue = null;
+ private List<IntermediateOperation> controlInputs = Collections.emptyList();
+
+ protected Function<OrderedTensorType, Value> constantValueFunction = null;
+
+ IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
+ this.name = name;
+ this.modelName = modelName;
+ this.inputs = Collections.unmodifiableList(inputs);
+ this.inputs.forEach(i -> i.outputs.add(this));
+ }
+
+ protected abstract OrderedTensorType lazyGetType();
+ protected abstract TensorFunction lazyGetFunction();
+
+ /** Returns the Vespa tensor type of this operation if it exists */
+ public Optional<OrderedTensorType> type() {
+ if (type == null) {
+ type = lazyGetType();
+ }
+ return Optional.ofNullable(type);
+ }
+
+ /** Returns the Vespa tensor function implementing all operations from this node with inputs */
+ public Optional<TensorFunction> function() {
+ if (function == null) {
+ if (isConstant()) {
+ ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
+ function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
+ } else if (outputs.size() > 1) {
+ rankingExpressionFunction = lazyGetFunction();
+ function = new VariableTensor(rankingExpressionFunctionName(), type.type());
+ } else {
+ function = lazyGetFunction();
+ }
+ }
+ return Optional.ofNullable(function);
+ }
+
+ /** Returns original name of this operation node */
+ public String name() { return name; }
+
+ /** Return unmodifiable list of inputs */
+ public List<IntermediateOperation> inputs() { return inputs; }
+
+ /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a function. */
+ public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); }
+
+ /** Returns a function that should be added as a ranking expression function */
+ public Optional<TensorFunction> rankingExpressionFunction() {
+ return Optional.ofNullable(rankingExpressionFunction);
+ }
+
+ /** Add dimension name constraints for this operation */
+ public void addDimensionNameConstraints(DimensionRenamer renamer) { }
+
+ /** Performs dimension rename for this operation */
+ public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
+
+ /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
+ public boolean isInput() { return false; }
+
+ /** Return true if this node is constant */
+ public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); }
+
+ /** Sets the constant value */
+ public void setConstantValue(Value value) { constantValue = value; }
+
+ /** Gets the constant value if it exists */
+ public Optional<Value> getConstantValue() {
+ if (constantValue != null) {
+ return Optional.of(constantValue);
+ }
+ if (constantValueFunction != null) {
+ return Optional.of(constantValueFunction.apply(type));
+ }
+ return Optional.empty();
+ }
+
+ /** Set the constant value function */
+ public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; }
+
+ /** Sets the external control inputs */
+ public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; }
+
+ /** Retrieve the control inputs for this operation */
+ public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
+
+ /** Retrieve the valid Vespa name of this node */
+ public String vespaName() { return vespaName(name); }
+ public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
+
+ /** Retrieve the valid Vespa name of this node if it is a ranking expression function */
+ public String rankingExpressionFunctionName() {
+ return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null;
+ }
+
+ /** Retrieve the list of warnings produced during its lifetime */
+ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
+
+ /** Set an input warning */
+ public void warning(String warning) { importWarnings.add(warning); }
+
+ boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) {
+ if (inputs.size() != expected) {
+ throw new IllegalArgumentException("Expected " + expected + " inputs " +
+ "for '" + name + "', got " + inputs.size());
+ }
+ return inputs.stream().map(func).allMatch(Optional::isPresent);
+ }
+
+ boolean allInputTypesPresent(int expected) {
+ return verifyInputs(expected, IntermediateOperation::type);
+ }
+
+ boolean allInputFunctionsPresent(int expected) {
+ return verifyInputs(expected, IntermediateOperation::function);
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ public static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
+ return name.split(":")[0];
+ }
+
+ /**
+ * This return the output index part. Indexes are used for nodes with
+ * multiple outputs.
+ */
+ public static int indexPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+ }
+
+ /**
+ * An interface mapping operation attributes to Vespa Values.
+ * Adapter for differences in different model types.
+ */
+ public interface AttributeMap {
+ Optional<Value> get(String key);
+ Optional<Value> get(String key, OrderedTensorType type);
+ Optional<List<Value>> getList(String key);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
new file mode 100644
index 00000000000..fed95e13bb7
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -0,0 +1,120 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+
+public class Join extends IntermediateOperation {
+
+ private final DoubleBinaryOperator operator;
+
+ public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) {
+ super(modelName, nodeName, inputs);
+ this.operator = operator;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ OrderedTensorType a = largestInput().type().get();
+ OrderedTensorType b = smallestInput().type().get();
+
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ int sizeDifference = a.rank() - b.rank();
+ for (int i = 0; i < a.rank(); ++i) {
+ TensorType.Dimension aDim = a.dimensions().get(i);
+ long size = aDim.size().orElse(-1L);
+
+ if (i - sizeDifference >= 0) {
+ TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference);
+ size = Math.max(size, bDim.size().orElse(-1L));
+ }
+
+ if (aDim.type() == TensorType.Dimension.Type.indexedBound) {
+ builder.add(TensorType.Dimension.indexed(aDim.name(), size));
+ } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) {
+ builder.add(TensorType.Dimension.indexed(aDim.name()));
+ } else if (aDim.type() == TensorType.Dimension.Type.mapped) {
+ builder.add(TensorType.Dimension.mapped(aDim.name()));
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ if (!allInputFunctionsPresent(2)) {
+ return null;
+ }
+
+ IntermediateOperation a = largestInput();
+ IntermediateOperation b = smallestInput();
+
+ List<String> aDimensionsToReduce = new ArrayList<>();
+ List<String> bDimensionsToReduce = new ArrayList<>();
+ int sizeDifference = a.type().get().rank() - b.type().get().rank();
+ for (int i = 0; i < b.type().get().rank(); ++i) {
+ TensorType.Dimension bDim = b.type().get().dimensions().get(i);
+ TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference);
+ long bSize = bDim.size().orElse(-1L);
+ long aSize = aDim.size().orElse(-1L);
+ if (bSize == 1L && aSize != 1L) {
+ bDimensionsToReduce.add(bDim.name());
+ }
+ if (aSize == 1L && bSize != 1L) {
+ aDimensionsToReduce.add(bDim.name());
+ }
+ }
+
+ TensorFunction aReducedFunction = a.function().get();
+ if (aDimensionsToReduce.size() > 0) {
+ aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
+ }
+ TensorFunction bReducedFunction = b.function().get();
+ if (bDimensionsToReduce.size() > 0) {
+ bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
+ }
+
+ return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!allInputTypesPresent(2)) {
+ return;
+ }
+ OrderedTensorType a = largestInput().type().get();
+ OrderedTensorType b = smallestInput().type().get();
+ int sizeDifference = a.rank() - b.rank();
+ for (int i = 0; i < b.rank(); ++i) {
+ String bDim = b.dimensions().get(i).name();
+ String aDim = a.dimensions().get(i + sizeDifference).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+ }
+ }
+
+ private IntermediateOperation largestInput() {
+ OrderedTensorType a = inputs.get(0).type().get();
+ OrderedTensorType b = inputs.get(1).type().get();
+ return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
+ }
+
+ private IntermediateOperation smallestInput() {
+ OrderedTensorType a = inputs.get(0).type().get();
+ OrderedTensorType b = inputs.get(1).type().get();
+ return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
new file mode 100644
index 00000000000..e0842d820f9
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
@@ -0,0 +1,37 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.DoubleUnaryOperator;
+
+public class Map extends IntermediateOperation {
+
+ private final DoubleUnaryOperator operator;
+
+ public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) {
+ super(modelName, nodeName, inputs);
+ this.operator = operator;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) {
+ return null;
+ }
+ return inputs.get(0).type().get();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1)) {
+ return null;
+ }
+ Optional<TensorFunction> input = inputs.get(0).function();
+ return new com.yahoo.tensor.functions.Map(input.get(), operator);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
new file mode 100644
index 00000000000..1dbfd6e40dc
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -0,0 +1,72 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+import java.util.Optional;
+
+public class MatMul extends IntermediateOperation {
+
+ public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
+ typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ OrderedTensorType aType = inputs.get(0).type().get();
+ OrderedTensorType bType = inputs.get(1).type().get();
+ if (aType.type().rank() < 2 || bType.type().rank() < 2)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
+ if (aType.type().rank() != bType.type().rank())
+ throw new IllegalArgumentException("Tensors in matmul must have the same rank");
+
+ Optional<TensorFunction> aFunction = inputs.get(0).function();
+ Optional<TensorFunction> bFunction = inputs.get(1).function();
+ if (!aFunction.isPresent() || !bFunction.isPresent()) {
+ return null;
+ }
+ return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!allInputTypesPresent(2)) {
+ return;
+ }
+ List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
+ List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+
+ String aDim0 = aDimensions.get(0).name();
+ String aDim1 = aDimensions.get(1).name();
+ String bDim0 = bDimensions.get(0).name();
+ String bDim1 = bDimensions.get(1).name();
+
+ // The second dimension of a should have the same name as the first dimension of b
+ renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
+
+ // The first dimension of a should have a different name than the second dimension of b
+ renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
+
+ // For efficiency, the dimensions to join over should be innermost - soft constraint
+ renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
+ renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
+ }
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
new file mode 100644
index 00000000000..4be220db9d5
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
@@ -0,0 +1,113 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+public class Mean extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private List<String> reduceDimensions;
+
+ public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ IntermediateOperation reductionIndices = inputs.get(1);
+ if (!reductionIndices.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("Mean in " + name + ": " +
+ "reduction indices must be a constant.");
+ }
+ Tensor indices = reductionIndices.getConstantValue().get().asTensor();
+ reduceDimensions = new ArrayList<>();
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext();) {
+ Tensor.Cell cell = cellIterator.next();
+ int dimensionIndex = cell.getValue().intValue();
+ if (dimensionIndex < 0) {
+ dimensionIndex = inputType.dimensions().size() - dimensionIndex;
+ }
+ reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
+ }
+ return reducedType(inputType, shouldKeepDimensions());
+ }
+
+ // optimization: if keepDims and one reduce dimension that has size 1: same as identity.
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
+ if (shouldKeepDimensions()) {
+ // multiply with a generated tensor created from the reduced dimensions
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (String name : reduceDimensions) {
+ typeBuilder.indexed(name, 1);
+ }
+ TensorType generatedType = typeBuilder.build();
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+ Generate generatedFunction = new Generate(generatedType,
+ new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+ output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+ }
+ return output;
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size());
+ for (String name : reduceDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (!newName.isPresent()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ reduceDimensions = renamedDimensions;
+ }
+
+ private boolean shouldKeepDimensions() {
+ Optional<Value> keepDims = attributeMap.get("keep_dims");
+ return keepDims.isPresent() && keepDims.get().asBoolean();
+ }
+
+ private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (TensorType.Dimension dimension: inputType.type().dimensions()) {
+ if (!reduceDimensions.contains(dimension.name())) {
+ builder.add(dimension);
+ } else if (keepDimensions) {
+ builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
new file mode 100644
index 00000000000..ce0c58971d0
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java
@@ -0,0 +1,35 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+public class Merge extends IntermediateOperation {
+
+ public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ for (IntermediateOperation operation : inputs) {
+ if (operation.type().isPresent()) {
+ return operation.type().get();
+ }
+ }
+ return null;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ for (IntermediateOperation operation : inputs) {
+ if (operation.function().isPresent()) {
+ return operation.function().get();
+ }
+ }
+ return null;
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
new file mode 100644
index 00000000000..4c5ce33b1b5
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java
@@ -0,0 +1,26 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.Collections;
+import java.util.List;
+
+public class NoOp extends IntermediateOperation {
+
+ public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return null;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null;
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
new file mode 100644
index 00000000000..e5e5c29f8f1
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java
@@ -0,0 +1,48 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+import java.util.Optional;
+
+public class PlaceholderWithDefault extends IntermediateOperation {
+
+ public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) {
+ return null;
+ }
+ return inputs().get(0).type().orElse(null);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1)) {
+ return null;
+ }
+ // This should be a call to the function we add below, but for now
+ // we treat this as as identity function and just pass the constant.
+ return inputs.get(0).function().orElse(null);
+ }
+
+ @Override
+ public Optional<TensorFunction> rankingExpressionFunction() {
+ // For now, it is much more efficient to assume we always will return
+ // the default value, as we can prune away large parts of the expression
+ // tree by having it calculated as a constant. If a case arises where
+ // it is important to support this, implement this.
+ return Optional.empty();
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true; // not true if we add to function
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
new file mode 100644
index 00000000000..18f3cc1cc39
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -0,0 +1,131 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class Reshape extends IntermediateOperation {
+
+ public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ IntermediateOperation newShape = inputs.get(1);
+ if (!newShape.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("Reshape in " + name + ": " +
+ "shape input must be a constant.");
+ }
+ Tensor shape = newShape.getConstantValue().get().asTensor();
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder();
+ int dimensionIndex = 0;
+ for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
+ Tensor.Cell cell = cellIterator.next();
+ int size = cell.getValue().intValue();
+ if (size < 0) {
+ size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() /
+ OrderedTensorType.tensorSize(inputType.type()).intValue();
+ }
+ outputTypeBuilder.add(TensorType.Dimension.indexed(
+ String.format("%s_%d", vespaName(), dimensionIndex), size));
+ dimensionIndex++;
+ }
+ return outputTypeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ if (!allInputFunctionsPresent(2)) {
+ return null;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ return reshape(inputFunction, inputType.type(), type.type());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
+ if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) {
+ throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
+ }
+
+ // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
+ // then use the dimension order of the new shape to roll back into a tensor.
+ // Here we create a transformation tensor that is multiplied with the from tensor to map into
+ // the new shape. We have to introduce temporary dimension names and rename back if dimension names
+ // in the new and old tensor type overlap.
+
+ ExpressionNode unrollFrom = unrollTensorExpression(inputType);
+ ExpressionNode unrollTo = unrollTensorExpression(outputType);
+ ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
+
+ TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
+ Generate transformTensor = new Generate(transformationType,
+ new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
+
+ TensorFunction outputFunction = new Reduce(
+ new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
+
+ return outputFunction;
+ }
+
+ private static ExpressionNode unrollTensorExpression(TensorType type) {
+ if (type.rank() == 0) {
+ return new ConstantNode(DoubleValue.zero);
+ }
+ List<ExpressionNode> children = new ArrayList<>();
+ List<ArithmeticOperator> operators = new ArrayList<>();
+ int size = 1;
+ for (int i = type.dimensions().size() - 1; i >= 0; --i) {
+ TensorType.Dimension dimension = type.dimensions().get(i);
+ children.add(0, new ReferenceNode(dimension.name()));
+ if (size > 1) {
+ operators.add(0, ArithmeticOperator.MULTIPLY);
+ children.add(0, new ConstantNode(new DoubleValue(size)));
+ }
+ size *= OrderedTensorType.dimensionSize(dimension);
+ if (i > 0) {
+ operators.add(0, ArithmeticOperator.PLUS);
+ }
+ }
+ return new ArithmeticNode(children, operators);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
new file mode 100644
index 00000000000..dc690329a8d
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
@@ -0,0 +1,88 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+
+import static ai.vespa.rankingexpression.importer.OrderedTensorType.dimensionSize;
+import static ai.vespa.rankingexpression.importer.OrderedTensorType.tensorSize;
+
+public class Select extends IntermediateOperation {
+
+ public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(3)) {
+ return null;
+ }
+ OrderedTensorType a = inputs.get(1).type().get();
+ OrderedTensorType b = inputs.get(2).type().get();
+ if ((a.type().rank() != b.type().rank()) || !(tensorSize(a.type()).equals(tensorSize(b.type())))) {
+ throw new IllegalArgumentException("'Select': input tensors must have the same shape");
+ }
+ return a;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(3)) {
+ return null;
+ }
+ IntermediateOperation conditionOperation = inputs().get(0);
+ TensorFunction a = inputs().get(1).function().get();
+ TensorFunction b = inputs().get(2).function().get();
+
+ // Shortcut: if we know during import which tensor to select, do that directly here.
+ if (conditionOperation.getConstantValue().isPresent()) {
+ Tensor condition = conditionOperation.getConstantValue().get().asTensor();
+ if (condition.type().rank() == 0) {
+ return ((int) condition.asDouble() == 0) ? b : a;
+ }
+ if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
+ return condition.cellIterator().next().getValue().intValue() == 0 ? b : a;
+ }
+ }
+
+ // The task is to select cells from 'x' or 'y' based on 'condition'.
+ // If 'condition' is 0 (false), select from 'y', if 1 (true) select
+ // from 'x'. We do this by individually joining 'x' and 'y' with
+ // 'condition', and then joining the resulting two tensors.
+
+ TensorFunction conditionFunction = conditionOperation.function().get();
+ TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply());
+ TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() {
+ @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
+ @Override public String toString() { return "f(a,b)(a * (1-b))"; }
+ });
+ return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!allInputTypesPresent(3)) {
+ return;
+ }
+ List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions();
+ List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions();
+
+ String aDim0 = aDimensions.get(0).name();
+ String aDim1 = aDimensions.get(1).name();
+ String bDim0 = bDimensions.get(0).name();
+ String bDim1 = bDimensions.get(1).name();
+
+ // These tensors should have the same dimension names
+ renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this);
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
new file mode 100644
index 00000000000..361729a8c14
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
@@ -0,0 +1,54 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+public class Shape extends IntermediateOperation {
+
+ public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ createConstantValue();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) {
+ return null;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ return new OrderedTensorType.Builder()
+ .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
+ .build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // will be added by function() since this is constant.
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+ private void createConstantValue() {
+ if (!allInputTypesPresent(1)) {
+ return;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type());
+ List<TensorType.Dimension> inputDimensions = inputType.dimensions();
+ for (int i = 0; i < inputDimensions.size(); i++) {
+ builder.cellByDirectIndex(i, inputDimensions.get(i).size().orElse(-1L));
+ }
+ this.setConstantValue(new TensorValue(builder.build()));
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
new file mode 100644
index 00000000000..2eeefcbe8a2
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
@@ -0,0 +1,85 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+public class Squeeze extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private List<String> squeezeDimensions;
+
+ public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) {
+ return null;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ squeezeDimensions = new ArrayList<>();
+
+ Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims");
+ if ( ! squeezeDimsAttr.isPresent()) {
+ squeezeDimensions = inputType.type().dimensions().stream().
+ filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
+ map(TensorType.Dimension::name).
+ collect(Collectors.toList());
+ } else {
+ squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue).
+ map(i -> i < 0 ? inputType.type().dimensions().size() - i : i).
+ map(i -> inputType.type().dimensions().get(i)).
+ filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
+ map(TensorType.Dimension::name).
+ collect(Collectors.toList());
+ }
+
+ return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1)) {
+ return null;
+ }
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ List<String> renamedDimensions = new ArrayList<>(squeezeDimensions.size());
+ for (String name : squeezeDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (!newName.isPresent()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ squeezeDimensions = renamedDimensions;
+ }
+
+ private OrderedTensorType reducedType(OrderedTensorType inputType) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ for (TensorType.Dimension dimension: inputType.type().dimensions()) {
+ if ( ! squeezeDimensions.contains(dimension.name())) {
+ builder.add(dimension);
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
new file mode 100644
index 00000000000..131af8de065
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java
@@ -0,0 +1,50 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+import java.util.Optional;
+
+public class Switch extends IntermediateOperation {
+
+ private final int port;
+
+ public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) {
+ super(modelName, nodeName, inputs);
+ this.port = port;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ Optional<OrderedTensorType> predicate = inputs.get(1).type();
+ if (predicate.get().type().rank() != 0) {
+ throw new IllegalArgumentException("Switch in " + name + ": " +
+ "predicate must be a scalar");
+ }
+ return inputs.get(0).type().orElse(null);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ IntermediateOperation predicateOperation = inputs().get(1);
+ if (!predicateOperation.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("Switch in " + name + ": " +
+ "predicate must be a constant");
+ }
+ if (port < 0 || port > 1) {
+ throw new IllegalArgumentException("Switch in " + name + ": " +
+ "choice should be boolean");
+ }
+
+ double predicate = predicateOperation.getConstantValue().get().asDouble();
+ return predicate == port ? inputs().get(0).function().get() : null;
+ }
+
+}
+
+
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java
new file mode 100644
index 00000000000..4473f306dcd
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java
@@ -0,0 +1,11 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+/**
+ * Model integration.
+ *
+ * CAUTION!: Config models depends on this API. It cannot be changed without ensuring compatibility with
+ * old config models.
+ */
+@ExportPackage
+package ai.vespa.rankingexpression.importer;
+
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
index 978fc3ecf60..ecb67f93d69 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
@@ -5,8 +5,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index e264b0daf6e..cb838cd67b1 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -3,27 +3,27 @@
package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Argument;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.ConcatV2;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Const;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.ExpandDims;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Identity;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Join;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Map;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.MatMul;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Mean;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Merge;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.NoOp;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.PlaceholderWithDefault;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Reshape;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Select;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Shape;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Squeeze;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Switch;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.operations.Argument;
+import ai.vespa.rankingexpression.importer.operations.ConcatV2;
+import ai.vespa.rankingexpression.importer.operations.Const;
+import ai.vespa.rankingexpression.importer.operations.Constant;
+import ai.vespa.rankingexpression.importer.operations.ExpandDims;
+import ai.vespa.rankingexpression.importer.operations.Identity;
+import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import ai.vespa.rankingexpression.importer.operations.Join;
+import ai.vespa.rankingexpression.importer.operations.Map;
+import ai.vespa.rankingexpression.importer.operations.MatMul;
+import ai.vespa.rankingexpression.importer.operations.Mean;
+import ai.vespa.rankingexpression.importer.operations.Merge;
+import ai.vespa.rankingexpression.importer.operations.NoOp;
+import ai.vespa.rankingexpression.importer.operations.PlaceholderWithDefault;
+import ai.vespa.rankingexpression.importer.operations.Reshape;
+import ai.vespa.rankingexpression.importer.operations.Select;
+import ai.vespa.rankingexpression.importer.operations.Shape;
+import ai.vespa.rankingexpression.importer.operations.Squeeze;
+import ai.vespa.rankingexpression.importer.operations.Switch;
import com.yahoo.tensor.functions.ScalarFunctions;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
index 80c23f2af69..6c92ffa6055 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
@@ -1,7 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.tensorflow;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
index a8b453b7a1c..2a406f92756 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
@@ -1,9 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.tensorflow;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.IntermediateGraph;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter;
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
+import ai.vespa.rankingexpression.importer.ModelImporter;
import org.tensorflow.SavedModelBundle;
import java.io.File;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
index d2430d34711..63a605ce97a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
@@ -2,7 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java
index c6bc889053a..31cb60b5509 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java
@@ -1,7 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.tensorflow;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.serialization.JsonFormat;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
index 35895995d3b..ac462cc39eb 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java
@@ -2,8 +2,8 @@
package ai.vespa.rankingexpression.importer.xgboost;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter;
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.ModelImporter;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import java.io.File;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
new file mode 100644
index 00000000000..cf8dd6e8e71
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
@@ -0,0 +1,48 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertTrue;
+
+public class DimensionRenamerTest {
+
+ @Test
+ public void testMnistRenaming() {
+ DimensionRenamer renamer = new DimensionRenamer();
+
+ renamer.addDimension("first_dimension_of_x");
+ renamer.addDimension("second_dimension_of_x");
+ renamer.addDimension("first_dimension_of_w");
+ renamer.addDimension("second_dimension_of_w");
+ renamer.addDimension("first_dimension_of_b");
+
+ // which dimension to join on matmul
+ renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null);
+
+ // other dimensions in matmul can't be equal
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null);
+
+ // for efficiency, put dimension to join on innermost
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null);
+ renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null);
+
+ // bias
+ renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null);
+
+ renamer.solve();
+
+ String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get();
+ String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get();
+ String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get();
+ String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get();
+ String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get();
+
+ assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0);
+ assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0);
+ assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0);
+ assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0);
+ assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0);
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
new file mode 100644
index 00000000000..afe699d6e05
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
@@ -0,0 +1,21 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class OrderedTensorTypeTestCase {
+
+ @Test
+ public void testToFromSpec() {
+ String spec = "tensor(b[],c{},a[3])";
+ OrderedTensorType type = OrderedTensorType.fromSpec(spec);
+ assertEquals(spec, type.toString());
+ assertEquals("tensor(a[3],b[],c{})", type.type().toString());
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index d86e7d6dd8e..d3996da9b58 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -6,7 +6,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
index d112a3fa9f2..1a072f54c89 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java
@@ -2,7 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java
index fa89e060006..37104ab43db 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java
@@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
index b3559a0a5f6..5e20be051ea 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
@@ -2,7 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.tensor.TensorType;
import org.junit.Assert;
import org.junit.Test;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
index 7e717c204f8..28b91b3797a 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java
@@ -2,7 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import org.junit.Assert;
import org.junit.Test;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
index f98b37b7e55..6215997d8f9 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java
@@ -2,7 +2,7 @@
package ai.vespa.rankingexpression.importer.tensorflow;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Assert;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
index faa2c7acc18..c3b82cccb46 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java
@@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
index 30b50c025d0..965d5eb8577 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java
@@ -1,7 +1,7 @@
package ai.vespa.rankingexpression.importer.xgboost;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import ai.vespa.rankingexpression.importer.ImportedModel;
import org.junit.Test;
import static org.junit.Assert.assertEquals;