authorJon Bratseth <bratseth@oath.com>2018-09-29 12:53:37 +0200
committerGitHub <noreply@github.com>2018-09-29 12:53:37 +0200
commitbfc4feb4f5b9b2cec76bd08027cdf3c2ca339f39 (patch)
tree3bd0250db894ae103c85062556b4b34c59e38e84
parent44313db093c7954c6fa979296832ed44dd28991d (diff)
parent35da647c2c32ea3a81b5c65bd440af1dc59b1b3c (diff)
Merge pull request #7145 from vespa-engine/lesters/add-java-reduce-join-optimization
Add reduce-join optimization in Java
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java | 6
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java | 85
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java | 43
-rw-r--r-- searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java | 116
-rw-r--r-- searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java | 27
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 9
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 2
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java | 2
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 97
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java | 68
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java | 372
11 files changed, 757 insertions, 70 deletions
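The pattern this merge targets is a ranking expression where a reduce is applied directly to a join, as exercised by the new test case:

    reduce(join(a, b, f(a,b)(a * b)), sum, d0)

With a and b both of type tensor(d0[3]) this is a dot product: the new ReduceJoin function multiplies and accumulates the cells in a single pass instead of first materializing the joined tensor and then reducing it.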
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java
index 7060cfc2132..84a90ee64c2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ExpressionOptimizer.java
@@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTForestOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization.TensorOptimizer;
/**
* This class will perform various optimizations on the ranking expressions. Clients using optimized expressions
@@ -32,8 +33,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization.GBDTOpt
public class ExpressionOptimizer {
private GBDTOptimizer gbdtOptimizer = new GBDTOptimizer();
-
private GBDTForestOptimizer gbdtForestOptimizer = new GBDTForestOptimizer();
+ private TensorOptimizer tensorOptimizer = new TensorOptimizer();
/** Gets an optimizer instance used by this by class name, or null if the optimizer is not known */
public Optimizer getOptimizer(Class<?> clazz) {
@@ -41,6 +42,8 @@ public class ExpressionOptimizer {
return gbdtOptimizer;
if (clazz == gbdtForestOptimizer.getClass())
return gbdtForestOptimizer;
+ if (clazz == tensorOptimizer.getClass())
+ return tensorOptimizer;
return null;
}
@@ -49,6 +52,7 @@ public class ExpressionOptimizer {
// Note: Order of optimizations matter
gbdtOptimizer.optimize(expression, contextIndex, report);
gbdtForestOptimizer.optimize(expression, contextIndex, report);
+ tensorOptimizer.optimize(expression, contextIndex, report);
return report;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
new file mode 100644
index 00000000000..63cea371d14
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
@@ -0,0 +1,85 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
+import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ReduceJoin;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Recognizes and optimizes tensor expressions.
+ *
+ * @author lesters
+ */
+public class TensorOptimizer extends Optimizer {
+
+ private OptimizationReport report;
+
+ @Override
+ public void optimize(RankingExpression expression, ContextIndex context, OptimizationReport report) {
+ if (!isEnabled()) return;
+ this.report = report;
+ expression.setRoot(optimize(expression.getRoot(), context));
+ report.note("Tensor expression optimization done");
+ }
+
+ private ExpressionNode optimize(ExpressionNode node, ContextIndex context) {
+ node = optimizeReduceJoin(node);
+ if (node instanceof CompositeNode) {
+ return optimizeChildren((CompositeNode)node, context);
+ }
+ return node;
+ }
+
+ private ExpressionNode optimizeChildren(CompositeNode node, ContextIndex context) {
+ List<ExpressionNode> children = node.children();
+ List<ExpressionNode> optimizedChildren = new ArrayList<>(children.size());
+ for (ExpressionNode child : children)
+ optimizedChildren.add(optimize(child, context));
+ return node.setChildren(optimizedChildren);
+ }
+
+ /**
+ * Recognizes a reduce whose argument is a join. In many cases, evaluating these
+ * two operations together is significantly more efficient than evaluating
+ * each on its own, as it avoids materializing a temporary tensor.
+ *
+ * Note that this does not guarantee that the optimization is performed.
+ * The ReduceJoin class determines whether or not the arguments are
+ * compatible with the optimization.
+ */
+ private ExpressionNode optimizeReduceJoin(ExpressionNode node) {
+ if ( ! (node instanceof TensorFunctionNode)) {
+ return node;
+ }
+ TensorFunction function = ((TensorFunctionNode) node).function();
+ if ( ! (function instanceof Reduce)) {
+ return node;
+ }
+ List<ExpressionNode> children = ((TensorFunctionNode) node).children();
+ if (children.size() != 1) {
+ return node;
+ }
+ ExpressionNode child = children.get(0);
+ if ( ! (child instanceof TensorFunctionNode)) {
+ return node;
+ }
+ TensorFunction argument = ((TensorFunctionNode) child).function();
+ if (argument instanceof Join) {
+ report.incMetric("Replaced reduce->join", 1);
+ return new TensorFunctionNode(new ReduceJoin((Reduce)function, (Join)argument));
+ }
+ return node;
+ }
+
+}
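A minimal sketch of how the new optimizer is driven, mirroring the TensorOptimizerTestCase added below (tensor types and variable names are illustrative only):

    RankingExpression expression =
            new RankingExpression("reduce(join(a, b, f(a,b)(a * b)), sum, d0)");
    ArrayContext context = new ArrayContext(expression);
    context.put("a", new TensorValue(Tensor.random(TensorType.fromSpec("tensor(d0[3])"))));
    context.put("b", new TensorValue(Tensor.random(TensorType.fromSpec("tensor(d0[3])"))));

    OptimizationReport report = new ExpressionOptimizer().optimize(expression, context);
    // The root is now a TensorFunctionNode wrapping a ReduceJoin, and
    // report.getMetric("Replaced reduce->join") is 1.
    Tensor result = expression.evaluate(context).asTensor();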
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index de98b01287e..3a3410aeebb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -12,6 +12,7 @@ import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.Deque;
import java.util.List;
+import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
@@ -90,7 +91,47 @@ public class LambdaFunctionNode extends CompositeNode {
if (arguments.size() > 2)
throw new IllegalStateException("Cannot apply " + this + " as a DoubleBinaryOperator: " +
"Must have at most two argument " + " but has " + arguments);
- return new DoubleBinaryLambda();
+
+ // Optimization: if possible, calculate directly rather than creating a context and evaluating the expression
+ return getDirectEvaluator().orElseGet(DoubleBinaryLambda::new);
+ }
+
+ private Optional<DoubleBinaryOperator> getDirectEvaluator() {
+ if ( ! (functionExpression instanceof ArithmeticNode)) {
+ return Optional.empty();
+ }
+ ArithmeticNode node = (ArithmeticNode) functionExpression;
+ if ( ! (node.children().get(0) instanceof ReferenceNode) || ! (node.children().get(1) instanceof ReferenceNode)) {
+ return Optional.empty();
+ }
+ if (node.operators().size() != 1) {
+ return Optional.empty();
+ }
+ ArithmeticOperator operator = node.operators().get(0);
+ switch (operator) {
+ case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
+ case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
+ case PLUS: return asFunctionExpression((left, right) -> left + right);
+ case MINUS: return asFunctionExpression((left, right) -> left - right);
+ case MULTIPLY: return asFunctionExpression((left, right) -> left * right);
+ case DIVIDE: return asFunctionExpression((left, right) -> left / right);
+ case MODULO: return asFunctionExpression((left, right) -> left % right);
+ case POWER: return asFunctionExpression(Math::pow);
+ }
+ return Optional.empty();
+ }
+
+ private Optional<DoubleBinaryOperator> asFunctionExpression(DoubleBinaryOperator operator) {
+ return Optional.of(new DoubleBinaryOperator() {
+ @Override
+ public double applyAsDouble(double left, double right) {
+ return operator.applyAsDouble(left, right);
+ }
+ @Override
+ public String toString() {
+ return LambdaFunctionNode.this.toString();
+ }
+ });
}
private class DoubleUnaryLambda implements DoubleUnaryOperator {
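In effect, a two-argument lambda such as f(a,b)(a * b) now compiles to a plain DoubleBinaryOperator instead of being evaluated through a per-call context; a rough sketch of what the MULTIPLY case above produces:

    // What getDirectEvaluator() returns for f(a,b)(a * b), conceptually:
    DoubleBinaryOperator multiply = (left, right) -> left * right;
    double v = multiply.applyAsDouble(2.0, 3.0); // 6.0, with no Context allocation per cell
    // asFunctionExpression wraps it so toString() still prints the original lambda text.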
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java
new file mode 100644
index 00000000000..f29083bddc9
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizerTestCase.java
@@ -0,0 +1,116 @@
+package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ReduceJoin;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class TensorOptimizerTestCase {
+
+ @Test
+ public void testReduceJoinOptimization() throws ParseException {
+ assertWillOptimize("d0[3]", "d0[3]");
+ assertWillOptimize("d0[1]", "d0[1]", "d0");
+ assertWillOptimize("d0[2]", "d0[2]", "d0");
+ assertWillOptimize("d0[1]", "d0[3]", "d0");
+ assertWillOptimize("d0[3]", "d0[3]", "d0");
+ assertWillOptimize("d0[3]", "d0[3],d1[2]", "d0");
+ assertWillOptimize("d0[3],d1[2]", "d0[3]", "d0");
+ assertWillOptimize("d1[3]", "d0[2],d1[3]", "d1");
+ assertWillOptimize("d0[2],d1[3]", "d1[3]", "d1");
+ assertWillOptimize("d0[2],d2[2]", "d1[3],d2[2]", "d2");
+ assertWillOptimize("d1[2],d2[2]", "d0[3],d2[2]", "d2");
+ assertWillOptimize("d0[1],d2[2]", "d1[3],d2[4]", "d2");
+ assertWillOptimize("d0[2],d2[2]", "d1[3],d2[4]", "d2");
+ assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]");
+ assertWillOptimize("d0[1],d1[2]", "d0[2],d1[3]", "d0,d1");
+ assertWillOptimize("d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3");
+ assertWillOptimize("d0[1],d2[3],d3[4]", "d1[2],d2[3],d3[4]", "d2,d3");
+ assertWillOptimize("d0[1],d1[2],d2[3]", "d2[3],d3[4],d4[5]", "d2");
+ assertWillOptimize("d0[1],d1[2],d2[3]", "d1[2],d2[3],d4[4]", "d1,d2");
+ assertWillOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]");
+ assertWillOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]", "d0,d1,d2");
+
+ // Will not currently use reduce-join optimization
+ assertCantOptimize("d0[2],d1[3]", "d1[3]", "d0"); // reducing on a dimension not joining on
+ assertCantOptimize("d0[1],d1[2]", "d1[2],d2[3]", "d2"); // same
+ assertCantOptimize("d0[3]", "d0[3],d1[2]"); // reducing on more then we are combining
+ assertCantOptimize("d0[1],d2[3]", "d1[2],d2[3]"); // same
+ assertCantOptimize("d0[1],d1[2],d2[3]", "d0[1],d1[2],d2[3]", "d1,d2"); // reducing on less then joining on
+ }
+
+ private void assertWillOptimize(String aType, String bType) throws ParseException {
+ assertWillOptimize(aType, bType, "", "sum");
+ }
+
+ private void assertWillOptimize(String aType, String bType, String reduceDim) throws ParseException {
+ assertWillOptimize(aType, bType, reduceDim, "sum");
+ }
+
+ private void assertWillOptimize(String aType, String bType, String reduceDim, String aggregator) throws ParseException {
+ assertReduceJoin(aType, bType, reduceDim, aggregator, true);
+ }
+
+ private void assertCantOptimize(String aType, String bType) throws ParseException {
+ assertCantOptimize(aType, bType, "", "sum");
+ }
+
+ private void assertCantOptimize(String aType, String bType, String reduceDim) throws ParseException {
+ assertCantOptimize(aType, bType, reduceDim, "sum");
+ }
+
+ private void assertCantOptimize(String aType, String bType, String reduceDim, String aggregator) throws ParseException {
+ assertReduceJoin(aType, bType, reduceDim, aggregator, false);
+ }
+
+ private void assertReduceJoin(String aType, String bType, String reduceDim, String aggregator, boolean assertOptimize) throws ParseException {
+ Tensor a = generateRandomTensor(aType);
+ Tensor b = generateRandomTensor(bType);
+ RankingExpression expression = generateRankingExpression(reduceDim, aggregator);
+ assert ((TensorFunctionNode)expression.getRoot()).function() instanceof Reduce;
+
+ ArrayContext context = generateContext(a, b, expression);
+ Tensor result = expression.evaluate(context).asTensor();
+
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ OptimizationReport report = optimizer.optimize(expression, context);
+ assertEquals(1, report.getMetric("Replaced reduce->join"));
+ assert ((TensorFunctionNode)expression.getRoot()).function() instanceof ReduceJoin;
+
+ assertEquals(result, expression.evaluate(context).asTensor());
+ assertEquals(assertOptimize, ((ReduceJoin)((TensorFunctionNode)expression.getRoot()).function()).canOptimize(a, b));
+ }
+
+ private RankingExpression generateRankingExpression(String reduceDim, String aggregator) throws ParseException {
+ String dimensions = "";
+ if (reduceDim.length() > 0) {
+ dimensions = ", " + reduceDim;
+ }
+ return new RankingExpression("reduce(join(a, b, f(a,b)(a * b)), " + aggregator + dimensions + ")");
+ }
+
+ private ArrayContext generateContext(Tensor a, Tensor b, RankingExpression expression) {
+ ArrayContext context = new ArrayContext(expression);
+ context.put("a", new TensorValue(a));
+ context.put("b", new TensorValue(b));
+ return context;
+ }
+
+ private Tensor generateRandomTensor(String type) {
+ return Tensor.random(TensorType.fromSpec("tensor(" + type + ")"));
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
index 5447e5240f7..19b3329e8a4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
@@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.TensorConverter;
@@ -50,7 +52,11 @@ public class TestableTensorFlowModel {
model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
- Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor();
+ RankingExpression expression = model.expressions().get(operationName);
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+
+ Tensor vespaResult = expression.evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results",
tfResult.sum().asDouble(), vespaResult.sum().asDouble(), delta);
}
@@ -64,7 +70,11 @@ public class TestableTensorFlowModel {
model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
- Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor();
+ RankingExpression expression = model.expressions().get(operationName);
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+
+ Tensor vespaResult = expression.evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
}
@@ -82,7 +92,7 @@ public class TestableTensorFlowModel {
}
private Context contextFrom(ImportedModel result) {
- MapContext context = new MapContext();
+ TestableModelContext context = new TestableModelContext();
result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
return context;
@@ -118,4 +128,15 @@ public class TestableTensorFlowModel {
}
}
+ private static class TestableModelContext extends MapContext implements ContextIndex {
+ @Override
+ public int size() {
+ return bindings().size();
+ }
+ @Override
+ public int getIndex(String name) {
+ throw new UnsupportedOperationException(this + " does not support index lookup by name");
+ }
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 38979411313..2d127eb86cf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -116,7 +116,14 @@ public class IndexedTensor implements Tensor {
}
}
- private double get(long valueIndex) { return values[(int)valueIndex]; }
+ /**
+ * Returns the value at the given index by direct lookup. Only use
+ * if you know the underlying data layout.
+ *
+ * @param valueIndex the direct index into the underlying data.
+ * @throws IndexOutOfBoundsException if index is out of bounds
+ */
+ public double get(long valueIndex) { return values[(int)valueIndex]; }
private static long toValueIndex(long[] indexes, DimensionSizes sizes) {
if (indexes.length == 1) return indexes[0]; // for speed
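Since get(long) is now public, callers must know the dense row-major layout IndexedTensor uses; a small sketch of the correspondence between multi-dimensional and direct indexes (the tensor type is just an example):

    // For tensor(d0[2],d1[3]) the cell (d0=i, d1=j) sits at direct index i * 3 + j
    IndexedTensor.BoundBuilder builder =
            (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])"));
    builder.cellByDirectIndex(1 * 3 + 2, 42.0);         // writes the cell d0=1, d1=2
    IndexedTensor tensor = (IndexedTensor) builder.build();
    double v = tensor.get(1 * 3 + 2);                   // 42.0, read back by direct lookup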
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index c846364df29..1d447ed3eed 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -462,7 +462,7 @@ public class TensorType {
return add(new MappedDimension(name));
}
- /** Adds the give dimension */
+ /** Adds the given dimension */
public Builder dimension(Dimension dimension) {
return add(dimension);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 2a52df20108..5dd2cc442aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -22,7 +22,7 @@ public abstract class CompositeTensorFunction extends TensorFunction {
/** Evaluates this by first converting it to a primitive function */
@Override
- public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
return toPrimitive().evaluate(context);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index be323313369..62ee471fcf4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -82,25 +82,29 @@ public class Join extends PrimitiveTensorFunction {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
+ return evaluate(a, b, joinedType, combinator);
+ }
+ static Tensor evaluate(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
// Choose join algorithm
if (hasSingleIndexedDimension(a) && hasSingleIndexedDimension(b) && a.type().dimensions().get(0).name().equals(b.type().dimensions().get(0).name()))
- return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType);
+ return indexedVectorJoin((IndexedTensor)a, (IndexedTensor)b, joinedType, combinator);
else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
- return singleSpaceJoin(a, b, joinedType);
+ return singleSpaceJoin(a, b, joinedType, combinator);
else if (a.type().dimensions().containsAll(b.type().dimensions()))
- return subspaceJoin(b, a, joinedType, true);
+ return subspaceJoin(b, a, joinedType, true, combinator);
else if (b.type().dimensions().containsAll(a.type().dimensions()))
- return subspaceJoin(a, b, joinedType, false);
+ return subspaceJoin(a, b, joinedType, false, combinator);
else
- return generalJoin(a, b, joinedType);
+ return generalJoin(a, b, joinedType, combinator);
+
}
- private boolean hasSingleIndexedDimension(Tensor tensor) {
+ private static boolean hasSingleIndexedDimension(Tensor tensor) {
return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
}
- private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
+ private static Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type, DoubleBinaryOperator combinator) {
long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
@@ -111,7 +115,7 @@ public class Join extends PrimitiveTensorFunction {
}
/** When both tensors have the same dimensions, at most one cell matches a cell in the other tensor */
- private Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType) {
+ private static Tensor singleSpaceJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> i = a.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> aCell = i.next();
@@ -123,14 +127,14 @@ public class Join extends PrimitiveTensorFunction {
}
/** Join a tensor into a superspace */
- private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+ private static Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
- return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder);
+ return indexedSubspaceJoin((IndexedTensor) subspace, (IndexedTensor) superspace, joinedType, reversedArgumentOrder, combinator);
else
- return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder);
+ return generalSubspaceJoin(subspace, superspace, joinedType, reversedArgumentOrder, combinator);
}
- private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+ private static Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
@@ -145,16 +149,17 @@ public class Join extends PrimitiveTensorFunction {
for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
joinSubspaces(subspace.valueIterator(), subspace.size(),
- subspaceInSuper, subspaceInSuper.size(),
- reversedArgumentOrder, builder);
+ subspaceInSuper, subspaceInSuper.size(),
+ reversedArgumentOrder, builder, combinator);
}
return builder.build();
}
- private void joinSubspaces(Iterator<Double> subspace, long subspaceSize,
- Iterator<Tensor.Cell> superspace, long superspaceSize,
- boolean reversedArgumentOrder, IndexedTensor.Builder builder) {
+ private static void joinSubspaces(Iterator<Double> subspace, long subspaceSize,
+ Iterator<Tensor.Cell> superspace, long superspaceSize,
+ boolean reversedArgumentOrder, IndexedTensor.Builder builder,
+ DoubleBinaryOperator combinator) {
long joinedLength = Math.min(subspaceSize, superspaceSize);
if (reversedArgumentOrder) {
for (int i = 0; i < joinedLength; i++) {
@@ -169,7 +174,7 @@ public class Join extends PrimitiveTensorFunction {
}
}
- private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
+ private static DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size());
for (int i = 0; i < builder.dimensions(); i++) {
String dimensionName = joinedType.dimensions().get(i).name();
@@ -185,7 +190,7 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
- private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
+ private static Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder, DoubleBinaryOperator combinator) {
int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Iterator<Tensor.Cell> i = superspace.cellIterator(); i.hasNext(); ) {
@@ -194,21 +199,21 @@ public class Join extends PrimitiveTensorFunction {
double subspaceValue = subspace.get(subaddress);
if ( ! Double.isNaN(subspaceValue))
builder.cell(supercell.getKey(),
- reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
- : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
+ reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
+ : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
}
return builder.build();
}
/** Returns the indexes in the superspace type which should be retained to create the subspace type */
- private int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
+ private static int[] subspaceIndexes(TensorType supertype, TensorType subtype) {
int[] subspaceIndexes = new int[subtype.dimensions().size()];
for (int i = 0; i < subtype.dimensions().size(); i++)
subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
return subspaceIndexes;
}
- private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
+ private static TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
String[] subspaceLabels = new String[subspaceIndexes.length];
for (int i = 0; i < subspaceIndexes.length; i++)
subspaceLabels[i] = superAddress.label(subspaceIndexes[i]);
@@ -216,25 +221,25 @@ public class Join extends PrimitiveTensorFunction {
}
/** Slow join which works for any two tensors */
- private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) {
+ private static Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
if (a instanceof IndexedTensor && b instanceof IndexedTensor)
- return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType);
+ return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType, combinator);
else
- return mappedHashJoin(a, b, joinedType);
+ return mappedHashJoin(a, b, joinedType, combinator);
}
- private Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
+ private static Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
DimensionSizes joinedSize = joinedSize(joinedType, a, b);
Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize);
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
- joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder);
-// joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
+ joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, builder, combinator);
return builder.build();
}
- private void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
- int[] aToIndexes, int[] bToIndexes, boolean reversedOrder, Tensor.Builder builder) {
+ private static void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
+ int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder,
+ DoubleBinaryOperator combinator) {
Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
@@ -252,15 +257,14 @@ public class Join extends PrimitiveTensorFunction {
for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
Tensor.Cell bCell = bSubspace.next();
TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
- double joinedValue = reversedOrder ? combinator.applyAsDouble(bCell.getValue(), aCell.getValue())
- : combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
+ double joinedValue = combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
builder.cell(joinedAddress, joinedValue);
}
}
}
}
- private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
+ private static PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
if (retainDimensions.contains(addressType.dimensions().get(i).name()))
@@ -269,7 +273,7 @@ public class Join extends PrimitiveTensorFunction {
}
/** Returns the sizes from the joined sizes which are present in the type argument */
- private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
+ private static DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
int dimensionIndex = 0;
for (int i = 0; i < joinedType.dimensions().size(); i++) {
@@ -279,7 +283,7 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
- private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
+ private static Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
Tensor.Builder builder = Tensor.Builder.of(joinedType);
@@ -288,7 +292,7 @@ public class Join extends PrimitiveTensorFunction {
for (Iterator<Tensor.Cell> bIterator = b.cellIterator(); bIterator.hasNext(); ) {
Map.Entry<TensorAddress, Double> bCell = bIterator.next();
TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes,
- bCell.getKey(), bToIndexes, joinedType);
+ bCell.getKey(), bToIndexes, joinedType);
if (combinedAddress == null) continue; // not combinable
builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
}
@@ -296,10 +300,10 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
- private Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType) {
+ private static Tensor mappedHashJoin(Tensor a, Tensor b, TensorType joinedType, DoubleBinaryOperator combinator) {
TensorType commonDimensionType = commonDimensions(a, b);
if (commonDimensionType.dimensions().isEmpty()) {
- return mappedGeneralJoin(a, b, joinedType); // fallback
+ return mappedGeneralJoin(a, b, joinedType, combinator); // fallback
}
boolean swapTensors = a.size() > b.size();
@@ -351,15 +355,15 @@ public class Join extends PrimitiveTensorFunction {
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
- private int[] mapIndexes(TensorType fromType, TensorType toType) {
+ static int[] mapIndexes(TensorType fromType, TensorType toType) {
int[] toIndexes = new int[fromType.dimensions().size()];
for (int i = 0; i < fromType.dimensions().size(); i++)
toIndexes[i] = toType.indexOfDimension(fromType.dimensions().get(i).name()).orElse(-1);
return toIndexes;
}
- private TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
- TensorType joinedType) {
+ private static TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
+ TensorType joinedType) {
String[] joinedLabels = new String[joinedType.dimensions().size()];
mapContent(a, joinedLabels, aToIndexes);
boolean compatible = mapContent(b, joinedLabels, bToIndexes);
@@ -373,7 +377,7 @@ public class Join extends PrimitiveTensorFunction {
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
+ private static boolean mapContent(TensorAddress from, String[] to, int[] indexMap) {
for (int i = 0; i < from.size(); i++) {
int toIndex = indexMap[i];
if (to[toIndex] != null && ! to[toIndex].equals(from.label(i))) return false;
@@ -382,11 +386,10 @@ public class Join extends PrimitiveTensorFunction {
return true;
}
-
/**
* Returns common dimension of a and b as a new tensor type
*/
- private TensorType commonDimensions(Tensor a, Tensor b) {
+ private static TensorType commonDimensions(Tensor a, Tensor b) {
TensorType.Builder typeBuilder = new TensorType.Builder();
TensorType aType = a.type();
TensorType bType = b.type();
@@ -402,14 +405,14 @@ public class Join extends PrimitiveTensorFunction {
return typeBuilder.build();
}
- private TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
+ private static TensorAddress partialCommonAddress(Tensor.Cell cell, int[] indexMap) {
TensorAddress address = cell.getKey();
String[] labels = new String[indexMap.length];
for (int i = 0; i < labels.length; ++i) {
labels[i] = address.label(indexMap[i]);
}
return TensorAddress.of(labels);
-
}
}
+
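The refactoring above exposes the join machinery as a static, combinator-taking entry point, which is what ReduceJoin falls back to when it cannot optimize. A sketch of that fallback path (both evaluate methods are package-private, so this only applies inside com.yahoo.tensor.functions; the tensor types are taken from one of the non-optimizable test cases):

    Tensor a = Tensor.random(TensorType.fromSpec("tensor(d0[2],d1[3])"));
    Tensor b = Tensor.random(TensorType.fromSpec("tensor(d1[3])"));
    List<String> dimensions = Collections.singletonList("d0"); // reducing over a dimension not joined on

    TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
    Tensor joined = Join.evaluate(a, b, joinedType, (x, y) -> x * y);             // plain join first...
    Tensor reduced = Reduce.evaluate(joined, dimensions, Reduce.Aggregator.sum);  // ...then reduce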
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 332150f957d..54d7710c9dc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -32,7 +32,7 @@ public class Reduce extends PrimitiveTensorFunction {
private final List<String> dimensions;
private final Aggregator aggregator;
- /** Creates a reduce function reducing aLL dimensions */
+ /** Creates a reduce function reducing all dimensions */
public Reduce(TensorFunction argument, Aggregator aggregator) {
this(argument, aggregator, Collections.emptyList());
}
@@ -61,6 +61,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
+ if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all
TensorType.Builder b = new TensorType.Builder();
for (TensorType.Dimension dimension : inputType.dimensions()) {
if ( ! reduceDimensions.contains(dimension.name()))
@@ -71,6 +72,10 @@ public class Reduce extends PrimitiveTensorFunction {
public TensorFunction argument() { return argument; }
+ Aggregator aggregator() { return aggregator; }
+
+ List<String> dimensions() { return dimensions; }
+
@Override
public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
@@ -91,7 +96,7 @@ public class Reduce extends PrimitiveTensorFunction {
return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
}
- private String commaSeparated(List<String> list) {
+ static String commaSeparated(List<String> list) {
StringBuilder b = new StringBuilder();
for (String element : list)
b.append(", ").append(element);
@@ -100,10 +105,10 @@ public class Reduce extends PrimitiveTensorFunction {
@Override
public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
- return type(argument.type(context));
+ return type(argument.type(context), dimensions);
}
- private TensorType type(TensorType argumentType) {
+ private static TensorType type(TensorType argumentType, List<String> dimensions) {
if (dimensions.isEmpty()) return TensorType.empty; // means reduce all
TensorType.Builder builder = new TensorType.Builder();
for (TensorType.Dimension dimension : argumentType.dimensions())
@@ -114,7 +119,10 @@ public class Reduce extends PrimitiveTensorFunction {
@Override
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
- Tensor argument = this.argument.evaluate(context);
+ return evaluate(this.argument.evaluate(context), dimensions, aggregator);
+ }
+
+ static Tensor evaluate(Tensor argument, List<String> dimensions, Aggregator aggregator) {
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
dimensions + ": Not all those dimensions are present in this tensor");
@@ -122,17 +130,17 @@ public class Reduce extends PrimitiveTensorFunction {
// Special case: Reduce all
if (dimensions.isEmpty() || dimensions.size() == argument.type().dimensions().size())
if (argument.type().dimensions().size() == 1 && argument instanceof IndexedTensor)
- return reduceIndexedVector((IndexedTensor)argument);
+ return reduceIndexedVector((IndexedTensor)argument, aggregator);
else
- return reduceAllGeneral(argument);
+ return reduceAllGeneral(argument, aggregator);
- TensorType reducedType = type(argument.type());
+ TensorType reducedType = type(argument.type(), dimensions);
// Reduce cells
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
- TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType);
+ TensorAddress reducedAddress = reduceDimensions(cell.getKey(), argument.type(), reducedType, dimensions);
aggregatingCells.putIfAbsent(reducedAddress, ValueAggregator.ofType(aggregator));
aggregatingCells.get(reducedAddress).aggregate(cell.getValue());
}
@@ -141,11 +149,12 @@ public class Reduce extends PrimitiveTensorFunction {
reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
return reducedBuilder.build();
+
}
- private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) {
+ private static TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType, List<String> dimensions) {
Set<Integer> indexesToRemove = new HashSet<>();
- for (String dimensionToRemove : this.dimensions)
+ for (String dimensionToRemove : dimensions)
indexesToRemove.add(argumentType.indexOfDimension(dimensionToRemove).get());
String[] reducedLabels = new String[reducedType.dimensions().size()];
@@ -156,23 +165,23 @@ public class Reduce extends PrimitiveTensorFunction {
return TensorAddress.of(reducedLabels);
}
- private Tensor reduceAllGeneral(Tensor argument) {
+ private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )
valueAggregator.aggregate(i.next());
return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
}
- private Tensor reduceIndexedVector(IndexedTensor argument) {
+ private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (int i = 0; i < argument.dimensionSizes().size(0); i++)
valueAggregator.aggregate(argument.get(i));
return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
}
- private static abstract class ValueAggregator {
+ static abstract class ValueAggregator {
- private static ValueAggregator ofType(Aggregator aggregator) {
+ static ValueAggregator ofType(Aggregator aggregator) {
switch (aggregator) {
case avg : return new AvgAggregator();
case count : return new CountAggregator();
@@ -191,6 +200,9 @@ public class Reduce extends PrimitiveTensorFunction {
/** Returns the value aggregated by this */
public abstract double aggregatedValue();
+ /** Resets the aggregator */
+ public abstract void reset();
+
}
private static class AvgAggregator extends ValueAggregator {
@@ -209,6 +221,11 @@ public class Reduce extends PrimitiveTensorFunction {
return valueSum / valueCount;
}
+ @Override
+ public void reset() {
+ valueCount = 0;
+ valueSum = 0.0;
+ }
}
private static class CountAggregator extends ValueAggregator {
@@ -225,6 +242,10 @@ public class Reduce extends PrimitiveTensorFunction {
return valueCount;
}
+ @Override
+ public void reset() {
+ valueCount = 0;
+ }
}
private static class ProdAggregator extends ValueAggregator {
@@ -241,6 +262,10 @@ public class Reduce extends PrimitiveTensorFunction {
return valueProd;
}
+ @Override
+ public void reset() {
+ valueProd = 1.0;
+ }
}
private static class SumAggregator extends ValueAggregator {
@@ -257,6 +282,10 @@ public class Reduce extends PrimitiveTensorFunction {
return valueSum;
}
+ @Override
+ public void reset() {
+ valueSum = 0.0;
+ }
}
private static class MaxAggregator extends ValueAggregator {
@@ -274,6 +303,10 @@ public class Reduce extends PrimitiveTensorFunction {
return maxValue;
}
+ @Override
+ public void reset() {
+ maxValue = Double.MIN_VALUE;
+ }
}
private static class MinAggregator extends ValueAggregator {
@@ -291,6 +324,11 @@ public class Reduce extends PrimitiveTensorFunction {
return minValue;
}
+ @Override
+ public void reset() {
+ minValue = Double.MAX_VALUE;
+ }
+
}
}
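The new reset() methods exist so ReduceJoin can reuse a single aggregator across output cells instead of allocating one per cell; a small sketch of that pattern (ValueAggregator is package-private, so this usage is limited to com.yahoo.tensor.functions):

    Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(Reduce.Aggregator.sum);
    agg.aggregate(1.0);
    agg.aggregate(2.0);
    double firstCell = agg.aggregatedValue();  // 3.0
    agg.reset();                               // back to the neutral value (0.0 for sum)
    agg.aggregate(5.0);
    double secondCell = agg.aggregatedValue(); // 5.0, unaffected by the previous round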
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
new file mode 100644
index 00000000000..b268e33b418
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -0,0 +1,372 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.DimensionSizes;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+import java.util.stream.Collectors;
+
+/**
+ * An optimization for tensor expressions where a reduce is applied directly to
+ * the result of a join. Evaluating the two as one operation is significantly
+ * more efficient than evaluating each separately.
+ *
+ * This implementation optimizes the case where the reduce is done on the same
+ * dimensions as the join. A particularly efficient evaluation is done if there
+ * is one common dimension that is joined and reduced on, which is a common
+ * case as it covers vector and matrix multiplication.
+ *
+ * @author lesters
+ */
+public class ReduceJoin extends CompositeTensorFunction {
+
+ private final TensorFunction argumentA, argumentB;
+ private final DoubleBinaryOperator combinator;
+ private final Reduce.Aggregator aggregator;
+ private final List<String> dimensions;
+
+ public ReduceJoin(Reduce reduce, Join join) {
+ this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions());
+ }
+
+ public ReduceJoin(TensorFunction argumentA,
+ TensorFunction argumentB,
+ DoubleBinaryOperator combinator,
+ Reduce.Aggregator aggregator,
+ List<String> dimensions) {
+ this.argumentA = argumentA;
+ this.argumentB = argumentB;
+ this.combinator = combinator;
+ this.aggregator = aggregator;
+ this.dimensions = ImmutableList.copyOf(dimensions);
+ }
+
+ @Override
+ public List<TensorFunction> arguments() {
+ return ImmutableList.of(argumentA, argumentB);
+ }
+
+ @Override
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 2)
+ throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + arguments.size());
+ return new ReduceJoin(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions);
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ Join join = new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
+ return new Reduce(join, aggregator, dimensions);
+ }
+
+ @Override
+ public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor a = argumentA.evaluate(context);
+ Tensor b = argumentB.evaluate(context);
+ TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
+
+ if (canOptimize(a, b)) {
+ return evaluate((IndexedTensor)a, (IndexedTensor)b, joinedType);
+ }
+ return Reduce.evaluate(Join.evaluate(a, b, joinedType, combinator), dimensions, aggregator);
+ }
+
+ /**
+ * Tests whether the reduce is over the dimensions the join combines on. The
+ * remaining logic in this class assumes this to be true.
+ *
+ * If no reduce dimensions are given (reduce over everything), all dimensions of both tensors must be join dimensions.
+ *
+ * @return {@code true} if the implementation can optimize evaluation
+ * given the two tensors.
+ */
+ public boolean canOptimize(Tensor a, Tensor b) {
+ if (a.type().dimensions().isEmpty() || b.type().dimensions().isEmpty()) // TODO: support scalars
+ return false;
+ if ( ! (a instanceof IndexedTensor))
+ return false;
+ if ( ! (a.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)))
+ return false;
+ if ( ! (b instanceof IndexedTensor))
+ return false;
+ if ( ! (b.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)))
+ return false;
+
+ TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b);
+ if (dimensions.isEmpty()) {
+ if (a.type().dimensions().size() != commonDimensions.dimensions().size())
+ return false;
+ if (b.type().dimensions().size() != commonDimensions.dimensions().size())
+ return false;
+ } else {
+ for (TensorType.Dimension dimension : commonDimensions.dimensions()) {
+ if (!dimensions.contains(dimension.name()))
+ return false;
+ }
+ }
+ return true;
+ }
+
+ /**
+ * Evaluates the reduce-join. Special handling for common cases where the
+ * reduce dimension is the innermost dimension in both tensors.
+ */
+ private Tensor evaluate(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
+ TensorType reducedType = Reduce.outputType(joinedType, dimensions);
+
+ if (reduceDimensionIsInnermost(a, b)) {
+ if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 1) {
+ return vectorVectorProduct(a, b, reducedType);
+ }
+ if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 2) {
+ return vectorMatrixProduct(a, b, reducedType, false);
+ }
+ if (a.type().dimensions().size() == 2 && b.type().dimensions().size() == 1) {
+ return vectorMatrixProduct(b, a, reducedType, true);
+ }
+ if (a.type().dimensions().size() == 2 && b.type().dimensions().size() == 2) {
+ return matrixMatrixProduct(a, b, reducedType);
+ }
+ }
+ return evaluateGeneral(a, b, reducedType);
+ }
+
+ private Tensor vectorVectorProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
+ if ( a.type().dimensions().size() != 1 || b.type().dimensions().size() != 1) {
+ throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-vector product");
+ }
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
+ long commonSize = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
+
+ Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
+ for (int ic = 0; ic < commonSize; ++ic) {
+ double va = a.get(ic);
+ double vb = b.get(ic);
+ agg.aggregate(combinator.applyAsDouble(va, vb));
+ }
+ builder.cellByDirectIndex(0, agg.aggregatedValue());
+ return builder.build();
+ }
+
+ private Tensor vectorMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType, boolean swapped) {
+ if ( a.type().dimensions().size() != 1 || b.type().dimensions().size() != 2) {
+ throw new IllegalArgumentException("Wrong dimension sizes for tensors for vector-matrix product");
+ }
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
+ DimensionSizes sizesA = a.dimensionSizes();
+ DimensionSizes sizesB = b.dimensionSizes();
+
+ Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
+ for (int ib = 0; ib < sizesB.size(0); ++ib) {
+ agg.reset();
+ for (int ic = 0; ic < Math.min(sizesA.size(0), sizesB.size(1)); ++ic) {
+ double va = a.get(ic);
+ double vb = b.get(ib * sizesB.size(1) + ic);
+ double result = swapped ? combinator.applyAsDouble(vb, va) : combinator.applyAsDouble(va, vb);
+ agg.aggregate(result);
+ }
+ builder.cellByDirectIndex(ib, agg.aggregatedValue());
+ }
+ return builder.build();
+ }
+
+ private Tensor matrixMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
+ if ( a.type().dimensions().size() != 2 || b.type().dimensions().size() != 2) {
+ throw new IllegalArgumentException("Wrong dimension sizes for tensors for matrix-matrix product");
+ }
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
+ DimensionSizes sizesA = a.dimensionSizes();
+ DimensionSizes sizesB = b.dimensionSizes();
+ int iaToReduced = reducedType.indexOfDimension(a.type().dimensions().get(0).name()).get();
+ int ibToReduced = reducedType.indexOfDimension(b.type().dimensions().get(0).name()).get();
+ long strideA = iaToReduced < ibToReduced ? sizesB.size(0) : 1;
+ long strideB = ibToReduced < iaToReduced ? sizesA.size(0) : 1;
+
+ Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
+ for (int ia = 0; ia < sizesA.size(0); ++ia) {
+ for (int ib = 0; ib < sizesB.size(0); ++ib) {
+ agg.reset();
+ for (int ic = 0; ic < Math.min(sizesA.size(1), sizesB.size(1)); ++ic) {
+ double va = a.get(ia * sizesA.size(1) + ic);
+ double vb = b.get(ib * sizesB.size(1) + ic);
+ agg.aggregate(combinator.applyAsDouble(va, vb));
+ }
+ builder.cellByDirectIndex(ia * strideA + ib * strideB, agg.aggregatedValue());
+ }
+ }
+ return builder.build();
+ }
+
+ private Tensor evaluateGeneral(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
+ TensorType onlyInA = Reduce.outputType(a.type(), dimensions);
+ TensorType onlyInB = Reduce.outputType(b.type(), dimensions);
+ TensorType common = dimensionsInCommon(a, b);
+
+ // pre-calculate strides for each index position
+ long[] stridesA = strides(a.type());
+ long[] stridesB = strides(b.type());
+ long[] stridesResult = strides(reducedType);
+
+ // mapping of dimension indexes
+ int[] mapOnlyAToA = Join.mapIndexes(onlyInA, a.type());
+ int[] mapCommonToA = Join.mapIndexes(common, a.type());
+ int[] mapOnlyBToB = Join.mapIndexes(onlyInB, b.type());
+ int[] mapCommonToB = Join.mapIndexes(common, b.type());
+ int[] mapOnlyAToResult = Join.mapIndexes(onlyInA, reducedType);
+ int[] mapOnlyBToResult = Join.mapIndexes(onlyInB, reducedType);
+
+ // TODO: refactor with code in IndexedTensor and Join
+
+ MultiDimensionIterator ic = new MultiDimensionIterator(common);
+ Reduce.ValueAggregator agg = Reduce.ValueAggregator.ofType(aggregator);
+ for (MultiDimensionIterator ia = new MultiDimensionIterator(onlyInA); ia.hasNext(); ia.next()) {
+ for (MultiDimensionIterator ib = new MultiDimensionIterator(onlyInB); ib.hasNext(); ib.next()) {
+ agg.reset();
+ for (ic.reset(); ic.hasNext(); ic.next()) {
+ double va = a.get(toDirectIndex(ia, ic, stridesA, mapOnlyAToA, mapCommonToA));
+ double vb = b.get(toDirectIndex(ib, ic, stridesB, mapOnlyBToB, mapCommonToB));
+ agg.aggregate(combinator.applyAsDouble(va, vb));
+ }
+ builder.cellByDirectIndex(toDirectIndex(ia, ib, stridesResult, mapOnlyAToResult, mapOnlyBToResult),
+ agg.aggregatedValue());
+ }
+ }
+ return builder.build();
+ }
+
+ private long toDirectIndex(MultiDimensionIterator iter, MultiDimensionIterator common, long[] strides, int[] map, int[] commonmap) {
+ long directIndex = 0;
+ for (int i = 0; i < iter.length(); ++i) {
+ directIndex += strides[map[i]] * iter.iterator[i];
+ }
+ for (int i = 0; i < common.length(); ++i) {
+ directIndex += strides[commonmap[i]] * common.iterator[i];
+ }
+ return directIndex;
+ }
+
+ private long[] strides(TensorType type) {
+ long[] strides = new long[type.dimensions().size()];
+ if (strides.length > 0) {
+ long previous = 1;
+ strides[strides.length - 1] = previous;
+ for (int i = strides.length - 2; i >= 0; --i) {
+ strides[i] = previous * type.dimensions().get(i + 1).size().get();
+ previous = strides[i];
+ }
+ }
+ return strides;
+ }
+
+ private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) {
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension aDim : a.type().dimensions()) {
+ for (TensorType.Dimension bDim : b.type().dimensions()) {
+ if (aDim.name().equals(bDim.name())) {
+ if ( ! aDim.size().isPresent()) {
+ builder.set(aDim);
+ } else if ( ! bDim.size().isPresent()) {
+ builder.set(bDim);
+ } else {
+ builder.set(aDim.size().get() < bDim.size().get() ? aDim : bDim); // minimum size of dimension
+ }
+ }
+ }
+ }
+ return builder.build();
+ }
+
+ /**
+ * Tests if there is exactly one reduce dimension and it is the innermost
+ * dimension in both tensors.
+ */
+ private boolean reduceDimensionIsInnermost(Tensor a, Tensor b) {
+ List<String> reducingDimensions = dimensions;
+ if (reducingDimensions.isEmpty()) {
+ reducingDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b).dimensions().stream()
+ .map(TensorType.Dimension::name)
+ .collect(Collectors.toList());
+ }
+ if (reducingDimensions.size() != 1) {
+ return false;
+ }
+ String dimension = reducingDimensions.get(0);
+ int indexInA = a.type().indexOfDimension(dimension).orElseThrow(() ->
+ new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor A."));
+ if (indexInA != (a.type().dimensions().size() - 1)) {
+ return false;
+ }
+ int indexInB = b.type().indexOfDimension(dimension).orElseThrow(() ->
+ new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor B."));
+ if (indexInB < (b.type().dimensions().size() - 1)) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "reduce_join(" + argumentA.toString(context) + ", " +
+ argumentB.toString(context) + ", " +
+ combinator + ", " +
+ aggregator +
+ Reduce.commaSeparated(dimensions) + ")";
+ }
+
+ private static class MultiDimensionIterator {
+
+ private long[] bounds;
+ private long[] iterator;
+ private int remaining;
+
+ MultiDimensionIterator(TensorType type) {
+ bounds = new long[type.dimensions().size()];
+ iterator = new long[type.dimensions().size()];
+ for (int i = 0; i < bounds.length; ++i) {
+ bounds[i] = type.dimensions().get(i).size().get();
+ }
+ reset();
+ }
+
+ public int length() {
+ return iterator.length;
+ }
+
+ public boolean hasNext() {
+ return remaining > 0;
+ }
+
+ public void reset() {
+ remaining = 1;
+ for (int i = iterator.length - 1; i >= 0; --i) {
+ iterator[i] = 0;
+ remaining *= bounds[i];
+ }
+ }
+
+ public void next() {
+ for (int i = iterator.length - 1; i >= 0; --i) {
+ iterator[i] += 1;
+ if (iterator[i] < bounds[i]) {
+ break;
+ }
+ iterator[i] = 0;
+ }
+ remaining -= 1;
+ }
+
+ public String toString() {
+ return Arrays.toString(iterator);
+ }
+ }
+
+}
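As a worked example of the matrix-matrix case above: with a of type tensor(d0[2],d2[2]), b of type tensor(d1[3],d2[2]) and the expression reduce(join(a, b, f(a,b)(a * b)), sum, d2), one of the cases exercised by TensorOptimizerTestCase, the shared innermost dimension is d2, so matrixMatrixProduct is chosen and each cell of the tensor(d0[2],d1[3]) result is

    result{d0:i, d1:j} = sum over k of a{d0:i, d2:k} * b{d1:j, d2:k}

i.e. an ordinary matrix product of a with the transpose of b, accumulated with a single ValueAggregator per output cell and written by direct index.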