author     Lester Solbakken <lesters@oath.com>   2019-11-22 11:40:14 +0100
committer  Lester Solbakken <lesters@oath.com>   2019-11-22 11:40:14 +0100
commit     296340ac996edac09a4f53997ae1a8a803d302c1 (patch)
tree       f7abfd7b4c30529bbdbc03334f523fcbebdd27e2 /model-integration
parent     69e4f6bf072d8ebfb12761c450f2bdacf86e226c (diff)
Add additional ONNX operations
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java  |  69
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java  |  21
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java  |  58
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java  |  12
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java  |  4
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java  |  183
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java  |  2
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java  |  105
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java  |  121
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java  |  35
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java  |  3
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java  |  7
12 files changed, 601 insertions, 19 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java
new file mode 100644
index 00000000000..8caa158e5be
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java
@@ -0,0 +1,69 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import onnx.Onnx;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Converts Onnx node attributes to Vespa attribute values.
+ *
+ * @author lesters
+ */
+class AttributeConverter implements IntermediateOperation.AttributeMap {
+
+ private final Onnx.NodeProto node;
+
+ private AttributeConverter(Onnx.NodeProto node) {
+ this.node = node;
+ }
+
+ static AttributeConverter convert(Onnx.NodeProto node) {
+ return new AttributeConverter(node);
+ }
+
+ @Override
+ public Optional<Value> get(String name) {
+ for (Onnx.AttributeProto attr : node.getAttributeList()) {
+ if (attr.getName().equals(name)) {
+ switch (attr.getType()) {
+ case INT: return Optional.of(DoubleValue.frozen(attr.getI()));
+ case FLOAT: return Optional.of(DoubleValue.frozen(attr.getF()));
+ case STRING: return Optional.of(StringValue.frozen(attr.getS().toString()));
+ default:
+ return Optional.empty();
+ }
+ }
+ }
+ return Optional.empty();
+ }
+
+ @Override
+ public Optional<Value> get(String name, OrderedTensorType type) {
+ return Optional.empty();
+ }
+
+ @Override
+ public Optional<List<Value>> getList(String name) {
+ for (Onnx.AttributeProto attr : node.getAttributeList()) {
+ if (attr.getName().equals(name)) {
+ switch (attr.getType()) {
+ case INTS: return Optional.of(attr.getIntsList().stream().map(DoubleValue::new).collect(Collectors.toList()));
+ case FLOATS: return Optional.of(attr.getFloatsList().stream().map(DoubleValue::new).collect(Collectors.toList()));
+ case STRINGS: return Optional.of(attr.getStringsList().stream().map((s) -> StringValue.frozen(s.toString())).collect(Collectors.toList()));
+ default:
+ return Optional.empty();
+ }
+ }
+ }
+ return Optional.empty();
+ }
+
+}
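A minimal sketch of how this converter is typically exercised, assuming the standard protobuf-java builders generated from onnx.proto (setName, setType, setI and addAttribute below are those generated accessors, not part of this commit); the attribute values are made up for illustration:

    Onnx.AttributeProto axis = Onnx.AttributeProto.newBuilder()
            .setName("axis")
            .setType(Onnx.AttributeProto.AttributeType.INT)
            .setI(1)
            .build();
    Onnx.NodeProto node = Onnx.NodeProto.newBuilder()
            .setOpType("Concat")
            .addAttribute(axis)
            .build();
    // INT attributes come back as frozen DoubleValues, matching how Gemm and
    // OnnxConcat below read "alpha", "transA" and "axis" via asDouble().
    double axisIndex = AttributeConverter.convert(node).get("axis").get().asDouble();  // 1.0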
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index ccf3c2d8fb0..4fa6c09c636 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -2,11 +2,16 @@
package ai.vespa.rankingexpression.importer.onnx;
+import ai.vespa.rankingexpression.importer.operations.Gemm;
+import ai.vespa.rankingexpression.importer.operations.OnnxConcat;
+import ai.vespa.rankingexpression.importer.operations.Reduce;
+import ai.vespa.rankingexpression.importer.operations.Select;
+import ai.vespa.rankingexpression.importer.operations.Softmax;
+import ai.vespa.rankingexpression.importer.operations.Squeeze;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.operations.Argument;
-import ai.vespa.rankingexpression.importer.operations.ConcatV2;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.Identity;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
@@ -36,6 +41,7 @@ class GraphImporter {
IntermediateGraph graph) {
String modelName = graph.name();
String nodeName = getNodeName(node);
+ AttributeConverter attributes = AttributeConverter.convert(node);
switch (node.getOpType().toLowerCase()) {
case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
@@ -44,13 +50,14 @@ class GraphImporter {
case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
- case "concat": return new ConcatV2(modelName, nodeName, inputs);
+ case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes);
case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "gemm": return new Gemm(modelName, nodeName, inputs, attributes);
case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
case "identity": return new Identity(modelName, nodeName, inputs);
case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
@@ -63,15 +70,21 @@ class GraphImporter {
case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum);
+ case "reducemean": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.avg);
case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+ case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu());
case "shape": return new Shape(modelName, nodeName, inputs);
- case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
- case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
+ case "softmax": return new Softmax(modelName, nodeName, inputs);
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
+ case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square());
+ case "where": return new Select(modelName, nodeName, inputs);
case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
}
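Together these hunks extend the operation mapping with gemm, ONNX-style concat, reducesum, reducemean, leakyrelu, softmax, squeeze and where. A minimal sketch of an ONNX node that would reach the new "gemm" case, built with the protobuf builders generated from onnx.proto (the input names are made up for illustration and not part of this commit):

    Onnx.NodeProto gemmNode = Onnx.NodeProto.newBuilder()
            .setOpType("Gemm")            // matched after toLowerCase() above
            .addInput("activations")      // A (illustrative name)
            .addInput("weights")          // B (illustrative name)
            .addInput("bias")             // optional C (illustrative name)
            .addAttribute(Onnx.AttributeProto.newBuilder()
                    .setName("transB")
                    .setType(Onnx.AttributeProto.AttributeType.INT)
                    .setI(1))
            .build();
    // AttributeConverter.convert(gemmNode) exposes transB to the Gemm constructor
    // shown further down; the other attributes fall back to their defaults there.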
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
index f3d87d89c27..69d18d0ffcb 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
@@ -10,7 +10,10 @@ import onnx.Onnx;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
/**
* Converts Onnx tensors into Vespa tensors.
@@ -31,11 +34,16 @@ class TensorConverter {
private static Values readValuesOf(Onnx.TensorProto tensorProto) {
if (tensorProto.hasRawData()) {
switch (tensorProto.getDataType()) {
+ case BOOL: return new RawBoolValues(tensorProto);
case FLOAT: return new RawFloatValues(tensorProto);
+ case DOUBLE: return new RawDoubleValues(tensorProto);
+ case INT64: return new RawLongValues(tensorProto);
}
} else {
switch (tensorProto.getDataType()) {
case FLOAT: return new FloatValues(tensorProto);
+ case DOUBLE: return new DoubleValues(tensorProto);
+ case INT64: return new LongValues(tensorProto);
}
}
throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
@@ -55,6 +63,17 @@ class TensorConverter {
}
}
+ private static class RawBoolValues extends RawValues {
+ private final IntBuffer values;
+ private final int size;
+ RawBoolValues(Onnx.TensorProto tensorProto) {
+ values = bytes(tensorProto).asIntBuffer();
+ size = values.remaining();
+ }
+ @Override double get(int i) { return values.get(i); }
+ @Override int size() { return size; }
+ }
+
private static class RawFloatValues extends RawValues {
private final FloatBuffer values;
private final int size;
@@ -66,6 +85,28 @@ class TensorConverter {
@Override int size() { return size; }
}
+ private static class RawDoubleValues extends RawValues {
+ private final DoubleBuffer values;
+ private final int size;
+ RawDoubleValues(Onnx.TensorProto tensorProto) {
+ values = bytes(tensorProto).asDoubleBuffer();
+ size = values.remaining();
+ }
+ @Override double get(int i) { return values.get(i); }
+ @Override int size() { return size; }
+ }
+
+ private static class RawLongValues extends RawValues {
+ private final LongBuffer values;
+ private final int size;
+ RawLongValues(Onnx.TensorProto tensorProto) {
+ values = bytes(tensorProto).asLongBuffer();
+ size = values.remaining();
+ }
+ @Override double get(int i) { return values.get(i); }
+ @Override int size() { return size; }
+ }
+
private static class FloatValues extends Values {
private final Onnx.TensorProto tensorProto;
FloatValues(Onnx.TensorProto tensorProto) {
@@ -75,5 +116,22 @@ class TensorConverter {
@Override int size() { return tensorProto.getFloatDataCount(); }
}
+ private static class DoubleValues extends Values {
+ private final Onnx.TensorProto tensorProto;
+ DoubleValues(Onnx.TensorProto tensorProto) {
+ this.tensorProto = tensorProto;
+ }
+ @Override double get(int i) { return tensorProto.getDoubleData(i); }
+ @Override int size() { return tensorProto.getDoubleDataCount(); }
+ }
+
+ private static class LongValues extends Values {
+ private final Onnx.TensorProto tensorProto;
+ LongValues(Onnx.TensorProto tensorProto) {
+ this.tensorProto = tensorProto;
+ }
+ @Override double get(int i) { return tensorProto.getInt64Data(i); }
+ @Override int size() { return tensorProto.getInt64DataCount(); }
+ }
}
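All of the raw readers above go through a bytes(tensorProto) helper that lies outside this hunk. A sketch of what that helper presumably looks like, given that ONNX serializes raw_data little-endian and the asIntBuffer()/asLongBuffer()/asDoubleBuffer() views above only make sense once the byte order is set (assumed, not shown in this diff):

    private static ByteBuffer bytes(Onnx.TensorProto tensorProto) {
        // raw_data is a protobuf ByteString; ONNX stores fixed-width types little-endian.
        return tensorProto.getRawData().asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
    }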
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
index dad4508bc61..f68372ce4dd 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
@@ -3,7 +3,6 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
-import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
@@ -39,7 +38,16 @@ public class Argument extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- addConstraintsFrom(type, renamer);
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ renamer.addDimension(type.dimensions().get(i).name());
+
+ // Each dimension is distinct and ordered correctly:
+ for (int j = i + 1; j < type.dimensions().size(); j++) {
+ renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(),
+ DimensionRenamer.Constraint.lessThan(false),
+ this);
+ }
+ }
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index e6cc96d48ad..3487d889338 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -82,9 +82,7 @@ public class ExpandDims extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
+ addConstraintsFrom(type, renamer);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
new file mode 100644
index 00000000000..b116f18c7d1
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -0,0 +1,183 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.text.ExpressionFormatter;
+
+import java.util.List;
+import java.util.Optional;
+
+public class Gemm extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private final float alpha, beta;
+ private final int transposeA, transposeB;
+
+ private final static DoubleValue zero = DoubleValue.frozen(0.0);
+ private final static DoubleValue one = DoubleValue.frozen(1.0);
+
+ public Gemm(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ this.alpha = (float) attributeMap.get("alpha").orElse(one).asDouble();
+ this.beta = (float) attributeMap.get("beta").orElse(one).asDouble();
+ this.transposeA = (int) attributeMap.get("transA").orElse(zero).asDouble();
+ this.transposeB = (int) attributeMap.get("transB").orElse(zero).asDouble();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! check2or3InputsPresent()) return null;
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+
+ TensorType.Dimension dimA = inputs.get(0).type().get().dimensions().get(transposeA);
+ TensorType.Dimension dimB = inputs.get(1).type().get().dimensions().get(1 - transposeB);
+
+ typeBuilder.add(dimA);
+ typeBuilder.add(dimB);
+ OrderedTensorType result = typeBuilder.build();
+
+ // Input tensor C: its shape must be unidirectionally broadcastable to (dimA, dimB).
+ if (inputs.size() == 3) {
+ List<TensorType.Dimension> cDimensions = inputs.get(2).type().get().dimensions();
+ if (cDimensions.size() == 2) {
+ TensorType.Dimension dimC0 = cDimensions.get(0);
+ TensorType.Dimension dimC1 = cDimensions.get(1);
+
+ if ( ! (dimA.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) {
+ throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
+ " is not compatible or not broadcastable to " + result.type());
+ }
+ if ( ! (dimB.size().get().equals(dimC1.size().get()) || dimC1.size().get() == 1) ) {
+ throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
+ " is not compatible or not broadcastable to " + result.type());
+ }
+
+ } else if (cDimensions.size() == 1) {
+ TensorType.Dimension dimC0 = cDimensions.get(0);
+ if ( ! (dimB.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) {
+ throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
+ " is not compatible or not broadcastable to " + result.type());
+ }
+ } else {
+ throw new IllegalArgumentException("GEMM: optional input C has no dimensions.");
+ }
+ }
+
+ return result;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! check2or3InputsPresent()) return null;
+
+ OrderedTensorType aType = inputs.get(0).type().get();
+ OrderedTensorType bType = inputs.get(1).type().get();
+ if (aType.type().rank() != 2 || bType.type().rank() != 2)
+ throw new IllegalArgumentException("Tensors in Gemm must have rank of exactly 2");
+
+ Optional<TensorFunction> aFunction = inputs.get(0).function();
+ Optional<TensorFunction> bFunction = inputs.get(1).function();
+ if (aFunction.isEmpty() || bFunction.isEmpty()) {
+ return null;
+ }
+
+ String joinDimension = aType.dimensions().get(1).name(); // TODO: check wrt transpose!
+
+ TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension);
+ TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
+ new ArithmeticNode(
+ new TensorFunctionNode(AxB),
+ ArithmeticOperator.MULTIPLY,
+ new ConstantNode(new DoubleValue(alpha))));
+
+ if (inputs.size() == 3) {
+ Optional<TensorFunction> cFunction = inputs.get(2).function();
+ TensorFunction betaxC = new TensorFunctionNode.ExpressionTensorFunction(
+ new ArithmeticNode(
+ new TensorFunctionNode(cFunction.get()),
+ ArithmeticOperator.MULTIPLY,
+ new ConstantNode(new DoubleValue(beta))));
+ return new com.yahoo.tensor.functions.Join(alphaxAxB, betaxC, ScalarFunctions.add());
+ }
+
+ return alphaxAxB;
+ }
+
+ private boolean check2or3InputsPresent() {
+ if (inputs.size() != 2 && inputs.size() != 3) {
+ throw new IllegalArgumentException("Expected 2 or 3 inputs for '" + name + "', got " + inputs.size());
+ }
+ return inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent);
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if ( ! check2or3InputsPresent()) return;
+
+ List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
+ List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+
+ assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
+ assertTwoDimensions(bDimensions, inputs.get(1), "second argument");
+
+ String aDim0 = aDimensions.get(transposeA).name();
+ String aDim1 = aDimensions.get(1 - transposeA).name();
+ String bDim0 = bDimensions.get(transposeB).name();
+ String bDim1 = bDimensions.get(1 - transposeB).name();
+
+ // The second dimension of a should have the same name as the first dimension of b
+ renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this);
+
+ // The first dimension of a should have a different name than the second dimension of b
+ renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this);
+
+ // If C is given, it must be unidirectionally broadcastable to the shape of A * B, meaning one of:
+ // - C has exactly the same shape as A * B,
+ // - C has the same rank and each of its dimensions either matches or has length 1,
+ // - C has fewer dimensions and can have dimensions of length 1 prepended to satisfy the previous property.
+ if (inputs.size() == 3) {
+ List<TensorType.Dimension> cDimensions = inputs.get(2).type().get().dimensions();
+
+ if (cDimensions.size() == 2) {
+ String cDim0 = cDimensions.get(0).name();
+ String cDim1 = cDimensions.get(1).name();
+ renamer.addConstraint(aDim0, cDim0, DimensionRenamer.Constraint.equal(false), this);
+ renamer.addConstraint(bDim1, cDim1, DimensionRenamer.Constraint.equal(false), this);
+ } else if (cDimensions.size() == 1) {
+ String cDim0 = cDimensions.get(0).name();
+ renamer.addConstraint(bDim1, cDim0, DimensionRenamer.Constraint.equal(false), this);
+ }
+ }
+
+ // For efficiency, the dimensions to join over should be innermost - soft constraint
+ renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this);
+ renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this);
+ }
+
+ private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
+ if (dimensions.size() >= 2) return;
+ throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
+ " but got just " + dimensions + " from\n" +
+ ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
+ }
+
+ @Override
+ public Gemm withInputs(List<IntermediateOperation> inputs) {
+ return new Gemm(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "Gemm"; }
+
+}
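For reference, the contract this class implements is the standard ONNX Gemm definition; the alpha, beta, transA and transB attributes read in the constructor map directly onto it:

    Y  = alpha * A' * B' + beta * C
    A' = transpose(A) if transA != 0, otherwise A   (shape M x K)
    B' = transpose(B) if transB != 0, otherwise B   (shape K x N)
    C  = optional input, unidirectionally broadcastable to (M, N)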
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 434261c6077..6849e64641e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -73,8 +73,6 @@ public class MatMul extends IntermediateOperation {
private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
if (dimensions.size() >= 2) return;
-
-
throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
" but got just " + dimensions + " from\n" +
ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
new file mode 100644
index 00000000000..ded76db60fe
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
@@ -0,0 +1,105 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+import java.util.Optional;
+
+public class OnnxConcat extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private String concatDimensionName;
+ private int concatDimensionIndex;
+
+ public OnnxConcat(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ if (attributeMap.get("axis").isEmpty())
+ throw new IllegalArgumentException("OnnxConcat in " + name + ": Required attribute 'axis' is missing.");
+ this.concatDimensionIndex = (int) attributeMap.get("axis").get().asDouble();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null;
+
+ OrderedTensorType aType = inputs.get(0).type().get();
+ long concatDimSize = aType.dimensions().get(concatDimensionIndex).size().orElse(-1L);
+
+ for (int i = 1; i < inputs.size(); ++i) {
+ OrderedTensorType bType = inputs.get(i).type().get();
+ if (bType.rank() != aType.rank())
+ throw new IllegalArgumentException("OnnxConcat in " + name + ": Inputs must have the same rank.");
+
+ for (int j = 0; j < aType.rank(); ++j) {
+ long dimSizeA = aType.dimensions().get(j).size().orElse(-1L);
+ long dimSizeB = bType.dimensions().get(j).size().orElse(-1L);
+ if (j == concatDimensionIndex) {
+ concatDimSize += dimSizeB;
+ } else if (dimSizeA != dimSizeB) {
+ throw new IllegalArgumentException("OnnxConcat in " + name + ": " +
+ "input dimension " + j + " differs in input tensors.");
+ }
+ }
+ }
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ int dimensionIndex = 0;
+ for (TensorType.Dimension dimension : aType.dimensions()) {
+ if (dimensionIndex == concatDimensionIndex) {
+ concatDimensionName = dimension.name();
+ typeBuilder.add(TensorType.Dimension.indexed(concatDimensionName, concatDimSize));
+ } else {
+ typeBuilder.add(dimension);
+ }
+ dimensionIndex++;
+ }
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) {
+ return null;
+ }
+ TensorFunction result = inputs.get(0).function().get();
+ for (int i = 1; i < inputs.size(); ++i) {
+ TensorFunction b = inputs.get(i).function().get();
+ result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName);
+ }
+ return result;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
+ return;
+ }
+ OrderedTensorType a = inputs.get(0).type().get();
+ for (int i = 1; i < inputs.size(); ++i) {
+ OrderedTensorType b = inputs.get(i).type().get();
+ String bDim = b.dimensions().get(concatDimensionIndex).name();
+ String aDim = a.dimensions().get(concatDimensionIndex).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
+ }
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName);
+ }
+
+ @Override
+ public OnnxConcat withInputs(List<IntermediateOperation> inputs) {
+ return new OnnxConcat(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "ConcatV2"; }
+
+}
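A worked example of the type computation in lazyGetType(), with illustrative shapes that are not taken from this commit: concatenating inputs of type tensor(d0[2],d1[3]) and tensor(d0[2],d1[4]) with axis = 1 keeps d0, sums the sizes along the concat dimension and yields tensor(d0[2],d1[7]); had the inputs differed in d0 instead, the size check above would throw.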
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
new file mode 100644
index 00000000000..1b2d9ac090e
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
@@ -0,0 +1,121 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * ONNX Reduce[Sum/Mean/etc] operation
+ */
+public class Reduce extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private final com.yahoo.tensor.functions.Reduce.Aggregator aggregator;
+
+ private List<String> reduceDimensions;
+
+ public Reduce(String modelName, String nodeName,
+ List<IntermediateOperation> inputs,
+ AttributeMap attributeMap,
+ com.yahoo.tensor.functions.Reduce.Aggregator aggregator) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ this.aggregator = aggregator;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(1)) return null;
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+
+ reduceDimensions = inputType.dimensionNames(); // default is to reduce all dimensions
+ if (attributeMap.getList("axes").isPresent()) {
+ reduceDimensions = new ArrayList<>();
+ for (Value i : attributeMap.getList("axes").get()) {
+ int dimensionIndex = (int) i.asDouble();
+ if (dimensionIndex < 0) {
+ dimensionIndex = inputType.dimensions().size() - dimensionIndex;
+ }
+ reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
+ }
+ }
+ return reducedType(inputType, shouldKeepDimensions());
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputTypesPresent(1)) return null;
+
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions);
+ if (shouldKeepDimensions()) {
+ // multiply with a generated tensor created from the reduced dimensions
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
+ for (String name : reduceDimensions) {
+ typeBuilder.indexed(name, 1);
+ }
+ TensorType generatedType = typeBuilder.build();
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+ Generate generatedFunction = new Generate(generatedType,
+ new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+ output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+ }
+ return output;
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size());
+ for (String name : reduceDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (newName.isEmpty()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ reduceDimensions = renamedDimensions;
+ }
+
+ @Override
+ public Reduce withInputs(List<IntermediateOperation> inputs) {
+ return new Reduce(modelName(), name(), inputs, attributeMap, aggregator);
+ }
+
+ @Override
+ public String operationName() { return "Reduce"; }
+
+ private boolean shouldKeepDimensions() {
+ Optional<Value> keepDims = attributeMap.get("keepdims");
+ return keepDims.isPresent() && keepDims.get().asBoolean();
+ }
+
+ private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
+ for (TensorType.Dimension dimension: inputType.type().dimensions()) {
+ if ( ! reduceDimensions.contains(dimension.name())) {
+ builder.add(dimension);
+ } else if (keepDimensions) {
+ builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
+ }
+ }
+ return builder.build();
+ }
+
+}
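The keepdims branch of lazyGetFunction() keeps the reduced dimensions in the output by joining the reduce with an all-ones generated tensor over exactly those dimensions. Schematically, and with illustrative shapes that are not from this commit, reducing x of type tensor(d0[2],d1[3]) over d1 with keepdims = 1 builds an expression of the form join(reduce(x, sum, d1), tensor(d1[1])(1.0), f(a,b)(a * b)), whose result type is tensor(d0[2],d1[1]).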
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index a210ed13f5d..c7accd00619 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -16,6 +17,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
@@ -35,7 +37,7 @@ public class Reshape extends IntermediateOperation {
if ( ! allInputTypesPresent(2)) return null;
IntermediateOperation newShape = inputs.get(1);
- if ( ! newShape.getConstantValue().isPresent())
+ if (newShape.getConstantValue().isEmpty())
throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant.");
Tensor shape = newShape.getConstantValue().get().asTensor();
@@ -69,9 +71,7 @@ public class Reshape extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
+ addConstraintsFrom(type, renamer);
}
@Override
@@ -89,17 +89,40 @@ public class Reshape extends IntermediateOperation {
// the new shape. We have to introduce temporary dimension names and rename back if dimension names
// in the new and old tensor type overlap.
+ List<String> from = new ArrayList<>();
+ List<String> to = new ArrayList<>();
+ boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType);
+ if (dimensionNamesOverlap) {
+ TensorType.Builder builder = new TensorType.Builder(outputType.valueType());
+ for (int i = 0; i < outputType.rank(); ++i) {
+ TensorType.Dimension dim = outputType.dimensions().get(i);
+ from.add(dim.name());
+ to.add("temp_" + dim.name());
+ builder.dimension(dim.withName("temp_" + dim.name()));
+ }
+ outputType = builder.build();
+ }
+
ExpressionNode unrollFrom = unrollTensorExpression(inputType);
ExpressionNode unrollTo = unrollTensorExpression(outputType);
- ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
+ ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, new EmbracedNode(unrollTo));
TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
Generate transformTensor = new Generate(transformationType,
new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
- return new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
+ TensorFunction result = new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
Reduce.Aggregator.sum,
inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
+
+ if (dimensionNamesOverlap) {
+ result = new Rename(result, to, from);
+ }
+ return result;
+ }
+
+ private static boolean dimensionNamesOverlap(TensorType a, TensorType b) {
+ return a.dimensionNames().stream().anyMatch(d -> b.dimension(d).isPresent());
}
private static ExpressionNode unrollTensorExpression(TensorType type) {
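The rename added above matters whenever the requested shape reuses dimension names from the input, which is common for ONNX reshapes. With illustrative types that are not from this commit: reshaping tensor(d0[2],d1[3]) to tensor(d0[6]) would make the transformation tensor clash on d0, so the output side is first built as temp_d0[6], the join-and-reduce runs over the input dimensions, and the final Rename maps temp_d0 back to d0.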
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
index 35a1b6e2b0e..8696d0f1858 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
@@ -51,6 +51,9 @@ public class Select extends IntermediateOperation {
if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
return condition.cellIterator().next().getValue().intValue() == 0 ? b : a;
}
+ if (condition.type().rank() == 2 && dimensionSize(condition.type().dimensions().get(0)) == 1 && dimensionSize(condition.type().dimensions().get(1)) == 1) {
+ return condition.cellIterator().next().getValue().intValue() == 0 ? b : a;
+ }
}
// The task is to select cells from 'x' or 'y' based on 'condition'.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
index 56d9b542093..a9e3fc6a43a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
@@ -31,7 +31,10 @@ public class Squeeze extends IntermediateOperation {
squeezeDimensions = new ArrayList<>();
Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims");
- if ( ! squeezeDimsAttr.isPresent()) {
+ if (squeezeDimsAttr.isEmpty()) {
+ squeezeDimsAttr = attributeMap.getList("axes"); // ONNX
+ }
+ if (squeezeDimsAttr.isEmpty()) {
squeezeDimensions = inputType.type().dimensions().stream().
filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
map(TensorType.Dimension::name).
@@ -62,7 +65,7 @@ public class Squeeze extends IntermediateOperation {
List<String> renamedDimensions = new ArrayList<>(squeezeDimensions.size());
for (String name : squeezeDimensions) {
Optional<String> newName = renamer.dimensionNameOf(name);
- if (!newName.isPresent()) {
+ if (newName.isEmpty()) {
return; // presumably, already renamed
}
renamedDimensions.add(newName.get());