From a9ede0792556b4ccb54af6a6367f4c7395e80f75 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Fri, 19 Jun 2020 15:06:04 +0200 Subject: Replace joins with same input with a map --- .../importer/operations/IntermediateOperation.java | 8 +++++ .../importer/operations/Join.java | 36 ++++++++++++++++++++++ 2 files changed, 44 insertions(+) (limited to 'model-integration/src/main') diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 6e637c72d0f..6711b999940 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -312,6 +312,14 @@ public abstract class IntermediateOperation { return -1; } + /** Removes outputs if they point to the same operation */ + public void removeDuplicateOutputsTo(IntermediateOperation op) { + int last, first = outputs.indexOf(op); + while (first >= 0 && (last = outputs.lastIndexOf(op)) > first) { + outputs.remove(last); + } + } + /** * Returns the largest value type among the input value types. * This should only be called after it has been verified that input types are available. diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index 3211a44fa68..2ebc7c3ddf6 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -5,11 +5,14 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; public class Join extends IntermediateOperation { @@ -54,6 +57,16 @@ public class Join extends IntermediateOperation { if ( ! allInputTypesPresent(2)) return null; if ( ! allInputFunctionsPresent(2)) return null; + // Optimization: if inputs are the same, replace with a map function. + if (inputs.get(0).equals(inputs.get(1))) { + Optional mapOperator = operatorAsUnary(operator); + if (mapOperator.isPresent()) { + IntermediateOperation input = inputs.get(0); + input.removeDuplicateOutputsTo(this); // avoids unnecessary function export + return new com.yahoo.tensor.functions.Map(input.function().get(), mapOperator.get()); + } + } + IntermediateOperation a = largestInput(); IntermediateOperation b = smallestInput(); @@ -126,4 +139,27 @@ public class Join extends IntermediateOperation { @Override public String operationName() { return "Join"; } + private Optional operatorAsUnary(DoubleBinaryOperator op) { + String unaryRep; + if (op instanceof ScalarFunctions.Add) unaryRep = "f(a)(a + a)"; + else if (op instanceof ScalarFunctions.Multiply) unaryRep = "f(a)(a * a)"; + else if (op instanceof ScalarFunctions.Subtract) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Divide) unaryRep = "f(a)(1)"; + else if (op instanceof ScalarFunctions.Equal) unaryRep = "f(a)(1)"; + else if (op instanceof ScalarFunctions.Greater) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Less) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Max) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Min) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Mean) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Pow) unaryRep = "f(a)(pow(a,a))"; + else if (op instanceof ScalarFunctions.SquaredDifference) unaryRep = "f(a)(0)"; + else return Optional.empty(); + return Optional.of(new DoubleUnaryOperator() { + @Override + public double applyAsDouble(double operand) { return op.applyAsDouble(operand, operand); } + @Override + public String toString() { return unaryRep; } + }); + } + } -- cgit v1.2.3