aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-11-22 11:40:14 +0100
committerLester Solbakken <lesters@oath.com>2019-11-22 11:40:14 +0100
commit296340ac996edac09a4f53997ae1a8a803d302c1 (patch)
treef7abfd7b4c30529bbdbc03334f523fcbebdd27e2 /model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations
parent69e4f6bf072d8ebfb12761c450f2bdacf86e226c (diff)
Add additional ONNX operations
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java183
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java105
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java121
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java35
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java7
9 files changed, 457 insertions, 15 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
index dad4508bc61..f68372ce4dd 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
@@ -3,7 +3,6 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
-import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
@@ -39,7 +38,16 @@ public class Argument extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- addConstraintsFrom(type, renamer);
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ renamer.addDimension(type.dimensions().get(i).name());
+
+ // Each dimension is distinct and ordered correctly:
+ for (int j = i + 1; j < type.dimensions().size(); j++) {
+ renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(),
+ DimensionRenamer.Constraint.lessThan(false),
+ this);
+ }
+ }
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index e6cc96d48ad..3487d889338 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -82,9 +82,7 @@ public class ExpandDims extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
+ addConstraintsFrom(type, renamer);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
new file mode 100644
index 00000000000..b116f18c7d1
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -0,0 +1,183 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.text.ExpressionFormatter;
+
+import java.util.List;
+import java.util.Optional;
+
+public class Gemm extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private final float alpha, beta;
+ private final int transposeA, transposeB;
+
+ private final static DoubleValue zero = DoubleValue.frozen(0.0);
+ private final static DoubleValue one = DoubleValue.frozen(1.0);
+
+ public Gemm(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ this.alpha = (float) attributeMap.get("alpha").orElse(one).asDouble();
+ this.beta = (float) attributeMap.get("beta").orElse(one).asDouble();
+ this.transposeA = (int) attributeMap.get("transA").orElse(zero).asDouble();
+ this.transposeB = (int) attributeMap.get("transB").orElse(zero).asDouble();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! check2or3InputsPresent()) return null;
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+
+ TensorType.Dimension dimA = inputs.get(0).type().get().dimensions().get(transposeA);
+ TensorType.Dimension dimB = inputs.get(1).type().get().dimensions().get(1 - transposeB);
+
+ typeBuilder.add(dimA);
+ typeBuilder.add(dimB);
+ OrderedTensorType result = typeBuilder.build();
+
+ // Input tensor C. The shape of C should be unidirectional "broadcastable" to (dimA, dimB).
+ if (inputs.size() == 3) {
+ List<TensorType.Dimension> cDimensions = inputs.get(2).type().get().dimensions();
+ if (cDimensions.size() == 2) {
+ TensorType.Dimension dimC0 = cDimensions.get(0);
+ TensorType.Dimension dimC1 = cDimensions.get(1);
+
+ if ( ! (dimA.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) {
+ throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
+ " is not compatible or not broadcastable to " + result.type());
+ }
+ if ( ! (dimB.size().get().equals(dimC1.size().get()) || dimC1.size().get() == 1) ) {
+ throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
+ " is not compatible or not broadcastable to " + result.type());
+ }
+
+ } else if (cDimensions.size() == 1) {
+ TensorType.Dimension dimC0 = cDimensions.get(0);
+ if ( ! (dimB.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) {
+ throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
+ " is not compatible or not broadcastable to " + result.type());
+ }
+ } else {
+ throw new IllegalArgumentException("GEMM: optional input C has no dimensions.");
+ }
+ }
+
+ return result;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! check2or3InputsPresent()) return null;
+
+ OrderedTensorType aType = inputs.get(0).type().get();
+ OrderedTensorType bType = inputs.get(1).type().get();
+ if (aType.type().rank() != 2 || bType.type().rank() != 2)
+ throw new IllegalArgumentException("Tensors in Gemm must have rank of exactly 2");
+
+ Optional<TensorFunction> aFunction = inputs.get(0).function();
+ Optional<TensorFunction> bFunction = inputs.get(1).function();
+ if (aFunction.isEmpty() || bFunction.isEmpty()) {
+ return null;
+ }
+
+ String joinDimension = aType.dimensions().get(1).name(); // TODO: check wrt transpose!
+
+ TensorFunction AxB = new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), joinDimension);
+ TensorFunction alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
+ new ArithmeticNode(
+ new TensorFunctionNode(AxB),
+ ArithmeticOperator.MULTIPLY,
+ new ConstantNode(new DoubleValue(alpha))));
+
+ if (inputs.size() == 3) {
+ Optional<TensorFunction> cFunction = inputs.get(2).function();
+ TensorFunction betaxC = new TensorFunctionNode.ExpressionTensorFunction(
+ new ArithmeticNode(
+ new TensorFunctionNode(cFunction.get()),
+ ArithmeticOperator.MULTIPLY,
+ new ConstantNode(new DoubleValue(beta))));
+ return new com.yahoo.tensor.functions.Join(alphaxAxB, betaxC, ScalarFunctions.add());
+ }
+
+ return alphaxAxB;
+ }
+
+ private boolean check2or3InputsPresent() {
+ if (inputs.size() != 2 && inputs.size() != 3) {
+ throw new IllegalArgumentException("Expected 2 or 3 inputs for '" + name + "', got " + inputs.size());
+ }
+ return inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent);
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if ( ! check2or3InputsPresent()) return;
+
+ List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
+ List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+
+ assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
+ assertTwoDimensions(bDimensions, inputs.get(1), "second argument");
+
+ String aDim0 = aDimensions.get(transposeA).name();
+ String aDim1 = aDimensions.get(1 - transposeA).name();
+ String bDim0 = bDimensions.get(transposeB).name();
+ String bDim1 = bDimensions.get(1 - transposeB).name();
+
+ // The second dimension of a should have the same name as the first dimension of b
+ renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this);
+
+ // The first dimension of a should have a different name than the second dimension of b
+ renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this);
+
+ // If c is given, should be unidirectionally broadcastable to tensor a * b:
+ // Tensor A and B both have exactly the same shape.
+ // Tensor A and B all have the same number of dimensions and the length of each dimensions is either a common length or B's length is 1.
+ // Tensor B has too few dimensions, and B can have its shapes prepended with a dimension of length 1 to satisfy property 2.
+ if (inputs.size() == 3) {
+ List<TensorType.Dimension> cDimensions = inputs.get(2).type().get().dimensions();
+
+ if (cDimensions.size() == 2) {
+ String cDim0 = cDimensions.get(0).name();
+ String cDim1 = cDimensions.get(1).name();
+ renamer.addConstraint(aDim0, cDim0, DimensionRenamer.Constraint.equal(false), this);
+ renamer.addConstraint(bDim1, cDim1, DimensionRenamer.Constraint.equal(false), this);
+ } else if (cDimensions.size() == 1) {
+ String cDim0 = cDimensions.get(0).name();
+ renamer.addConstraint(bDim1, cDim0, DimensionRenamer.Constraint.equal(false), this);
+ }
+ }
+
+ // For efficiency, the dimensions to join over should be innermost - soft constraint
+ renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this);
+ renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this);
+ }
+
+ private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
+ if (dimensions.size() >= 2) return;
+ throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
+ " but got just " + dimensions + " from\n" +
+ ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
+ }
+
+ @Override
+ public Gemm withInputs(List<IntermediateOperation> inputs) {
+ return new Gemm(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "Gemm"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 434261c6077..6849e64641e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -73,8 +73,6 @@ public class MatMul extends IntermediateOperation {
private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
if (dimensions.size() >= 2) return;
-
-
throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
" but got just " + dimensions + " from\n" +
ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
new file mode 100644
index 00000000000..ded76db60fe
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConcat.java
@@ -0,0 +1,105 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+import java.util.Optional;
+
+public class OnnxConcat extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private String concatDimensionName;
+ private int concatDimensionIndex;
+
+ public OnnxConcat(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ if (attributeMap.get("axis").isEmpty())
+ throw new IllegalArgumentException("OnnxConcat in " + name + ": Required attribute 'axis' is missing.");
+ this.concatDimensionIndex = (int) attributeMap.get("axis").get().asDouble();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null;
+
+ OrderedTensorType aType = inputs.get(0).type().get();
+ long concatDimSize = aType.dimensions().get(concatDimensionIndex).size().orElse(-1L);
+
+ for (int i = 1; i < inputs.size(); ++i) {
+ OrderedTensorType bType = inputs.get(i).type().get();
+ if (bType.rank() != aType.rank())
+ throw new IllegalArgumentException("OnnxConcat in " + name + ": Inputs must have the same rank.");
+
+ for (int j = 0; j < aType.rank(); ++j) {
+ long dimSizeA = aType.dimensions().get(j).size().orElse(-1L);
+ long dimSizeB = bType.dimensions().get(j).size().orElse(-1L);
+ if (j == concatDimensionIndex) {
+ concatDimSize += dimSizeB;
+ } else if (dimSizeA != dimSizeB) {
+ throw new IllegalArgumentException("OnnxConcat in " + name + ": " +
+ "input dimension " + j + " differs in input tensors.");
+ }
+ }
+ }
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ int dimensionIndex = 0;
+ for (TensorType.Dimension dimension : aType.dimensions()) {
+ if (dimensionIndex == concatDimensionIndex) {
+ concatDimensionName = dimension.name();
+ typeBuilder.add(TensorType.Dimension.indexed(concatDimensionName, concatDimSize));
+ } else {
+ typeBuilder.add(dimension);
+ }
+ dimensionIndex++;
+ }
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) {
+ return null;
+ }
+ TensorFunction result = inputs.get(0).function().get();
+ for (int i = 1; i < inputs.size(); ++i) {
+ TensorFunction b = inputs.get(i).function().get();
+ result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName);
+ }
+ return result;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
+ return;
+ }
+ OrderedTensorType a = inputs.get(0).type().get();
+ for (int i = 1; i < inputs.size(); ++i) {
+ OrderedTensorType b = inputs.get(i).type().get();
+ String bDim = b.dimensions().get(concatDimensionIndex).name();
+ String aDim = a.dimensions().get(concatDimensionIndex).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
+ }
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName);
+ }
+
+ @Override
+ public OnnxConcat withInputs(List<IntermediateOperation> inputs) {
+ return new OnnxConcat(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "ConcatV2"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
new file mode 100644
index 00000000000..1b2d9ac090e
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
@@ -0,0 +1,121 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * ONNX Reduce[Sum/Mean/etc] operation
+ */
+public class Reduce extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private final com.yahoo.tensor.functions.Reduce.Aggregator aggregator;
+
+ private List<String> reduceDimensions;
+
+ public Reduce(String modelName, String nodeName,
+ List<IntermediateOperation> inputs,
+ AttributeMap attributeMap,
+ com.yahoo.tensor.functions.Reduce.Aggregator aggregator) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ this.aggregator = aggregator;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(1)) return null;
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+
+ reduceDimensions = inputType.dimensionNames(); // default is to reduce all dimensions
+ if (attributeMap.getList("axes").isPresent()) {
+ reduceDimensions = new ArrayList<>();
+ for (Value i : attributeMap.getList("axes").get()) {
+ int dimensionIndex = (int) i.asDouble();
+ if (dimensionIndex < 0) {
+ dimensionIndex = inputType.dimensions().size() - dimensionIndex;
+ }
+ reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
+ }
+ }
+ return reducedType(inputType, shouldKeepDimensions());
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputTypesPresent(1)) return null;
+
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ TensorFunction output = new com.yahoo.tensor.functions.Reduce(inputFunction, aggregator, reduceDimensions);
+ if (shouldKeepDimensions()) {
+ // multiply with a generated tensor created from the reduced dimensions
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
+ for (String name : reduceDimensions) {
+ typeBuilder.indexed(name, 1);
+ }
+ TensorType generatedType = typeBuilder.build();
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+ Generate generatedFunction = new Generate(generatedType,
+ new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+ output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+ }
+ return output;
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size());
+ for (String name : reduceDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (newName.isEmpty()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ reduceDimensions = renamedDimensions;
+ }
+
+ @Override
+ public Reduce withInputs(List<IntermediateOperation> inputs) {
+ return new Reduce(modelName(), name(), inputs, attributeMap, aggregator);
+ }
+
+ @Override
+ public String operationName() { return "Reduce"; }
+
+ private boolean shouldKeepDimensions() {
+ Optional<Value> keepDims = attributeMap.get("keepdims");
+ return keepDims.isPresent() && keepDims.get().asBoolean();
+ }
+
+ private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
+ for (TensorType.Dimension dimension: inputType.type().dimensions()) {
+ if ( ! reduceDimensions.contains(dimension.name())) {
+ builder.add(dimension);
+ } else if (keepDimensions) {
+ builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
+ }
+ }
+ return builder.build();
+ }
+
+
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index a210ed13f5d..c7accd00619 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
@@ -16,6 +17,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
@@ -35,7 +37,7 @@ public class Reshape extends IntermediateOperation {
if ( ! allInputTypesPresent(2)) return null;
IntermediateOperation newShape = inputs.get(1);
- if ( ! newShape.getConstantValue().isPresent())
+ if (newShape.getConstantValue().isEmpty())
throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant.");
Tensor shape = newShape.getConstantValue().get().asTensor();
@@ -69,9 +71,7 @@ public class Reshape extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
+ addConstraintsFrom(type, renamer);
}
@Override
@@ -89,17 +89,40 @@ public class Reshape extends IntermediateOperation {
// the new shape. We have to introduce temporary dimension names and rename back if dimension names
// in the new and old tensor type overlap.
+ List<String> from = new ArrayList<>();
+ List<String> to = new ArrayList<>();
+ boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType);
+ if (dimensionNamesOverlap) {
+ TensorType.Builder builder = new TensorType.Builder(outputType.valueType());
+ for (int i = 0; i < outputType.rank(); ++i) {
+ TensorType.Dimension dim = outputType.dimensions().get(i);
+ from.add(dim.name());
+ to.add("temp_" + dim.name());
+ builder.dimension(dim.withName("temp_" + dim.name()));
+ }
+ outputType = builder.build();
+ }
+
ExpressionNode unrollFrom = unrollTensorExpression(inputType);
ExpressionNode unrollTo = unrollTensorExpression(outputType);
- ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
+ ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, new EmbracedNode(unrollTo));
TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
Generate transformTensor = new Generate(transformationType,
new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
- return new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
+ TensorFunction result = new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
Reduce.Aggregator.sum,
inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
+
+ if (dimensionNamesOverlap) {
+ result = new Rename(result, to, from);
+ }
+ return result;
+ }
+
+ private static boolean dimensionNamesOverlap(TensorType a, TensorType b) {
+ return a.dimensionNames().stream().anyMatch(d -> b.dimension(d).isPresent());
}
private static ExpressionNode unrollTensorExpression(TensorType type) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
index 35a1b6e2b0e..8696d0f1858 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
@@ -51,6 +51,9 @@ public class Select extends IntermediateOperation {
if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
return condition.cellIterator().next().getValue().intValue() == 0 ? b : a;
}
+ if (condition.type().rank() == 2 && dimensionSize(condition.type().dimensions().get(0)) == 1 && dimensionSize(condition.type().dimensions().get(1)) == 1) {
+ return condition.cellIterator().next().getValue().intValue() == 0 ? b : a;
+ }
}
// The task is to select cells from 'x' or 'y' based on 'condition'.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
index 56d9b542093..a9e3fc6a43a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
@@ -31,7 +31,10 @@ public class Squeeze extends IntermediateOperation {
squeezeDimensions = new ArrayList<>();
Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims");
- if ( ! squeezeDimsAttr.isPresent()) {
+ if (squeezeDimsAttr.isEmpty()) {
+ squeezeDimsAttr = attributeMap.getList("axes"); // ONNX
+ }
+ if (squeezeDimsAttr.isEmpty()) {
squeezeDimensions = inputType.type().dimensions().stream().
filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
map(TensorType.Dimension::name).
@@ -62,7 +65,7 @@ public class Squeeze extends IntermediateOperation {
List<String> renamedDimensions = new ArrayList<>(squeezeDimensions.size());
for (String name : squeezeDimensions) {
Optional<String> newName = renamer.dimensionNameOf(name);
- if (!newName.isPresent()) {
+ if (newName.isEmpty()) {
return; // presumably, already renamed
}
renamedDimensions.add(newName.get());