aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2020-02-10 12:52:50 +0100
committerGitHub <noreply@github.com>2020-02-10 12:52:50 +0100
commita0db4db00f1a426741b09b2bc77ed06a87d930b9 (patch)
tree39662cd1405408ca35cbc57defbe7ba4a10d5016
parent075ad25176380a89ce1bd80a86ec8626de903586 (diff)
parent238085125e0c14fc7e3251338530b7508994ea3a (diff)
Merge pull request #12131 from vespa-engine/lesters/add-onnx-operators
Add gather,slice,cast,unsqueeze onnx operations
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java9
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java22
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java170
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java82
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java203
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java109
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java165
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java56
-rw-r--r--model-integration/src/test/models/onnx/simple/gather.onnxbin0 -> 150 bytes
-rwxr-xr-xmodel-integration/src/test/models/onnx/simple/gather.py23
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java50
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java2
16 files changed, 886 insertions, 24 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index d22a8067bd4..c7f320ed3b4 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -224,6 +224,11 @@ public class DimensionRenamer {
/** Returns whether this is an opposite of another constraint */
boolean isOpposite() { return opposite; }
+ public static Constraint equal() { return new EqualConstraint(false, false); }
+ public static Constraint notEqual() { return new NotEqualConstraint(false, false); }
+ public static Constraint lessThan() { return new LessThanConstraint(false, false); }
+ public static Constraint greaterThan() { return new GreaterThanConstraint(false, false); }
+
public static Constraint equal(boolean soft) { return new EqualConstraint(soft, false); }
public static Constraint notEqual(boolean soft) { return new NotEqualConstraint(soft, false); }
public static Constraint lessThan(boolean soft) { return new LessThanConstraint(soft, false); }
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java
index 8caa158e5be..b272d4c6750 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import onnx.Onnx;
@@ -37,6 +38,7 @@ class AttributeConverter implements IntermediateOperation.AttributeMap {
case INT: return Optional.of(DoubleValue.frozen(attr.getI()));
case FLOAT: return Optional.of(DoubleValue.frozen(attr.getF()));
case STRING: return Optional.of(StringValue.frozen(attr.getS().toString()));
+ case TENSOR: return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attr.getT(), TypeConverter.typeFrom(attr.getT()))));
default:
return Optional.empty();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index d42338deaf8..ffc64c38f16 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -2,13 +2,18 @@
package ai.vespa.rankingexpression.importer.onnx;
+import ai.vespa.rankingexpression.importer.operations.ExpandDims;
+import ai.vespa.rankingexpression.importer.operations.Gather;
+import ai.vespa.rankingexpression.importer.operations.OnnxCast;
import ai.vespa.rankingexpression.importer.operations.Gemm;
import ai.vespa.rankingexpression.importer.operations.ConcatReduce;
import ai.vespa.rankingexpression.importer.operations.OnnxConcat;
import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
+import ai.vespa.rankingexpression.importer.operations.Slice;
import ai.vespa.rankingexpression.importer.operations.Softmax;
import ai.vespa.rankingexpression.importer.operations.Squeeze;
+import ai.vespa.rankingexpression.importer.operations.Unsqueeze;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
@@ -67,6 +72,7 @@ class GraphImporter {
case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
+ case "cast": return new OnnxCast(modelName, nodeName, inputs, attributes);
case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes);
case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
@@ -75,6 +81,7 @@ class GraphImporter {
case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "gather": return new Gather(modelName, nodeName, inputs, attributes);
case "gemm": return new Gemm(modelName, nodeName, inputs, attributes);
case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
case "identity": return new Identity(modelName, nodeName, inputs);
@@ -105,6 +112,7 @@ class GraphImporter {
case "shape": return new Shape(modelName, nodeName, inputs);
case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
+ case "slice": return new Slice(modelName, nodeName, inputs, attributes);
case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
@@ -113,6 +121,7 @@ class GraphImporter {
case "where": return new Select(modelName, nodeName, inputs);
case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
+ case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes);
}
IntermediateOperation op = new NoOp(modelName, nodeName, inputs);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
index 69d18d0ffcb..f8c7dc15857 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
@@ -37,12 +37,14 @@ class TensorConverter {
case BOOL: return new RawBoolValues(tensorProto);
case FLOAT: return new RawFloatValues(tensorProto);
case DOUBLE: return new RawDoubleValues(tensorProto);
+ case INT32: return new RawIntValues(tensorProto);
case INT64: return new RawLongValues(tensorProto);
}
} else {
switch (tensorProto.getDataType()) {
case FLOAT: return new FloatValues(tensorProto);
case DOUBLE: return new DoubleValues(tensorProto);
+ case INT32: return new IntValues(tensorProto);
case INT64: return new LongValues(tensorProto);
}
}
@@ -96,6 +98,17 @@ class TensorConverter {
@Override int size() { return size; }
}
+ private static class RawIntValues extends RawValues {
+ private final IntBuffer values;
+ private final int size;
+ RawIntValues(Onnx.TensorProto tensorProto) {
+ values = bytes(tensorProto).asIntBuffer();
+ size = values.remaining();
+ }
+ @Override double get(int i) { return values.get(i); }
+ @Override int size() { return size; }
+ }
+
private static class RawLongValues extends RawValues {
private final LongBuffer values;
private final int size;
@@ -125,6 +138,15 @@ class TensorConverter {
@Override int size() { return tensorProto.getDoubleDataCount(); }
}
+ private static class IntValues extends Values {
+ private final Onnx.TensorProto tensorProto;
+ IntValues(Onnx.TensorProto tensorProto) {
+ this.tensorProto = tensorProto;
+ }
+ @Override double get(int i) { return tensorProto.getInt32Data(i); }
+ @Override int size() { return tensorProto.getInt32DataCount(); }
+ }
+
private static class LongValues extends Values {
private final Onnx.TensorProto tensorProto;
LongValues(Onnx.TensorProto tensorProto) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index 3487d889338..e02f29a63f9 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -40,7 +40,7 @@ public class ExpandDims extends IntermediateOperation {
OrderedTensorType inputType = inputs.get(0).type().get();
int dimensionToInsert = (int)axis.asDouble();
if (dimensionToInsert < 0) {
- dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
+ dimensionToInsert = inputType.dimensions().size() + dimensionToInsert;
}
OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
new file mode 100644
index 00000000000..2a34ae53d5e
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
@@ -0,0 +1,170 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Slice;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+/*
+ * Onnx gather is the same as Numpy take.
+ */
+public class Gather extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+
+ private int axis;
+
+ public Gather(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(2)) return null;
+
+ OrderedTensorType dataType = inputs.get(0).type().get();
+ OrderedTensorType indicesType = inputs.get(1).type().get();
+
+ axis = (int) attributeMap.get("axis").orElse(DoubleValue.zero).asDouble();
+ if (axis < 0)
+ axis = dataType.rank() + axis;
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < axis; ++i) {
+ addDimension(i, dataType.dimensions().get(i).size().orElse(-1L), typeBuilder);
+ }
+ for (int i = 0; i < indicesType.rank(); ++i) {
+ addDimension(i + axis, indicesType.dimensions().get(i).size().orElse(-1L), typeBuilder);
+ }
+ for (int i = axis + 1; i < dataType.rank(); ++i) {
+ addDimension(i + indicesType.rank(), dataType.dimensions().get(i).size().orElse(-1L), typeBuilder);
+ }
+
+ inputs.get(0).exportAsRankingFunction = true;
+ inputs.get(1).exportAsRankingFunction = true;
+
+ return typeBuilder.build();
+ }
+
+ private void addDimension(int dimensionIndex, long size, OrderedTensorType.Builder typeBuilder) {
+ String name = String.format("%s_%d", vespaName(), dimensionIndex);
+ typeBuilder.add(TensorType.Dimension.indexed(name, size));
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputFunctionsPresent(2)) return null;
+
+ IntermediateOperation data = inputs.get(0);
+ IntermediateOperation indices = inputs.get(1);
+ OrderedTensorType dataType = data.type().get();
+ OrderedTensorType indicesType = indices.type().get();
+
+ String dataFunctionName = data.rankingExpressionFunctionName();
+ String indicesFunctionName = indices.rankingExpressionFunctionName();
+
+ List<Slice.DimensionValue<Reference>> dataSliceDimensions = new ArrayList<>();
+ for (int i = 0; i < axis; ++i) {
+ addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i);
+ }
+
+ List<Slice.DimensionValue<Reference>> indicesSliceDimensions = new ArrayList<>();
+ for (int i = 0; i < indicesType.rank(); ++i) {
+ addSliceDimension(indicesSliceDimensions, indicesType.dimensions().get(i).name(), axis + i);
+ }
+ ExpressionNode sliceExpression = createSliceExpression(indicesSliceDimensions, indicesFunctionName);
+ ExpressionNode indexExpression = createIndexExpression(dataType, sliceExpression);
+ addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression);
+
+ for (int i = axis + 1; i < dataType.rank(); ++i) {
+ addSliceDimension(dataSliceDimensions, dataType.dimensions().get(i).name(), i + indicesType.rank() - 1);
+ }
+
+ sliceExpression = createSliceExpression(dataSliceDimensions, dataFunctionName);
+ return Generate.bound(type.type(), wrapScalar(sliceExpression));
+ }
+
+ private ExpressionNode createSliceExpression(List<Slice.DimensionValue<Reference>> dimensionValues, String referenceName) {
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(referenceName));
+ Slice<Reference> sliceIndices = new Slice<>(inputIndices, dimensionValues);
+ return new TensorFunctionNode(sliceIndices);
+ }
+
+ /** to support negative indexing */
+ private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) {
+ ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get()));
+ ExpressionNode plus = new EmbracedNode(new ArithmeticNode(slice, ArithmeticOperator.PLUS, axisSize));
+ ExpressionNode mod = new ArithmeticNode(plus, ArithmeticOperator.MODULO, axisSize);
+ return mod;
+ }
+
+ private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, ExpressionNode expr) {
+ dimensionValues.add(new Slice.DimensionValue<>(Optional.of(dimensionName), wrapScalar(new EmbracedNode(expr))));
+ }
+
+ private void addSliceDimension(List<Slice.DimensionValue<Reference>> dimensionValues, String dimensionName, int dimensionIndex) {
+ String outputDimensionName = type.dimensions().get(dimensionIndex).name();
+ addSliceDimension(dimensionValues, dimensionName, new ReferenceNode(outputDimensionName));
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if ( ! allInputTypesPresent(2)) return;
+
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ renamer.addDimension(type.dimensions().get(i).name());
+ for (int j = i + 1; j < type.dimensions().size(); j++) {
+ renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(),
+ DimensionRenamer.Constraint.lessThan(), this);
+ }
+ }
+
+ OrderedTensorType dataType = inputs.get(0).type().get();
+ OrderedTensorType indicesType = inputs.get(1).type().get();
+
+ for (int i = 0; i < axis; ++i) {
+ renamer.addConstraint(type.dimensions().get(i).name(),
+ dataType.dimensions().get(i).name(),
+ DimensionRenamer.Constraint.equal(), this);
+ }
+ for (int i = 0; i < indicesType.rank(); ++i) {
+ renamer.addConstraint(type.dimensions().get(i + axis).name(),
+ indicesType.dimensions().get(i).name(),
+ DimensionRenamer.Constraint.equal(), this);
+ }
+ for (int i = axis + 1; i < dataType.rank(); ++i) {
+ renamer.addConstraint(type.dimensions().get(i + indicesType.rank() - 1).name(),
+ dataType.dimensions().get(i).name(),
+ DimensionRenamer.Constraint.equal(), this);
+ }
+
+ }
+
+ @Override
+ public Gather withInputs(List<IntermediateOperation> inputs) {
+ return new Gather(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "Gather"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 724b5c6b3ac..2aa8b2a0d48 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -45,6 +45,7 @@ public abstract class IntermediateOperation {
protected OrderedTensorType type;
protected TensorFunction function;
protected TensorFunction rankingExpressionFunction = null;
+ protected boolean exportAsRankingFunction = false;
private final List<String> importWarnings = new ArrayList<>();
private Value constantValue = null;
@@ -78,7 +79,7 @@ public abstract class IntermediateOperation {
if (isConstant()) {
ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
function = new TensorFunctionNode.ExpressionTensorFunction(constant);
- } else if (outputs.size() > 1) {
+ } else if (outputs.size() > 1 || exportAsRankingFunction) {
rankingExpressionFunction = lazyGetFunction();
function = new VariableTensor(rankingExpressionFunctionName(), type.type());
} else {
@@ -137,7 +138,7 @@ public abstract class IntermediateOperation {
return Optional.of(constantValue);
}
if (constantValueFunction != null) {
- return Optional.of(constantValueFunction.apply(type));
+ return Optional.of(constantValueFunction.apply(type().orElse(null)));
}
return Optional.empty();
}
@@ -188,7 +189,7 @@ public abstract class IntermediateOperation {
throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant.");
}
Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN));
- if ( ! val.asTensor().type().equals(type.type()) ) {
+ if (type != null && ! val.asTensor().type().equals(type.type()) ) {
throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " +
"Expected: " + type.type() + " Got: " + val.asTensor().type());
}
@@ -211,6 +212,9 @@ public abstract class IntermediateOperation {
result = new TensorValue(lazyGetFunction().evaluate(context));
}
context.put(constantName, result);
+ if (outputs.size() > 1 || exportAsRankingFunction) {
+ context.put(rankingExpressionFunctionName(), result);
+ }
}
return result;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java
new file mode 100644
index 00000000000..d15ac1b69f7
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxCast.java
@@ -0,0 +1,82 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx.TensorProto.DataType;
+
+import java.util.List;
+import java.util.function.DoubleUnaryOperator;
+
+public class OnnxCast extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private final DataType toType;
+
+ public OnnxCast(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ if (attributeMap.get("to").isEmpty()) {
+ throw new IllegalArgumentException("OnnxCast in " + name + ": Required attribute 'to' is missing.");
+ }
+ toType = DataType.forNumber((int) attributeMap.get("to").get().asDouble());
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1))
+ return null;
+ return inputs.get(0).type().orElse(null);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputFunctionsPresent(1))
+ return null;
+ TensorFunction input = inputs.get(0).function().get();
+ switch (toType) {
+ case BOOL:
+ return new com.yahoo.tensor.functions.Map(input, new AsBool());
+ case INT8:
+ case INT16:
+ case INT32:
+ case INT64:
+ case UINT8:
+ case UINT16:
+ case UINT32:
+ case UINT64:
+ return new com.yahoo.tensor.functions.Map(input, new AsInt());
+ case FLOAT:
+ case DOUBLE:
+ case FLOAT16:
+ return input;
+ case STRING:
+ throw new IllegalArgumentException("OnnxCast in " + name + ": Casting to string is not implemented.");
+ default:
+ throw new IllegalArgumentException("OnnxCast in " + name + ": Unknown or undefined cast: " + toType.name());
+ }
+ }
+
+ @Override
+ public OnnxCast withInputs(List<IntermediateOperation> inputs) {
+ return new OnnxCast(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "Cast"; }
+
+ private static class AsBool implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return operand != 0.0 ? 1 : 0; }
+ @Override
+ public String toString() { return "f(a)(a!=0)"; }
+ }
+
+ private static class AsInt implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return operand < 0 ? Math.ceil(operand) : Math.floor(operand); }
+ @Override
+ public String toString() { return "f(a)(if (a < 0, ceil(a), floor(a)))"; }
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
new file mode 100644
index 00000000000..e5463291ef8
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
@@ -0,0 +1,203 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+/**
+ * Onnx slice operation.
+ *
+ * Opset 1 to 9 accepts starts, ends, and axes tensors as attributes
+ *
+ * Opset 10 and up accepts starts, ends, axes, and steps tensors as inputs. Here we assume these are
+ * constants, otherwise we can't import this model because that would mean we
+ * would not know the resulting tensor type until run-time, and that is currently
+ * not supported in Vespa.
+ */
+public class Slice extends IntermediateOperation {
+
+ private final AttributeMap attributes;
+
+ private int[] starts;
+ private int[] ends;
+ private int[] steps;
+
+ public Slice(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes) {
+ super(modelName, nodeName, inputs);
+ this.attributes = attributes;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (inputs.size() < 1 || inputs.get(0).type().isEmpty()) {
+ return null;
+ }
+ OrderedTensorType dataType = inputs.get(0).type().get();
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
+ // Todo: only supports opsets 1-9, for >= get these from inputs
+ int[] startsInput = attributeListAsArray("starts", 0);
+ int[] endsInput = attributeListAsArray("ends", 0);
+ int[] stepsInput = new int[dataType.rank()]; Arrays.fill(stepsInput, 1); // Todo: get from input when opset >= 10
+
+ int[] axes;
+ if (attributes.getList("axes").isPresent()) {
+ axes = attributeListAsArray("axes", 0);
+ } else {
+ // infer axes: default is [0, 1, ..., len('starts')-1]
+ axes = new int[startsInput.length];
+ for (int i = 0; i < startsInput.length; ++i) {
+ axes[i] = i;
+ }
+ }
+
+ if (startsInput.length != endsInput.length) {
+ throw new IllegalArgumentException("Slice in " + name + ": 'starts' and 'ends' indexes are not of the same size.");
+ }
+ if (startsInput.length != axes.length) {
+ throw new IllegalArgumentException("Slice in " + name + ": 'axes' and 'starts' are not of same size.");
+ }
+
+ int[] dimensionSizes = new int [dataType.rank()];
+ for (int i = 0; i < dataType.rank(); ++i) {
+ dimensionSizes[i] = dataType.dimensions().get(i).size().get().intValue();
+ }
+
+ starts = new int[dataType.rank()]; Arrays.fill(starts, 0);
+ ends = new int[dataType.rank()];
+ steps = new int[dataType.rank()]; Arrays.fill(steps, 1);
+
+ for (int i = 0; i < axes.length; ++i) {
+ int axis = axes[i];
+ int start = startsInput[i];
+ int end = endsInput[i];
+ int step = stepsInput[i];
+
+ axis = Math.min(axis, dataType.rank() - 1);
+ axis = axis < 0 ? axis + dataType.rank() : axis;
+
+ start = Math.min(start, dimensionSizes[axis]);
+ start = start < 0 ? start + dimensionSizes[axis] : start;
+
+ end = Math.min(end, dimensionSizes[axis]);
+ end = end < 0 ? end + dimensionSizes[axis] : end;
+
+ // Todo: check negative values for step size
+
+ starts[axis] = start;
+ steps[axis] = step;
+
+ if (step == 0) {
+ throw new IllegalArgumentException("Slice in " + name + ": illegal step size of 0.");
+ }
+ if ((end - start) < 1) {
+ throw new IllegalArgumentException("Slice in " + name + ": illegal start (" + start + ") and end (" + end + ") index.");
+ }
+ dimensionSizes[axis] = (end - start) / step;
+ }
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < dataType.rank(); ++i) {
+ addDimension(i, dimensionSizes[i], typeBuilder);
+ }
+ return typeBuilder.build();
+ }
+
+ private int[] attributeListAsArray(String name, int defaultValue) {
+ if (attributes.getList(name).isEmpty()) {
+ throw new IllegalArgumentException("Slice in " + name + ": Required attribute '" + name + "' is missing.");
+ }
+ List<Value> list = attributes.getList(name).get();
+ int[] result = new int[list.size()]; Arrays.fill(result, defaultValue);
+ for (int i = 0; i < list.size(); ++i) {
+ result[i] = (int)list.get(i).asDouble();
+ }
+ return result;
+ }
+
+ private void addDimension(int dimensionIndex, long size, OrderedTensorType.Builder typeBuilder) {
+ String name = String.format("%s_%d", vespaName(), dimensionIndex);
+ typeBuilder.add(TensorType.Dimension.indexed(name, size));
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (inputs.size() < 1 || inputs.get(0).function().isEmpty()) {
+ return null;
+ }
+
+ IntermediateOperation data = inputs.get(0);
+ OrderedTensorType dataType = data.type().get();
+ String dataFunctionName = data.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ for (int axis = 0; axis < dataType.rank(); ++axis) {
+ int start = starts[axis];
+ int step = steps[axis];
+
+ String inputDimensionName = dataType.dimensions().get(axis).name();
+ String outputDimensionName = type.dimensions().get(axis).name();
+
+ ExpressionNode stepSize = new ConstantNode(new DoubleValue(step));
+ ExpressionNode startIndex = new ConstantNode(new DoubleValue(start));
+
+ // step * (d0 + start)
+ ExpressionNode reference = new ReferenceNode(outputDimensionName);
+ ExpressionNode plus = new EmbracedNode(new ArithmeticNode(reference, ArithmeticOperator.PLUS, startIndex));
+ ExpressionNode mul = new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, plus);
+
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul))));
+ }
+
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(dataFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
+ return generate;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ // Todo: what to do?
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ renamer.addDimension(type.dimensions().get(i).name());
+ for (int j = i + 1; j < type.dimensions().size(); j++) {
+ renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(),
+ DimensionRenamer.Constraint.lessThan(), this);
+ }
+ }
+ }
+
+ @Override
+ public Slice withInputs(List<IntermediateOperation> inputs) {
+ return new Slice(modelName(), name(), inputs, attributes);
+ }
+
+ @Override
+ public String operationName() { return "Slice"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java
new file mode 100644
index 00000000000..0df09c21530
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Unsqueeze.java
@@ -0,0 +1,109 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Implements the ONNX Unsqueeze operation: inserts dimensions of size 1 into
+ * the input tensor at the positions given by the required 'axes' attribute.
+ * Realized as a join (multiply) with a generated all-ones tensor that holds
+ * only the inserted size-1 dimensions.
+ */
+public class Unsqueeze extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ // Names of the inserted size-1 dimensions. Populated by lazyGetType and
+ // rewritten by renameDimensions once the renamer has assigned final names.
+ private List<String> expandDimensions;
+
+ public Unsqueeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ if (attributeMap.getList("axes").isEmpty()) {
+ throw new IllegalArgumentException("Unsqueeze in " + name + ": Required attribute 'axes' is missing.");
+ }
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(1)) return null;
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ // Collecting into a Set means duplicate axis values collapse to one insertion.
+ Set<Integer> dimensionsToInsert = attributeMap.getList("axes").get().stream().
+ map(d -> (int)d.asDouble()).collect(Collectors.toSet());
+
+ // handle negative dimension indexes: rank is the rank of the OUTPUT,
+ // so -1 refers to the last output dimension
+ int rank = inputType.rank() + dimensionsToInsert.size();
+ dimensionsToInsert = dimensionsToInsert.stream().map(d -> d < 0 ? rank + d : d).collect(Collectors.toSet());
+
+ // Interleave new size-1 dimensions with the input's dimensions, in order.
+ expandDimensions = new ArrayList<>();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ int inputDimensionIndex = 0;
+ for (int expandedDimensionIndex = 0; expandedDimensionIndex < rank; ++expandedDimensionIndex) {
+ if (dimensionsToInsert.contains(expandedDimensionIndex)) {
+ addDimension(expandedDimensionIndex, typeBuilder);
+ } else {
+ typeBuilder.add(inputType.dimensions().get(inputDimensionIndex));
+ inputDimensionIndex++;
+ }
+ }
+ return typeBuilder.build();
+ }
+
+ // Adds a generated size-1 indexed dimension and records its (temporary) name.
+ private void addDimension(int dimensionIndex, OrderedTensorType.Builder typeBuilder) {
+ String name = String.format("%s_%d", vespaName(), dimensionIndex);
+ expandDimensions.add(name);
+ typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputFunctionsPresent(1)) return null;
+
+ // multiply with a generated tensor created from the expanded dimensions
+ // (a constant 1 in each cell), which adds those dimensions to the result
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
+ for (String name : expandDimensions) {
+ typeBuilder.indexed(name, 1);
+ }
+ TensorType generatedType = typeBuilder.build();
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+ Generate generatedFunction = new Generate(generatedType,
+ new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+ return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ addConstraintsFrom(type, renamer);
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ // Also rewrite our recorded generated-dimension names to their final names,
+ // since lazyGetFunction reads them after renaming has happened.
+ List<String> renamedDimensions = new ArrayList<>(expandDimensions.size());
+ for (String name : expandDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (newName.isEmpty()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ expandDimensions = renamedDimensions;
+ }
+
+ @Override
+ public Unsqueeze withInputs(List<IntermediateOperation> inputs) {
+ // Recreate this operation with new inputs, keeping the original attributes.
+ return new Unsqueeze(modelName(), name(), inputs, attributeMap);
+ }
+
+ /** The ONNX operator name this class implements. */
+ @Override
+ public String operationName() { return "Unsqueeze"; }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index 6954abe5157..94c5577357b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -17,6 +17,7 @@ import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
import onnx.Onnx;
+import org.junit.Ignore;
import org.junit.Test;
import java.util.ArrayList;
@@ -26,7 +27,9 @@ import static ai.vespa.rankingexpression.importer.onnx.GraphImporter.*;
import static onnx.Onnx.AttributeProto.AttributeType.FLOAT;
import static onnx.Onnx.AttributeProto.AttributeType.INT;
import static onnx.Onnx.AttributeProto.AttributeType.INTS;
+import static onnx.Onnx.AttributeProto.AttributeType.TENSOR;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
/**
* Unit tests for ONNX operators. The number on the test reflects the minimum
@@ -294,6 +297,27 @@ public class OnnxOperationsTestCase {
}
@Test
+ public void testUnsqueeze1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2]):[1, 2]");
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1]):[1, 2]"), createAttribute("axes", new int[] {1}));
+ // negative axes count from the end of the OUTPUT shape
+ assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1]):[1, 2]"), createAttribute("axes", new int[] {-1}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {-2}));
+ // duplicate axes collapse to a single insertion (Unsqueeze collects axes into a Set)
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0,0}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1]):[1, 2]"), createAttribute("axes", new int[] {0,2}));
+ // axis order in the attribute does not matter
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1]):[1, 2]"), createAttribute("axes", new int[] {2,0}));
+
+ x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,1}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,2}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[3],d3[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,3}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[1],d2[1],d3[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {1,2}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[2],d1[3],d2[1],d3[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {2,3}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3],d4[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0,2,4}));
+ assertEval("unsqueeze", x, evaluate("tensor(d0[1],d1[2],d2[1],d3[3],d4[1]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {4,2,0}));
+ }
+
+ @Test
public void testWhere9() throws ParseException {
Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]");
Tensor y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]");
@@ -308,6 +332,109 @@ public class OnnxOperationsTestCase {
assertEval("where", evaluate("tensor(d0[1],d1[1]):[1]"), x, y, x);
}
+ @Test
+ public void testCast1() throws ParseException {
+ // 'to' values are ONNX TensorProto.DataType enum codes.
+ Tensor x = evaluate("tensor(d0[4]):[-1.9, 0.0, 1.1, 2.9]");
+ assertEval("cast", x, evaluate("tensor(d0[4]):[1,0,1,1]"), createAttribute("to", 9)); // boolean
+ assertEval("cast", x, evaluate("tensor(d0[4]):[-1,0,1,2]"), createAttribute("to", 6)); // int32
+ assertEval("cast", x, evaluate("tensor(d0[4]):[-1,0,1,2]"), createAttribute("to", 12)); // uint32
+ assertEval("cast", x, evaluate("tensor(d0[4]):[-1.9,0,1.1,2.9]"), createAttribute("to", 1)); // float
+ try {
+ assertEval("cast", x, evaluate("tensor(d0[4]):[1,0,1,1]"), createAttribute("to", 8)); // string
+ fail();
+ } catch (IllegalArgumentException e) {
+ // assertEquals takes (expected, actual) - keep the literal first so a
+ // mismatch produces a correctly-labeled failure message.
+ assertEquals("OnnxCast in cast: Casting to string is not implemented.", e.getMessage());
+ }
+ }
+
+ @Test
+ public void testGather1() throws ParseException {
+ // Gather(data, indices, axis): picks slices of 'data' along 'axis' by index.
+ // Default axis is 0 when no attribute is given.
+
+ // 1 dim input, 1 dim indices
+ Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ Tensor y = evaluate("tensor(d0[3]):[0,2,4]");
+ assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,3,5]"));
+
+ // 2 dim input, 1 dim indices - axis 0
+ x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]");
+ y = evaluate("tensor(d0[4]):[2, 1, 0, 2]");
+ assertEval("gather", x, y, evaluate("tensor(d0[4],d1[2]):[5, 6, 3, 4, 1, 2, 5, 6]"));
+
+ // 1 dim input, 2 dim indices - axis 0
+ x = evaluate("tensor(d0[6]):[1, 2, 3, 4, 5, 6]");
+ y = evaluate("tensor(d0[2],d1[2]):[0, 1, 3, 5]");
+ assertEval("gather", x, y, evaluate("tensor(d0[2],d1[2]):[1, 2, 4, 6]"));
+
+ // 2 dim input, 2 dim indices - axis 0 (given as negative axis -2)
+ x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]");
+ y = evaluate("tensor(d0[2],d1[2]):[0, 1, 1, 2]");
+ assertEval("gather", x, y, evaluate("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]"), createAttribute("axis", -2));
+
+ // 2 dim input, 1 dim indices - axis 1
+ x = evaluate("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]");
+ y = evaluate("tensor(d0[4]):[0, 1, 0, 1]");
+ assertEval("gather", x, y, evaluate("tensor(d0[3],d1[4]):[1,2,1,2,3,4,3,4,5,6,5,6]"), createAttribute("axis", 1));
+
+ // 2 dim input, 2 dim indices - axis 1
+ x = evaluate("tensor(d0[3],d1[3]):[1, 2, 3, 4, 5, 6, 7, 8, 9]");
+ y = evaluate("tensor(d0[1],d1[2]):[0, 2]");
+ assertEval("gather", x, y, evaluate("tensor(d0[3],d1[1],d2[2]):[1,3,4,6,7,9]"), createAttribute("axis", 1));
+
+ // 1 dim input, 1 dim indices - negative indices count from the end
+ x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ y = evaluate("tensor(d0[3]):[0,-2,-4]");
+ assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,5,3]"));
+ }
+
+ @Test
+ public void testSlice1() throws ParseException {
+ // Slice opset 1: starts/ends/axes are given as node attributes.
+ // Axes not listed are left untouched; ends are clamped to dimension size.
+ Tensor x = evaluate("tensor(d0[2],d1[4]):[ [1,2,3,4],[5,6,7,8] ]");
+ AttributeConverter attributes = createAttributes().
+ attr("starts", new int[] {1, 0}).
+ attr("ends", new int[] {2, 3}).
+ attr("axes", new int[] {0, 1}).build();
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes);
+
+ // no 'axes' attribute: starts/ends apply to dimensions in order;
+ // negative end counts from the end, oversized end is clamped
+ attributes = createAttributes().
+ attr("starts", new int[] {0, 1}).
+ attr("ends", new int[] {-1, 1000}).build();
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [2,3,4] ]"), attributes);
+
+ attributes = createAttributes().
+ attr("starts", new int[] {0, 1}).
+ attr("ends", new int[] {3, 2}).
+ attr("axes", new int[] {1, 0}).build(); // axes are switched
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes);
+
+ attributes = createAttributes().
+ attr("starts", new int[] {1, 0}).
+ attr("ends", new int[] {2, 3}).
+ attr("axes", new int[] {0, -1}).build(); // negative axes
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[3]):[ [5,6,7] ]"), attributes);
+
+ attributes = createAttributes().
+ attr("starts", new int[] {1}).
+ attr("ends", new int[] {2}).
+ attr("axes", new int[] {0}).build(); // axis 1 is not specified
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[4]):[ [5,6,7,8] ]"), attributes);
+
+ attributes = createAttributes().
+ attr("starts", new int[] {0}).
+ attr("ends", new int[] {1}).build();
+ assertEval("slice", x, evaluate("tensor(d0[1],d1[4]):[ [1,2,3,4] ]"), attributes);
+ }
+
+ // Slice opset 10 takes starts/ends/axes/steps as tensor INPUTS rather than
+ // attributes. NOTE(review): ignored, presumably because input-based Slice
+ // (including steps) is not yet supported by the importer - TODO confirm.
+ @Ignore
+ @Test
+ public void testSlice10() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[4]):[ [1,2,3,4],[5,6,7,8] ]");
+ Tensor starts = evaluate("tensor(d0[2]):[1,0]");
+ Tensor ends = evaluate("tensor(d0[2]):[2,3]");
+ Tensor axes = evaluate("tensor(d0[2]):[0,1]");
+ Tensor steps = evaluate("tensor(d0[2]):[1,2]");
+ assertEval("slice", x, starts, ends, axes, steps, evaluate("tensor(d0[1],d1[2]):[ [5,7] ]"));
+
+ }
+
private Tensor evaluate(String expr) throws ParseException {
return evaluate(expr, null, null, null);
}
@@ -334,28 +461,40 @@ public class OnnxOperationsTestCase {
}
private void assertEval(String opName, Tensor x, Tensor expected) {
- assertEval(opName, x, null, null, expected, null);
+ assertEval(opName, x, null, null, null, null, expected, null);
}
private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, null, null, expected, attr);
+ assertEval(opName, x, null, null, null, null, expected, attr);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, y, null, expected, attr);
+ assertEval(opName, x, y, null, null, null, expected, attr);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) {
- assertEval(opName, x, y, null, expected, null);
+ assertEval(opName, x, y, null, null, null, expected, null);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) {
- assertEval(opName, x, y, z, expected, null);
+ assertEval(opName, x, y, z, null, null, expected, null);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) {
+ assertEval(opName, x, y, z, null, null, expected, attr);
+ }
+
+ // Convenience overload: four inputs, no attributes.
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor expected) {
+ assertEval(opName, x, y, z, q, null, expected, null);
+ }
+
+ // Convenience overload: five inputs, no attributes.
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected) {
+ assertEval(opName, x, y, z, q, r, expected, null);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr) {
Context context = new MapContext(DoubleValue.NaN);
- List<IntermediateOperation> inputs = createInputs(context, x, y, z);
+ List<IntermediateOperation> inputs = createInputs(context, x, y, z, q, r);
IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build());
optimizeAndRename(opName, op);
Tensor result = evaluate(op);
@@ -363,11 +502,13 @@ public class OnnxOperationsTestCase {
assertEquals(expected.type(), result.type());
}
- private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z) {
+ private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r) {
List<IntermediateOperation> inputs = new ArrayList<>();
addInput(inputs, context, x, "x");
addInput(inputs, context, y, "y");
addInput(inputs, context, z, "z");
+ addInput(inputs, context, q, "q");
+ addInput(inputs, context, r, "r");
return inputs;
}
@@ -451,6 +592,16 @@ public class OnnxOperationsTestCase {
return this;
}
+ // Builds a TENSOR-typed ONNX attribute from a Vespa tensor, stored as
+ // double data. Assumes all dimensions are bounded/indexed (size().get()
+ // would throw for unbound dimensions - fine for these tests).
+ Attributes attr(String name, Tensor tensor) {
+ Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder();
+ builder.setDataType(Onnx.TensorProto.DataType.DOUBLE); // was ';;' - stray empty statement removed
+ tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get()));
+ tensor.valueIterator().forEachRemaining(builder::addDoubleData);
+ Onnx.TensorProto val = builder.build();
+ nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(TENSOR).setT(val).build());
+ return this;
+ }
+
AttributeConverter build() {
return AttributeConverter.convert(nodeBuilder.build());
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
index d1dea730da5..9631bddd93d 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -3,8 +3,13 @@
package ai.vespa.rankingexpression.importer.onnx;
import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -21,21 +26,48 @@ public class SimpleImportTestCase {
ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/simple.onnx");
MapContext context = new MapContext();
- context.put("query_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[4])")).
- cell(0.1, 0, 0).
- cell(0.2, 0, 1).
- cell(0.3, 0, 2).
- cell(0.4, 0, 3).build()));
- context.put("attribute_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[4],d1[1])")).
- cell(0.1, 0, 0).
- cell(0.2, 1, 0).
- cell(0.3, 2, 0).
- cell(0.4, 3, 0).build()));
- context.put("bias_tensor", new TensorValue(Tensor.Builder.of(TensorType.fromSpec("tensor(d0[1],d1[1])")).
- cell(1.0, 0, 0).build()));
+ context.put("query_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[4]):[0.1, 0.2, 0.3, 0.4]")));
+ context.put("attribute_tensor", new TensorValue(Tensor.from("tensor(d0[4],d1[1]):[0.1, 0.2, 0.3, 0.4]")));
+ context.put("bias_tensor", new TensorValue(Tensor.from("tensor(d0[1],d1[1]):[1.0]")));
Tensor result = model.expressions().get("output").evaluate(context).asTensor();
assertEquals(result, Tensor.from("tensor(d0[1],d1[1]):{{d0:0,d1:0}:1.3}"));
}
+ @Test
+ public void testGather() {
+ // End-to-end import of the Gather model produced by gather.py.
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx");
+
+ MapContext context = new MapContext();
+ context.put("data", new TensorValue(Tensor.from("tensor(d0[3],d1[2]):[1, 2, 3, 4, 5, 6]")));
+ context.put("indices", new TensorValue(Tensor.from("tensor(d0[2],d1[2]):[0, 1, 1, 2]")));
+
+ // Evaluate generated functions first so the output expression can refer to them.
+ model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
+
+ Tensor result = model.expressions().get("y").evaluate(context).asTensor();
+ // assertEquals takes (expected, actual) - expected value first.
+ assertEquals(Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]"), result);
+ }
+
+ // Evaluates the named imported function (after first evaluating any
+ // functions it depends on) and binds its result in the context.
+ // Already-evaluated functions are skipped.
+ private void evaluateFunction(Context context, ImportedModel model, String functionName) {
+ if (!context.names().contains(functionName)) {
+ RankingExpression e = RankingExpression.from(model.functions().get(functionName));
+ evaluateFunctionDependencies(context, model, e.getRoot());
+ context.put(functionName, new TensorValue(e.evaluate(context).asTensor()));
+ }
+ }
+
+ // Recursively walks the expression tree; any reference that names another
+ // imported function is evaluated (and bound in the context) first.
+ private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ String name = node.toString();
+ if (model.functions().containsKey(name)) {
+ evaluateFunction(context, model, name);
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children()) {
+ evaluateFunctionDependencies(context, model, child);
+ }
+ }
+ }
+
}
diff --git a/model-integration/src/test/models/onnx/simple/gather.onnx b/model-integration/src/test/models/onnx/simple/gather.onnx
new file mode 100644
index 00000000000..62451ad953d
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/gather.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/simple/gather.py b/model-integration/src/test/models/onnx/simple/gather.py
new file mode 100755
index 00000000000..63a2103fd86
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/gather.py
@@ -0,0 +1,23 @@
+# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+import numpy as np
+from onnx import helper, TensorProto
+
+# Builds a minimal ONNX model with a single Gather node (axis=0) and saves it
+# as gather.onnx, for use by SimpleImportTestCase.testGather.
+# NOTE(review): the ONNX Gather spec constrains 'indices' to int32/int64;
+# FLOAT is used here, presumably because Vespa tensors are float-valued and
+# the importer accepts it - TODO confirm against the importer.
+data_type = helper.make_tensor_value_info('data', TensorProto.FLOAT, [3,2])
+indices_type = helper.make_tensor_value_info('indices', TensorProto.FLOAT, [2,2])
+output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [2,2,2])
+
+node = onnx.helper.make_node(
+ 'Gather',
+ inputs=['data', 'indices'],
+ outputs=['y'],
+ axis=0,
+)
+graph_def = onnx.helper.make_graph(
+ nodes = [node],
+ name = 'gather_test',
+ inputs = [data_type, indices_type],
+ outputs = [output_type]
+)
+model_def = helper.make_model(graph_def, producer_name='gather.py')
+onnx.save(model_def, 'gather.onnx')
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 6a87e0c6d46..807eb3aa7ce 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -386,6 +386,56 @@ public class EvaluationTestCase {
// tensor result dimensions are given from argument dimensions, not the resulting values
tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{}):{ {x:1}:1 }");
tester.assertEvaluates("tensor(x{},y{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{},y{}):{ {x:1,y:0}:1, {x:2,y:1}:1 }");
+
+ }
+
+ @Test
+ public void testTake() {
+ EvaluationTester tester = new EvaluationTester();
+
+ // numpy.take(a, indices, axis) with tensors: expressed as a generate
+ // where the taken axis is indexed via a slice into the indices tensor.
+
+ // 1 dim input, 1 dim indices
+ tester.assertEvaluates("tensor(d0[3]):[1, 3, 5]",
+ "tensor(d0[3])(tensor0{a0:(tensor1{indices0:(d0)})})",
+ "tensor(a0[6]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[3]):[0, 2, 4]");
+
+ // 1 dim input, 1 dim indices - negative indices wrapped via fmod
+ tester.assertEvaluates("tensor(d0[3]):[1, 5, 3]",
+ "tensor(d0[3])(tensor0{a0:(fmod(6 + tensor1{indices0:(d0)}, 6) ) })",
+ "tensor(a0[6]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[3]):[0, -2, -4]");
+
+ // 2 dim input, 1 dim indices - axis 0
+ tester.assertEvaluates("tensor(d0[4],d1[2]):[5, 6, 3, 4, 1, 2, 5, 6]",
+ "tensor(d0[4],d1[2])(tensor0{a0:(tensor1{indices0:(d0)}),a1:(d1)})",
+ "tensor(a0[3],a1[2]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[4]):[2, 1, 0, 2]");
+
+ // 1 dim input, 2 dim indices - axis 0
+ tester.assertEvaluates("tensor(d0[2],d1[2]):[1, 2, 4, 6]",
+ "tensor(d0[2],d1[2])(tensor0{a0:(tensor1{indices0:(d0),indices1:(d1)}) })",
+ "tensor(a0[6]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[2],indices1[2]):[0, 1, 3, 5]");
+
+ // 2 dim input, 2 dim indices - axis 0
+ tester.assertEvaluates("tensor(d0[2],d1[2],d2[2]):[1,2,3,4,3,4,5,6]",
+ "tensor(d0[2],d1[2],d2[2])(tensor0{a0:(tensor1{indices0:(d0),indices1:(d1)}),a1:(d2)})",
+ "tensor(a0[3],a1[2]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[2],indices1[2]):[0, 1, 1, 2]");
+
+ // 2 dim input, 1 dim indices - axis 1
+ tester.assertEvaluates("tensor(d0[3],d1[4]):[1,2,1,2,3,4,3,4,5,6,5,6]",
+ "tensor(d0[3],d1[4])(tensor0{a0:(d0), a1:(tensor1{indices0:(d1)}) })",
+ "tensor(a0[3],a1[2]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[4]):[0, 1, 0, 1]");
+
+ // 2 dim input, 2 dim indices - axis 1
+ tester.assertEvaluates("tensor(d0[3],d1[1],d2[2]):[1,3,4,6,7,9]",
+ "tensor(d0[3],d1[1],d2[2])(tensor0{a0:(d0), a1:(tensor1{indices0:(d1),indices1:(d2)}) })",
+ "tensor(a0[3],a1[3]):[1, 2, 3, 4, 5, 6, 7, 8, 9]",
+ "tensor(indices0[1],indices1[2]):[0, 2]");
}
@Test
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index 4d3989b8782..bccd66acd31 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -230,7 +230,7 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
@Override
public String toString() {
- return toString(null);
+ return toString(ToStringContext.empty());
}
public String toString(ToStringContext context) {