aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-06-19 15:06:04 +0200
committerLester Solbakken <lesters@oath.com>2020-06-19 15:06:04 +0200
commita9ede0792556b4ccb54af6a6367f4c7395e80f75 (patch)
treebd2b7ea95972ac04b65362e715656ca76a2dd731 /model-integration/src/main/java/ai
parent92b73034306ca58f6841f158149bd048bddb374f (diff)
Replace joins with same input with a map
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java36
2 files changed, 44 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 6e637c72d0f..6711b999940 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -312,6 +312,14 @@ public abstract class IntermediateOperation {
return -1;
}
+ /** Removes outputs if they point to the same operation */
+ public void removeDuplicateOutputsTo(IntermediateOperation op) {
+ int last, first = outputs.indexOf(op);
+ while (first >= 0 && (last = outputs.lastIndexOf(op)) > first) {
+ outputs.remove(last);
+ }
+ }
+
/**
* Returns the largest value type among the input value types.
* This should only be called after it has been verified that input types are available.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index 3211a44fa68..2ebc7c3ddf6 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -5,11 +5,14 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
+import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
public class Join extends IntermediateOperation {
@@ -54,6 +57,16 @@ public class Join extends IntermediateOperation {
if ( ! allInputTypesPresent(2)) return null;
if ( ! allInputFunctionsPresent(2)) return null;
+ // Optimization: if inputs are the same, replace with a map function.
+ if (inputs.get(0).equals(inputs.get(1))) {
+ Optional<DoubleUnaryOperator> mapOperator = operatorAsUnary(operator);
+ if (mapOperator.isPresent()) {
+ IntermediateOperation input = inputs.get(0);
+ input.removeDuplicateOutputsTo(this); // avoids unnecessary function export
+ return new com.yahoo.tensor.functions.Map(input.function().get(), mapOperator.get());
+ }
+ }
+
IntermediateOperation a = largestInput();
IntermediateOperation b = smallestInput();
@@ -126,4 +139,27 @@ public class Join extends IntermediateOperation {
@Override
public String operationName() { return "Join"; }
+ private Optional<DoubleUnaryOperator> operatorAsUnary(DoubleBinaryOperator op) {
+ String unaryRep;
+ if (op instanceof ScalarFunctions.Add) unaryRep = "f(a)(a + a)";
+ else if (op instanceof ScalarFunctions.Multiply) unaryRep = "f(a)(a * a)";
+ else if (op instanceof ScalarFunctions.Subtract) unaryRep = "f(a)(0)";
+ else if (op instanceof ScalarFunctions.Divide) unaryRep = "f(a)(1)";
+ else if (op instanceof ScalarFunctions.Equal) unaryRep = "f(a)(1)";
+ else if (op instanceof ScalarFunctions.Greater) unaryRep = "f(a)(0)";
+ else if (op instanceof ScalarFunctions.Less) unaryRep = "f(a)(0)";
+ else if (op instanceof ScalarFunctions.Max) unaryRep = "f(a)(a)";
+ else if (op instanceof ScalarFunctions.Min) unaryRep = "f(a)(a)";
+ else if (op instanceof ScalarFunctions.Mean) unaryRep = "f(a)(a)";
+ else if (op instanceof ScalarFunctions.Pow) unaryRep = "f(a)(pow(a,a))";
+ else if (op instanceof ScalarFunctions.SquaredDifference) unaryRep = "f(a)(0)";
+ else return Optional.empty();
+ return Optional.of(new DoubleUnaryOperator() {
+ @Override
+ public double applyAsDouble(double operand) { return op.applyAsDouble(operand, operand); }
+ @Override
+ public String toString() { return unaryRep; }
+ });
+ }
+
}