aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-04-21 15:26:58 +0200
committerLester Solbakken <lesters@oath.com>2020-04-21 15:26:58 +0200
commitaad5c7184f37e1441c928efa77b434620742ff88 (patch)
tree34a92e7f954aa92e21d48816335771ff607fe404
parent6f5ca49e45cdc8262fcf360b1c731a393385ffa8 (diff)
Update model-integration for supporting BERT-type models
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java21
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java15
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java40
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java14
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java73
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java159
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java90
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java119
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java100
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java53
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java132
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java14
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java2
-rw-r--r--model-integration/src/test/models/onnx/simple/concat.onnxbin0 -> 135 bytes
-rwxr-xr-xmodel-integration/src/test/models/onnx/simple/concat.py25
-rw-r--r--model-integration/src/test/models/onnx/simple/gather.onnxbin150 -> 150 bytes
-rw-r--r--model-integration/src/test/models/onnx/simple/simple.onnx4
27 files changed, 752 insertions, 159 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index c7f320ed3b4..11abd3d24d8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -66,7 +66,7 @@ public class DimensionRenamer {
void solve() {
log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints));
- renames = solve(100000);
+ renames = solve(1000000);
log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames));
}
@@ -86,7 +86,7 @@ public class DimensionRenamer {
private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) {
Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations);
- if ( solution == null) {
+ if (solution == null) {
ListMap<Arc, Constraint> hardConstraints = new ListMap<>();
boolean anyRemoved = copyHard(constraints, hardConstraints);
if (anyRemoved) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index 14aa3ebf84e..ea981603481 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -7,6 +7,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import java.util.Collection;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -81,28 +82,36 @@ public class IntermediateGraph {
DimensionRenamer renamer = new DimensionRenamer(this);
for (String signature : signatures()) {
for (String output : outputs(signature).values()) {
- addDimensionNameConstraints(operations.get(output), renamer);
+ addDimensionNameConstraints(operations.get(output), renamer, new HashSet<>());
}
}
renamer.solve();
for (String signature : signatures()) {
for (String output : outputs(signature).values()) {
- renameDimensions(operations.get(output), renamer);
+ renameDimensions(operations.get(output), renamer, new HashSet<>());
}
}
}
- private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
+ private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer, Set<String> processed) {
+ if (processed.contains(operation.name())) {
+ return;
+ }
if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer, processed));
operation.addDimensionNameConstraints(renamer);
+ processed.add(operation.name());
}
}
- private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
+ private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer, Set<String> processed) {
+ if (processed.contains(operation.name())) {
+ return;
+ }
if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.inputs().forEach(input -> renameDimensions(input, renamer, processed));
operation.renameDimensions(renamer);
+ processed.add(operation.name());
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 3774e64c886..a9d71b7d9d5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -15,9 +15,11 @@ import com.yahoo.text.ExpressionFormatter;
import com.yahoo.yolean.Exceptions;
import java.io.File;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -108,6 +110,9 @@ public abstract class ModelImporter implements MlModelImporter {
}
private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) {
+ if (model.expressions().containsKey(operation.name())) {
+ return operation.function();
+ }
if (operation.type().isEmpty()) {
return Optional.empty();
}
@@ -206,18 +211,22 @@ public abstract class ModelImporter implements MlModelImporter {
private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
for (ImportedModel.Signature signature : model.signatures().values()) {
for (String outputName : signature.outputs().values()) {
- reportWarnings(graph.get(outputName), model);
+ reportWarnings(graph.get(outputName), model, new HashSet<>());
}
}
}
- private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
+ private static void reportWarnings(IntermediateOperation operation, ImportedModel model, Set<String> processed) {
+ if (processed.contains(operation.name())) {
+ return;
+ }
for (String warning : operation.warnings()) {
// If we want to report warnings, that code goes here
}
for (IntermediateOperation input : operation.inputs()) {
- reportWarnings(input, model);
+ reportWarnings(input, model, processed);
}
+ processed.add(operation.name());
}
/**
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
index 21cc6b27dad..9a7fcc85ee1 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
@@ -37,7 +37,8 @@ class NamingConstraintSolver {
private static ListMap<String, Integer> allPossibilities(Set<String> dimensions) {
ListMap<String, Integer> all = new ListMap<>();
for (String dimension : dimensions) {
- for (int i = 0; i < dimensions.size(); ++i)
+ // 20 (different dimension names) should be enough for most problems.
+ for (int i = 0; i < Math.min(dimensions.size(), 20); ++i)
all.put(dimension, i);
}
return all;
@@ -89,6 +90,7 @@ class NamingConstraintSolver {
workList.add(constraint);
}
}
+ if (iterations > maxIterations) return false;
}
return true;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index ffc64c38f16..d14ad033a69 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -2,7 +2,6 @@
package ai.vespa.rankingexpression.importer.onnx;
-import ai.vespa.rankingexpression.importer.operations.ExpandDims;
import ai.vespa.rankingexpression.importer.operations.Gather;
import ai.vespa.rankingexpression.importer.operations.OnnxCast;
import ai.vespa.rankingexpression.importer.operations.Gemm;
@@ -12,7 +11,10 @@ import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Slice;
import ai.vespa.rankingexpression.importer.operations.Softmax;
+import ai.vespa.rankingexpression.importer.operations.Split;
import ai.vespa.rankingexpression.importer.operations.Squeeze;
+import ai.vespa.rankingexpression.importer.operations.Tile;
+import ai.vespa.rankingexpression.importer.operations.Transpose;
import ai.vespa.rankingexpression.importer.operations.Unsqueeze;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -32,6 +34,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
+import java.util.Collection;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
@@ -53,19 +57,21 @@ class GraphImporter {
private static IntermediateOperation mapOperation(Onnx.NodeProto node,
List<IntermediateOperation> inputs,
- IntermediateGraph graph) {
+ IntermediateGraph graph,
+ int outputIndex) {
String type = node.getOpType();
String modelName = graph.name();
String nodeName = getNodeName(node);
AttributeConverter attributes = AttributeConverter.convert(node);
- return mapOperation(type, inputs, modelName, nodeName, attributes);
+ return mapOperation(type, inputs, modelName, nodeName, attributes, outputIndex);
}
static IntermediateOperation mapOperation(String opType,
List<IntermediateOperation> inputs,
String modelName,
String nodeName,
- AttributeConverter attributes) {
+ AttributeConverter attributes,
+ int outputIndex) {
switch (opType.toLowerCase()) {
case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
@@ -115,12 +121,15 @@ class GraphImporter {
case "slice": return new Slice(modelName, nodeName, inputs, attributes);
case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "split": return new Split(modelName, nodeName, inputs, attributes, outputIndex);
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square());
case "where": return new Select(modelName, nodeName, inputs);
case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
+ case "tile": return new Tile(modelName, nodeName, inputs);
+ case "transpose": return new Transpose(modelName, nodeName, inputs, attributes);
case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes);
}
@@ -168,11 +177,11 @@ class GraphImporter {
OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto);
operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
-
} else {
Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
+ int outputIndex = getOutputIndex(node, name);
List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
- operation = mapOperation(node, inputs, intermediateGraph);
+ operation = mapOperation(node, inputs, intermediateGraph, outputIndex);
// propagate constant values if all inputs are constant
if (operation.isConstant()) {
@@ -185,7 +194,7 @@ class GraphImporter {
}
}
intermediateGraph.put(operation.name(), operation);
-
+ intermediateGraph.put(name, operation);
return operation;
}
@@ -296,6 +305,10 @@ class GraphImporter {
return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst();
}
+ private static int getOutputIndex(Onnx.NodeProto node, String outputName) {
+ return node.getOutputCount() == 0 ? 0 : Math.max(node.getOutputList().indexOf(outputName), 0);
+ }
+
private static String getNodeName(Onnx.NodeProto node) {
String nodeName = node.getName();
if (nodeName.length() > 0)
@@ -307,11 +320,14 @@ class GraphImporter {
}
private static Set<String> getWarnings(IntermediateOperation op) {
- Set<String> warnings = new HashSet<>(op.warnings());
- for (IntermediateOperation input : op.inputs()) {
- warnings.addAll(getWarnings(input));
- }
- return warnings;
+ java.util.Map<String, Set<String>> warnings = new HashMap<>();
+ getWarnings(op, warnings);
+ return warnings.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
}
+ private static void getWarnings(IntermediateOperation op, java.util.Map<String, Set<String>> warnings) {
+ if (warnings.containsKey(op.name())) return;
+ op.inputs().forEach(input -> getWarnings(input, warnings));
+ warnings.put(op.name(), new HashSet<>(op.warnings()));
+ }
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 7c8038cea66..9354a346aaf 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -22,8 +22,11 @@ class TypeConverter {
for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) {
int vespaIndex = type.dimensionMap(onnxIndex);
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
- TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
long onnxDimensionSize = onnxDimension.getDimValue() == 0 ? 1 : onnxDimension.getDimValue();
+ if (onnxDimensionSize == -1) {
+ continue; // disregard batch dimensions
+ }
+ TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
if (onnxDimensionSize != vespaDimension.size().orElse(-1L)) {
throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions");
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
index 01fd7ee55bd..bc883076b33 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -53,12 +53,6 @@ public class Const extends IntermediateOperation {
return new TensorFunctionNode.ExpressionTensorFunction(expressionNode);
}
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + super.vespaName();
- }
-
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
addConstraintsFrom(type, renamer);
@@ -77,7 +71,7 @@ public class Const extends IntermediateOperation {
private Value value() {
Optional<Value> value = attributeMap.get("value", type);
- if ( ! value.isPresent()) {
+ if (value.isEmpty()) {
throw new IllegalArgumentException("Node '" + name + "' of type " +
"const has missing or non-recognized 'value' attribute");
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
index ad56eefe5f2..8f6e2335b10 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
@@ -4,7 +4,6 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.Collections;
@@ -13,20 +12,11 @@ import java.util.Optional;
public class Constant extends IntermediateOperation {
- private final String modelName;
-
public Constant(String modelName, String nodeName, OrderedTensorType type) {
super(modelName, nodeName, Collections.emptyList());
- this.modelName = modelName;
this.type = type.rename(vespaName() + "_");
}
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + vespaName(name);
- }
-
@Override
protected OrderedTensorType lazyGetType() {
return type;
@@ -61,7 +51,9 @@ public class Constant extends IntermediateOperation {
public Constant withInputs(List<IntermediateOperation> inputs) {
if ( ! inputs.isEmpty())
throw new IllegalArgumentException("Constant cannot take inputs");
- return new Constant(modelName(), name(), type);
+ Constant constant = new Constant(modelName(), name(), type);
+ constant.setConstantValueFunction(constantValueFunction);
+ return constant;
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
index 5463f645355..af192fcec38 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
@@ -12,12 +12,6 @@ public class Identity extends IntermediateOperation {
super(modelName, nodeName, inputs);
}
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + super.vespaName();
- }
-
@Override
protected OrderedTensorType lazyGetType() {
if (!allInputTypesPresent(1))
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 2aa8b2a0d48..af134fac6cf 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -3,6 +3,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -13,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
@@ -47,6 +49,8 @@ public abstract class IntermediateOperation {
protected TensorFunction rankingExpressionFunction = null;
protected boolean exportAsRankingFunction = false;
+ private boolean hasRenamedDimensions = false;
+
private final List<String> importWarnings = new ArrayList<>();
private Value constantValue = null;
private List<IntermediateOperation> controlInputs = Collections.emptyList();
@@ -121,7 +125,10 @@ public abstract class IntermediateOperation {
}
/** Performs dimension rename for this operation */
- public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
+ public void renameDimensions(DimensionRenamer renamer) {
+ type = type.rename(renamer);
+ hasRenamedDimensions = true;
+ }
/** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
public boolean isInput() { return false; }
@@ -153,12 +160,23 @@ public abstract class IntermediateOperation {
public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
/** Retrieve the valid Vespa name of this node */
- public String vespaName() { return vespaName(name); }
- public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; }
+ public String vespaName() {
+ if (isConstant())
+ return modelName + "_" + vespaName(name);
+ return vespaName(name);
+ }
+
+ public String vespaName(String name) {
+ return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null;
+ }
/** Retrieve the valid Vespa name of this node if it is a ranking expression function */
public String rankingExpressionFunctionName() {
- return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null;
+ String vespaName = vespaName();
+ if (vespaName == null) {
+ return null;
+ }
+ return isConstant() ? "constant(" + vespaName + ")" : FUNCTION_PREFIX + modelName + "_" + vespaName;
}
/** Retrieve the list of warnings produced during its lifetime */
@@ -188,12 +206,51 @@ public abstract class IntermediateOperation {
if ( ! isConstant() ) {
throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant.");
}
- Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN));
- if (type != null && ! val.asTensor().type().equals(type.type()) ) {
+ Value val = evaluableCopy().evaluateAsConstant(new MapContext(DoubleValue.NaN));
+ if (type == null) {
+ return val;
+ }
+ Tensor tensor = val.asTensor();
+ checkIfRenameableTo(tensor, type);
+ setConstantValueFunction(t -> new TensorValue(tensor.withType(t.type()))); // so we don't have to re-evaluate
+ return new TensorValue(tensor.withType(type.type()));
+ }
+
+ private void checkIfRenameableTo(Tensor tensor, OrderedTensorType type) {
+ if ( ! tensor.type().isRenamableTo(type.type()) ) {
throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " +
- "Expected: " + type.type() + " Got: " + val.asTensor().type());
+ "Expected: " + type.type() + " Got: " + tensor.type());
+ }
+ }
+
+ private IntermediateOperation evaluableCopy() {
+ if (hasRenamedDimensions) {
+ return this;
+ }
+ IntermediateOperation copy = copyTree();
+
+ // Must have performed dimension renaming to properly evaluate as constant
+ IntermediateGraph graph = new IntermediateGraph(modelName);
+ graph.put(name, copy);
+ graph.outputs(graph.defaultSignature()).put(name, name);
+ graph.optimize();
+
+ return copy;
+ }
+
+ private IntermediateOperation copyTree() {
+ List<IntermediateOperation> in = new ArrayList<>();
+ if (constantValue != null) {
+ IntermediateOperation constant = new Constant(modelName, name, type);
+ constant.setConstantValueFunction(t -> new TensorValue(constantValue.asTensor().withType(t.type())));
+ return constant;
+ }
+ inputs.forEach(i -> in.add(i.copyTree()));
+ IntermediateOperation copy = withInputs(in);
+ if (constantValueFunction != null) {
+ copy.constantValueFunction = constantValueFunction;
}
- return val;
+ return copy;
}
private Value evaluateAsConstant(Context context) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index adb54474812..3211a44fa68 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -82,6 +82,13 @@ public class Join extends IntermediateOperation {
bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
}
+ // retain order of inputs
+ if (a == inputs.get(1)) {
+ TensorFunction temp = bReducedFunction;
+ bReducedFunction = aReducedFunction;
+ aReducedFunction = temp;
+ }
+
return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 6849e64641e..2b0af93fd8e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -3,13 +3,23 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.Slice;
import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.text.ExpressionFormatter;
+import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
public class MatMul extends IntermediateOperation {
public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
@@ -20,62 +30,139 @@ public class MatMul extends IntermediateOperation {
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(2)) return null;
+ OrderedTensorType typeA = inputs.get(0).type().get();
+ OrderedTensorType typeB = inputs.get(1).type().get();
+
+ if (typeA.type().rank() < 1 || typeB.type().rank() < 1)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1");
+
OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
- typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
- typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
+ OrderedTensorType largestRankType = typeA.rank() >= typeB.rank() ? typeA : typeB;
+ OrderedTensorType smallestRankType = typeA.rank() >= typeB.rank() ? typeB : typeA;
+ for (int i = 0; i < largestRankType.rank() - 2; ++i) {
+ TensorType.Dimension dim = largestRankType.dimensions().get(i);
+ // broadcasting
+ int j = smallestRankType.rank() - largestRankType.rank() + i;
+ if (j >= 0 && smallestRankType.dimensions().get(j).size().get() > dim.size().get()) {
+ dim = smallestRankType.dimensions().get(j);
+ }
+ typeBuilder.add(dim);
+ }
+ if (typeA.rank() >= 2) {
+ typeBuilder.add(typeA.dimensions().get(typeA.rank() - 2));
+ }
+ if (typeB.rank() >= 2) {
+ typeBuilder.add(typeB.dimensions().get(typeB.rank() - 1));
+ }
return typeBuilder.build();
}
@Override
protected TensorFunction lazyGetFunction() {
if ( ! allInputTypesPresent(2)) return null;
+ if ( ! allInputFunctionsPresent(2)) return null;
+
+ OrderedTensorType typeA = inputs.get(0).type().get();
+ OrderedTensorType typeB = inputs.get(1).type().get();
+
+ TensorFunction functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB);
+ TensorFunction functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA);
+
+ return new com.yahoo.tensor.functions.Reduce(
+ new Join(functionA, functionB, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ typeA.dimensions().get(typeA.rank() - 1).name());
+ }
- OrderedTensorType aType = inputs.get(0).type().get();
- OrderedTensorType bType = inputs.get(1).type().get();
- if (aType.type().rank() < 2 || bType.type().rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (aType.type().rank() != bType.type().rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
- Optional<TensorFunction> aFunction = inputs.get(0).function();
- Optional<TensorFunction> bFunction = inputs.get(1).function();
- if (!aFunction.isPresent() || !bFunction.isPresent()) {
- return null;
+ private TensorFunction handleBroadcasting(TensorFunction tensorFunction, OrderedTensorType typeA, OrderedTensorType typeB) {
+ List<Slice.DimensionValue> slices = new ArrayList<>();
+ for (int i = 0; i < typeA.rank() - 2; ++i) {
+ long dimSizeA = typeA.dimensions().get(i).size().get();
+ String dimNameA = typeA.dimensionNames().get(i);
+ int j = typeB.rank() - typeA.rank() + i;
+ if (j >= 0) {
+ long dimSizeB = typeB.dimensions().get(j).size().get();
+ if (dimSizeB > dimSizeA && dimSizeA == 1) {
+ ExpressionNode dimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
+ slices.add(new Slice.DimensionValue(Optional.of(dimNameA), wrapScalar(dimensionExpression)));
+ }
+ }
}
- return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
+ return slices.size() == 0 ? tensorFunction : new Slice(tensorFunction, slices);
}
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
if ( ! allInputTypesPresent(2)) return;
- List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
- List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
-
- assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
- assertTwoDimensions(bDimensions, inputs.get(1), "second argument");
+ OrderedTensorType typeA = inputs.get(0).type().get();
+ OrderedTensorType typeB = inputs.get(1).type().get();
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
+ String lastDimA = typeA.dimensions().get(typeA.rank()-1).name();
+ String lastDimB = typeB.dimensions().get(typeB.rank()-1).name();
+ String secondLastDimA = typeA.dimensions().get(Math.max(0,typeA.rank()-2)).name();
+ String secondLastDimB = typeB.dimensions().get(Math.max(0,typeB.rank()-2)).name();
- // The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this);
+ // The last dimension of A should have the same name as the second-to-last dimension of B
+ renamer.addConstraint(lastDimA, secondLastDimB, DimensionRenamer.Constraint.equal(false), this);
- // The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this);
+ // The second-to-last dimension of A should have a different name than the last dimension of B
+ if (typeA.rank() >= 2 && typeB.rank() >= 2) {
+ renamer.addConstraint(secondLastDimA, lastDimB, DimensionRenamer.Constraint.lessThan(false), this);
+ }
// For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this);
- }
+ if (typeA.rank() >= 2) {
+ renamer.addConstraint(secondLastDimA, lastDimA, DimensionRenamer.Constraint.lessThan(true), this);
+ }
+ if (typeB.rank() >= 2) {
+ renamer.addConstraint(secondLastDimB, lastDimB, DimensionRenamer.Constraint.greaterThan(true), this);
+ }
+
+ // Handle different cases when at least one of the tensors have rank > 2
+ for (int i = 0; i < typeA.rank() - 2; ++i) {
+ String iDim = typeA.dimensionNames().get(i);
+
+ // a1 < a2 < a3 < a4
+ for (int j = i+1; j < typeA.rank(); ++j) {
+ String jDim = typeA.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this);
+ }
+ // not equal to last 2 dimensions in B
+ for (int j = typeB.rank()-2; j < typeB.rank(); ++j) {
+ if (j < 0) continue;
+ String jDim = typeB.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.notEqual(false), this);
+ }
+ // equal to matching dimension in tensor B
+ int j = typeB.rank() - typeA.rank() + i;
+ if (j >= 0) {
+ String jDim = typeB.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.equal(false), this);
+ }
+ }
- private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
- if (dimensions.size() >= 2) return;
- throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
- " but got just " + dimensions + " from\n" +
- ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
+ for (int i = 0; i < typeB.rank() - 2; ++i) {
+ String iDim = typeB.dimensionNames().get(i);
+
+ // b1 < b2 < b3 < b4
+ for (int j = i+1; j < typeB.rank(); ++j) {
+ String jDim = typeB.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this);
+ }
+ // not equal to last 2 dimensions in A
+ for (int j = typeA.rank()-2; j < typeA.rank(); ++j) {
+ if (j < 0) continue;
+ String jDim = typeA.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.notEqual(false), this);
+ }
+ // equal to matching dimension in tensor A
+ int j = typeA.rank() - typeB.rank() + i;
+ if (j >= 0) {
+ String jDim = typeA.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.equal(false), this);
+ }
+ }
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
index e040ae62149..07ac457cca8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
@@ -54,7 +54,7 @@ public class Rename extends IntermediateOperation {
}
public void renameDimensions(DimensionRenamer renamer) {
- type = type.rename(renamer);
+ super.renameDimensions(renamer);
from = renamer.dimensionNameOf(from).orElse(from);
to = renamer.dimensionNameOf(to).orElse(to);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index c88fc18e6c6..d8e806ae7e4 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -2,30 +2,29 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
-import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.Function;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
-import java.util.stream.Collectors;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
public class Reshape extends IntermediateOperation {
@@ -38,6 +37,10 @@ public class Reshape extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
+
+ // the reshape is expressed with tensor generate, which references the input by name,
+ // so the input must be exported as its own ranking function
+ inputs.get(0).exportAsRankingFunction = true;
+
if (inputs.size() == 2) {
return typeWithShapeAsInput();
} else if (inputs.size() == 1) {
@@ -126,52 +129,53 @@ public class Reshape extends IntermediateOperation {
return new Reshape(modelName(), name(), inputs, attributeMap);
}
- public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
+ public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type())))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
+ IntermediateOperation input = inputs.get(0);
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
// Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
// then use the dimension order of the new shape to roll back into a tensor.
- // Here we create a transformation tensor that is multiplied with the from tensor to map into
- // the new shape. We have to introduce temporary dimension names and rename back if dimension names
- // in the new and old tensor type overlap.
-
- // Todo: change this to use tensor generate when available
-
- List<String> from = new ArrayList<>();
- List<String> to = new ArrayList<>();
- boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType);
- if (dimensionNamesOverlap) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(outputType.type().valueType());
- for (int i = 0; i < outputType.rank(); ++i) {
- TensorType.Dimension dim = outputType.dimensions().get(i);
- from.add(dim.name());
- to.add("temp_" + dim.name());
- builder.add(dim.withName("temp_" + dim.name()));
- }
- outputType = builder.build();
- }
-
- ExpressionNode unrollFrom = unrollTensorExpression(inputType);
- ExpressionNode unrollTo = unrollTensorExpression(outputType);
- ExpressionNode transformExpression = new ComparisonNode(new EmbracedNode(unrollFrom), TruthOperator.EQUAL, new EmbracedNode(unrollTo));
- TensorType transformationType = new TensorType.Builder(inputType.type(), outputType.type()).build();
- Generate transformTensor = new Generate(transformationType,
- new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
+ ExpressionNode unrolled = new EmbracedNode(unrollTensorExpression(outputType));
- TensorFunction result = new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
+ long innerSize = 1;
+ for (int dim = 0; dim < inputType.rank(); ++dim) {
+ innerSize *= inputType.dimensions().get(dim).size().get();
+ }
- if (dimensionNamesOverlap) {
- result = new Rename(result, to, from);
+ for (int dim = 0; dim < inputType.rank(); ++dim) {
+ String inputDimensionName = inputType.dimensions().get(dim).name();
+ long inputDimensionSize = inputType.dimensions().get(dim).size().get();
+ long previousInnerSize = innerSize;
+ innerSize /= inputDimensionSize;
+
+ ExpressionNode inputDimensionExpression;
+ if (inputDimensionSize == 1) {
+ inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
+ } else if (dim == (inputType.rank() - 1)) {
+ ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
+ ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size);
+ inputDimensionExpression = new EmbracedNode(div);
+ } else {
+ ExpressionNode size = new ConstantNode(new DoubleValue(innerSize));
+ ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize));
+ ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize);
+ ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size);
+ inputDimensionExpression = new EmbracedNode(new FunctionNode(Function.floor, div));
+ }
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression)));
}
- return result;
- }
- private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) {
- return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent());
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ return Generate.bound(outputType.type(), wrapScalar(sliceExpression));
}
private static ExpressionNode unrollTensorExpression(OrderedTensorType type) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
index e5463291ef8..372d70a9fda 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
@@ -176,13 +176,11 @@ public class Slice extends IntermediateOperation {
com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
- TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
- return generate;
+ return Generate.bound(type.type(), wrapScalar(sliceExpression));
}
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- // Todo: what to do?
for (int i = 0; i < type.dimensions().size(); i++) {
renamer.addDimension(type.dimensions().get(i).name());
for (int j = i + 1; j < type.dimensions().size(); j++) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index 83086926316..d03827f4c72 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -28,6 +28,10 @@ public class Softmax extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(1)) return null;
+
+ // input is referenced twice due to overflow avoidance, so make this its own function.
+ inputs.get(0).exportAsRankingFunction = true;
+
return inputs.get(0).type().get();
}
@@ -50,7 +54,9 @@ public class Softmax extends IntermediateOperation {
}
TensorFunction input = inputs.get(0).function().get();
- TensorFunction exp = new Map(input, ScalarFunctions.exp());
+ TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
+ TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow
+ TensorFunction exp = new Map(cap, ScalarFunctions.exp());
TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions);
TensorFunction div = new Join(exp, sum, ScalarFunctions.divide());
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
new file mode 100644
index 00000000000..3241732671d
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
@@ -0,0 +1,119 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+public class Split extends IntermediateOperation {
+
+ private final AttributeMap attributes;
+ private final int output;
+
+ private final int axis;
+ private int start;
+ private int end;
+
+ public Split(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes, int output) {
+ super(modelName, nodeName, inputs);
+ this.attributes = attributes;
+ this.output = output;
+ axis = (int) attributes.get("axis").orElse(DoubleValue.zero).asDouble();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1))
+ return null;
+ OrderedTensorType inputType = inputs.get(0).type().get();
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
+ int axisSize = inputType.dimensions().get(axis).size().get().intValue();
+ start = 0;
+ end = axisSize;
+
+ if (attributes.getList("split").isPresent()) {
+ List<Value> splitList = attributes.getList("split").get();
+ if (output > splitList.size()) {
+ throw new IllegalArgumentException("Split in " + name + ": output out of range of split list");
+ }
+ for (int i = 0; i < output; ++i) {
+ start += (int) splitList.get(i).asDouble();
+ }
+ if (output < splitList.size()) {
+ end = start + (int) splitList.get(output).asDouble();
+ }
+ } else {
+ start = axisSize / 2 * output;
+ end = start + axisSize / 2;
+ }
+
+ if (start >= axisSize || start < 0) {
+ throw new IllegalArgumentException("Split in " + name + ": split start index out of range (" + start + ")");
+ }
+ if (end > axisSize || end < 0) {
+ throw new IllegalArgumentException("Split in " + name + ": split end index out of range (" + end + ")");
+ }
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < inputType.rank(); ++i) {
+ TensorType.Dimension inputDimension = inputType.dimensions().get(i);
+ long dimSize = i == axis ? end - start : inputDimension.size().get();
+ typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), dimSize));
+ }
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1)) return null;
+
+ IntermediateOperation input = inputs.get(0);
+ OrderedTensorType inputType = input.type().get();
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ for (int i = 0; i < inputType.rank(); ++i) {
+ String inputDimensionName = inputType.dimensions().get(i).name();
+ ExpressionNode reference = new ReferenceNode(inputDimensionName);
+ ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset))));
+ }
+
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
+ return generate;
+ }
+
+ @Override
+ public Split withInputs(List<IntermediateOperation> inputs) {
+ return new Split(modelName(), name(), inputs, attributes, output);
+ }
+
+ @Override
+ public String operationName() { return "Split"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
index a9e3fc6a43a..7ab01bf65c1 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
@@ -80,7 +80,7 @@ public class Squeeze extends IntermediateOperation {
private OrderedTensorType reducedType(OrderedTensorType inputType) {
OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
- for (TensorType.Dimension dimension: inputType.type().dimensions()) {
+ for (TensorType.Dimension dimension: inputType.dimensions()) {
if ( ! squeezeDimensions.contains(dimension.name())) {
builder.add(dimension);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
new file mode 100644
index 00000000000..a3c755bf1c0
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
@@ -0,0 +1,100 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+/**
+ * Onnx tile operation.
+ */
+public class Tile extends IntermediateOperation {
+
+ public Tile(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) return null;
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
+ IntermediateOperation repeats = inputs.get(1);
+ if (repeats.getConstantValue().isEmpty())
+ throw new IllegalArgumentException("Tile " + name + ": repeats input must be a constant.");
+
+ Tensor shape = repeats.getConstantValue().get().asTensor();
+ if (shape.type().rank() != 1)
+ throw new IllegalArgumentException("Tile " + name + ": repeats must be a 1-d tensor.");
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ if (shape.type().dimensions().get(0).size().get() != inputType.rank())
+ throw new IllegalArgumentException("Tile " + name + ": repeats must be the same size as input rank.");
+
+ List<Integer> dimSizes = new ArrayList<>(inputType.rank());
+ shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue()));
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ TensorType.Dimension inputDimension = inputType.dimensions().get(i);
+ typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get() * dimSizes.get(i)));
+ }
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(2)) return null;
+
+ IntermediateOperation input = inputs.get(0);
+ OrderedTensorType inputType = input.type().get();
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ for (int axis = 0; axis < inputType.rank(); ++axis) {
+ String inputDimensionName = inputType.dimensions().get(axis).name();
+ long inputDimensionSize = inputType.dimensions().get(axis).size().get();
+
+ ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
+ ExpressionNode reference = new ReferenceNode(inputDimensionName);
+ ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size);
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod))));
+ }
+
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
+ return generate;
+ }
+
+ @Override
+ public Tile withInputs(List<IntermediateOperation> inputs) {
+ return new Tile(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "Tile"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java
new file mode 100644
index 00000000000..5e7bc1a1f36
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java
@@ -0,0 +1,53 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+public class Transpose extends IntermediateOperation {
+
+ private final AttributeMap attributes;
+
+ public Transpose(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes) {
+ super(modelName, nodeName, inputs);
+ this.attributes = attributes;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) return null;
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < inputType.rank(); ++i) {
+ int inputIndex = inputType.rank() - 1 - i;
+ if (attributes.getList("perm").isPresent()) {
+ inputIndex = (int) attributes.getList("perm").get().get(i).asDouble();
+ }
+ TensorType.Dimension inputDimension = inputType.dimensions().get(inputIndex);
+ typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get()));
+ }
+ OrderedTensorType result = typeBuilder.build();
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1))
+ return null;
+ return inputs.get(0).function().orElse(null);
+ }
+
+ @Override
+ public Transpose withInputs(List<IntermediateOperation> inputs) {
+ return new Transpose(modelName(), name(), inputs, attributes);
+ }
+
+ @Override
+ public String operationName() { return "Transpose"; }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index 94c5577357b..d5dff7fb1b7 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -107,6 +107,18 @@ public class OnnxOperationsTestCase {
assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y));
assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y));
assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y));
+
+ // broadcasting - opposite order
+ x = evaluate("random(d0[4]) + 1");
+ y = evaluate("random(d0[2],d1[3],d2[4]) + 1");
+ assertEval("add", x, y, evaluate("rename(x, d0, d2) + y", x, y));
+ assertEval("sub", x, y, evaluate("rename(x, d0, d2) - y", x, y));
+ assertEval("mul", x, y, evaluate("rename(x, d0, d2) * y", x, y));
+ assertEval("div", x, y, evaluate("rename(x, d0, d2) / y", x, y));
+ assertEval("greater", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a > b))", x, y));
+ assertEval("less", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a < b))", x, y));
+ assertEval("equal", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a == b))", x, y));
+ assertEval("pow", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(pow(a,b)))", x, y));
}
@Test
@@ -185,9 +197,49 @@ public class OnnxOperationsTestCase {
@Test
public void testMatMul1() throws ParseException {
- Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]");
- Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]");
- assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]"));
+ Tensor a = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ Tensor b = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("91"));
+
+ a = evaluate("tensor(d0[3]):[1,2,3]");
+ b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2]):[22, 28]"));
+
+ a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[3]):[1,2,3]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2]):[14, 32]"));
+
+ a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[2],d1[3],d2[2]):[1,2,3,4,5,6,7,8,9,10,11,12]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2]):[22,28,49,64,58,64,139,154]"));
+
+ a = evaluate("tensor(d0[2],d1[2],d2[3]):[1,2,3,4,5,6,7,8,9,10,11,12]");
+ b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2]):[22,28,49,64,76,100,103,136]"));
+
+ a = evaluate("tensor(d0[1],d1[4],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+ b = evaluate("tensor(d0[1],d1[4],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,220,244,301,334,634,676,769,820,1264,1324,1453,1522]"));
+
+ a = evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[2],d1[2],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2],d3[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]"));
}
@Test
@@ -217,6 +269,10 @@ public class OnnxOperationsTestCase {
y = evaluate("tensor(d0[4]):[3,2,-1,1]");
assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]"));
+
+ x = evaluate("tensor(d0[1],d1[2],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12]");
+ y = evaluate("tensor(d0[2]):[2,6]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[6]):[1,2,3,4,5,6,7,8,9,10,11,12]"));
}
@Test
@@ -435,6 +491,48 @@ public class OnnxOperationsTestCase {
}
+ @Test
+ public void testTranspose1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[3]):[[1,2,3],[4,5,6]]");
+ assertEval("transpose", x, evaluate("tensor(d0[3],d1[2]):[[1,4],[2,5],[3,6]]"));
+ }
+
+ @Test
+ public void testTile6() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]");
+ Tensor y = evaluate("tensor(d0[2]):[1,2]");
+ assertEval("tile", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,1,2,3,4,3,4]"));
+
+ x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]");
+ y = evaluate("tensor(d0[2]):[3,1]");
+ assertEval("tile", x, y, evaluate("tensor(d0[6],d1[2]):[1,2,3,4,1,2,3,4,1,2,3,4]"));
+
+ x = evaluate("tensor(d0[1],d1[1],d2[1]):[1]");
+ y = evaluate("tensor(d0[3]):[1,6,1]");
+ assertEval("tile", x, y, evaluate("tensor(d0[1],d1[6],d2[1]):[1,1,1,1,1,1]"));
+ }
+
+ @Test
+ public void testSplit2() throws ParseException {
+ Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ assertEval("split", x, evaluate("tensor(d0[3]):[1,2,3]"), 0);
+ assertEval("split", x, evaluate("tensor(d0[3]):[4,5,6]"), 1);
+ assertEval("split", x, evaluate("tensor(d0[2]):[1,2]"), createAttribute("split", new int[] {2}), 0);
+ assertEval("split", x, evaluate("tensor(d0[4]):[3,4,5,6]"), createAttribute("split", new int[] {2}), 1);
+ assertEval("split", x, evaluate("tensor(d0[3]):[3,4,5]"), createAttribute("split", new int[] {2,3}), 1);
+ assertEval("split", x, evaluate("tensor(d0[1]):[6]"), createAttribute("split", new int[] {2,3}), 2);
+
+ x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"));
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), 0);
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), 1);
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), createAttribute("split", new int[] {1}), 0);
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), createAttribute("split", new int[] {1}), 1);
+ assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[1,4]"), createAttribute("axis", 1), 0);
+ assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[2,5]"), createAttribute("axis", 1), 1);
+ assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[3,6]"), createAttribute("axis", 1), 2);
+ }
+
private Tensor evaluate(String expr) throws ParseException {
return evaluate(expr, null, null, null);
}
@@ -461,41 +559,49 @@ public class OnnxOperationsTestCase {
}
private void assertEval(String opName, Tensor x, Tensor expected) {
- assertEval(opName, x, null, null, null, null, expected, null);
+ assertEval(opName, x, null, null, null, null, expected, null, 0);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor expected, int output) {
+ assertEval(opName, x, null, null, null, null, expected, null, output);
}
private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, null, null, null, null, expected, attr);
+ assertEval(opName, x, null, null, null, null, expected, attr, 0);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr, int output) {
+ assertEval(opName, x, null, null, null, null, expected, attr, output);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, y, null, null, null, expected, attr);
+ assertEval(opName, x, y, null, null, null, expected, attr, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) {
- assertEval(opName, x, y, null, null, null, expected, null);
+ assertEval(opName, x, y, null, null, null, expected, null, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) {
- assertEval(opName, x, y, z, null, null, expected, null);
+ assertEval(opName, x, y, z, null, null, expected, null, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, y, z, null, null, expected, attr);
+ assertEval(opName, x, y, z, null, null, expected, attr, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor expected) {
- assertEval(opName, x, y, z, q, null, expected, null);
+ assertEval(opName, x, y, z, q, null, expected, null, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected) {
- assertEval(opName, x, y, z, q, r, expected, null);
+ assertEval(opName, x, y, z, q, r, expected, null, 0);
}
- private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr) {
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr, int output) {
Context context = new MapContext(DoubleValue.NaN);
List<IntermediateOperation> inputs = createInputs(context, x, y, z, q, r);
- IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build());
+ IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build(), output);
optimizeAndRename(opName, op);
Tensor result = evaluate(op);
assertEquals(expected, result);
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
index 9631bddd93d..04db902073b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -11,7 +11,6 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -48,6 +47,19 @@ public class SimpleImportTestCase {
assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]"));
}
+ @Test
+ public void testConcat() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx");
+
+ MapContext context = new MapContext();
+ context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]")));
+ context.put("j", new TensorValue(Tensor.from("tensor(d0[1]):[2]")));
+ context.put("k", new TensorValue(Tensor.from("tensor(d0[1]):[3]")));
+
+ Tensor result = model.expressions().get("y").evaluate(context).asTensor();
+ assertEquals(result, Tensor.from("tensor(d0[3]):[1, 2, 3]"));
+ }
+
private void evaluateFunction(Context context, ImportedModel model, String functionName) {
if (!context.names().contains(functionName)) {
RankingExpression e = RankingExpression.from(model.functions().get(functionName));
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
index b9d767774be..25f8acf1f6d 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java
@@ -34,7 +34,7 @@ public class DropoutImportTestCase {
ImportedMlFunction function = signature.outputFunction("y", "y");
assertNotNull(function);
- assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(reduce(constant(test_outputs_Const), sum, d1), imported_ml_function_test_outputs_BiasAdd, f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
function.expression());
model.assertEqualResult("X", "outputs/Maximum");
assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
diff --git a/model-integration/src/test/models/onnx/simple/concat.onnx b/model-integration/src/test/models/onnx/simple/concat.onnx
new file mode 100644
index 00000000000..945bc3c9445
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/concat.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/simple/concat.py b/model-integration/src/test/models/onnx/simple/concat.py
new file mode 100755
index 00000000000..186002c2abb
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/concat.py
@@ -0,0 +1,25 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+import onnx
+from onnx import helper, TensorProto
+
+i_type = helper.make_tensor_value_info('i', TensorProto.FLOAT, [1])
+j_type = helper.make_tensor_value_info('j', TensorProto.FLOAT, [1])
+k_type = helper.make_tensor_value_info('k', TensorProto.FLOAT, [1])
+
+output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3])
+
+node = onnx.helper.make_node(
+ 'Concat',
+ inputs=['i', 'j', 'k'],
+ outputs=['y'],
+ axis=0,
+)
+graph_def = onnx.helper.make_graph(
+ nodes = [node],
+ name = 'concat_test',
+ inputs = [i_type, j_type, k_type],
+ outputs = [output_type]
+)
+model_def = helper.make_model(graph_def, producer_name='concat.py')
+onnx.save(model_def, 'concat.onnx')
diff --git a/model-integration/src/test/models/onnx/simple/gather.onnx b/model-integration/src/test/models/onnx/simple/gather.onnx
index 62451ad953d..0647d86ed0f 100644
--- a/model-integration/src/test/models/onnx/simple/gather.onnx
+++ b/model-integration/src/test/models/onnx/simple/gather.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/simple/simple.onnx b/model-integration/src/test/models/onnx/simple/simple.onnx
index 1c746c90efa..41b458451d0 100644
--- a/model-integration/src/test/models/onnx/simple/simple.onnx
+++ b/model-integration/src/test/models/onnx/simple/simple.onnx
@@ -1,4 +1,4 @@
- simple.py:ã
+ simple.py:ã
0
query_tensor
attribute_tensormatmul"MatMul
@@ -20,4 +20,4 @@
output


-B
+B \ No newline at end of file