aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java25
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java25
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java56
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java92
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java131
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java60
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java1
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java9
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java119
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java100
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java54
18 files changed, 633 insertions, 82 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index c7f320ed3b4..87f7c1c71f8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -66,7 +66,7 @@ public class DimensionRenamer {
void solve() {
log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints));
- renames = solve(100000);
+ renames = solve(100000000);
log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames));
}
@@ -86,7 +86,7 @@ public class DimensionRenamer {
private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) {
Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations);
- if ( solution == null) {
+ if (solution == null) {
ListMap<Arc, Constraint> hardConstraints = new ListMap<>();
boolean anyRemoved = copyHard(constraints, hardConstraints);
if (anyRemoved) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index 14aa3ebf84e..3c8a6bde232 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -7,6 +7,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import java.util.Collection;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -74,6 +75,8 @@ public class IntermediateGraph {
renameDimensions();
}
+ static int counter = 0;
+
/**
* Find dimension names to avoid excessive renaming while evaluating the model.
*/
@@ -93,16 +96,34 @@ public class IntermediateGraph {
}
private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
+ Set<String> operations = new HashSet<>();
+ addDimensionNameConstraints(operation, renamer, operations);
+ }
+
+ private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer, Set<String> operations) {
+ if (operations.contains(operation.name())) {
+ return;
+ }
if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer, operations));
operation.addDimensionNameConstraints(renamer);
+ operations.add(operation.name());
}
}
private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
+ Set<String> operations = new HashSet<>();
+ renameDimensions(operation, renamer, operations);
+ }
+
+ private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer, Set<String> operations) {
+ if (operations.contains(operation.name())) {
+ return;
+ }
if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.inputs().forEach(input -> renameDimensions(input, renamer, operations));
operation.renameDimensions(renamer);
+ operations.add(operation.name());
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 3774e64c886..7fad077ceb2 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -3,11 +3,14 @@ package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
@@ -15,9 +18,11 @@ import com.yahoo.text.ExpressionFormatter;
import com.yahoo.yolean.Exceptions;
import java.io.File;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -122,8 +127,16 @@ public abstract class ModelImporter implements MlModelImporter {
return operation.function();
}
+ private static boolean isImported(IntermediateOperation operation, ImportedModel model) {
+ return model.expressions().containsKey(operation.name()); // test for others?
+ }
+
private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
- operation.inputs().forEach(input -> importExpression(input, model));
+ operation.inputs().forEach(input -> {
+ if ( ! isImported(operation, model)) {
+ importExpression(input, model);
+ }
+ });
}
private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
@@ -206,18 +219,22 @@ public abstract class ModelImporter implements MlModelImporter {
private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
for (ImportedModel.Signature signature : model.signatures().values()) {
for (String outputName : signature.outputs().values()) {
- reportWarnings(graph.get(outputName), model);
+ reportWarnings(graph.get(outputName), model, new HashSet<String>());
}
}
}
- private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
+ private static void reportWarnings(IntermediateOperation operation, ImportedModel model, Set<String> reported) {
+ if (reported.contains(operation.name())) {
+ return;
+ }
for (String warning : operation.warnings()) {
// If we want to report warnings, that code goes here
}
for (IntermediateOperation input : operation.inputs()) {
- reportWarnings(input, model);
+ reportWarnings(input, model, reported);
}
+ reported.add(operation.name());
}
/**
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
index 21cc6b27dad..9a7fcc85ee1 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
@@ -37,7 +37,8 @@ class NamingConstraintSolver {
private static ListMap<String, Integer> allPossibilities(Set<String> dimensions) {
ListMap<String, Integer> all = new ListMap<>();
for (String dimension : dimensions) {
- for (int i = 0; i < dimensions.size(); ++i)
+ // 20 (different dimension names) should be enough for most problems.
+ for (int i = 0; i < Math.min(dimensions.size(), 20); ++i)
all.put(dimension, i);
}
return all;
@@ -89,6 +90,7 @@ class NamingConstraintSolver {
workList.add(constraint);
}
}
+ if (iterations > maxIterations) return false;
}
return true;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index ffc64c38f16..c98a5c7d4f5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -2,7 +2,6 @@
package ai.vespa.rankingexpression.importer.onnx;
-import ai.vespa.rankingexpression.importer.operations.ExpandDims;
import ai.vespa.rankingexpression.importer.operations.Gather;
import ai.vespa.rankingexpression.importer.operations.OnnxCast;
import ai.vespa.rankingexpression.importer.operations.Gemm;
@@ -12,7 +11,10 @@ import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Slice;
import ai.vespa.rankingexpression.importer.operations.Softmax;
+import ai.vespa.rankingexpression.importer.operations.Split;
import ai.vespa.rankingexpression.importer.operations.Squeeze;
+import ai.vespa.rankingexpression.importer.operations.Tile;
+import ai.vespa.rankingexpression.importer.operations.Transpose;
import ai.vespa.rankingexpression.importer.operations.Unsqueeze;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -32,6 +34,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
+import java.util.Collection;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
@@ -53,19 +57,21 @@ class GraphImporter {
private static IntermediateOperation mapOperation(Onnx.NodeProto node,
List<IntermediateOperation> inputs,
- IntermediateGraph graph) {
+ IntermediateGraph graph,
+ int outputIndex) {
String type = node.getOpType();
String modelName = graph.name();
String nodeName = getNodeName(node);
AttributeConverter attributes = AttributeConverter.convert(node);
- return mapOperation(type, inputs, modelName, nodeName, attributes);
+ return mapOperation(type, inputs, modelName, nodeName, attributes, outputIndex);
}
static IntermediateOperation mapOperation(String opType,
List<IntermediateOperation> inputs,
String modelName,
String nodeName,
- AttributeConverter attributes) {
+ AttributeConverter attributes,
+ int outputIndex) {
switch (opType.toLowerCase()) {
case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
@@ -115,17 +121,21 @@ class GraphImporter {
case "slice": return new Slice(modelName, nodeName, inputs, attributes);
case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "split": return new Split(modelName, nodeName, inputs, attributes, outputIndex);
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square());
case "where": return new Select(modelName, nodeName, inputs);
case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
+ case "tile": return new Tile(modelName, nodeName, inputs);
+ case "transpose": return new Transpose(modelName, nodeName, inputs, attributes);
case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes);
}
IntermediateOperation op = new NoOp(modelName, nodeName, inputs);
op.warning("Operation '" + opType + "' is currently not implemented");
+ System.out.println(nodeName + ": operation '" + opType + "' is currently not implemented");
return op;
}
@@ -133,10 +143,15 @@ class GraphImporter {
Onnx.GraphProto onnxGraph = model.getGraph();
IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
+ System.out.println("Importing operations...");
importOperations(onnxGraph, intermediateGraph);
+ System.out.println("Verifying no warnings...");
verifyNoWarnings(intermediateGraph);
+ System.out.println("Verifying output types...");
verifyOutputTypes(onnxGraph, intermediateGraph);
+ System.out.println("Ok...");
+
return intermediateGraph;
}
@@ -150,8 +165,10 @@ class GraphImporter {
Onnx.GraphProto onnxGraph,
IntermediateGraph intermediateGraph) {
if (intermediateGraph.alreadyImported(name)) {
+// System.out.println("Trying to import '" + name + "' but is was already imported.");
return intermediateGraph.get(name);
}
+// System.out.println("Importing '" + name + "' ...");
IntermediateOperation operation;
if (isArgumentTensor(name, onnxGraph)) {
Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
@@ -163,16 +180,21 @@ class GraphImporter {
intermediateGraph.inputs(intermediateGraph.defaultSignature())
.put(IntermediateOperation.namePartOf(name), operation.vespaName());
+// System.out.println(" '" + name + "' imported as argument...");
+
} else if (isConstantTensor(name, onnxGraph)) {
Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto);
operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
+// System.out.println(" '" + name + "' imported as constant...");
+
} else {
Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
+ int outputIndex = getOutputIndex(node, name);
List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
- operation = mapOperation(node, inputs, intermediateGraph);
+ operation = mapOperation(node, inputs, intermediateGraph, outputIndex);
// propagate constant values if all inputs are constant
if (operation.isConstant()) {
@@ -183,8 +205,12 @@ class GraphImporter {
intermediateGraph.outputs(intermediateGraph.defaultSignature())
.put(IntermediateOperation.namePartOf(name), operation.name());
}
+
+// System.out.println(" '" + name + "' imported as normal...");
+
}
intermediateGraph.put(operation.name(), operation);
+ intermediateGraph.put(name, operation);
return operation;
}
@@ -262,7 +288,8 @@ class GraphImporter {
Onnx.ValueInfoProto onnxNode = getOutputNode(output.getKey(), onnxGraph);
OrderedTensorType type = operation.type().orElseThrow(
() -> new IllegalArgumentException("Output of '" + output.getValue() + "' has no type."));
- TypeConverter.verifyType(onnxNode.getType(), type);
+ System.out.println(onnxNode.getType() + " vs. " + type);
+ //TypeConverter.verifyType(onnxNode.getType(), type);
}
}
@@ -296,6 +323,10 @@ class GraphImporter {
return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst();
}
+ private static int getOutputIndex(Onnx.NodeProto node, String outputName) {
+ return node.getOutputCount() == 0 ? 0 : Math.max(node.getOutputList().indexOf(outputName), 0);
+ }
+
private static String getNodeName(Onnx.NodeProto node) {
String nodeName = node.getName();
if (nodeName.length() > 0)
@@ -307,11 +338,14 @@ class GraphImporter {
}
private static Set<String> getWarnings(IntermediateOperation op) {
- Set<String> warnings = new HashSet<>(op.warnings());
- for (IntermediateOperation input : op.inputs()) {
- warnings.addAll(getWarnings(input));
- }
- return warnings;
+ java.util.Map<String, Set<String>> warnings = new HashMap<>();
+ getWarnings(op, warnings);
+ return warnings.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
}
+ private static void getWarnings(IntermediateOperation op, java.util.Map<String, Set<String>> warnings) {
+ if (warnings.containsKey(op.name())) return;
+ op.inputs().forEach(input -> getWarnings(input, warnings));
+ warnings.put(op.name(), new HashSet<>(op.warnings()));
+ }
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
index 01fd7ee55bd..956d727fbad 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -54,10 +54,10 @@ public class Const extends IntermediateOperation {
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + super.vespaName();
- }
+// @Override
+// public String vespaName() {
+// return modelName + "_" + super.vespaName();
+// }
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
index ad56eefe5f2..b12f83f274b 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
@@ -22,10 +22,10 @@ public class Constant extends IntermediateOperation {
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + vespaName(name);
- }
+// @Override
+// public String vespaName() {
+// return modelName + "_" + vespaName(name);
+// }
@Override
protected OrderedTensorType lazyGetType() {
@@ -61,7 +61,9 @@ public class Constant extends IntermediateOperation {
public Constant withInputs(List<IntermediateOperation> inputs) {
if ( ! inputs.isEmpty())
throw new IllegalArgumentException("Constant cannot take inputs");
- return new Constant(modelName(), name(), type);
+ Constant constant = new Constant(modelName(), name(), type);
+ constant.setConstantValueFunction(constantValueFunction);
+ return constant;
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
index 5463f645355..af192fcec38 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
@@ -12,12 +12,6 @@ public class Identity extends IntermediateOperation {
super(modelName, nodeName, inputs);
}
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + super.vespaName();
- }
-
@Override
protected OrderedTensorType lazyGetType() {
if (!allInputTypesPresent(1))
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 2aa8b2a0d48..83e15a4081a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -3,6 +3,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -13,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
@@ -47,6 +49,8 @@ public abstract class IntermediateOperation {
protected TensorFunction rankingExpressionFunction = null;
protected boolean exportAsRankingFunction = false;
+ private boolean hasRenamedDimensions = false;
+
private final List<String> importWarnings = new ArrayList<>();
private Value constantValue = null;
private List<IntermediateOperation> controlInputs = Collections.emptyList();
@@ -121,7 +125,10 @@ public abstract class IntermediateOperation {
}
/** Performs dimension rename for this operation */
- public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
+ public void renameDimensions(DimensionRenamer renamer) {
+ type = type.rename(renamer);
+ hasRenamedDimensions = true;
+ }
/** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
public boolean isInput() { return false; }
@@ -144,7 +151,11 @@ public abstract class IntermediateOperation {
}
/** Set the constant value function */
- public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; }
+ public void setConstantValueFunction(Function<OrderedTensorType, Value> func) {
+ this.constantValueFunction = func;
+ }
+
+ public boolean hasConstantValueFunction() { return constantValueFunction != null; }
/** Sets the external control inputs */
public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; }
@@ -153,12 +164,23 @@ public abstract class IntermediateOperation {
public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
/** Retrieve the valid Vespa name of this node */
- public String vespaName() { return vespaName(name); }
- public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; }
+ public String vespaName() {
+ if (isConstant())
+ return modelName + "_" + vespaName(name);
+ return vespaName(name);
+ }
+
+ public String vespaName(String name) {
+ return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null;
+ }
/** Retrieve the valid Vespa name of this node if it is a ranking expression function */
public String rankingExpressionFunctionName() {
- return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null;
+ String vespaName = vespaName();
+ if (vespaName == null) {
+ return null;
+ }
+ return isConstant() ? "constant(" + vespaName + ")" : FUNCTION_PREFIX + modelName + "_" + vespaName;
}
/** Retrieve the list of warnings produced during its lifetime */
@@ -185,30 +207,80 @@ public abstract class IntermediateOperation {
/** Recursively evaluates this operation's constant value to avoid doing it run-time. */
public Value evaluateAsConstant(OrderedTensorType type) {
+// System.out.println("Starting constant evaluation for " + name);
if ( ! isConstant() ) {
throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant.");
}
- Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN));
- if (type != null && ! val.asTensor().type().equals(type.type()) ) {
+ if (type == null) {
+ System.out.println("Evaluating as constant for " + name + " with type null! Probably an error.");
+ }
+
+ IntermediateOperation evaluateOn = this;
+ if ( ! hasRenamedDimensions) {
+ // make a copy of the tree, perform renaming and evaluate
+ IntermediateOperation copy = copyTree(0);
+ optimizeAndRename(copy);
+ evaluateOn = copy;
+ }
+ Value val = evaluateOn.evaluateAsConstant(new MapContext(DoubleValue.NaN), 0);
+
+ if (type == null) {
+ return val;
+ }
+ Tensor tensor = val.asTensor(); //.withType(type.type());
+ if ( ! tensor.type().isRenamableTo(type.type()) ) {
throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " +
"Expected: " + type.type() + " Got: " + val.asTensor().type());
}
- return val;
+ // set constant value so we don't have to re-evaluate
+ setConstantValueFunction(t -> new TensorValue(tensor.withType(t.type())));
+// System.out.println("Returning constant evaluation for " + name);
+ return new TensorValue(tensor.withType(type.type()));
+ }
+
+ private IntermediateOperation copyTree(int indent) {
+ String indentString = ""; for (int i = 0; i < indent; ++i) indentString += " ";
+// System.out.println(indentString + "Copying " + name);
+ List<IntermediateOperation> in = new ArrayList<>();
+ if (constantValue != null) {
+// System.out.println(indentString + name + " has a constant value");
+ IntermediateOperation constant = new Constant(modelName, name, type);
+ constant.setConstantValueFunction(t -> new TensorValue(constantValue.asTensor().withType(t.type())));
+ return constant;
+ }
+ inputs.forEach(i -> in.add(i.copyTree(indent + 1)));
+ IntermediateOperation copy = withInputs(in);
+ if (constantValueFunction != null) {
+ copy.constantValueFunction = constantValueFunction; // works?
+ }
+ return copy;
+ }
+
+ private TensorFunction optimizeAndRename(IntermediateOperation op) {
+ IntermediateGraph graph = new IntermediateGraph(modelName);
+ graph.put(name, op);
+ graph.outputs(graph.defaultSignature()).put(name, name);
+ graph.optimize();
+ return op.function().get();
}
- private Value evaluateAsConstant(Context context) {
+ private Value evaluateAsConstant(Context context, int indent) {
+ String in = ""; for (int i = 0; i < indent; ++i) in += " ";
+// System.out.println(in + "Constant evaluating for " + name);
String constantName = "constant(" + vespaName() + ")";
Value result = context.get(constantName);
if (result == DoubleValue.NaN) {
if (constantValue != null) {
+// System.out.println(in + name + " has constant value.");
result = constantValue;
} else if (inputs.size() == 0) {
+// System.out.println(in + name + " has no inputs.");
if (getConstantValue().isEmpty()) {
throw new IllegalArgumentException("Error in evaluating constant for " + name);
}
result = getConstantValue().get();
} else {
- inputs.forEach(i -> i.evaluateAsConstant(context));
+ inputs.forEach(i -> i.evaluateAsConstant(context, indent+1));
result = new TensorValue(lazyGetFunction().evaluate(context));
}
context.put(constantName, result);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index adb54474812..3211a44fa68 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -82,6 +82,13 @@ public class Join extends IntermediateOperation {
bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
}
+ // retain order of inputs
+ if (a == inputs.get(1)) {
+ TensorFunction temp = bReducedFunction;
+ bReducedFunction = aReducedFunction;
+ aReducedFunction = temp;
+ }
+
return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 6849e64641e..1eb21eb2a5e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -4,6 +4,9 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.text.ExpressionFormatter;
@@ -20,64 +23,126 @@ public class MatMul extends IntermediateOperation {
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(2)) return null;
+ OrderedTensorType aType = inputs.get(0).type().get();
+ OrderedTensorType bType = inputs.get(1).type().get();
+
+ // add some more checks here
+ if (aType.type().rank() < 1 || bType.type().rank() < 1)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1");
+
OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
- typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
- typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
+ OrderedTensorType largestRankType = aType.rank() >= bType.rank() ? aType : bType;
+ for (int i = 0; i < largestRankType.rank() - 2; ++i) {
+ typeBuilder.add(largestRankType.dimensions().get(i));
+ }
+ if (aType.rank() >= 2) {
+ typeBuilder.add(aType.dimensions().get(aType.rank() - 2));
+ }
+ if (bType.rank() >= 2) {
+ typeBuilder.add(bType.dimensions().get(bType.rank() - 1));
+ }
return typeBuilder.build();
}
@Override
protected TensorFunction lazyGetFunction() {
if ( ! allInputTypesPresent(2)) return null;
+ if ( ! allInputFunctionsPresent(2)) return null;
OrderedTensorType aType = inputs.get(0).type().get();
- OrderedTensorType bType = inputs.get(1).type().get();
- if (aType.type().rank() < 2 || bType.type().rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (aType.type().rank() != bType.type().rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
Optional<TensorFunction> aFunction = inputs.get(0).function();
Optional<TensorFunction> bFunction = inputs.get(1).function();
- if (!aFunction.isPresent() || !bFunction.isPresent()) {
- return null;
- }
- return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
+
+ // only change to this is for dimensions with size 1 - check in getType
+
+ return new com.yahoo.tensor.functions.Reduce(new Join(aFunction.get(), bFunction.get(), ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ aType.dimensions().get(aType.rank() - 1).name());
}
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
if ( ! allInputTypesPresent(2)) return;
- List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
- List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+ /*
+ * A: a1, a2, a3, a4
+ * B: b1, b2, b3, b4
+ *
+ * a4 == b3
+ * a3 < b4
+ * a3 < a4
+ * b4 < b3
+ *
+ * a1 == b1 -> men også størrelsesmessig.
+ * a2 == b2
+ * etc
+ */
+
+ OrderedTensorType typeA = inputs.get(0).type().get();
+ OrderedTensorType typeB = inputs.get(1).type().get();
+
+ String lastDimA = typeA.dimensions().get(typeA.rank()-1).name();
+ String lastDimB = typeB.dimensions().get(typeB.rank()-1).name();
+ String secondLastDimA = typeA.dimensions().get(Math.max(0,typeA.rank()-2)).name();
+ String secondLastDimB = typeB.dimensions().get(Math.max(0,typeB.rank()-2)).name();
+
+ // The last dimension of A should have the same name as the second-to-last dimension of B
+ renamer.addConstraint(lastDimA, secondLastDimB, DimensionRenamer.Constraint.equal(false), this);
- assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
- assertTwoDimensions(bDimensions, inputs.get(1), "second argument");
+ // For efficiency, the dimensions to join over should be innermost - soft constraint
+ if (typeA.rank() >= 2) {
+ renamer.addConstraint(secondLastDimA, lastDimA, DimensionRenamer.Constraint.lessThan(true), this);
+ }
+ if (typeB.rank() >= 2) {
+ renamer.addConstraint(secondLastDimB, lastDimB, DimensionRenamer.Constraint.greaterThan(true), this);
+ }
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
+ // The second-to-last dimension of a should have a different name than the last dimension of b
+ if (typeA.rank() >= 2 && typeB.rank() >= 2) {
+ renamer.addConstraint(secondLastDimA, lastDimB, DimensionRenamer.Constraint.lessThan(false), this);
+ }
- // The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this);
+ // a1 < a2 < a3 < a4
+ OrderedTensorType largestRankType = typeA.rank() >= typeB.rank() ? typeA : typeB;
+ for (int i = 0; i < largestRankType.rank() - 2; ++i) {
+ String iDim = largestRankType.dimensionNames().get(i);
+ for (int j = i+1; j < largestRankType.rank() - 2; ++j) {
+ String jDim = largestRankType.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this);
+ }
+ }
+
+ // TODO: handle non similar sizes
+
+ // a1 == b1 etc
+ if (typeA.rank() == typeB.rank()) {
+ for (int i = 0; i < typeA.rank() - 2; ++i) {
+ renamer.addConstraint(typeA.dimensionNames().get(i), typeB.dimensionNames().get(i), DimensionRenamer.Constraint.equal(false), this);
+ }
+ }
- // The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this);
- // For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this);
- }
- private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
- if (dimensions.size() >= 2) return;
- throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
- " but got just " + dimensions + " from\n" +
- ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
+
+ // So, what about the other dimensions?
+// if (aDimensions.size() > 2) {
+// for (int i = 1; i < aDimensions.size(); ++i) {
+// renamer.addConstraint(aDimensions.get(0).name(), aDimensions.get(i).name(), DimensionRenamer.Constraint.notEqual(false), this);
+// }
+// for (int i = 0; i < bDimensions.size(); ++i) {
+// renamer.addConstraint(aDimensions.get(0).name(), bDimensions.get(i).name(), DimensionRenamer.Constraint.notEqual(false), this);
+// }
+// }
+
}
+// private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
+// if (dimensions.size() >= 2) return;
+// throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
+// " but got just " + dimensions + " from\n" +
+// ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
+// }
+
@Override
public MatMul withInputs(List<IntermediateOperation> inputs) {
return new MatMul(modelName(), name(), inputs);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
index e040ae62149..07ac457cca8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
@@ -54,7 +54,7 @@ public class Rename extends IntermediateOperation {
}
public void renameDimensions(DimensionRenamer renamer) {
- type = type.rename(renamer);
+ super.renameDimensions(renamer);
from = renamer.dimensionNameOf(from).orElse(from);
to = renamer.dimensionNameOf(to).orElse(to);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index c88fc18e6c6..f96dd420d30 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -2,8 +2,10 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
@@ -11,8 +13,11 @@ import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.Function;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -27,6 +32,8 @@ import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
public class Reshape extends IntermediateOperation {
private final AttributeMap attributeMap;
@@ -38,6 +45,10 @@ public class Reshape extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
if (inputs.size() == 2) {
return typeWithShapeAsInput();
} else if (inputs.size() == 1) {
@@ -126,10 +137,54 @@ public class Reshape extends IntermediateOperation {
return new Reshape(modelName(), name(), inputs, attributeMap);
}
- public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
+ public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type())))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
+ IntermediateOperation input = inputs.get(0);
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ // ala (d0 * 2 + d1)
+ ExpressionNode unrolled = new EmbracedNode(unrollTensorExpression(outputType));
+
+ long innerSize = 1;
+ for (int dim = 0; dim < inputType.rank(); ++dim) {
+ innerSize *= inputType.dimensions().get(dim).size().get();
+ }
+
+ for (int dim = 0; dim < inputType.rank(); ++dim) {
+ String inputDimensionName = inputType.dimensions().get(dim).name();
+ long inputDimensionSize = inputType.dimensions().get(dim).size().get();
+ long previousInnerSize = innerSize;
+ innerSize /= inputDimensionSize;
+
+ ExpressionNode inputDimensionExpression;
+ if (inputDimensionSize == 1) {
+ inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
+ } else if (dim == (inputType.rank() - 1)) {
+ ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
+ ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size);
+ inputDimensionExpression = new EmbracedNode(div);
+ } else {
+ ExpressionNode size = new ConstantNode(new DoubleValue(innerSize));
+ ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize));
+ ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize);
+ ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size);
+ inputDimensionExpression = new EmbracedNode(new FunctionNode(Function.floor, div));
+ }
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression)));
+ }
+
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ TensorFunction generate = Generate.bound(outputType.type(), wrapScalar(sliceExpression));
+ return generate;
+
+ /*
// Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
// then use the dimension order of the new shape to roll back into a tensor.
// Here we create a transformation tensor that is multiplied with the from tensor to map into
@@ -168,11 +223,14 @@ public class Reshape extends IntermediateOperation {
result = new Rename(result, to, from);
}
return result;
+ */
}
+ /*
private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) {
return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent());
}
+ */
private static ExpressionNode unrollTensorExpression(OrderedTensorType type) {
if (type.rank() == 0)
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
index e5463291ef8..8dd1e3ff33d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
@@ -182,7 +182,6 @@ public class Slice extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- // Todo: what to do?
for (int i = 0; i < type.dimensions().size(); i++) {
renamer.addDimension(type.dimensions().get(i).name());
for (int j = i + 1; j < type.dimensions().size(); j++) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index 83086926316..e2b83246bfc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
@@ -28,6 +29,10 @@ public class Softmax extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(1)) return null;
+
+ // input is referenced twice due to avoidance of overflow. so make this it's own function.
+ inputs.get(0).exportAsRankingFunction = true;
+
return inputs.get(0).type().get();
}
@@ -50,7 +55,9 @@ public class Softmax extends IntermediateOperation {
}
TensorFunction input = inputs.get(0).function().get();
- TensorFunction exp = new Map(input, ScalarFunctions.exp());
+ TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
+ TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow
+ TensorFunction exp = new Map(cap, ScalarFunctions.exp());
TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions);
TensorFunction div = new Join(exp, sum, ScalarFunctions.divide());
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
new file mode 100644
index 00000000000..02d780c52cd
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
@@ -0,0 +1,119 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+public class Split extends IntermediateOperation {
+
+ private final AttributeMap attributes;
+ private final int output;
+
+ private final int axis;
+ private int start;
+ private int end;
+
+ public Split(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes, int output) {
+ super(modelName, nodeName, inputs);
+ this.attributes = attributes;
+ this.output = output;
+ axis = (int) attributes.get("axis").orElse(DoubleValue.zero).asDouble();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1))
+ return null;
+ OrderedTensorType inputType = inputs.get(0).type().get();
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
+ int axisSize = inputType.dimensions().get(axis).size().get().intValue();
+ start = 0;
+ end = axisSize;
+
+ if (attributes.getList("split").isPresent()) {
+ List<Value> splitList = attributes.getList("split").get();
+ if (output > splitList.size()) {
+ throw new IllegalArgumentException("Split in " + name + ": output out of range of split list");
+ }
+ for (int i = 0; i < output; ++i) {
+ start += (int) splitList.get(i).asDouble();
+ }
+ if (output < splitList.size()) {
+ end = start + (int) splitList.get(output).asDouble();
+ }
+ } else {
+ start = axisSize / 2 * output;
+ end = start + axisSize / 2;
+ }
+
+ if (start >= axisSize || start < 0) {
+ throw new IllegalArgumentException("Split in " + name + ": split start index out of range (" + start + ")");
+ }
+ if (end > axisSize || end < 0) {
+ throw new IllegalArgumentException("Split in " + name + ": split end index out of range (" + end + ")");
+ }
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < inputType.rank(); ++i) {
+ TensorType.Dimension inputDimension = inputType.dimensions().get(i);
+ long dimSize = i == axis ? end - start : inputDimension.size().get();
+ typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), dimSize));
+ }
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1)) return null;
+
+ IntermediateOperation input = inputs.get(0);
+ OrderedTensorType inputType = input.type().get();
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ for (int i = 0; i < inputType.rank(); ++i) {
+ String inputDimensionName = inputType.dimensions().get(i).name();
+ ExpressionNode reference = new ReferenceNode(inputDimensionName);
+ ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset))));
+ }
+
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
+ return generate;
+ }
+
+ @Override
+ public Split withInputs(List<IntermediateOperation> inputs) {
+ return new Split(modelName(), name(), inputs, attributes, output);
+ }
+
+ @Override
+ public String operationName() { return "Split"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
new file mode 100644
index 00000000000..8d3468f3d04
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
@@ -0,0 +1,100 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+/**
+ * Onnx tile operation.
+ */
+public class Tile extends IntermediateOperation {
+
+ public Tile(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) return null;
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
+ IntermediateOperation repeats = inputs.get(1);
+ if (repeats.getConstantValue().isEmpty())
+ throw new IllegalArgumentException("Tile " + name + ": repeats input must be a constant.");
+
+ Tensor shape = repeats.getConstantValue().get().asTensor();
+ if (shape.type().rank() != 1)
+ throw new IllegalArgumentException("Tile " + name + ": repeats must be a 1-d tensor.");
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ if (shape.type().dimensions().get(0).size().get() != inputType.rank())
+ throw new IllegalArgumentException("Tile " + name + ": repeats must be the same size as input rank.");
+
+ List<Integer> dimSizes = new ArrayList<>(inputType.rank());
+ shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue()));
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ TensorType.Dimension inputDimension = inputType.dimensions().get(i);
+ typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get() * dimSizes.get(i)));
+ }
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(2)) return null;
+
+ IntermediateOperation input = inputs.get(0);
+ OrderedTensorType inputType = input.type().get();
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ for (int axis = 0; axis < inputType.rank(); ++axis) {
+ String inputDimensionName = inputType.dimensions().get(axis).name();
+ long inputDimensionSize = inputType.dimensions().get(axis).size().get();
+
+ ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
+ ExpressionNode reference = new ReferenceNode(inputDimensionName);
+ ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size);
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod))));
+ }
+
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
+ return generate;
+ }
+
+ @Override
+ public Tile withInputs(List<IntermediateOperation> inputs) {
+ return new Tile(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "Tile"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java
new file mode 100644
index 00000000000..178759fbf2a
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java
@@ -0,0 +1,54 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+public class Transpose extends IntermediateOperation {
+
+ private final AttributeMap attributes;
+
+ public Transpose(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes) {
+ super(modelName, nodeName, inputs);
+ this.attributes = attributes;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) return null;
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < inputType.rank(); ++i) {
+ int inputIndex = inputType.rank() - 1 - i;
+ if (attributes.getList("perm").isPresent()) {
+ inputIndex = (int) attributes.getList("perm").get().get(i).asDouble();
+ }
+ TensorType.Dimension inputDimension = inputType.dimensions().get(inputIndex);
+ typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get()));
+ }
+ OrderedTensorType result = typeBuilder.build();
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1))
+ return null;
+ return inputs.get(0).function().orElse(null);
+ }
+
+ @Override
+ public Transpose withInputs(List<IntermediateOperation> inputs) {
+ return new Transpose(modelName(), name(), inputs, attributes);
+ }
+
+ @Override
+ public String operationName() { return "Transpose"; }
+
+}