Diffstat (limited to 'model-integration/src')
11 files changed, 713 insertions, 56 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index 6c583d960bd..14aa3ebf84e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -70,7 +70,7 @@ public class IntermediateGraph { return operations; } - void optimize() { + public void optimize() { renameDimensions(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index 280fe354149..63b04470d00 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -3,11 +3,13 @@ package ai.vespa.rankingexpression.importer.onnx; import ai.vespa.rankingexpression.importer.operations.Gemm; +import ai.vespa.rankingexpression.importer.operations.ConcatReduce; import ai.vespa.rankingexpression.importer.operations.OnnxConcat; import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; import ai.vespa.rankingexpression.importer.operations.Softmax; import ai.vespa.rankingexpression.importer.operations.Squeeze; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; @@ -21,6 +23,8 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import ai.vespa.rankingexpression.importer.operations.NoOp; import ai.vespa.rankingexpression.importer.operations.Reshape; import ai.vespa.rankingexpression.importer.operations.Shape; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; @@ -36,24 +40,37 @@ import java.util.stream.Collectors; */ class GraphImporter { + private static final Value eluAlpha = DoubleValue.frozen(1.0); + private static final Value seluAlpha = DoubleValue.frozen(1.6732632423543772848170429916717); + private static final Value seluGamma = DoubleValue.frozen(1.0507009873554804934193349852946); + private static final Value leakyReluAlpha = DoubleValue.frozen(0.01); + private static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs, IntermediateGraph graph) { + String type = node.getOpType(); String modelName = graph.name(); String nodeName = getNodeName(node); AttributeConverter attributes = AttributeConverter.convert(node); + return mapOperation(type, inputs, modelName, nodeName, attributes); + } - switch (node.getOpType().toLowerCase()) { + static IntermediateOperation mapOperation(String opType, + List<IntermediateOperation> inputs, + String modelName, + String nodeName, + AttributeConverter attributes) { + switch (opType.toLowerCase()) { case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); - case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "add": return new Join(modelName, 
nodeName, inputs, ScalarFunctions.add()); case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes); case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); - case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble())); case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); @@ -63,23 +80,31 @@ class GraphImporter { case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less()); case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log()); case "matmul": return new MatMul(modelName, nodeName, inputs); - case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); - case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min()); - case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean()); + case "max": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.max); + case "min": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.min); + case "mean": return new ConcatReduce(modelName, nodeName, inputs, com.yahoo.tensor.functions.Reduce.Aggregator.avg); case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); - case "reshape": return new Reshape(modelName, nodeName, inputs); - case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum); + case "reshape": return new Reshape(modelName, nodeName, inputs, attributes); + case "reducel1": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.abs(), null); + case "reducel2": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), ScalarFunctions.sqrt()); + case "reducelogsum":return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, null, ScalarFunctions.log()); + case "reducelogsumexp": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.exp(), ScalarFunctions.log()); + case "reducemax": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.max); case "reducemean": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.avg); + case "reducemin": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.min); + case "reduceprod": return new Reduce(modelName, nodeName, inputs, attributes, 
com.yahoo.tensor.functions.Reduce.Aggregator.prod); + case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum); + case "reducesumsquare": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), null); case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); - case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); - case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu(attributes.get("gamma").orElse(seluGamma).asDouble(), attributes.get("alpha").orElse(seluAlpha).asDouble())); + case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu(attributes.get("alpha").orElse(leakyReluAlpha).asDouble())); case "shape": return new Shape(modelName, nodeName, inputs); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); - case "softmax": return new Softmax(modelName, nodeName, inputs); + case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); @@ -90,7 +115,7 @@ class GraphImporter { } IntermediateOperation op = new NoOp(modelName, nodeName, inputs); - op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); + op.warning("Operation '" + opType + "' is currently not implemented"); return op; } @@ -260,5 +285,4 @@ class GraphImporter { "Either no explicit name given or no single output name."); } - } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java new file mode 100644 index 00000000000..497e7e7550d --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatReduce.java @@ -0,0 +1,78 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; +import java.util.Optional; + +public class ConcatReduce extends IntermediateOperation { + + private final static String tmpDimensionName = "__concat_reduce_tmp_dimension_name__"; + private final Reduce.Aggregator aggregator; + + public ConcatReduce(String modelName, String nodeName, List<IntermediateOperation> inputs, Reduce.Aggregator aggregator) { + super(modelName, nodeName, inputs); + this.aggregator = aggregator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(inputs.size())) return null; + return inputs.get(0).type().get(); // todo, not necessarily so. Broadcasting etc? 
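A note on the variadic reductions mapped above: ONNX Max, Min and Mean accept any number of inputs, and the new ConcatReduce operation handles this by chaining pairwise Concat along a temporary dimension and then reducing that dimension away (see lazyGetFunction below). As a rough sketch of the resulting expression for a three-input Mean, with made-up tensor values and mirroring what the new OnnxOperationsTestCase asserts:

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;

public class ConcatReduceSketch {
    public static void main(String[] args) throws ParseException {
        MapContext context = new MapContext(DoubleValue.NaN);
        context.put("x", new TensorValue(Tensor.from("tensor(d0[2]):[3, 4]")));
        context.put("y", new TensorValue(Tensor.from("tensor(d0[2]):[1, 2]")));
        context.put("z", new TensorValue(Tensor.from("tensor(d0[2]):[5, 6]")));
        // Concatenate all inputs along a temporary dimension, then reduce it away with the aggregator.
        Tensor mean = new RankingExpression("reduce(concat(concat(x, y, tmp), z, tmp), avg, tmp)")
                .evaluate(context).asTensor();
        System.out.println(mean); // elementwise mean of the three inputs: [3, 4]
    }
}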
+ } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputFunctionsPresent(inputs.size())) return null; + + TensorFunction result = inputs.get(0).function().get(); + for (int i = 1; i < inputs.size(); ++i) { + TensorFunction b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat(result, b, tmpDimensionName); + } + return new com.yahoo.tensor.functions.Reduce(result, aggregator, tmpDimensionName); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if ( ! allInputTypesPresent(inputs.size())) return; + + OrderedTensorType a = inputs.get(0).type().get(); + for (int i = 1; i < inputs.size(); ++i) { + OrderedTensorType b = inputs.get(i).type().get(); + + OrderedTensorType largest = largestInput(a, b); + OrderedTensorType smallest = smallestInput(a, b); + + int sizeDifference = largest.rank() - smallest.rank(); + for (int j = 0; j < smallest.rank(); ++j) { + String bDim = smallest.dimensions().get(j).name(); + String aDim = largest.dimensions().get(j + sizeDifference).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); + } + a = b; + } + } + + private OrderedTensorType largestInput(OrderedTensorType a, OrderedTensorType b) { + return a.rank() >= b.rank() ? a : b; + } + + private OrderedTensorType smallestInput(OrderedTensorType a, OrderedTensorType b) { + return a.rank() < b.rank() ? a : b; + } + + + @Override + public ConcatReduce withInputs(List<IntermediateOperation> inputs) { + return new ConcatReduce(modelName(), name(), inputs, aggregator); + } + + @Override + public String operationName() { return "ConcatReduce"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index f091ae165d1..3fba8680332 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -92,7 +92,7 @@ public class Gemm extends IntermediateOperation { return null; } - String joinDimension = aType.dimensions().get(1).name(); // TODO: check wrt transpose! 
+ String joinDimension = aType.dimensions().get(1 - transposeA).name(); TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension); TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index bd302afa5c7..efd6f9d3339 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -199,7 +199,9 @@ public abstract class IntermediateOperation { String constantName = "constant(" + vespaName() + ")"; Value result = context.get(constantName); if (result == DoubleValue.NaN) { - if (inputs.size() == 0) { + if (constantValue != null) { + result = constantValue; + } else if (inputs.size() == 0) { if (getConstantValue().isEmpty()) { throw new IllegalArgumentException("Error in evaluating constant for " + name); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java index ded76db60fe..5785621eed3 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java @@ -28,6 +28,9 @@ public class OnnxConcat extends IntermediateOperation { if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null; OrderedTensorType aType = inputs.get(0).type().get(); + if (concatDimensionIndex < 0) { + concatDimensionIndex = aType.dimensions().size() + concatDimensionIndex; + } long concatDimSize = aType.dimensions().get(concatDimensionIndex).size().orElse(-1L); for (int i = 1; i < inputs.size(); ++i) { @@ -92,7 +95,7 @@ public class OnnxConcat extends IntermediateOperation { public void renameDimensions(DimensionRenamer renamer) { super.renameDimensions(renamer); concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName); - } + } @Override public OnnxConcat withInputs(List<IntermediateOperation> inputs) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java index 1b2d9ac090e..b3fe1da931e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java @@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; import java.util.List; import java.util.Optional; +import java.util.function.DoubleUnaryOperator; /** * ONNX Reduce[Sum/Mean/etc] operation @@ -24,6 +25,8 @@ public class Reduce extends IntermediateOperation { private final AttributeMap attributeMap; private final com.yahoo.tensor.functions.Reduce.Aggregator aggregator; + private final DoubleUnaryOperator preOperator; + private final DoubleUnaryOperator postOperator; private List<String> reduceDimensions; @@ -31,11 +34,23 @@ public class Reduce extends IntermediateOperation { List<IntermediateOperation> inputs, AttributeMap 
attributeMap, com.yahoo.tensor.functions.Reduce.Aggregator aggregator) { + this(modelName, nodeName, inputs, attributeMap, aggregator, null, null); + } + + public Reduce(String modelName, String nodeName, + List<IntermediateOperation> inputs, + AttributeMap attributeMap, + com.yahoo.tensor.functions.Reduce.Aggregator aggregator, + DoubleUnaryOperator preOperator, + DoubleUnaryOperator postOperator) { super(modelName, nodeName, inputs); this.attributeMap = attributeMap; this.aggregator = aggregator; + this.preOperator = preOperator; + this.postOperator = postOperator; } + @Override protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(1)) return null; @@ -48,7 +63,7 @@ public class Reduce extends IntermediateOperation { for (Value i : attributeMap.getList("axes").get()) { int dimensionIndex = (int) i.asDouble(); if (dimensionIndex < 0) { - dimensionIndex = inputType.dimensions().size() - dimensionIndex; + dimensionIndex = inputType.dimensions().size() - (-1 * dimensionIndex); } reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); } @@ -61,6 +76,9 @@ public class Reduce extends IntermediateOperation { if ( ! allInputTypesPresent(1)) return null; TensorFunction inputFunction = inputs.get(0).function().get(); + if (preOperator != null) { + inputFunction = new com.yahoo.tensor.functions.Map(inputFunction, preOperator); + } TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions @@ -74,6 +92,9 @@ public class Reduce extends IntermediateOperation { new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); } + if (postOperator != null) { + output = new com.yahoo.tensor.functions.Map(output, postOperator); + } return output; } @@ -93,7 +114,7 @@ public class Reduce extends IntermediateOperation { @Override public Reduce withInputs(List<IntermediateOperation> inputs) { - return new Reduce(modelName(), name(), inputs, attributeMap, aggregator); + return new Reduce(modelName(), name(), inputs, attributeMap, aggregator, preOperator, postOperator); } @Override @@ -101,7 +122,7 @@ public class Reduce extends IntermediateOperation { private boolean shouldKeepDimensions() { Optional<Value> keepDims = attributeMap.get("keepdims"); - return keepDims.isPresent() && keepDims.get().asBoolean(); + return keepDims.isEmpty() || keepDims.get().asBoolean(); // default is 1 } private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index c7accd00619..1b72565b423 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -4,6 +4,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import 
com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; @@ -22,51 +23,97 @@ import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; -import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.stream.Collectors; public class Reshape extends IntermediateOperation { - public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + private final AttributeMap attributeMap; + + public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override protected OrderedTensorType lazyGetType() { - if ( ! allInputTypesPresent(2)) return null; + if (inputs.size() == 2) { + return typeWithShapeAsInput(); + } else if (inputs.size() == 1) { + return typeWithShapeAsAttribute(); + } + throw new IllegalArgumentException("Expected 2 or 3 inputs for '" + name + "', got " + inputs.size()); + } + private OrderedTensorType typeWithShapeAsInput() { IntermediateOperation newShape = inputs.get(1); if (newShape.getConstantValue().isEmpty()) - throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant."); + throw new IllegalArgumentException("Reshape " + name + ": Shape input must be a constant."); + OrderedTensorType inputType = inputs.get(0).type().get(); Tensor shape = newShape.getConstantValue().get().asTensor(); + List<Integer> dimSizes = new ArrayList<>(shape.type().rank()); + shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue())); + + // first pass - set 0 values + for (int i = 0; i < dimSizes.size(); ++i) { + if (dimSizes.get(i) == 0) { + if (i >= inputType.dimensions().size()) { + throw new IllegalArgumentException("Reshape " + name + ": 0 value for dimension not found in input"); + } + dimSizes.set(i, inputType.dimensions().get(i).size().get().intValue()); + } + } + + // second pass - set any -1 values + for (int i = 0; i < dimSizes.size(); ++i) { + if (dimSizes.get(i) < 0) { + int shapeSize = dimSizes.stream().reduce(1, (a, b) -> a * b); + int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue(); + dimSizes.set(i, -1 * tensorSize / (shapeSize == 0 ? 
-1 : shapeSize)); + } + } + + return buildOutputType(dimSizes); + } + + private OrderedTensorType typeWithShapeAsAttribute() { + if (attributeMap.getList("shape").isEmpty() || attributeMap.getList("shape").get().size() == 0) + throw new IllegalArgumentException("Reshape in " + name + ": Shape attribute is empty."); OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); - int dimensionIndex = 0; - for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int size = cell.getValue().intValue(); + List<Value> shape = attributeMap.getList("shape").get(); + List<Integer> dimSizes = new ArrayList<>(shape.size()); + + for (Value v : shape) { + int size = (int) v.asDouble(); if (size < 0) { - size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / - OrderedTensorType.tensorSize(inputType.type()).intValue(); + int shapeSize = (int) shape.stream().mapToDouble(Value::asDouble).reduce(1, (a, b) -> a * b); + int tensorSize = OrderedTensorType.tensorSize(inputType.type()).intValue(); + size = -1 * shapeSize / tensorSize; } - outputTypeBuilder.add(TensorType.Dimension.indexed( - String.format("%s_%d", vespaName(), dimensionIndex), size)); - dimensionIndex++; + dimSizes.add(size); + } + return buildOutputType(dimSizes); + } + + private OrderedTensorType buildOutputType(List<Integer> dimSizes) { + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < dimSizes.size(); ++i) { + outputTypeBuilder.add(TensorType.Dimension.indexed(String.format("%s_%d", vespaName(), i), dimSizes.get(i))); } return outputTypeBuilder.build(); } @Override protected TensorFunction lazyGetFunction() { - if ( ! allInputTypesPresent(2)) return null; - if ( ! allInputFunctionsPresent(2)) return null; + if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent) ) return null; + if ( ! inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent) ) return null; OrderedTensorType inputType = inputs.get(0).type().get(); TensorFunction inputFunction = inputs.get(0).function().get(); - return reshape(inputFunction, inputType.type(), type.type()); + return reshape(inputFunction, inputType, type); } @Override @@ -76,11 +123,11 @@ public class Reshape extends IntermediateOperation { @Override public Reshape withInputs(List<IntermediateOperation> inputs) { - return new Reshape(modelName(), name(), inputs); + return new Reshape(modelName(), name(), inputs, attributeMap); } - public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) + public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { + if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type()))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, @@ -89,25 +136,27 @@ public class Reshape extends IntermediateOperation { // the new shape. We have to introduce temporary dimension names and rename back if dimension names // in the new and old tensor type overlap. 
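For reference, the 0 and -1 entries handled in typeWithShapeAsInput() above follow the ONNX Reshape convention: 0 keeps the corresponding input dimension, and a single -1 is inferred so that the total size is preserved. A minimal standalone sketch of that resolution, using plain ints and made-up names (it assumes at most one -1, as the code above does):

import java.util.Arrays;

class ReshapeShapeSketch {
    // Resolve an ONNX Reshape "shape" input against the input dimension sizes.
    static int[] resolveShape(int[] inputDims, int[] shape) {
        int[] dims = shape.clone();
        for (int i = 0; i < dims.length; i++)      // first pass: 0 copies the input dimension
            if (dims[i] == 0) dims[i] = inputDims[i];
        int tensorSize = Arrays.stream(inputDims).reduce(1, (a, b) -> a * b);
        int knownSize = 1;
        for (int d : dims) if (d > 0) knownSize *= d;
        for (int i = 0; i < dims.length; i++)      // second pass: -1 is inferred from the rest
            if (dims[i] < 0) dims[i] = tensorSize / knownSize;
        return dims;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(resolveShape(new int[] {2, 2}, new int[] {2, 0})));           // [2, 2]
        System.out.println(Arrays.toString(resolveShape(new int[] {1, 2, 3}, new int[] {3, 2, -1, 1}))); // [3, 2, 1, 1]
    }
}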
+ // Todo: change this to use tensor generate when available + List<String> from = new ArrayList<>(); List<String> to = new ArrayList<>(); boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType); if (dimensionNamesOverlap) { - TensorType.Builder builder = new TensorType.Builder(outputType.valueType()); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(outputType.type().valueType()); for (int i = 0; i < outputType.rank(); ++i) { TensorType.Dimension dim = outputType.dimensions().get(i); from.add(dim.name()); to.add("temp_" + dim.name()); - builder.dimension(dim.withName("temp_" + dim.name())); + builder.add(dim.withName("temp_" + dim.name())); } outputType = builder.build(); } ExpressionNode unrollFrom = unrollTensorExpression(inputType); ExpressionNode unrollTo = unrollTensorExpression(outputType); - ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, new EmbracedNode(unrollTo)); + ExpressionNode transformExpression = new ComparisonNode(new EmbracedNode(unrollFrom), TruthOperator.EQUAL, new EmbracedNode(unrollTo)); - TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); + TensorType transformationType = new TensorType.Builder(inputType.type(), outputType.type()).build(); Generate transformTensor = new Generate(transformationType, new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); @@ -121,11 +170,11 @@ public class Reshape extends IntermediateOperation { return result; } - private static boolean dimensionNamesOverlap(TensorType a, TensorType b) { - return a.dimensionNames().stream().anyMatch(d -> b.dimension(d).isPresent()); + private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) { + return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent()); } - private static ExpressionNode unrollTensorExpression(TensorType type) { + private static ExpressionNode unrollTensorExpression(OrderedTensorType type) { if (type.rank() == 0) return new ConstantNode(DoubleValue.zero); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index 032ffb88a46..306387ad206 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -2,8 +2,13 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Map; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; +import java.util.ArrayList; import java.util.List; /** @@ -13,8 +18,11 @@ import java.util.List; */ public class Softmax extends IntermediateOperation { - public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs) { + private final AttributeMap attributeMap; + + public Softmax(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; } @Override @@ -28,18 +36,30 @@ public class Softmax extends IntermediateOperation { if ( ! 
allInputFunctionsPresent(1)) return null; OrderedTensorType inputType = inputs.get(0).type().get(); - String dimension = inputType.dimensions().get(0).name(); - if (inputType.rank() == 2) { - dimension = inputType.dimensions().get(1).name(); // assumption: first dimension is batch dimension + + int axis = inputType.rank() == 1 ? 0 : 1; // assumption: first dimension is batch dimension, except if there's only one dimension + if (attributeMap.get("axis").isPresent()) { + axis = (int)attributeMap.get("axis").get().asDouble(); + } + if (axis < 0) { + axis = inputType.rank() + axis; } + List<String> reduceDimensions = new ArrayList<>(); + for (int i = axis; i < inputType.rank(); ++i) { + reduceDimensions.add(inputType.dimensions().get(i).name()); // Do softmax over all dimensions except batch dimension + } + + TensorFunction input = inputs.get(0).function().get(); + TensorFunction exp = new Map(input, ScalarFunctions.exp()); + TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions); + TensorFunction div = new Join(exp, sum, ScalarFunctions.divide()); - TensorFunction inputFunction = inputs.get(0).function().get(); - return new com.yahoo.tensor.functions.Softmax(inputFunction, dimension); + return div; } @Override public Softmax withInputs(List<IntermediateOperation> inputs) { - return new Softmax(modelName(), name(), inputs); + return new Softmax(modelName(), name(), inputs, attributeMap); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index 4f656d86929..0d2ba0cc714 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -64,7 +64,7 @@ class GraphImporter { case "identity": return new Identity(modelName, nodeName, inputs); case "placeholder": return new Argument(modelName, nodeName, nodeType); case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs); - case "reshape": return new Reshape(modelName, nodeName, inputs); + case "reshape": return new Reshape(modelName, nodeName, inputs, attributes); case "shape": return new Shape(modelName, nodeName, inputs); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); @@ -113,7 +113,7 @@ class GraphImporter { case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); - case "softmax": return new Softmax(modelName, nodeName, inputs); + case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); // state ops case "variable": return new Constant(modelName, nodeName, nodeType); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java new file mode 100644 index 00000000000..6954abe5157 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -0,0 +1,460 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.rankingexpression.importer.onnx; + +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.operations.Constant; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.ConstantTensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static ai.vespa.rankingexpression.importer.onnx.GraphImporter.*; +import static onnx.Onnx.AttributeProto.AttributeType.FLOAT; +import static onnx.Onnx.AttributeProto.AttributeType.INT; +import static onnx.Onnx.AttributeProto.AttributeType.INTS; +import static org.junit.Assert.assertEquals; + +/** + * Unit tests for ONNX operators. The number on the test reflects the minimum + * opset number for the operations tested. + * + * @author lesters + */ +public class OnnxOperationsTestCase { + + private static final String modelName = "test_model"; + + @Test + public void testElementwiseOperators7() throws ParseException { + Tensor x = evaluate("tensor(d0[7]):[-1.0, -0.5, -0.1, 0.0, 0.1, 0.5, 1.0]"); + assertEval("acos", x, evaluate("acos(x)", x)); + assertEval("asin", x, evaluate("asin(x)", x)); + assertEval("atan", x, evaluate("atan(x)", x)); + assertEval("cos", x, evaluate("cos(x)", x)); + assertEval("sin", x, evaluate("sin(x)", x)); + assertEval("tan", x, evaluate("tan(x)", x)); + assertEval("tanh", x, evaluate("tanh(x)", x)); + assertEval("neg", x, evaluate("-x", x)); + assertEval("sigmoid", x, evaluate("sigmoid(x)", x)); + assertEval("exp", x, evaluate("exp(x)", x)); + assertEval("floor", x, evaluate("floor(x)", x)); + assertEval("ceil", x, evaluate("ceil(x)", x)); + assertEval("abs", x, evaluate("abs(x)", x)); + + assertEval("relu", x, evaluate("max(0, x)", x)); + assertEval("elu", x, evaluate("map(x, f(a)(if(a < 0, 1.0 * (exp(a)-1), a)))", x)); + assertEval("elu", x, evaluate("map(x, f(a)(if(a < 0, 0.5 * (exp(a)-1), a)))", x), createAttribute("alpha", 0.5f)); + assertEval("selu", x, evaluate("map(x, f(a)(1.050700987 * if(a >= 0, a, 1.673263242 * (exp(a) - 1))))", x)); + assertEval("selu", x, evaluate("map(x, f(a)(1.0 * if(a >= 0, a, 1.5 * (exp(a) - 1))))", x), createAttributes().attr("gamma", 1.0f).attr("alpha", 1.5f).build()); + assertEval("leakyrelu", x, evaluate("max(0.01 * x, x)", x)); + assertEval("leakyrelu", x, evaluate("max(0.001 * x, x)", x), createAttribute("alpha", 0.001f)); + + x = evaluate("tensor(d0[3]):[0.01, 1.0, 10.0]"); + assertEval("log", x, evaluate("log(x)", x)); + assertEval("sqrt", x, evaluate("sqrt(x)", x)); + assertEval("reciprocal", x, evaluate("map(x, f(a)(1.0 / a))", x)); + } + + @Test + public void testJoinOperators7() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[3, 4]"); + Tensor y = evaluate("tensor(d0[2]):[1, 2]"); + assertEval("add", x, y, evaluate("tensor(d0[2]):[4, 6]")); + assertEval("sub", x, y, 
evaluate("tensor(d0[2]):[2, 2]")); + assertEval("mul", x, y, evaluate("tensor(d0[2]):[3, 8]")); + assertEval("div", x, y, evaluate("tensor(d0[2]):[3, 2]")); + assertEval("greater", x, y, evaluate("tensor(d0[2]):[1, 1]")); + assertEval("less", x, y, evaluate("tensor(d0[2]):[0, 0]")); + assertEval("equal", x, y, evaluate("tensor(d0[2]):[0, 0]")); + assertEval("pow", x, y, evaluate("tensor(d0[2]):[3, 16]")); + + x = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + y = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + assertEval("add", x, y, evaluate("x + y", x, y)); + assertEval("sub", x, y, evaluate("x - y", x, y)); + assertEval("mul", x, y, evaluate("x * y", x, y)); + assertEval("div", x, y, evaluate("x / y", x, y)); + assertEval("greater", x, y, evaluate("join(x, y, f(a,b)(a > b))", x, y)); + assertEval("less", x, y, evaluate("join(x, y, f(a,b)(a < b))", x, y)); + assertEval("equal", x, y, evaluate("join(x, y, f(a,b)(a == b))", x, y)); + assertEval("pow", x, y, evaluate("join(x, y, f(a,b)(pow(a,b)))", x, y)); + + // broadcasting + x = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + y = evaluate("random(d0[4]) + 1"); + assertEval("add", x, y, evaluate("x + rename(y, d0, d2)", x, y)); + assertEval("sub", x, y, evaluate("x - rename(y, d0, d2)", x, y)); + assertEval("mul", x, y, evaluate("x * rename(y, d0, d2)", x, y)); + assertEval("div", x, y, evaluate("x / rename(y, d0, d2)", x, y)); + assertEval("greater", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a > b))", x, y)); + assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y)); + assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y)); + assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y)); + } + + @Test + public void testConcatReduce8() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[3, 4]"); + Tensor y = evaluate("tensor(d0[2]):[1, 2]"); + Tensor z = evaluate("tensor(d0[2]):[5, 6]"); + assertEval("max", x, y, z, evaluate("tensor(d0[2]):[5, 6]")); + assertEval("min", x, y, z, evaluate("tensor(d0[2]):[1, 2]")); + assertEval("mean", x, y, z, evaluate("tensor(d0[2]):[3, 4]")); + + x = evaluate("random(d0[2],d1[3],d2[4])"); + y = evaluate("random(d0[2],d1[3],d2[4])"); + z = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("max", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), max, tmp)", x, y, z)); + assertEval("min", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), min, tmp)", x, y, z)); + assertEval("mean", x, y, z, evaluate("reduce(concat(concat(x, y, tmp), z, tmp), avg, tmp)", x, y, z)); + + // broadcasting + x = evaluate("random(d0[2],d1[3],d2[4])"); + y = evaluate("random(d0[3],d1[4])"); + z = evaluate("random(d0[4])"); + assertEval("max", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), max, tmp)", x, y, z)); + assertEval("min", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), min, tmp)", x, y, z)); + assertEval("mean", x, y, z, evaluate("reduce(concat(concat(x, rename(y, (d0,d1), (d1,d2)), tmp), rename(z, d0, d2), tmp), avg, tmp)", x, y, z)); + } + + @Test + public void testConcat4() throws ParseException { + Tensor x = evaluate("tensor(d0[2]):[1, 2]"); + Tensor y = evaluate("tensor(d0[2]):[3, 4]"); + Tensor expected = evaluate("tensor(d0[4]):[1,2,3,4]"); + assertEval("concat", x, y, expected, createAttribute("axis", 0)); + assertEval("concat", x, y, expected, createAttribute("axis", -1)); + + x = 
evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); + assertEval("concat", x, y, evaluate("tensor(d0[4],d1[2]):[1,2,3,4,5,6,7,8]"), createAttribute("axis", 0)); + assertEval("concat", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,5,6,3,4,7,8]"), createAttribute("axis", 1)); + assertEval("concat", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,5,6,3,4,7,8]"), createAttribute("axis", -1)); + assertEval("concat", x, y, evaluate("tensor(d0[4],d1[2]):[1,2,3,4,5,6,7,8]"), createAttribute("axis", -2)); + + x = evaluate("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 5, 6, 7, 8]"); + y = evaluate("tensor(d0[2],d1[2],d2[2]):[9,10,11,12,13,14,15,16]"); + assertEval("concat", x, y, evaluate("concat(x, y, d0)", x, y), createAttribute("axis", 0)); + assertEval("concat", x, y, evaluate("concat(x, y, d1)", x, y), createAttribute("axis", 1)); + assertEval("concat", x, y, evaluate("concat(x, y, d2)", x, y), createAttribute("axis", 2)); + assertEval("concat", x, y, evaluate("concat(x, y, d2)", x, y), createAttribute("axis", -1)); + assertEval("concat", x, y, evaluate("concat(x, y, d1)", x, y), createAttribute("axis", -2)); + assertEval("concat", x, y, evaluate("concat(x, y, d0)", x, y), createAttribute("axis", -3)); + } + + @Test + public void testGemm7() throws ParseException { + Tensor a = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + Tensor b = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); + Tensor c = evaluate("tensor(d0[2],d1[2]):[0.1, 0.2, 0.3, 0.4]"); + + assertEval("gemm", a, b, evaluate("tensor(d0[2],d1[2]):[19, 22, 43, 50]")); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.1, 22.2, 43.3, 50.4]")); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[38.1, 44.2, 86.3, 100.4]"), createAttribute("alpha", 2.0f)); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.2, 22.4, 43.6, 50.8]"), createAttribute("beta", 2.0f)); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[26.1, 30.2, 38.3, 44.4]"), createAttribute("transA", 1)); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[17.1, 23.2, 39.3, 53.4]"), createAttribute("transB", 1)); + + // unidictional broadcasting for c + c = evaluate("tensor(d0[2]):[0.1, 0.2]"); + assertEval("gemm", a, b, c, evaluate("tensor(d0[2],d1[2]):[19.1, 22.2, 43.1, 50.2]")); + } + + @Test + public void testIdentity1() throws ParseException { + Tensor x = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("identity", x, x); + } + + @Test + public void testMatMul1() throws ParseException { + Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]"); + Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]")); + } + + @Test + public void testReshape5() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"); + Tensor y = evaluate("tensor(d0[1]):[4]"); + assertEval("reshape", x, y, evaluate("tensor(d0[4]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[2,2]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[3]):[2,1,2]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[1],d2[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[2,-1]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[2,0]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + y = evaluate("tensor(d0[2]):[0,-1]"); + assertEval("reshape", x, y, 
evaluate("tensor(d0[2],d1[2]):[1,2,3,4]")); + + x = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + y = evaluate("tensor(d0[2]):[3,2]"); + assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]")); + + y = evaluate("tensor(d0[4]):[3,2,-1,1]"); + assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]")); + } + + @Test + public void testReduceOperators1() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]")); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"), createAttribute("axes", new int[] {0,1})); + assertEval("reducesum", x, evaluate("tensor():[10]"), createAttribute("keepdims", 0)); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[1]):[10]"), createAttribute("keepdims", 1)); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[2]):[4, 6]"), createAttribute("axes", new int[]{0})); + assertEval("reducesum", x, evaluate("tensor(d0[2]):[4, 6]"), createAttributes().attr("axes", new int[]{0}).attr("keepdims", 0).build()); + assertEval("reducesum", x, evaluate("tensor(d0[2],d1[1]):[3, 7]"), createAttribute("axes", new int[] {1})); + assertEval("reducesum", x, evaluate("tensor(d0[2]):[3, 7]"), createAttributes().attr("axes", new int[]{1}).attr("keepdims", 0).build()); + assertEval("reducesum", x, evaluate("tensor(d0[1],d1[2]):[4, 6]"), createAttribute("axes", new int[] {-2})); + assertEval("reducesum", x, evaluate("tensor(d0[2],d1[1]):[3, 7]"), createAttribute("axes", new int[] {-1})); + assertEval("reducesum", x, evaluate("tensor(d0[2]):[3, 7]"), createAttributes().attr("axes", new int[] {-1}).attr("keepdims", 0).build()); + + assertEval("reduceprod", x, evaluate("tensor(d0[1],d1[1]):[24]")); + assertEval("reduceprod", x, evaluate("tensor(d0[1],d1[2]):[3, 8]"), createAttribute("axes", new int[] {0})); + + assertEval("reducemin", x, evaluate("tensor(d0[1],d1[1]):[1]")); + assertEval("reducemin", x, evaluate("tensor(d0[1],d1[2]):[1, 2]"), createAttribute("axes", new int[] {0})); + + assertEval("reducemax", x, evaluate("tensor(d0[1],d1[1]):[4]")); + assertEval("reducemax", x, evaluate("tensor(d0[1],d1[2]):[3, 4]"), createAttribute("axes", new int[] {0})); + + assertEval("reducemean", x, evaluate("tensor():[2.5]"), createAttribute("keepdims", 0)); + assertEval("reducemean", x, evaluate("tensor(d0[2]):[2, 3]"), createAttributes().attr("axes", new int[] {0}).attr("keepdims", 0).build()); + + assertEval("reducelogsum", x, evaluate("tensor():[log(10)]"), createAttribute("keepdims", 0)); + assertEval("reducelogsumexp", x, evaluate("tensor():[log(exp(1)+exp(2)+exp(3)+exp(4))]"), createAttribute("keepdims", 0)); + assertEval("reducesumsquare", x, evaluate("tensor():[1*1+2*2+3*3+4*4]"), createAttribute("keepdims", 0)); + + x = evaluate("tensor(d0[1],d1[5]):[-10, -5, 0, 5, 10]"); + assertEval("reducel1", x, evaluate("tensor():[30]"), createAttribute("keepdims", 0)); + assertEval("reducel2", x, evaluate("tensor():[sqrt(10*10 + 5*5 + 5*5 + 10*10)]"), createAttribute("keepdims", 0)); + } + + @Test + public void testShape1() throws ParseException { + Tensor x = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("shape", x, evaluate("tensor(d0[3]):[2,3,4]")); + } + + @Test + public void testSoftmax1() throws ParseException { + Tensor x = evaluate("tensor(d0[1],d1[3]):[-1, 0, 1]"); + assertEval("softmax", x, evaluate("tensor(d0[1],d1[3]):[0.09003058, 0.24472848, 0.66524094]")); + + x = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 7]"); + 
assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1)", x), createAttribute("axis", 0)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x), createAttribute("axis", 1)); // 1 is default + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1)", x), createAttribute("axis", -1)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1)", x), createAttribute("axis", -2)); + + x = evaluate("random(d0[2],d1[3],d2[4])"); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1, d2)", x), createAttribute("axis", 0)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x), createAttribute("axis", 1)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d2)", x), createAttribute("axis", 2)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d2)", x), createAttribute("axis", -1)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d1, d2)", x), createAttribute("axis", -2)); + assertEval("softmax", x, evaluate("exp(x) / sum(exp(x), d0, d1, d2)", x), createAttribute("axis", -3)); + } + + @Test + public void testSqueeze1() throws ParseException { + Tensor x = evaluate("tensor(d0[1],d1[2]):[1, 2]"); + assertEval("squeeze", x, evaluate("tensor(d0[2]):[1, 2]")); + + x = evaluate("tensor(d0[1],d1[2],d2[1],d3[3]):[1,2,3,4,5,6]"); + assertEval("squeeze", x, evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]")); + assertEval("squeeze", x, evaluate("tensor(d0[2],d1[1],d2[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0})); + assertEval("squeeze", x, evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {2})); + assertEval("squeeze", x, evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"), createAttribute("axes", new int[] {0, 2})); + } + + @Test + public void testWhere9() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1, 2, 3, 4]"); + Tensor y = evaluate("tensor(d0[2],d1[2]):[5, 6, 7, 8]"); + Tensor condition = evaluate("tensor(d0[2],d1[2]):[0, 1, 0, 1]"); + assertEval("where", condition, x, y, evaluate("tensor(d0[2],d1[2]):[5, 2, 7, 4]")); + + assertEval("where", evaluate("tensor():[0]"), x, y, y); + assertEval("where", evaluate("tensor():[1]"), x, y, x); + assertEval("where", evaluate("tensor(d0[1]):[0]"), x, y, y); + assertEval("where", evaluate("tensor(d0[1]):[1]"), x, y, x); + assertEval("where", evaluate("tensor(d0[1],d1[1]):[0]"), x, y, y); + assertEval("where", evaluate("tensor(d0[1],d1[1]):[1]"), x, y, x); + } + + private Tensor evaluate(String expr) throws ParseException { + return evaluate(expr, null, null, null); + } + + private Tensor evaluate(String expr, Tensor x) throws ParseException { + return evaluate(expr, x, null, null); + } + + private Tensor evaluate(String expr, Tensor x, Tensor y) throws ParseException { + return evaluate(expr, x, y, null); + } + + private Tensor evaluate(String expr, Tensor x, Tensor y, Tensor z) throws ParseException { + Context context = new MapContext(DoubleValue.NaN); + if (x != null) context.put("x", new TensorValue(x)); + if (y != null) context.put("y", new TensorValue(y)); + if (z != null) context.put("z", new TensorValue(z)); + return new RankingExpression(expr).evaluate(context).asTensor(); + } + + private Tensor evaluate(IntermediateOperation op) { + Tensor tensor = op.evaluateAsConstant(op.type().get()).asTensor(); + return renameToStandardType(op, tensor); + } + + private void 
assertEval(String opName, Tensor x, Tensor expected) { + assertEval(opName, x, null, null, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) { + assertEval(opName, x, null, null, expected, attr); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) { + assertEval(opName, x, y, null, expected, attr); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) { + assertEval(opName, x, y, null, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) { + assertEval(opName, x, y, z, expected, null); + } + + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) { + Context context = new MapContext(DoubleValue.NaN); + List<IntermediateOperation> inputs = createInputs(context, x, y, z); + IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build()); + optimizeAndRename(opName, op); + Tensor result = evaluate(op); + assertEquals(expected, result); + assertEquals(expected.type(), result.type()); + } + + private List<IntermediateOperation> createInputs(Context context, Tensor x, Tensor y, Tensor z) { + List<IntermediateOperation> inputs = new ArrayList<>(); + addInput(inputs, context, x, "x"); + addInput(inputs, context, y, "y"); + addInput(inputs, context, z, "z"); + return inputs; + } + + private void addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) { + if (x == null) return; + context.put(name, new TensorValue(x)); + IntermediateOperation op = new Constant(modelName, name, OrderedTensorType.fromSpec(x.type().toString())); + op.setConstantValueFunction(type -> new TensorValue(convertTypeAfterRename(x, type))); + inputs.add(op); + } + + Tensor convertTypeAfterRename(Tensor tensor, OrderedTensorType type) { + IndexedTensor indexedTensor = (IndexedTensor) tensor; + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); + for (int i = 0; i < indexedTensor.size(); i++) { + builder.cellByDirectIndex(type.toDirectIndex(i), indexedTensor.get(i)); + } + return builder.build(); + } + + private TensorFunction optimizeAndRename(String opName, IntermediateOperation op) { + IntermediateGraph graph = new IntermediateGraph(modelName); + graph.put(opName, op); + graph.outputs(graph.defaultSignature()).put(opName, opName); + graph.optimize(); + return op.function().get(); + } + + private Tensor renameToStandardType(IntermediateOperation op, Tensor tensor) { + OrderedTensorType operationType = op.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); + if ( ! 
operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + TensorFunction func = new Rename(new ConstantTensor(tensor), renameFrom, renameTo); + return func.evaluate(); + } + return tensor; + } + + static AttributeConverter createAttribute(String name, int val) { + return new Attributes().attr(name, val).build(); + } + + static AttributeConverter createAttribute(String name, float val) { + return new Attributes().attr(name, val).build(); + } + + static AttributeConverter createAttribute(String name, int [] vals) { + return new Attributes().attr(name, vals).build(); + } + + static Attributes createAttributes() { + return new Attributes(); + } + + private static class Attributes { + + Onnx.NodeProto.Builder nodeBuilder; + + Attributes() { + this.nodeBuilder = Onnx.NodeProto.newBuilder(); + } + + Attributes attr(String name, int val) { + nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(INT).setI(val).build()); + return this; + } + + Attributes attr(String name, float val) { + nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(FLOAT).setF(val).build()); + return this; + } + + Attributes attr(String name, int [] vals) { + Onnx.AttributeProto.Builder builder = Onnx.AttributeProto.newBuilder(); + for (int val : vals) { + builder.addInts(val); + } + nodeBuilder.addAttribute(builder.setName(name).setType(INTS).build()); + return this; + } + + AttributeConverter build() { + return AttributeConverter.convert(nodeBuilder.build()); + } + + } + +} |
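Finally, a sketch of the axis handling introduced in Softmax above (tensor values are illustrative; the expressions are the same ones the new testSoftmax1 checks against). With a rank-2 input, the default ONNX axis of 1 normalizes each batch row over d1, while axis 0 normalizes over both dimensions:

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;

public class SoftmaxAxisSketch {
    public static void main(String[] args) throws ParseException {
        MapContext context = new MapContext(DoubleValue.NaN);
        context.put("x", new TensorValue(Tensor.from("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 7]")));
        // axis = 1 (the ONNX default): reduce over every dimension from the axis onwards, here just d1,
        // so each row sums to 1.
        Tensor perRow = new RankingExpression("exp(x) / sum(exp(x), d1)").evaluate(context).asTensor();
        // axis = 0: reduce over d0 and d1, so the whole tensor sums to 1.
        Tensor overAll = new RankingExpression("exp(x) / sum(exp(x), d0, d1)").evaluate(context).asTensor();
        System.out.println(perRow);
        System.out.println(overAll);
    }
}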