diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-19 15:06:04 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-19 15:06:04 +0200 |
commit | a9ede0792556b4ccb54af6a6367f4c7395e80f75 (patch) | |
tree | bd2b7ea95972ac04b65362e715656ca76a2dd731 /model-integration | |
parent | 92b73034306ca58f6841f158149bd048bddb374f (diff) |
Replace joins with same input with a map
Diffstat (limited to 'model-integration')
3 files changed, 64 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 6e637c72d0f..6711b999940 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -312,6 +312,14 @@ public abstract class IntermediateOperation { return -1; } + /** Removes outputs if they point to the same operation */ + public void removeDuplicateOutputsTo(IntermediateOperation op) { + int last, first = outputs.indexOf(op); + while (first >= 0 && (last = outputs.lastIndexOf(op)) > first) { + outputs.remove(last); + } + } + /** * Returns the largest value type among the input value types. * This should only be called after it has been verified that input types are available. diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index 3211a44fa68..2ebc7c3ddf6 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -5,11 +5,14 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; public class Join extends IntermediateOperation { @@ -54,6 +57,16 @@ public class Join extends IntermediateOperation { if ( ! allInputTypesPresent(2)) return null; if ( ! allInputFunctionsPresent(2)) return null; + // Optimization: if inputs are the same, replace with a map function. + if (inputs.get(0).equals(inputs.get(1))) { + Optional<DoubleUnaryOperator> mapOperator = operatorAsUnary(operator); + if (mapOperator.isPresent()) { + IntermediateOperation input = inputs.get(0); + input.removeDuplicateOutputsTo(this); // avoids unnecessary function export + return new com.yahoo.tensor.functions.Map(input.function().get(), mapOperator.get()); + } + } + IntermediateOperation a = largestInput(); IntermediateOperation b = smallestInput(); @@ -126,4 +139,27 @@ public class Join extends IntermediateOperation { @Override public String operationName() { return "Join"; } + private Optional<DoubleUnaryOperator> operatorAsUnary(DoubleBinaryOperator op) { + String unaryRep; + if (op instanceof ScalarFunctions.Add) unaryRep = "f(a)(a + a)"; + else if (op instanceof ScalarFunctions.Multiply) unaryRep = "f(a)(a * a)"; + else if (op instanceof ScalarFunctions.Subtract) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Divide) unaryRep = "f(a)(1)"; + else if (op instanceof ScalarFunctions.Equal) unaryRep = "f(a)(1)"; + else if (op instanceof ScalarFunctions.Greater) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Less) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Max) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Min) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Mean) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Pow) unaryRep = "f(a)(pow(a,a))"; + else if (op instanceof ScalarFunctions.SquaredDifference) unaryRep = "f(a)(0)"; + else return Optional.empty(); + return Optional.of(new DoubleUnaryOperator() { + @Override + public double applyAsDouble(double operand) { return op.applyAsDouble(operand, operand); } + @Override + public String toString() { return unaryRep; } + }); + } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index 7b9868d71f5..7a0488362a9 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -579,6 +579,23 @@ public class OnnxOperationsTestCase { assertEval("expand", input, shape, evaluate("tensor(d0[2],d1[3],d2[4]):[1,1,1,1,2,2,2,2,3,3,3,3,1,1,1,1,2,2,2,2,3,3,3,3]")); } + @Test + public void testJoinWithSameInput() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + String opName = "mul"; + + Context context = new MapContext(DoubleValue.NaN); + List<IntermediateOperation> inputs = new ArrayList<>(); + inputs.add(addInput(inputs, context, x, "x")); + IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, createAttributes().build(), 0); + optimizeAndRename(opName, op); + + Tensor result = evaluate(op); + Tensor expected = evaluate("tensor(d0[2],d1[3]):[1,4,9,16,25,36]"); + assertEquals(expected, result); + assertEquals(expected.type(), result.type()); + } + private Tensor evaluate(String expr) throws ParseException { return evaluate(expr, null, null, null); } @@ -668,12 +685,13 @@ public class OnnxOperationsTestCase { return inputs; } - private void addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) { - if (x == null) return; + private IntermediateOperation addInput(List<IntermediateOperation> inputs, Context context, Tensor x, String name) { + if (x == null) return null; context.put(name, new TensorValue(x)); IntermediateOperation op = new Constant(modelName, name, OrderedTensorType.fromSpec(x.type().toString())); op.setConstantValueFunction(type -> new TensorValue(convertTypeAfterRename(x, type))); inputs.add(op); + return op; } Tensor convertTypeAfterRename(Tensor tensor, OrderedTensorType type) { |