summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java120
1 files changed, 120 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java
new file mode 100644
index 00000000000..c98bcb43331
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java
@@ -0,0 +1,120 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+
+public class Join extends IntermediateOperation {
+
+ private final DoubleBinaryOperator operator;
+
+ public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) {
+ super(modelName, nodeName, inputs);
+ this.operator = operator;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ OrderedTensorType a = largestInput().type().get();
+ OrderedTensorType b = smallestInput().type().get();
+
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ int sizeDifference = a.rank() - b.rank();
+ for (int i = 0; i < a.rank(); ++i) {
+ TensorType.Dimension aDim = a.dimensions().get(i);
+ long size = aDim.size().orElse(-1L);
+
+ if (i - sizeDifference >= 0) {
+ TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference);
+ size = Math.max(size, bDim.size().orElse(-1L));
+ }
+
+ if (aDim.type() == TensorType.Dimension.Type.indexedBound) {
+ builder.add(TensorType.Dimension.indexed(aDim.name(), size));
+ } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) {
+ builder.add(TensorType.Dimension.indexed(aDim.name()));
+ } else if (aDim.type() == TensorType.Dimension.Type.mapped) {
+ builder.add(TensorType.Dimension.mapped(aDim.name()));
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ if (!allInputFunctionsPresent(2)) {
+ return null;
+ }
+
+ IntermediateOperation a = largestInput();
+ IntermediateOperation b = smallestInput();
+
+ List<String> aDimensionsToReduce = new ArrayList<>();
+ List<String> bDimensionsToReduce = new ArrayList<>();
+ int sizeDifference = a.type().get().rank() - b.type().get().rank();
+ for (int i = 0; i < b.type().get().rank(); ++i) {
+ TensorType.Dimension bDim = b.type().get().dimensions().get(i);
+ TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference);
+ long bSize = bDim.size().orElse(-1L);
+ long aSize = aDim.size().orElse(-1L);
+ if (bSize == 1L && aSize != 1L) {
+ bDimensionsToReduce.add(bDim.name());
+ }
+ if (aSize == 1L && bSize != 1L) {
+ aDimensionsToReduce.add(bDim.name());
+ }
+ }
+
+ TensorFunction aReducedFunction = a.function().get();
+ if (aDimensionsToReduce.size() > 0) {
+ aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
+ }
+ TensorFunction bReducedFunction = b.function().get();
+ if (bDimensionsToReduce.size() > 0) {
+ bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
+ }
+
+ return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!allInputTypesPresent(2)) {
+ return;
+ }
+ OrderedTensorType a = largestInput().type().get();
+ OrderedTensorType b = smallestInput().type().get();
+ int sizeDifference = a.rank() - b.rank();
+ for (int i = 0; i < b.rank(); ++i) {
+ String bDim = b.dimensions().get(i).name();
+ String aDim = a.dimensions().get(i + sizeDifference).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+ }
+ }
+
+ private IntermediateOperation largestInput() {
+ OrderedTensorType a = inputs.get(0).type().get();
+ OrderedTensorType b = inputs.get(1).type().get();
+ return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
+ }
+
+ private IntermediateOperation smallestInput() {
+ OrderedTensorType a = inputs.get(0).type().get();
+ OrderedTensorType b = inputs.get(1).type().get();
+ return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
+ }
+
+}