aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-06-19 15:06:04 +0200
committerLester Solbakken <lesters@oath.com>2020-06-19 15:06:04 +0200
commita9ede0792556b4ccb54af6a6367f4c7395e80f75 (patch)
treebd2b7ea95972ac04b65362e715656ca76a2dd731 /model-integration
parent92b73034306ca58f6841f158149bd048bddb374f (diff)
Replace joins with same input with a map
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java36
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java22
3 files changed, 64 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 6e637c72d0f..6711b999940 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -312,6 +312,14 @@ public abstract class IntermediateOperation {
return -1;
}
+ /** Removes outputs if they point to the same operation */
+ public void removeDuplicateOutputsTo(IntermediateOperation op) {
+ int last, first = outputs.indexOf(op);
+ while (first >= 0 && (last = outputs.lastIndexOf(op)) > first) {
+ outputs.remove(last);
+ }
+ }
+
/**
* Returns the largest value type among the input value types.
* This should only be called after it has been verified that input types are available.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index 3211a44fa68..2ebc7c3ddf6 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -5,11 +5,14 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
+import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
public class Join extends IntermediateOperation {
@@ -54,6 +57,16 @@ public class Join extends IntermediateOperation {
if ( ! allInputTypesPresent(2)) return null;
if ( ! allInputFunctionsPresent(2)) return null;
+ // Optimization: if inputs are the same, replace with a map function.
+ if (inputs.get(0).equals(inputs.get(1))) {
+ Optional<DoubleUnaryOperator> mapOperator = operatorAsUnary(operator);
+ if (mapOperator.isPresent()) {
+ IntermediateOperation input = inputs.get(0);
+ input.removeDuplicateOutputsTo(this); // avoids unnecessary function export
+ return new com.yahoo.tensor.functions.Map(input.function().get(), mapOperator.get());
+ }
+ }
+
IntermediateOperation a = largestInput();
IntermediateOperation b = smallestInput();
@@ -126,4 +139,27 @@ public class Join extends IntermediateOperation {
@Override
public String operationName() { return "Join"; }
+ private Optional<DoubleUnaryOperator> operatorAsUnary(DoubleBinaryOperator op) {
+ String unaryRep;
+ if (op instanceof ScalarFunctions.Add) unaryRep = "f(a)(a + a)";
+ else if (op instanceof ScalarFunctions.Multiply) unaryRep = "f(a)(a * a)";
+ else if (op instanceof ScalarFunctions.Subtract) unaryRep = "f(a)(0)";
+ else if (op instanceof ScalarFunctions.Divide) unaryRep = "f(a)(1)";
+ else if (op instanceof ScalarFunctions.Equal) unaryRep = "f(a)(1)";
+ else if (op instanceof ScalarFunctions.Greater) unaryRep = "f(a)(0)";
+ else if (op instanceof ScalarFunctions.Less) unaryRep = "f(a)(0)";
+ else if (op instanceof ScalarFunctions.Max) unaryRep = "f(a)(a)";
+ else if (op instanceof ScalarFunctions.Min) unaryRep = "f(a)(a)";
+ else if (op instanceof ScalarFunctions.Mean) unaryRep = "f(a)(a)";
+ else if (op instanceof ScalarFunctions.Pow) unaryRep = "f(a)(pow(a,a))";
+ else if (op instanceof ScalarFunctions.SquaredDifference) unaryRep = "f(a)(0)";
+ else return Optional.empty();
+ return Optional.of(new DoubleUnaryOperator() {
+ @Override
+ public double applyAsDouble(double operand) { return op.applyAsDouble(operand, operand); }
+ @Override
+ public String toString() { return unaryRep; }
+ });
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index 7b9868d71f5..7a0488362a9 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -579,6 +579,23 @@ public class OnnxOperationsTestCase {
assertEval("expand", input, shape, evaluate("tensor(d0[2],d1[3],d2[4]):[1,1,1,1,2,2,2,2,3,3,3,3,1,1,1,1,2,2,2,2,3,3,3,3]"));
}
+ @Test
+ public void testJoinWithSameInput() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ String opName = "mul";
+
+ Context context = new MapContext(DoubleValue.NaN);
+ List<IntermediateOperation> inputs = new ArrayList<>();
+ inputs.add(addInput(inputs, context, x, "x"));
+ IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, createAttributes().build(), 0);
+ optimizeAndRename(opName, op);
+
+ Tensor result = evaluate(op);
+ Tensor expected = evaluate("tensor(d0[2],d1[3]):[1,4,9,16,25,36]");
+ assertEquals(expected, result);
+ assertEquals(expected.type(), result.type());
+ }
+
private Tensor evaluate(String expr) throws ParseException {
return evaluate(expr, null, null, null);
}
@@ -668,12 +685,13 @@ public class OnnxOperationsTestCase {
return inputs;
}
- private void addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) {
- if (x == null) return;
+ private IntermediateOperation addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) {
+ if (x == null) return null;
context.put(name, new TensorValue(x));
IntermediateOperation op = new Constant(modelName, name, OrderedTensorType.fromSpec(x.type().toString()));
op.setConstantValueFunction(type -> new TensorValue(convertTypeAfterRename(x, type)));
inputs.add(op);
+ return op;
}
Tensor convertTypeAfterRename(Tensor tensor, OrderedTensorType type) {