diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java | 36 |
1 files changed, 36 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index 3211a44fa68..2ebc7c3ddf6 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -5,11 +5,14 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; public class Join extends IntermediateOperation { @@ -54,6 +57,16 @@ public class Join extends IntermediateOperation { if ( ! allInputTypesPresent(2)) return null; if ( ! allInputFunctionsPresent(2)) return null; + // Optimization: if inputs are the same, replace with a map function. + if (inputs.get(0).equals(inputs.get(1))) { + Optional<DoubleUnaryOperator> mapOperator = operatorAsUnary(operator); + if (mapOperator.isPresent()) { + IntermediateOperation input = inputs.get(0); + input.removeDuplicateOutputsTo(this); // avoids unnecessary function export + return new com.yahoo.tensor.functions.Map(input.function().get(), mapOperator.get()); + } + } + IntermediateOperation a = largestInput(); IntermediateOperation b = smallestInput(); @@ -126,4 +139,27 @@ public class Join extends IntermediateOperation { @Override public String operationName() { return "Join"; } + private Optional<DoubleUnaryOperator> operatorAsUnary(DoubleBinaryOperator op) { + String unaryRep; + if (op instanceof ScalarFunctions.Add) unaryRep = "f(a)(a + a)"; + else if (op instanceof ScalarFunctions.Multiply) unaryRep = "f(a)(a * a)"; + else if (op instanceof ScalarFunctions.Subtract) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Divide) unaryRep = "f(a)(1)"; + else if (op instanceof ScalarFunctions.Equal) unaryRep = "f(a)(1)"; + else if (op instanceof ScalarFunctions.Greater) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Less) unaryRep = "f(a)(0)"; + else if (op instanceof ScalarFunctions.Max) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Min) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Mean) unaryRep = "f(a)(a)"; + else if (op instanceof ScalarFunctions.Pow) unaryRep = "f(a)(pow(a,a))"; + else if (op instanceof ScalarFunctions.SquaredDifference) unaryRep = "f(a)(0)"; + else return Optional.empty(); + return Optional.of(new DoubleUnaryOperator() { + @Override + public double applyAsDouble(double operand) { return op.applyAsDouble(operand, operand); } + @Override + public String toString() { return unaryRep; } + }); + } + } |