diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-19 15:06:04 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-19 15:06:04 +0200 |
commit | a9ede0792556b4ccb54af6a6367f4c7395e80f75 (patch) | |
tree | bd2b7ea95972ac04b65362e715656ca76a2dd731 /model-integration/src/main | |
parent | 92b73034306ca58f6841f158149bd048bddb374f (diff) |
Replace joins with same input with a map
Diffstat (limited to 'model-integration/src/main')
2 files changed, 44 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 6e637c72d0f..6711b999940 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -312,6 +312,14 @@ public abstract class IntermediateOperation { return -1; } + /** Removes outputs if they point to the same operation */ + public void removeDuplicateOutputsTo(IntermediateOperation op) { + int last, first = outputs.indexOf(op); + while (first >= 0 && (last = outputs.lastIndexOf(op)) > first) { + outputs.remove(last); + } + } + /** * Returns the largest value type among the input value types. * This should only be called after it has been verified that input types are available. diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index 3211a44fa68..2ebc7c3ddf6 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -5,11 +5,14 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; public class Join extends IntermediateOperation { @@ -54,6 +57,16 @@ public class Join extends IntermediateOperation { if ( ! allInputTypesPresent(2)) return null; if ( ! allInputFunctionsPresent(2)) return null; + // Optimization: if inputs are the same, replace with a map function. + if (inputs.get(0).equals(inputs.get(1))) { + Optional<DoubleUnaryOperator> mapOperator = operatorAsUnary(operator); + if (mapOperator.isPresent()) { + IntermediateOperation input = inputs.get(0); + input.removeDuplicateOutputsTo(this); // avoids unnecessary function export + return new com.yahoo.tensor.functions.Map(input.function().get(), mapOperator.get()); + } + } + IntermediateOperation a = largestInput(); IntermediateOperation b = smallestInput(); @@ -126,4 +139,27 @@ public class Join extends IntermediateOperation { @Override public String operationName() { return "Join"; } + private Optional<DoubleUnaryOperator> operatorAsUnary(DoubleBinaryOperator op) { + String unaryRep; + if (op instanceof ScalarFunctions.Add) unaryRep = "f(a)(a + a)"; + else if (op instanceof ScalarFunctions.Multiply) unaryRep = "f(a)(a * a)"; + else if (op instanceof ScalarFunctions.Subtract) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Divide) unaryRep = "f(a)(1)"; + else if (op instanceof ScalarFunctions.Equal) unaryRep = "f(a)(1)"; + else if (op instanceof ScalarFunctions.Greater) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Less) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Max) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Min) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Mean) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Pow) unaryRep = "f(a)(pow(a,a))"; + else if (op instanceof ScalarFunctions.SquaredDifference) unaryRep = "f(a)(0)"; + else return Optional.empty(); + return Optional.of(new DoubleUnaryOperator() { + @Override + public double applyAsDouble(double operand) { return op.applyAsDouble(operand, operand); } + @Override + public String toString() { return unaryRep; } + }); + } + } |