summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java36
1 files changed, 36 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index 3211a44fa68..2ebc7c3ddf6 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -5,11 +5,14 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
+import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
public class Join extends IntermediateOperation {
@@ -54,6 +57,16 @@ public class Join extends IntermediateOperation {
if ( ! allInputTypesPresent(2)) return null;
if ( ! allInputFunctionsPresent(2)) return null;
+ // Optimization: if inputs are the same, replace with a map function.
+ if (inputs.get(0).equals(inputs.get(1))) {
+ Optional<DoubleUnaryOperator> mapOperator = operatorAsUnary(operator);
+ if (mapOperator.isPresent()) {
+ IntermediateOperation input = inputs.get(0);
+ input.removeDuplicateOutputsTo(this); // avoids unnecessary function export
+ return new com.yahoo.tensor.functions.Map(input.function().get(), mapOperator.get());
+ }
+ }
+
IntermediateOperation a = largestInput();
IntermediateOperation b = smallestInput();
@@ -126,4 +139,27 @@ public class Join extends IntermediateOperation {
@Override
public String operationName() { return "Join"; }
+ private Optional<DoubleUnaryOperator> operatorAsUnary(DoubleBinaryOperator op) {
+ String unaryRep;
+ if (op instanceof ScalarFunctions.Add) unaryRep = "f(a)(a + a)";
+ else if (op instanceof ScalarFunctions.Multiply) unaryRep = "f(a)(a * a)";
+ else if (op instanceof ScalarFunctions.Subtract) unaryRep = "f(a)(0)";
+ else if (op instanceof ScalarFunctions.Divide) unaryRep = "f(a)(1)";
+ else if (op instanceof ScalarFunctions.Equal) unaryRep = "f(a)(1)";
+ else if (op instanceof ScalarFunctions.Greater) unaryRep = "f(a)(0)";
+ else if (op instanceof ScalarFunctions.Less) unaryRep = "f(a)(0)";
+ else if (op instanceof ScalarFunctions.Max) unaryRep = "f(a)(a)";
+ else if (op instanceof ScalarFunctions.Min) unaryRep = "f(a)(a)";
+ else if (op instanceof ScalarFunctions.Mean) unaryRep = "f(a)(a)";
+ else if (op instanceof ScalarFunctions.Pow) unaryRep = "f(a)(pow(a,a))";
+ else if (op instanceof ScalarFunctions.SquaredDifference) unaryRep = "f(a)(0)";
+ else return Optional.empty();
+ return Optional.of(new DoubleUnaryOperator() {
+ @Override
+ public double applyAsDouble(double operand) { return op.applyAsDouble(operand, operand); }
+ @Override
+ public String toString() { return unaryRep; }
+ });
+ }
+
}