author     Lester Solbakken <lesters@oath.com>   2019-12-05 09:09:43 +0100
committer  Lester Solbakken <lesters@oath.com>   2019-12-05 09:09:43 +0100
commit     cd4e23a47c1993d5c9dbe17dfb23bdce3e037844 (patch)
tree       7bf90e97261c246f4a3fe78b9401c50357fdac7f /model-integration
parent     7cd2264c56253a1e9745cb063b8868a5589c6b51 (diff)
Add unit tests for ONNX operators (and fix some of the implementations)
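
As an illustration (not part of the diff below): with these changes an ONNX
ReduceL2 node imports as a sum-reduce with a square pre-operator and a sqrt
post-operator, so for a 1-dimensional input x it is equivalent to the ranking
expression

    sqrt(reduce(map(x, f(a)(a * a)), sum, d0))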
Diffstat (limited to 'model-integration')
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java              2
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java            50
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java       78
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java                2
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java  4
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java          5
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java             27
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java           103
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java            34
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java       4
-rw-r--r--  model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java  460
11 files changed, 713 insertions(+), 56 deletions(-)
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index 6c583d960bd..14aa3ebf84e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -70,7 +70,7 @@ public class IntermediateGraph {
return operations;
}
- void optimize() {
+ public void optimize() {
renameDimensions();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index 280fe354149..63b04470d00 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -3,11 +3,13 @@
package ai.vespa.rankingexpression.importer.onnx;
import ai.vespa.rankingexpression.importer.operations.Gemm;
+import ai.vespa.rankingexpression.importer.operations.ConcatReduce;
import ai.vespa.rankingexpression.importer.operations.OnnxConcat;
import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Softmax;
import ai.vespa.rankingexpression.importer.operations.Squeeze;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
@@ -21,6 +23,8 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Shape;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
@@ -36,24 +40,37 @@ import java.util.stream.Collectors;
*/
class GraphImporter {
+ private static final Value eluAlpha = DoubleValue.frozen(1.0);
+ private static final Value seluAlpha = DoubleValue.frozen(1.6732632423543772848170429916717);
+ private static final Value seluGamma = DoubleValue.frozen(1.0507009873554804934193349852946);
+ private static final Value leakyReluAlpha = DoubleValue.frozen(0.01);
+
private static IntermediateOperation mapOperation(Onnx.NodeProto node,
List<IntermediateOperation> inputs,
IntermediateGraph graph) {
+ String type = node.getOpType();
String modelName = graph.name();
String nodeName = getNodeName(node);
AttributeConverter attributes = AttributeConverter.convert(node);
+ return mapOperation(type, inputs, modelName, nodeName, attributes);
+ }
- switch (node.getOpType().toLowerCase()) {
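+ // Overload taking the already extracted operator type and attributes, so
+ // that unit tests (OnnxOperationsTestCase) can map single operators directly.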
+ static IntermediateOperation mapOperation(String opType,
+ List<IntermediateOperation> inputs,
+ String modelName,
+ String nodeName,
+ AttributeConverter attributes) {
+ switch (opType.toLowerCase()) {
case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
- case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes);
case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
- case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble()));
case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
@@ -63,23 +80,31 @@ class GraphImporter {
case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
case "matmul": return new MatMul(modelName, nodeName, inputs);
- case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
- case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
- case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean());
+ case "max": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.max);
+ case "min": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.min);
+ case "mean": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.avg);
case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
- case "reshape": return new Reshape(modelName, nodeName, inputs);
- case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum);
+ case "reshape": return new Reshape(modelName, nodeName, inputs, attributes);
+ case "reducel1": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.abs(), null);
+ case "reducel2": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), ScalarFunctions.sqrt());
+ case "reducelogsum":return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, null, ScalarFunctions.log());
+ case "reducelogsumexp": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.exp(), ScalarFunctions.log());
+ case "reducemax": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.max);
case "reducemean": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.avg);
+ case "reducemin": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.min);
+ case "reduceprod": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.prod);
+ case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum);
+ case "reducesumsquare": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), null);
case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
- case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu(attributes.get("gamma").orElse(seluGamma).asDouble(), attributes.get("alpha").orElse(seluAlpha).asDouble()));
+ case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu(attributes.get("alpha").orElse(leakyReluAlpha).asDouble()));
case "shape": return new Shape(modelName, nodeName, inputs);
case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
- case "softmax": return new Softmax(modelName, nodeName, inputs);
+ case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
@@ -90,7 +115,7 @@ class GraphImporter {
}
IntermediateOperation op = new NoOp(modelName, nodeName, inputs);
- op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
+ op.warning("Operation '" + opType + "' is currently not implemented");
return op;
}
@@ -260,5 +285,4 @@ class GraphImporter {
"Either no explicit name given or no single output name.");
}
-
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java
new file mode 100644
index 00000000000..497e7e7550d
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java
@@ -0,0 +1,78 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+import java.util.Optional;
+
+public class ConcatReduce extends IntermediateOperation {
+
+ private final static String tmpDimensionName = "__concat_reduce_tmp_dimension_name__";
+ private final Reduce.Aggregator aggregator;
+
+ public ConcatReduce(String modelName, String nodeName, List<IntermediateOperation> inputs, Reduce.Aggregator aggregator) {
+ super(modelName, nodeName, inputs);
+ this.aggregator = aggregator;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(inputs.size())) return null;
+ return inputs.get(0).type().get(); // TODO: not necessarily correct; with broadcasting the output type can differ from the first input's
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputFunctionsPresent(inputs.size())) return null;
+
+ TensorFunction result = inputs.get(0).function().get();
+ for (int i = 1; i < inputs.size(); ++i) {
+ TensorFunction b = inputs.get(i).function().get();
+ result = new com.yahoo.tensor.functions.Concat(result, b, tmpDimensionName);
+ }
+ return new com.yahoo.tensor.functions.Reduce(result, aggregator, tmpDimensionName);
+ }
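+ // For example, for inputs x, y, z and the max aggregator this builds
+ // reduce(concat(concat(x, y, tmp), z, tmp), max, tmp), which is what
+ // OnnxOperationsTestCase.testConcatReduce8 asserts against.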
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if ( ! allInputTypesPresent(inputs.size())) return;
+
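+ // ONNX uses numpy-style broadcasting: dimensions are aligned from the right,
+ // so each dimension of the smaller-rank input is constrained to equal the
+ // corresponding trailing dimension of the larger-rank input: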
+ OrderedTensorType a = inputs.get(0).type().get();
+ for (int i = 1; i < inputs.size(); ++i) {
+ OrderedTensorType b = inputs.get(i).type().get();
+
+ OrderedTensorType largest = largestInput(a, b);
+ OrderedTensorType smallest = smallestInput(a, b);
+
+ int sizeDifference = largest.rank() - smallest.rank();
+ for (int j = 0; j < smallest.rank(); ++j) {
+ String bDim = smallest.dimensions().get(j).name();
+ String aDim = largest.dimensions().get(j + sizeDifference).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
+ }
+ a = b;
+ }
+ }
+
+ private OrderedTensorType largestInput(OrderedTensorType a, OrderedTensorType b) {
+ return a.rank() >= b.rank() ? a : b;
+ }
+
+ private OrderedTensorType smallestInput(OrderedTensorType a, OrderedTensorType b) {
+ return a.rank() < b.rank() ? a : b;
+ }
+
+ @Override
+ public ConcatReduce withInputs(List<IntermediateOperation> inputs) {
+ return new ConcatReduce(modelName(), name(), inputs, aggregator);
+ }
+
+ @Override
+ public String operationName() { return "ConcatReduce"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index f091ae165d1..3fba8680332 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -92,7 +92,7 @@ public class Gemm extends IntermediateOperation {
return null;
}
- String joinDimension = aType.dimensions().get(1).name(); // TODO: check wrt transpose!
+ String joinDimension = aType.dimensions().get(1 - transposeA).name();
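+ // (transposeA is 1 when the transA attribute is set, in which case the
+ // contraction dimension of A is its first dimension rather than its second)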
TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension);
TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index bd302afa5c7..efd6f9d3339 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -199,7 +199,9 @@ public abstract class IntermediateOperation {
String constantName = "constant(" + vespaName() + ")";
Value result = context.get(constantName);
if (result == DoubleValue.NaN) {
- if (inputs.size() == 0) {
+ if (constantValue != null) {
+ result = constantValue;
+ } else if (inputs.size() == 0) {
if (getConstantValue().isEmpty()) {
throw new IllegalArgumentException("Error in evaluating constant for " + name);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
index ded76db60fe..5785621eed3 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
@@ -28,6 +28,9 @@ public class OnnxConcat extends IntermediateOperation {
if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null;
OrderedTensorType aType = inputs.get(0).type().get();
+ if (concatDimensionIndex < 0) {
+ concatDimensionIndex = aType.dimensions().size() + concatDimensionIndex;
+ }
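+ // a negative axis counts from the back, e.g. axis -1 on rank-2 inputs
+ // concatenates along the last dimension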
long concatDimSize = aType.dimensions().get(concatDimensionIndex).size().orElse(-1L);
for (int i = 1; i < inputs.size(); ++i) {
@@ -92,7 +95,7 @@ public class OnnxConcat extends IntermediateOperation {
public void renameDimensions(DimensionRenamer renamer) {
super.renameDimensions(renamer);
concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName);
- }
+ }
@Override
public OnnxConcat withInputs(List<IntermediateOperation> inputs) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
index 1b2d9ac090e..b3fe1da931e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
@@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+import java.util.function.DoubleUnaryOperator;
/**
* ONNX Reduce[Sum/Mean/etc] operation
@@ -24,6 +25,8 @@ public class Reduce extends IntermediateOperation {
private final AttributeMap attributeMap;
private final com.yahoo.tensor.functions.Reduce.Aggregator aggregator;
+ private final DoubleUnaryOperator preOperator;
+ private final DoubleUnaryOperator postOperator;
private List<String> reduceDimensions;
@@ -31,11 +34,23 @@ public class Reduce extends IntermediateOperation {
List<IntermediateOperation> inputs,
AttributeMap attributeMap,
com.yahoo.tensor.functions.Reduce.Aggregator aggregator) {
+ this(modelName, nodeName, inputs, attributeMap, aggregator, null, null);
+ }
+
+ public Reduce(String modelName, String nodeName,
+ List<IntermediateOperation> inputs,
+ AttributeMap attributeMap,
+ com.yahoo.tensor.functions.Reduce.Aggregator aggregator,
+ DoubleUnaryOperator preOperator,
+ DoubleUnaryOperator postOperator) {
super(modelName, nodeName, inputs);
this.attributeMap = attributeMap;
this.aggregator = aggregator;
+ this.preOperator = preOperator;
+ this.postOperator = postOperator;
}
+
@Override
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(1)) return null;
@@ -48,7 +63,7 @@ public class Reduce extends IntermediateOperation {
for (Value i : attributeMap.getList("axes").get()) {
int dimensionIndex = (int) i.asDouble();
if (dimensionIndex < 0) {
- dimensionIndex = inputType.dimensions().size() - dimensionIndex;
+ dimensionIndex = inputType.dimensions().size() + dimensionIndex;
}
reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
}
@@ -61,6 +76,9 @@ public class Reduce extends IntermediateOperation {
if ( ! allInputTypesPresent(1)) return null;
TensorFunction inputFunction = inputs.get(0).function().get();
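+ // an optional pre-operator is mapped over the input before reducing and a
+ // post-operator over the result; e.g. ReduceLogSumExp imports with exp as
+ // pre-operator and log as post-operator: log(sum(exp(input)))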
+ if (preOperator != null) {
+ inputFunction = new com.yahoo.tensor.functions.Map(inputFunction, preOperator);
+ }
TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions);
if (shouldKeepDimensions()) {
// multiply with a generated tensor created from the reduced dimensions
@@ -74,6 +92,9 @@ public class Reduce extends IntermediateOperation {
new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
}
+ if (postOperator != null) {
+ output = new com.yahoo.tensor.functions.Map(output, postOperator);
+ }
return output;
}
@@ -93,7 +114,7 @@ public class Reduce extends IntermediateOperation {
@Override
public Reduce withInputs(List<IntermediateOperation> inputs) {
- return new Reduce(modelName(), name(), inputs, attributeMap, aggregator);
+ return new Reduce(modelName(), name(), inputs, attributeMap, aggregator, preOperator, postOperator);
}
@Override
@@ -101,7 +122,7 @@ public class Reduce extends IntermediateOperation {
private boolean shouldKeepDimensions() {
Optional<Value> keepDims = attributeMap.get("keepdims");
- return keepDims.isPresent() && keepDims.get().asBoolean();
+ return keepDims.isEmpty() || keepDims.get().asBoolean(); // default is 1
}
private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index c7accd00619..1b72565b423 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -22,51 +23,97 @@ import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
-import java.util.Iterator;
import java.util.List;
+import java.util.Optional;
import java.util.stream.Collectors;
public class Reshape extends IntermediateOperation {
- public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ private final AttributeMap attributeMap;
+
+ public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
protected OrderedTensorType lazyGetType() {
- if ( ! allInputTypesPresent(2)) return null;
+ if (inputs.size() == 2) {
+ return typeWithShapeAsInput();
+ } else if (inputs.size() == 1) {
+ return typeWithShapeAsAttribute();
+ }
+ throw new IllegalArgumentException("Expected 2 or 3 inputs for '" + name + "', got " + inputs.size());
+ }
+ private OrderedTensorType typeWithShapeAsInput() {
IntermediateOperation newShape = inputs.get(1);
if (newShape.getConstantValue().isEmpty())
- throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant.");
+ throw new IllegalArgumentException("Reshape " + name + ": Shape input must be a constant.");
+ OrderedTensorType inputType = inputs.get(0).type().get();
Tensor shape = newShape.getConstantValue().get().asTensor();
+ List<Integer> dimSizes = new ArrayList<>(shape.type().rank());
+ shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue()));
+
+ // first pass - set 0 values
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ if (dimSizes.get(i) == 0) {
+ if (i >= inputType.dimensions().size()) {
+ throw new IllegalArgumentException("Reshape " + name + ": 0 value for dimension not found in input");
+ }
+ dimSizes.set(i, inputType.dimensions().get(i).size().get().intValue());
+ }
+ }
+
+ // second pass - set any -1 values
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ if (dimSizes.get(i) < 0) {
+ int shapeSize = dimSizes.stream().reduce(1, (a, b) -> a * b);
+ int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue();
+ dimSizes.set(i, -1 * tensorSize / (shapeSize == 0 ? -1 : shapeSize));
+ }
+ }
+
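+ // e.g. an input of tensor(d0[2],d1[2]) reshaped with shape [0, -1]
+ // resolves to dimension sizes [2, 2]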
+ return buildOutputType(dimSizes);
+ }
+
+ private OrderedTensorType typeWithShapeAsAttribute() {
+ if (attributeMap.getList("shape").isEmpty() || attributeMap.getList("shape").get().size() == 0)
+ throw new IllegalArgumentException("Reshape in " + name + ": Shape attribute is empty.");
OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType());
- int dimensionIndex = 0;
- for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
- Tensor.Cell cell = cellIterator.next();
- int size = cell.getValue().intValue();
+ List<Value> shape = attributeMap.getList("shape").get();
+ List<Integer> dimSizes = new ArrayList<>(shape.size());
+
+ for (Value v : shape) {
+ int size = (int) v.asDouble();
if (size < 0) {
- size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() /
- OrderedTensorType.tensorSize(inputType.type()).intValue();
+ int shapeSize = (int) shape.stream().mapToDouble(Value::asDouble).reduce(1, (a, b) -> a * b);
+ int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue();
+ size = -1 * tensorSize / (shapeSize == 0 ? -1 : shapeSize);
}
- outputTypeBuilder.add(TensorType.Dimension.indexed(
- String.format("%s_%d", vespaName(), dimensionIndex), size));
- dimensionIndex++;
+ dimSizes.add(size);
+ }
+ return buildOutputType(dimSizes);
+ }
+
+ private OrderedTensorType buildOutputType(List<Integer> dimSizes) {
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ outputTypeBuilder.add(TensorType.Dimension.indexed(String.format("%s_%d", vespaName(), i), dimSizes.get(i)));
}
return outputTypeBuilder.build();
}
@Override
protected TensorFunction lazyGetFunction() {
- if ( ! allInputTypesPresent(2)) return null;
- if ( ! allInputFunctionsPresent(2)) return null;
+ if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent) ) return null;
+ if ( ! inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent) ) return null;
OrderedTensorType inputType = inputs.get(0).type().get();
TensorFunction inputFunction = inputs.get(0).function().get();
- return reshape(inputFunction, inputType.type(), type.type());
+ return reshape(inputFunction, inputType, type);
}
@Override
@@ -76,11 +123,11 @@ public class Reshape extends IntermediateOperation {
@Override
public Reshape withInputs(List<IntermediateOperation> inputs) {
- return new Reshape(modelName(), name(), inputs);
+ return new Reshape(modelName(), name(), inputs, attributeMap);
}
- public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType)))
+ public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
+ if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type())))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
// Conceptually, reshaping consists of unrolling a tensor to an array using the dimension order,
@@ -89,25 +136,27 @@ public class Reshape extends IntermediateOperation {
// the new shape. We have to introduce temporary dimension names and rename back if dimension names
// in the new and old tensor type overlap.
+ // Todo: change this to use tensor generate when available
+
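+ // For example, reshaping tensor(d0[2],d1[3]) into a rank-1 tensor of size 6
+ // generates a transformation tensor which is 1 where the unrolled (flattened)
+ // index of the input, d0 * 3 + d1, equals the unrolled index of the output.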
List<String> from = new ArrayList<>();
List<String> to = new ArrayList<>();
boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType);
if (dimensionNamesOverlap) {
- TensorType.Builder builder = new TensorType.Builder(outputType.valueType());
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(outputType.type().valueType());
for (int i = 0; i < outputType.rank(); ++i) {
TensorType.Dimension dim = outputType.dimensions().get(i);
from.add(dim.name());
to.add("temp_" + dim.name());
- builder.dimension(dim.withName("temp_" + dim.name()));
+ builder.add(dim.withName("temp_" + dim.name()));
}
outputType = builder.build();
}
ExpressionNode unrollFrom = unrollTensorExpression(inputType);
ExpressionNode unrollTo = unrollTensorExpression(outputType);
- ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, new EmbracedNode(unrollTo));
+ ExpressionNode transformExpression = new ComparisonNode(new EmbracedNode(unrollFrom), TruthOperator.EQUAL, new EmbracedNode(unrollTo));
- TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
+ TensorType transformationType = new TensorType.Builder(inputType.type(), outputType.type()).build();
Generate transformTensor = new Generate(transformationType,
new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
@@ -121,11 +170,11 @@ public class Reshape extends IntermediateOperation {
return result;
}
- private static boolean dimensionNamesOverlap(TensorType a, TensorType b) {
- return a.dimensionNames().stream().anyMatch(d -> b.dimension(d).isPresent());
+ private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) {
+ return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent());
}
- private static ExpressionNode unrollTensorExpression(TensorType type) {
+ private static ExpressionNode unrollTensorExpression(OrderedTensorType type) {
if (type.rank() == 0)
return new ConstantNode(DoubleValue.zero);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index 032ffb88a46..306387ad206 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -2,8 +2,13 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Map;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
+import java.util.ArrayList;
import java.util.List;
/**
@@ -13,8 +18,11 @@ import java.util.List;
*/
public class Softmax extends IntermediateOperation {
- public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ private final AttributeMap attributeMap;
+
+ public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
}
@Override
@@ -28,18 +36,30 @@ public class Softmax extends IntermediateOperation {
if ( ! allInputFunctionsPresent(1)) return null;
OrderedTensorType inputType = inputs.get(0).type().get();
- String dimension = inputType.dimensions().get(0).name();
- if (inputType.rank() == 2) {
- dimension = inputType.dimensions().get(1).name(); // assumption: first dimension is batch dimension
+
+ int axis = inputType.rank() == 1 ? 0 : 1; // assumption: first dimension is batch dimension, except if there's only one dimension
+ if (attributeMap.get("axis").isPresent()) {
+ axis = (int)attributeMap.get("axis").get().asDouble();
+ }
+ if (axis < 0) {
+ axis = inputType.rank() + axis;
}
+ List<String> reduceDimensions = new ArrayList<>();
+ for (int i = axis; i < inputType.rank(); ++i) {
+ reduceDimensions.add(inputType.dimensions().get(i).name()); // softmax is computed over all dimensions from 'axis' and onwards
+ }
+
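+ // softmax(x) = exp(x) / sum(exp(x)) over the reduce dimensions; e.g. for a
+ // tensor(d0[2],d1[3]) input with the default axis: exp(x) / sum(exp(x), d1)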
+ TensorFunction input = inputs.get(0).function().get();
+ TensorFunction exp = new Map(input, ScalarFunctions.exp());
+ TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions);
+ TensorFunction div = new Join(exp, sum, ScalarFunctions.divide());
- TensorFunction inputFunction = inputs.get(0).function().get();
- return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension);
+ return div;
}
@Override
public Softmax withInputs(List<IntermediateOperation> inputs) {
- return new Softmax(modelName(), name(), inputs);
+ return new Softmax(modelName(), name(), inputs, attributeMap);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
index 4f656d86929..0d2ba0cc714 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
@@ -64,7 +64,7 @@ class GraphImporter {
case "identity": return new Identity(modelName, nodeName, inputs);
case "placeholder": return new Argument(modelName, nodeName, nodeType);
case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
- case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "reshape": return new Reshape(modelName, nodeName, inputs, attributes);
case "shape": return new Shape(modelName, nodeName, inputs);
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
@@ -113,7 +113,7 @@ class GraphImporter {
case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
- case "softmax": return new Softmax(modelName, nodeName, inputs);
+ case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
// state ops
case "variable": return new Constant(modelName, nodeName, nodeType);
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
new file mode 100644
index 00000000000..6954abe5157
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -0,0 +1,460 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import ai.vespa.rankingexpression.importer.operations.Constant;
+import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.ConstantTensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static ai.vespa.rankingexpression.importer.onnx.GraphImporter.*;
+import static onnx.Onnx.AttributeProto.AttributeType.FLOAT;
+import static onnx.Onnx.AttributeProto.AttributeType.INT;
+import static onnx.Onnx.AttributeProto.AttributeType.INTS;
+import static org.junit.Assert.assertEquals;
+
+/**
+ * Unit tests for ONNX operators. The number suffix of each test name is the
+ * minimum opset version of the operators it tests.
+ *
+ * @author lesters
+ */
+public class OnnxOperationsTestCase {
+
+ private static final String modelName = "test_model";
+
+ @Test
+ public void testElementwiseOperators7() throws ParseException {
+ Tensor x = evaluate("tensor(d0[7]):[-1.0, -0.5, -0.1, 0.0, 0.1, 0.5, 1.0]");
+ assertEval("acos", x, evaluate("acos(x)", x));
+ assertEval("asin", x, evaluate("asin(x)", x));
+ assertEval("atan", x, evaluate("atan(x)", x));
+ assertEval("cos", x, evaluate("cos(x)", x));
+ assertEval("sin", x, evaluate("sin(x)", x));
+ assertEval("tan", x, evaluate("tan(x)", x));
+ assertEval("tanh", x, evaluate("tanh(x)", x));
+ assertEval("neg", x, evaluate("-x", x));
+ assertEval("sigmoid", x, evaluate("sigmoid(x)", x));
+ assertEval("exp", x, evaluate("exp(x)", x));
+ assertEval("floor", x, evaluate("floor(x)", x));
+ assertEval("ceil", x, evaluate("ceil(x)", x));
+ assertEval("abs", x, evaluate("abs(x)", x));
+
+ assertEval("relu", x, evaluate("max(0, x)", x));
+ assertEval("elu", x, evaluate("map(x, f(a)(if(a < 0, 1.0 * (exp(a)-1), a)))", x));
+ assertEval("elu", x, evaluate("map(x, f(a)(if(a < 0, 0.5 * (exp(a)-1), a)))", x), createAttribute("alpha", 0.5f));
+ assertEval("selu", x, evaluate("map(x, f(a)(1.050700987 * if(a >= 0, a, 1.673263242 * (exp(a) - 1))))", x));
+ assertEval("selu", x, evaluate("map(x, f(a)(1.0 * if(a >= 0, a, 1.5 * (exp(a) - 1))))", x), createAttributes().attr("gamma", 1.0f).attr("alpha", 1.5f).build());
+ assertEval("leakyrelu", x, evaluate("max(0.01 * x, x)", x));
+ assertEval("leakyrelu", x, evaluate("max(0.001 * x, x)", x), createAttribute("alpha", 0.001f));
+
+ x = evaluate("tensor(d0[3]):[0.01, 1.0, 10.0]");
+ assertEval("log", x, evaluate("log(x)", x));
+ assertEval("sqrt", x, evaluate("sqrt(x)", x));
+ assertEval("reciprocal", x, evaluate("map(x, f(a)(1.0 / a))", x));
+ }
+
+ @Test
+ public void testJoinOperators7() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2]):[3, 4]");
+ Tensor y = evaluate("tensor(d0[2]):[1, 2]");
+ assertEval("add", x, y, evaluate("tensor(d0[2]):[4, 6]"));
+ assertEval("sub", x, y, evaluate("tensor(d0[2]):[2, 2]"));
+ assertEval("mul", x, y, evaluate("tensor(d0[2]):[3, 8]"));
+ assertEval("div", x, y, evaluate("tensor(d0[2]):[3, 2]"));
+ assertEval("greater", x, y, evaluate("tensor(d0[2]):[1, 1]"));
+ assertEval("less", x, y, evaluate("tensor(d0[2]):[0, 0]"));
+ assertEval("equal", x, y, evaluate("tensor(d0[2]):[0, 0]"));
+ assertEval("pow", x, y, evaluate("tensor(d0[2]):[3, 16]"));
+
+ x = evaluate("random(d0[2],d1[3],d2[4]) + 1");
+ y = evaluate("random(d0[2],d1[3],d2[4]) + 1");
+ assertEval("add", x, y, evaluate("x + y", x, y));
+ assertEval("sub", x, y, evaluate("x - y", x, y));
+ assertEval("mul", x, y, evaluate("x * y", x, y));
+ assertEval("div", x, y, evaluate("x / y", x, y));
+ assertEval("greater", x, y, evaluate("join(x, y, f(a,b)(a > b))", x, y));
+ assertEval("less", x, y, evaluate("join(x, y, f(a,b)(a < b))", x, y));
+ assertEval("equal", x, y, evaluate("join(x, y, f(a,b)(a == b))", x, y));
+ assertEval("pow", x, y, evaluate("join(x, y, f(a,b)(pow(a,b)))", x, y));
+
+ // broadcasting
+ x = evaluate("random(d0[2],d1[3],d2[4]) + 1");
+ y = evaluate("random(d0[4]) + 1");
+ assertEval("add", x, y, evaluate("x + rename(y, d0, d2)", x, y));
+ assertEval("sub", x, y, evaluate("x - rename(y, d0, d2)", x, y));
+ assertEval("mul", x, y, evaluate("x * rename(y, d0, d2)", x, y));
+ assertEval("div", x, y, evaluate("x / rename(y, d0, d2)", x, y));
+ assertEval("greater", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a > b))", x, y));
+ assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y));
+ assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y));
+ assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y));
+ }
+
+ @Test
+ public void testConcatReduce8() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2]):[3, 4]");
+ Tensor y = evaluate("tensor(d0[2]):[1, 2]");
+ Tensor z = evaluate("tensor(d0[2]):[5, 6]");
+ assertEval("max", x, y, z, evaluate("tensor(d0[2]):[5, 6]"));
+ assertEval("min", x, y, z, evaluate("tensor(d0[2]):[1, 2]"));
+ assertEval("mean", x, y, z, evaluate("tensor(d0[2]):[3, 4]"));
+
+ x = evaluate("random(d0[2],d1[3],d2[4])");
+ y = evaluate("random(d0[2],d1[3],d2[4])");
+ z = evaluate("random(d0[2],d1[3],d2[4])");
+ assertEval("max", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), max, tmp)", x, y, z));
+ assertEval("min", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), min, tmp)", x, y, z));
+ assertEval("mean", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), avg, tmp)", x, y, z));
+
+ // broadcasting
+ x = evaluate("random(d0[2],d1[3],d2[4])");
+ y = evaluate("random(d0[3],d1[4])");
+ z = evaluate("random(d0[4])");
+ assertEval("max", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), max, tmp)", x, y, z));
+ assertEval("min", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), min, tmp)", x, y, z));
+ assertEval("mean", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), avg, tmp)", x, y, z));
+ }
+
+ @Test
+ public void testConcat4() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2]):[1, 2]");
+ Tensor y = evaluate("tensor(d0[2]):[3, 4]");
+ Tensor expected = evaluate("tensor(d0[4]):[1,2,3,4]");
+ assertEval("concat", x, y, expected, createAttribute("axis", 0));
+ assertEval("concat", x, y, expected, createAttribute("axis", -1));
+
+ x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]");
+ y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]");
+ assertEval("concat", x, y, evaluate("tensor(d0[4],d1[2]):[1,2,3,4,5,6,7,8]"), createAttribute("axis", 0));
+ assertEval("concat", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,5,6,3,4,7,8]"), createAttribute("axis", 1));
+ assertEval("concat", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,5,6,3,4,7,8]"), createAttribute("axis", -1));
+ assertEval("concat", x, y, evaluate("tensor(d0[4],d1[2]):[1,2,3,4,5,6,7,8]"), createAttribute("axis", -2));
+
+ x = evaluate("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 5, 6, 7, 8]");
+ y = evaluate("tensor(d0[2],d1[2],d2[2]):[9,10,11,12,13,14,15,16]");
+ assertEval("concat", x, y, evaluate("concat(x, y, d0)", x, y), createAttribute("axis", 0));
+ assertEval("concat", x, y, evaluate("concat(x, y, d1)", x, y), createAttribute("axis", 1));
+ assertEval("concat", x, y, evaluate("concat(x, y, d2)", x, y), createAttribute("axis", 2));
+ assertEval("concat", x, y, evaluate("concat(x, y, d2)", x, y), createAttribute("axis", -1));
+ assertEval("concat", x, y, evaluate("concat(x, y, d1)", x, y), createAttribute("axis", -2));
+ assertEval("concat", x, y, evaluate("concat(x, y, d0)", x, y), createAttribute("axis", -3));
+ }
+
+ @Test
+ public void testGemm7() throws ParseException {
+ Tensor a = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]");
+ Tensor b = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]");
+ Tensor c = evaluate("tensor(d0[2],d1[2]):[0.1, 0.2, 0.3, 0.4]");
+
+ assertEval("gemm", a, b, evaluate("tensor(d0[2],d1[2]):[19, 22, 43, 50]"));
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.1, 22.2, 43.3, 50.4]"));
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[38.1, 44.2, 86.3, 100.4]"), createAttribute("alpha", 2.0f));
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.2, 22.4, 43.6, 50.8]"), createAttribute("beta", 2.0f));
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[26.1, 30.2, 38.3, 44.4]"), createAttribute("transA", 1));
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[17.1, 23.2, 39.3, 53.4]"), createAttribute("transB", 1));
+
+ // unidirectional broadcasting for c
+ c = evaluate("tensor(d0[2]):[0.1, 0.2]");
+ assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.1, 22.2, 43.1, 50.2]"));
+ }
+
+ @Test
+ public void testIdentity1() throws ParseException {
+ Tensor x = evaluate("random(d0[2],d1[3],d2[4])");
+ assertEval("identity", x, x);
+ }
+
+ @Test
+ public void testMatMul1() throws ParseException {
+ Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]");
+ Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]"));
+ }
+
+ @Test
+ public void testReshape5() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]");
+ Tensor y = evaluate("tensor(d0[1]):[4]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[4]):[1,2,3,4]"));
+
+ y = evaluate("tensor(d0[2]):[2,2]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"));
+
+ y = evaluate("tensor(d0[3]):[2,1,2]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[1],d2[2]):[1,2,3,4]"));
+
+ y = evaluate("tensor(d0[2]):[2,-1]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"));
+
+ y = evaluate("tensor(d0[2]):[2,0]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"));
+
+ y = evaluate("tensor(d0[2]):[0,-1]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"));
+
+ x = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+ y = evaluate("tensor(d0[2]):[3,2]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"));
+
+ y = evaluate("tensor(d0[4]):[3,2,-1,1]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]"));
+ }
+
+ @Test
+ public void testReduceOperators1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]");
+
+ assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"));
+ assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"), createAttribute("axes", new int[] {0,1}));
+ assertEval("reducesum", x, evaluate("tensor():[10]"), createAttribute("keepdims", 0));
+ assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"), createAttribute("keepdims", 1));
+ assertEval("reducesum", x, evaluate("tensor(d0[1],d1[2]):[4, 6]"), createAttribute("axes", new int[]{0}));
+ assertEval("reducesum", x, evaluate("tensor(d0[2]):[4, 6]"), createAttributes().attr("axes", new int[]{0}).attr("keepdims", 0).build());
+ assertEval("reducesum", x, evaluate("tensor(d0[2],d1[1]):[3, 7]"), createAttribute("axes", new int[] {1}));
+ assertEval("reducesum", x, evaluate("tensor(d0[2]):[3, 7]"), createAttributes().attr("axes", new int[]{1}).attr("keepdims", 0).build());
+ assertEval("reducesum", x, evaluate("tensor(d0[1],d1[2]):[4, 6]"), createAttribute("axes", new int[] {-2}));
+ assertEval("reducesum", x, evaluate("tensor(d0[2],d1[1]):[3, 7]"), createAttribute("axes", new int[] {-1}));
+ assertEval("reducesum", x, evaluate("tensor(d0[2]):[3, 7]"), createAttributes().attr("axes", new int[] {-1}).attr("keepdims", 0).build());
+
+ assertEval("reduceprod", x, evaluate("tensor(d0[1],d1[1]):[24]"));
+ assertEval("reduceprod", x, evaluate("tensor(d0[1],d1[2]):[3, 8]"), createAttribute("axes", new int[] {0}));
+
+ assertEval("reducemin", x, evaluate("tensor(d0[1],d1[1]):[1]"));
+ assertEval("reducemin", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0}));
+
+ assertEval("reducemax", x, evaluate("tensor(d0[1],d1[1]):[4]"));
+ assertEval("reducemax", x, evaluate("tensor(d0[1],d1[2]):[3, 4]"), createAttribute("axes", new int[] {0}));
+
+ assertEval("reducemean", x, evaluate("tensor():[2.5]"), createAttribute("keepdims", 0));
+ assertEval("reducemean", x, evaluate("tensor(d0[2]):[2, 3]"), createAttributes().attr("axes", new int[] {0}).attr("keepdims", 0).build());
+
+ assertEval("reducelogsum", x, evaluate("tensor():[log(10)]"), createAttribute("keepdims", 0));
+ assertEval("reducelogsumexp", x, evaluate("tensor():[log(exp(1)+exp(2)+exp(3)+exp(4))]"), createAttribute("keepdims", 0));
+ assertEval("reducesumsquare", x, evaluate("tensor():[1*1+2*2+3*3+4*4]"), createAttribute("keepdims", 0));
+
+ x = evaluate("tensor(d0[1],d1[5]):[-10, -5, 0, 5, 10]");
+ assertEval("reducel1", x, evaluate("tensor():[30]"), createAttribute("keepdims", 0));
+ assertEval("reducel2", x, evaluate("tensor():[sqrt(10*10 + 5*5 + 5*5 + 10*10)]"), createAttribute("keepdims", 0));
+ }
+
+ @Test
+ public void testShape1() throws ParseException {
+ Tensor x = evaluate("random(d0[2],d1[3],d2[4])");
+ assertEval("shape", x, evaluate("tensor(d0[3]):[2,3,4]"));
+ }
+
+ @Test
+ public void testSoftmax1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[1],d1[3]):[-1, 0, 1]");
+ assertEval("softmax", x, evaluate("tensor(d0[1],d1[3]):[0.09003058, 0.24472848, 0.66524094]"));
+
+ x = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 7]");
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1)", x), createAttribute("axis", 0));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x), createAttribute("axis", 1)); // 1 is default
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x), createAttribute("axis", -1));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1)", x), createAttribute("axis", -2));
+
+ x = evaluate("random(d0[2],d1[3],d2[4])");
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1, d2)", x), createAttribute("axis", 0));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x), createAttribute("axis", 1));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d2)", x), createAttribute("axis", 2));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d2)", x), createAttribute("axis", -1));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x), createAttribute("axis", -2));
+ assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1, d2)", x), createAttribute("axis", -3));
+ }
+
+ @Test
+ public void testSqueeze1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[1],d1[2]):[1, 2]");
+ assertEval("squeeze", x, evaluate("tensor(d0[2]):[1, 2]"));
+
+ x = evaluate("tensor(d0[1],d1[2],d2[1],d3[3]):[1,2,3,4,5,6]");
+ assertEval("squeeze", x, evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"));
+ assertEval("squeeze", x, evaluate("tensor(d0[2],d1[1],d2[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0}));
+ assertEval("squeeze", x, evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {2}));
+ assertEval("squeeze", x, evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0, 2}));
+ }
+
+ @Test
+ public void testWhere9() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]");
+ Tensor y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]");
+ Tensor condition = evaluate("tensor(d0[2],d1[2]):[0, 1, 0, 1]");
+ assertEval("where", condition, x, y, evaluate("tensor(d0[2],d1[2]):[5, 2, 7, 4]"));
+
+ assertEval("where", evaluate("tensor():[0]"), x, y, y);
+ assertEval("where", evaluate("tensor():[1]"), x, y, x);
+ assertEval("where", evaluate("tensor(d0[1]):[0]"), x, y, y);
+ assertEval("where", evaluate("tensor(d0[1]):[1]"), x, y, x);
+ assertEval("where", evaluate("tensor(d0[1],d1[1]):[0]"), x, y, y);
+ assertEval("where", evaluate("tensor(d0[1],d1[1]):[1]"), x, y, x);
+ }
+
+ private Tensor evaluate(String expr) throws ParseException {
+ return evaluate(expr, null, null, null);
+ }
+
+ private Tensor evaluate(String expr, Tensor x) throws ParseException {
+ return evaluate(expr, x, null, null);
+ }
+
+ private Tensor evaluate(String expr, Tensor x, Tensor y) throws ParseException {
+ return evaluate(expr, x, y, null);
+ }
+
+ private Tensor evaluate(String expr, Tensor x, Tensor y, Tensor z) throws ParseException {
+ Context context = new MapContext(DoubleValue.NaN);
+ if (x != null) context.put("x", new TensorValue(x));
+ if (y != null) context.put("y", new TensorValue(y));
+ if (z != null) context.put("z", new TensorValue(z));
+ return new RankingExpression(expr).evaluate(context).asTensor();
+ }
+
+ private Tensor evaluate(IntermediateOperation op) {
+ Tensor tensor = op.evaluateAsConstant(op.type().get()).asTensor();
+ return renameToStandardType(op, tensor);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor expected) {
+ assertEval(opName, x, null, null, expected, null);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) {
+ assertEval(opName, x, null, null, expected, attr);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) {
+ assertEval(opName, x, y, null, expected, attr);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) {
+ assertEval(opName, x, y, null, expected, null);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) {
+ assertEval(opName, x, y, z, expected, null);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) {
+ Context context = new MapContext(DoubleValue.NaN);
+ List<IntermediateOperation> inputs = createInputs(context, x, y, z);
+ IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build());
+ optimizeAndRename(opName, op);
+ Tensor result = evaluate(op);
+ assertEquals(expected, result);
+ assertEquals(expected.type(), result.type());
+ }
+
+ private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z) {
+ List<IntermediateOperation> inputs = new ArrayList<>();
+ addInput(inputs, context, x, "x");
+ addInput(inputs, context, y, "y");
+ addInput(inputs, context, z, "z");
+ return inputs;
+ }
+
+ private void addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) {
+ if (x == null) return;
+ context.put(name, new TensorValue(x));
+ IntermediateOperation op = new Constant(modelName, name, OrderedTensorType.fromSpec(x.type().toString()));
+ op.setConstantValueFunction(type -> new TensorValue(convertTypeAfterRename(x, type)));
+ inputs.add(op);
+ }
+
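+ // Dimension renaming may reorder dimensions, so constant inputs are rebuilt
+ // with cells laid out according to the (possibly renamed) type: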
+ private Tensor convertTypeAfterRename(Tensor tensor, OrderedTensorType type) {
+ IndexedTensor indexedTensor = (IndexedTensor) tensor;
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
+ for (int i = 0; i < indexedTensor.size(); i++) {
+ builder.cellByDirectIndex(type.toDirectIndex(i), indexedTensor.get(i));
+ }
+ return builder.build();
+ }
+
+ private TensorFunction optimizeAndRename(String opName, IntermediateOperation op) {
+ IntermediateGraph graph = new IntermediateGraph(modelName);
+ graph.put(opName, op);
+ graph.outputs(graph.defaultSignature()).put(opName, opName);
+ graph.optimize();
+ return op.function().get();
+ }
+
+ private Tensor renameToStandardType(IntermediateOperation op, Tensor tensor) {
+ OrderedTensorType operationType = op.type().get();
+ OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
+ if ( ! operationType.equals(standardNamingType)) {
+ List<String> renameFrom = operationType.dimensionNames();
+ List<String> renameTo = standardNamingType.dimensionNames();
+ TensorFunction func = new Rename(new ConstantTensor(tensor), renameFrom, renameTo);
+ return func.evaluate();
+ }
+ return tensor;
+ }
+
+ static AttributeConverter createAttribute(String name, int val) {
+ return new Attributes().attr(name, val).build();
+ }
+
+ static AttributeConverter createAttribute(String name, float val) {
+ return new Attributes().attr(name, val).build();
+ }
+
+ static AttributeConverter createAttribute(String name, int [] vals) {
+ return new Attributes().attr(name, vals).build();
+ }
+
+ static Attributes createAttributes() {
+ return new Attributes();
+ }
+
+ private static class Attributes {
+
+ Onnx.NodeProto.Builder nodeBuilder;
+
+ Attributes() {
+ this.nodeBuilder = Onnx.NodeProto.newBuilder();
+ }
+
+ Attributes attr(String name, int val) {
+ nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(INT).setI(val).build());
+ return this;
+ }
+
+ Attributes attr(String name, float val) {
+ nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(FLOAT).setF(val).build());
+ return this;
+ }
+
+ Attributes attr(String name, int [] vals) {
+ Onnx.AttributeProto.Builder builder = Onnx.AttributeProto.newBuilder();
+ for (int val : vals) {
+ builder.addInts(val);
+ }
+ nodeBuilder.addAttribute(builder.setName(name).setType(INTS).build());
+ return this;
+ }
+
+ AttributeConverter build() {
+ return AttributeConverter.convert(nodeBuilder.build());
+ }
+
+ }
+
+}