From 296340ac996edac09a4f53997ae1a8a803d302c1 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Fri, 22 Nov 2019 11:40:14 +0100 Subject: Add additional ONNX operations --- .../importer/onnx/AttributeConverter.java | 69 ++++++++ .../importer/onnx/GraphImporter.java | 21 ++- .../importer/onnx/TensorConverter.java | 58 +++++++ .../importer/operations/Argument.java | 12 +- .../importer/operations/ExpandDims.java | 4 +- .../importer/operations/Gemm.java | 183 +++++++++++++++++++++ .../importer/operations/MatMul.java | 2 - .../importer/operations/OnnxConcat.java | 105 ++++++++++++ .../importer/operations/Reduce.java | 121 ++++++++++++++ .../importer/operations/Reshape.java | 35 +++- .../importer/operations/Select.java | 3 + .../importer/operations/Squeeze.java | 7 +- 12 files changed, 601 insertions(+), 19 deletions(-) create mode 100644 model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java create mode 100644 model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java create mode 100644 model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java create mode 100644 model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java (limited to 'model-integration/src') diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java new file mode 100644 index 00000000000..8caa158e5be --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/AttributeConverter.java @@ -0,0 +1,69 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.onnx; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import onnx.Onnx; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Converts Onnx node attributes to Vespa attribute values. + * + * @author lesters + */ +class AttributeConverter implements IntermediateOperation.AttributeMap { + + private final Onnx.NodeProto node; + + private AttributeConverter(Onnx.NodeProto node) { + this.node = node; + } + + static AttributeConverter convert(Onnx.NodeProto node) { + return new AttributeConverter(node); + } + + @Override + public Optional get(String name) { + for (Onnx.AttributeProto attr : node.getAttributeList()) { + if (attr.getName().equals(name)) { + switch (attr.getType()) { + case INT: return Optional.of(DoubleValue.frozen(attr.getI())); + case FLOAT: return Optional.of(DoubleValue.frozen(attr.getF())); + case STRING: return Optional.of(StringValue.frozen(attr.getS().toString())); + default: + return Optional.empty(); + } + } + } + return Optional.empty(); + } + + @Override + public Optional get(String name, OrderedTensorType type) { + return Optional.empty(); + } + + @Override + public Optional> getList(String name) { + for (Onnx.AttributeProto attr : node.getAttributeList()) { + if (attr.getName().equals(name)) { + switch (attr.getType()) { + case INTS: return Optional.of(attr.getIntsList().stream().map(DoubleValue::new).collect(Collectors.toList())); + case FLOATS: return Optional.of(attr.getFloatsList().stream().map(DoubleValue::new).collect(Collectors.toList())); + case STRINGS: return Optional.of(attr.getStringsList().stream().map((s) -> StringValue.frozen(s.toString())).collect(Collectors.toList())); + default: + return Optional.empty(); + } + } + } + return Optional.empty(); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index ccf3c2d8fb0..4fa6c09c636 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -2,11 +2,16 @@ package ai.vespa.rankingexpression.importer.onnx; +import ai.vespa.rankingexpression.importer.operations.Gemm; +import ai.vespa.rankingexpression.importer.operations.OnnxConcat; +import ai.vespa.rankingexpression.importer.operations.Reduce; +import ai.vespa.rankingexpression.importer.operations.Select; +import ai.vespa.rankingexpression.importer.operations.Softmax; +import ai.vespa.rankingexpression.importer.operations.Squeeze; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.operations.Argument; -import ai.vespa.rankingexpression.importer.operations.ConcatV2; import ai.vespa.rankingexpression.importer.operations.Constant; import ai.vespa.rankingexpression.importer.operations.Identity; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; @@ -36,6 +41,7 @@ class GraphImporter { IntermediateGraph graph) { String modelName = graph.name(); String nodeName = getNodeName(node); + AttributeConverter attributes = AttributeConverter.convert(node); switch (node.getOpType().toLowerCase()) { case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); @@ -44,13 +50,14 @@ class GraphImporter { case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); - case "concat": return new ConcatV2(modelName, nodeName, inputs); + case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes); case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "gemm": return new Gemm(modelName, nodeName, inputs, attributes); case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater()); case "identity": return new Identity(modelName, nodeName, inputs); case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less()); @@ -63,15 +70,21 @@ class GraphImporter { case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); case "reshape": return new Reshape(modelName, nodeName, inputs); + case "reducesum": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum); + case "reducemean": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.avg); case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + case "leakyrelu": return new Map(modelName, nodeName, inputs, ScalarFunctions.leakyrelu()); case "shape": return new Shape(modelName, nodeName, inputs); - case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); - case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); + case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); + case "softmax": return new Softmax(modelName, nodeName, inputs); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); + case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square()); + case "where": return new Select(modelName, nodeName, inputs); case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index f3d87d89c27..69d18d0ffcb 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -10,7 +10,10 @@ import onnx.Onnx; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.DoubleBuffer; import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; /** * Converts Onnx tensors into Vespa tensors. @@ -31,11 +34,16 @@ class TensorConverter { private static Values readValuesOf(Onnx.TensorProto tensorProto) { if (tensorProto.hasRawData()) { switch (tensorProto.getDataType()) { + case BOOL: return new RawBoolValues(tensorProto); case FLOAT: return new RawFloatValues(tensorProto); + case DOUBLE: return new RawDoubleValues(tensorProto); + case INT64: return new RawLongValues(tensorProto); } } else { switch (tensorProto.getDataType()) { case FLOAT: return new FloatValues(tensorProto); + case DOUBLE: return new DoubleValues(tensorProto); + case INT64: return new LongValues(tensorProto); } } throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + @@ -55,6 +63,17 @@ class TensorConverter { } } + private static class RawBoolValues extends RawValues { + private final IntBuffer values; + private final int size; + RawBoolValues(Onnx.TensorProto tensorProto) { + values = bytes(tensorProto).asIntBuffer(); + size = values.remaining(); + } + @Override double get(int i) { return values.get(i); } + @Override int size() { return size; } + } + private static class RawFloatValues extends RawValues { private final FloatBuffer values; private final int size; @@ -66,6 +85,28 @@ class TensorConverter { @Override int size() { return size; } } + private static class RawDoubleValues extends RawValues { + private final DoubleBuffer values; + private final int size; + RawDoubleValues(Onnx.TensorProto tensorProto) { + values = bytes(tensorProto).asDoubleBuffer(); + size = values.remaining(); + } + @Override double get(int i) { return values.get(i); } + @Override int size() { return size; } + } + + private static class RawLongValues extends RawValues { + private final LongBuffer values; + private final int size; + RawLongValues(Onnx.TensorProto tensorProto) { + values = bytes(tensorProto).asLongBuffer(); + size = values.remaining(); + } + @Override double get(int i) { return values.get(i); } + @Override int size() { return size; } + } + private static class FloatValues extends Values { private final Onnx.TensorProto tensorProto; FloatValues(Onnx.TensorProto tensorProto) { @@ -75,5 +116,22 @@ class TensorConverter { @Override int size() { return tensorProto.getFloatDataCount(); } } + private static class DoubleValues extends Values { + private final Onnx.TensorProto tensorProto; + DoubleValues(Onnx.TensorProto tensorProto) { + this.tensorProto = tensorProto; + } + @Override double get(int i) { return tensorProto.getDoubleData(i); } + @Override int size() { return tensorProto.getDoubleDataCount(); } + } + + private static class LongValues extends Values { + private final Onnx.TensorProto tensorProto; + LongValues(Onnx.TensorProto tensorProto) { + this.tensorProto = tensorProto; + } + @Override double get(int i) { return tensorProto.getInt64Data(i); } + @Override int size() { return tensorProto.getInt64DataCount(); } + } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index dad4508bc61..f68372ce4dd 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -3,7 +3,6 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; -import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; @@ -39,7 +38,16 @@ public class Argument extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - addConstraintsFrom(type, renamer); + for (int i = 0; i < type.dimensions().size(); i++) { + renamer.addDimension(type.dimensions().get(i).name()); + + // Each dimension is distinct and ordered correctly: + for (int j = i + 1; j < type.dimensions().size(); j++) { + renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), + DimensionRenamer.Constraint.lessThan(false), + this); + } + } } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index e6cc96d48ad..3487d889338 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -82,9 +82,7 @@ public class ExpandDims extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } + addConstraintsFrom(type, renamer); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java new file mode 100644 index 00000000000..b116f18c7d1 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -0,0 +1,183 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.text.ExpressionFormatter; + +import java.util.List; +import java.util.Optional; + +public class Gemm extends IntermediateOperation { + + private final AttributeMap attributeMap; + private final float alpha, beta; + private final int transposeA, transposeB; + + private final static DoubleValue zero = DoubleValue.frozen(0.0); + private final static DoubleValue one = DoubleValue.frozen(1.0); + + public Gemm(String modelName, String nodeName, List inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + this.alpha = (float) attributeMap.get("alpha").orElse(one).asDouble(); + this.beta = (float) attributeMap.get("beta").orElse(one).asDouble(); + this.transposeA = (int) attributeMap.get("transA").orElse(zero).asDouble(); + this.transposeB = (int) attributeMap.get("transB").orElse(zero).asDouble(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! check2or3InputsPresent()) return null; + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + + TensorType.Dimension dimA = inputs.get(0).type().get().dimensions().get(transposeA); + TensorType.Dimension dimB = inputs.get(1).type().get().dimensions().get(1 - transposeB); + + typeBuilder.add(dimA); + typeBuilder.add(dimB); + OrderedTensorType result = typeBuilder.build(); + + // Input tensor C. The shape of C should be unidirectional "broadcastable" to (dimA, dimB). + if (inputs.size() == 3) { + List cDimensions = inputs.get(2).type().get().dimensions(); + if (cDimensions.size() == 2) { + TensorType.Dimension dimC0 = cDimensions.get(0); + TensorType.Dimension dimC1 = cDimensions.get(1); + + if ( ! (dimA.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) { + throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() + + " is not compatible or not broadcastable to " + result.type()); + } + if ( ! (dimB.size().get().equals(dimC1.size().get()) || dimC1.size().get() == 1) ) { + throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() + + " is not compatible or not broadcastable to " + result.type()); + } + + } else if (cDimensions.size() == 1) { + TensorType.Dimension dimC0 = cDimensions.get(0); + if ( ! (dimB.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) { + throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() + + " is not compatible or not broadcastable to " + result.type()); + } + } else { + throw new IllegalArgumentException("GEMM: optional input C has no dimensions."); + } + } + + return result; + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! check2or3InputsPresent()) return null; + + OrderedTensorType aType = inputs.get(0).type().get(); + OrderedTensorType bType = inputs.get(1).type().get(); + if (aType.type().rank() != 2 || bType.type().rank() != 2) + throw new IllegalArgumentException("Tensors in Gemm must have rank of exactly 2"); + + Optional aFunction = inputs.get(0).function(); + Optional bFunction = inputs.get(1).function(); + if (aFunction.isEmpty() || bFunction.isEmpty()) { + return null; + } + + String joinDimension = aType.dimensions().get(1).name(); // TODO: check wrt transpose! + + TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension); + TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( + new ArithmeticNode( + new TensorFunctionNode(AxB), + ArithmeticOperator.MULTIPLY, + new ConstantNode(new DoubleValue(alpha)))); + + if (inputs.size() == 3) { + Optional cFunction = inputs.get(2).function(); + TensorFunction betaxC = new TensorFunctionNode.ExpressionTensorFunction( + new ArithmeticNode( + new TensorFunctionNode(cFunction.get()), + ArithmeticOperator.MULTIPLY, + new ConstantNode(new DoubleValue(beta)))); + return new com.yahoo.tensor.functions.Join(alphaxAxB, betaxC, ScalarFunctions.add()); + } + + return alphaxAxB; + } + + private boolean check2or3InputsPresent() { + if (inputs.size() != 2 && inputs.size() != 3) { + throw new IllegalArgumentException("Expected 2 or 3 inputs for '" + name + "', got " + inputs.size()); + } + return inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if ( ! check2or3InputsPresent()) return; + + List aDimensions = inputs.get(0).type().get().dimensions(); + List bDimensions = inputs.get(1).type().get().dimensions(); + + assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); + assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + + String aDim0 = aDimensions.get(transposeA).name(); + String aDim1 = aDimensions.get(1 - transposeA).name(); + String bDim0 = bDimensions.get(transposeB).name(); + String bDim1 = bDimensions.get(1 - transposeB).name(); + + // The second dimension of a should have the same name as the first dimension of b + renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this); + + // The first dimension of a should have a different name than the second dimension of b + renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this); + + // If c is given, should be unidirectionally broadcastable to tensor a * b: + // Tensor A and B both have exactly the same shape. + // Tensor A and B all have the same number of dimensions and the length of each dimensions is either a common length or B's length is 1. + // Tensor B has too few dimensions, and B can have its shapes prepended with a dimension of length 1 to satisfy property 2. + if (inputs.size() == 3) { + List cDimensions = inputs.get(2).type().get().dimensions(); + + if (cDimensions.size() == 2) { + String cDim0 = cDimensions.get(0).name(); + String cDim1 = cDimensions.get(1).name(); + renamer.addConstraint(aDim0, cDim0, DimensionRenamer.Constraint.equal(false), this); + renamer.addConstraint(bDim1, cDim1, DimensionRenamer.Constraint.equal(false), this); + } else if (cDimensions.size() == 1) { + String cDim0 = cDimensions.get(0).name(); + renamer.addConstraint(bDim1, cDim0, DimensionRenamer.Constraint.equal(false), this); + } + } + + // For efficiency, the dimensions to join over should be innermost - soft constraint + renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this); + renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this); + } + + private void assertTwoDimensions(List dimensions, IntermediateOperation supplier, String inputDescription) { + if (dimensions.size() >= 2) return; + throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + + " but got just " + dimensions + " from\n" + + ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); + } + + @Override + public Gemm withInputs(List inputs) { + return new Gemm(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Gemm"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 434261c6077..6849e64641e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -73,8 +73,6 @@ public class MatMul extends IntermediateOperation { private void assertTwoDimensions(List dimensions, IntermediateOperation supplier, String inputDescription) { if (dimensions.size() >= 2) return; - - throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + " but got just " + dimensions + " from\n" + ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java new file mode 100644 index 00000000000..ded76db60fe --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java @@ -0,0 +1,105 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; +import java.util.Optional; + +public class OnnxConcat extends IntermediateOperation { + + private final AttributeMap attributeMap; + private String concatDimensionName; + private int concatDimensionIndex; + + public OnnxConcat(String modelName, String nodeName, List inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + if (attributeMap.get("axis").isEmpty()) + throw new IllegalArgumentException("OnnxConcat in " + name + ": Required attribute 'axis' is missing."); + this.concatDimensionIndex = (int) attributeMap.get("axis").get().asDouble(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null; + + OrderedTensorType aType = inputs.get(0).type().get(); + long concatDimSize = aType.dimensions().get(concatDimensionIndex).size().orElse(-1L); + + for (int i = 1; i < inputs.size(); ++i) { + OrderedTensorType bType = inputs.get(i).type().get(); + if (bType.rank() != aType.rank()) + throw new IllegalArgumentException("OnnxConcat in " + name + ": Inputs must have the same rank."); + + for (int j = 0; j < aType.rank(); ++j) { + long dimSizeA = aType.dimensions().get(j).size().orElse(-1L); + long dimSizeB = bType.dimensions().get(j).size().orElse(-1L); + if (j == concatDimensionIndex) { + concatDimSize += dimSizeB; + } else if (dimSizeA != dimSizeB) { + throw new IllegalArgumentException("OnnxConcat in " + name + ": " + + "input dimension " + j + " differs in input tensors."); + } + } + } + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + int dimensionIndex = 0; + for (TensorType.Dimension dimension : aType.dimensions()) { + if (dimensionIndex == concatDimensionIndex) { + concatDimensionName = dimension.name(); + typeBuilder.add(TensorType.Dimension.indexed(concatDimensionName, concatDimSize)); + } else { + typeBuilder.add(dimension); + } + dimensionIndex++; + } + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) { + return null; + } + TensorFunction result = inputs.get(0).function().get(); + for (int i = 1; i < inputs.size(); ++i) { + TensorFunction b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName); + } + return result; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { + return; + } + OrderedTensorType a = inputs.get(0).type().get(); + for (int i = 1; i < inputs.size(); ++i) { + OrderedTensorType b = inputs.get(i).type().get(); + String bDim = b.dimensions().get(concatDimensionIndex).name(); + String aDim = a.dimensions().get(concatDimensionIndex).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); + } + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName); + } + + @Override + public OnnxConcat withInputs(List inputs) { + return new OnnxConcat(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "ConcatV2"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java new file mode 100644 index 00000000000..1b2d9ac090e --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java @@ -0,0 +1,121 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * ONNX Reduce[Sum/Mean/etc] operation + */ +public class Reduce extends IntermediateOperation { + + private final AttributeMap attributeMap; + private final com.yahoo.tensor.functions.Reduce.Aggregator aggregator; + + private List reduceDimensions; + + public Reduce(String modelName, String nodeName, + List inputs, + AttributeMap attributeMap, + com.yahoo.tensor.functions.Reduce.Aggregator aggregator) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + this.aggregator = aggregator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().get(); + + reduceDimensions = inputType.dimensionNames(); // default is to reduce all dimensions + if (attributeMap.getList("axes").isPresent()) { + reduceDimensions = new ArrayList<>(); + for (Value i : attributeMap.getList("axes").get()) { + int dimensionIndex = (int) i.asDouble(); + if (dimensionIndex < 0) { + dimensionIndex = inputType.dimensions().size() - dimensionIndex; + } + reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); + } + } + return reducedType(inputType, shouldKeepDimensions()); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputTypesPresent(1)) return null; + + TensorFunction inputFunction = inputs.get(0).function().get(); + TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions); + if (shouldKeepDimensions()) { + // multiply with a generated tensor created from the reduced dimensions + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); + for (String name : reduceDimensions) { + typeBuilder.indexed(name, 1); + } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + } + return output; + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List renamedDimensions = new ArrayList<>(reduceDimensions.size()); + for (String name : reduceDimensions) { + Optional newName = renamer.dimensionNameOf(name); + if (newName.isEmpty()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + reduceDimensions = renamedDimensions; + } + + @Override + public Reduce withInputs(List inputs) { + return new Reduce(modelName(), name(), inputs, attributeMap, aggregator); + } + + @Override + public String operationName() { return "Reduce"; } + + private boolean shouldKeepDimensions() { + Optional keepDims = attributeMap.get("keepdims"); + return keepDims.isPresent() && keepDims.get().asBoolean(); + } + + private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); + for (TensorType.Dimension dimension: inputType.type().dimensions()) { + if ( ! reduceDimensions.contains(dimension.name())) { + builder.add(dimension); + } else if (keepDimensions) { + builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); + } + } + return builder.build(); + } + + + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index a210ed13f5d..c7accd00619 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; @@ -16,6 +17,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; @@ -35,7 +37,7 @@ public class Reshape extends IntermediateOperation { if ( ! allInputTypesPresent(2)) return null; IntermediateOperation newShape = inputs.get(1); - if ( ! newShape.getConstantValue().isPresent()) + if (newShape.getConstantValue().isEmpty()) throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant."); Tensor shape = newShape.getConstantValue().get().asTensor(); @@ -69,9 +71,7 @@ public class Reshape extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } + addConstraintsFrom(type, renamer); } @Override @@ -89,17 +89,40 @@ public class Reshape extends IntermediateOperation { // the new shape. We have to introduce temporary dimension names and rename back if dimension names // in the new and old tensor type overlap. + List from = new ArrayList<>(); + List to = new ArrayList<>(); + boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType); + if (dimensionNamesOverlap) { + TensorType.Builder builder = new TensorType.Builder(outputType.valueType()); + for (int i = 0; i < outputType.rank(); ++i) { + TensorType.Dimension dim = outputType.dimensions().get(i); + from.add(dim.name()); + to.add("temp_" + dim.name()); + builder.dimension(dim.withName("temp_" + dim.name())); + } + outputType = builder.build(); + } + ExpressionNode unrollFrom = unrollTensorExpression(inputType); ExpressionNode unrollTo = unrollTensorExpression(outputType); - ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo); + ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, new EmbracedNode(unrollTo)); TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); Generate transformTensor = new Generate(transformationType, new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); - return new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), + TensorFunction result = new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), Reduce.Aggregator.sum, inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); + + if (dimensionNamesOverlap) { + result = new Rename(result, to, from); + } + return result; + } + + private static boolean dimensionNamesOverlap(TensorType a, TensorType b) { + return a.dimensionNames().stream().anyMatch(d -> b.dimension(d).isPresent()); } private static ExpressionNode unrollTensorExpression(TensorType type) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index 35a1b6e2b0e..8696d0f1858 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -51,6 +51,9 @@ public class Select extends IntermediateOperation { if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { return condition.cellIterator().next().getValue().intValue() == 0 ? b : a; } + if (condition.type().rank() == 2 && dimensionSize(condition.type().dimensions().get(0)) == 1 && dimensionSize(condition.type().dimensions().get(1)) == 1) { + return condition.cellIterator().next().getValue().intValue() == 0 ? b : a; + } } // The task is to select cells from 'x' or 'y' based on 'condition'. diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index 56d9b542093..a9e3fc6a43a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -31,7 +31,10 @@ public class Squeeze extends IntermediateOperation { squeezeDimensions = new ArrayList<>(); Optional> squeezeDimsAttr = attributeMap.getList("squeeze_dims"); - if ( ! squeezeDimsAttr.isPresent()) { + if (squeezeDimsAttr.isEmpty()) { + squeezeDimsAttr = attributeMap.getList("axes"); // ONNX + } + if (squeezeDimsAttr.isEmpty()) { squeezeDimensions = inputType.type().dimensions().stream(). filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). map(TensorType.Dimension::name). @@ -62,7 +65,7 @@ public class Squeeze extends IntermediateOperation { List renamedDimensions = new ArrayList<>(squeezeDimensions.size()); for (String name : squeezeDimensions) { Optional newName = renamer.dimensionNameOf(name); - if (!newName.isPresent()) { + if (newName.isEmpty()) { return; // presumably, already renamed } renamedDimensions.add(newName.get()); -- cgit v1.2.3