summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-09-29 12:29:29 +0200
committerLester Solbakken <lesters@oath.com>2018-09-29 12:29:29 +0200
commit35da647c2c32ea3a81b5c65bd440af1dc59b1b3c (patch)
treea2e29cad34de2da864b9898245e802f7ab6cd0f9
parented3bc0556f0ec8b867c8f5d80b26f8e79b6a8b27 (diff)
Non-functional changes
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java106
2 files changed, 65 insertions, 59 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index ab71c3f64dc..3a3410aeebb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -109,19 +109,19 @@ public class LambdaFunctionNode extends CompositeNode {
}
ArithmeticOperator operator = node.operators().get(0);
switch (operator) {
- case OR: return asDoubleBinaryOperator((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
- case AND: return asDoubleBinaryOperator((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
- case PLUS: return asDoubleBinaryOperator((left, right) -> left + right);
- case MINUS: return asDoubleBinaryOperator((left, right) -> left - right);
- case MULTIPLY: return asDoubleBinaryOperator((left, right) -> left * right);
- case DIVIDE: return asDoubleBinaryOperator((left, right) -> left / right);
- case MODULO: return asDoubleBinaryOperator((left, right) -> left % right);
- case POWER: return asDoubleBinaryOperator(Math::pow);
+ case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
+ case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
+ case PLUS: return asFunctionExpression((left, right) -> left + right);
+ case MINUS: return asFunctionExpression((left, right) -> left - right);
+ case MULTIPLY: return asFunctionExpression((left, right) -> left * right);
+ case DIVIDE: return asFunctionExpression((left, right) -> left / right);
+ case MODULO: return asFunctionExpression((left, right) -> left % right);
+ case POWER: return asFunctionExpression(Math::pow);
}
return Optional.empty();
}
- private Optional<DoubleBinaryOperator> asDoubleBinaryOperator(DoubleBinaryOperator operator) {
+ private Optional<DoubleBinaryOperator> asFunctionExpression(DoubleBinaryOperator operator) {
return Optional.of(new DoubleBinaryOperator() {
@Override
public double applyAsDouble(double left, double right) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index 9f53ced7719..b268e33b418 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -16,7 +16,7 @@ import java.util.stream.Collectors;
/**
* An optimization for tensor expressions where a join immediately follows a
- * reduce. Evaluating this as one operations is significantly more efficient
+ * reduce. Evaluating this as one operation is significantly more efficient
* than evaluating each separately.
*
* This implementation optimizes the case where the reduce is done on the same
@@ -71,30 +71,33 @@ public class ReduceJoin extends CompositeTensorFunction {
public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- TensorType joinType = new TensorType.Builder(a.type(), b.type()).build();
+ TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
if (canOptimize(a, b)) {
- return evaluate((IndexedTensor)a, (IndexedTensor)b, joinType);
+ return evaluate((IndexedTensor)a, (IndexedTensor)b, joinedType);
}
- return Reduce.evaluate(Join.evaluate(a, b, joinType, combinator), dimensions, aggregator);
+ return Reduce.evaluate(Join.evaluate(a, b, joinedType, combinator), dimensions, aggregator);
}
/**
* Tests whether or not the reduce is over the join dimensions. The
- * remaining logic in this class assumes this to return true.
+ * remaining logic in this class assumes this to be true.
*
* If no dimensions are given, the join must be on all tensor dimensions.
+ *
+ * @return {@code true} if the implementation can optimize evaluation
+ * given the two tensors.
*/
public boolean canOptimize(Tensor a, Tensor b) {
- if (a.type().dimensions().size() == 0 || b.type().dimensions().size() == 0) // TODO: support scalars
+ if (a.type().dimensions().isEmpty() || b.type().dimensions().isEmpty()) // TODO: support scalars
return false;
if ( ! (a instanceof IndexedTensor))
return false;
- if ( ! (a.type().dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)))
+ if ( ! (a.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)))
return false;
if ( ! (b instanceof IndexedTensor))
return false;
- if ( ! (b.type().dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)))
+ if ( ! (b.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)))
return false;
TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b);
@@ -113,36 +116,11 @@ public class ReduceJoin extends CompositeTensorFunction {
}
/**
- * Tests if the reduce dimension is the innermost dimension in both tensors
- */
- private boolean reduceDimensionIsInnermost(Tensor a, Tensor b) {
- List<String> reducingDimensions = dimensions;
- if (reducingDimensions.isEmpty()) {
- reducingDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b).dimensions().stream()
- .map(TensorType.Dimension::name)
- .collect(Collectors.toList());
- }
- if (reducingDimensions.size() != 1) {
- return false;
- }
- String dimension = reducingDimensions.get(0);
- int indexInA = a.type().indexOfDimension(dimension).orElseThrow(() -> new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor A."));
- if (indexInA != (a.type().dimensions().size() - 1)) {
- return false;
- }
- int indexInB = b.type().indexOfDimension(dimension).orElseThrow(() -> new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor B."));
- if (indexInB < (b.type().dimensions().size() - 1)) {
- return false;
- }
- return true;
- }
-
- /**
* Evaluates the reduce-join. Special handling for common cases where the
* reduce dimension is the innermost dimension in both tensors.
*/
- private Tensor evaluate(IndexedTensor a, IndexedTensor b, TensorType joinType) {
- TensorType reducedType = Reduce.outputType(joinType, dimensions);
+ private Tensor evaluate(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
+ TensorType reducedType = Reduce.outputType(joinedType, dimensions);
if (reduceDimensionIsInnermost(a, b)) {
if (a.type().dimensions().size() == 1 && b.type().dimensions().size() == 1) {
@@ -200,15 +178,15 @@ public class ReduceJoin extends CompositeTensorFunction {
return builder.build();
}
- private Tensor matrixMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reduceType) {
+ private Tensor matrixMatrixProduct(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
if ( a.type().dimensions().size() != 2 || b.type().dimensions().size() != 2) {
throw new IllegalArgumentException("Wrong dimension sizes for tensors for matrix-matrix product");
}
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reduceType);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
DimensionSizes sizesA = a.dimensionSizes();
DimensionSizes sizesB = b.dimensionSizes();
- int iaToReduced = reduceType.indexOfDimension(a.type().dimensions().get(0).name()).get();
- int ibToReduced = reduceType.indexOfDimension(b.type().dimensions().get(0).name()).get();
+ int iaToReduced = reducedType.indexOfDimension(a.type().dimensions().get(0).name()).get();
+ int ibToReduced = reducedType.indexOfDimension(b.type().dimensions().get(0).name()).get();
long strideA = iaToReduced < ibToReduced ? sizesB.size(0) : 1;
long strideB = ibToReduced < iaToReduced ? sizesA.size(0) : 1;
@@ -227,8 +205,8 @@ public class ReduceJoin extends CompositeTensorFunction {
return builder.build();
}
- private Tensor evaluateGeneral(IndexedTensor a, IndexedTensor b, TensorType reduceType) {
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reduceType);
+ private Tensor evaluateGeneral(IndexedTensor a, IndexedTensor b, TensorType reducedType) {
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(reducedType);
TensorType onlyInA = Reduce.outputType(a.type(), dimensions);
TensorType onlyInB = Reduce.outputType(b.type(), dimensions);
TensorType common = dimensionsInCommon(a, b);
@@ -236,15 +214,15 @@ public class ReduceJoin extends CompositeTensorFunction {
// pre-calculate strides for each index position
long[] stridesA = strides(a.type());
long[] stridesB = strides(b.type());
- long[] stridesResult = strides(reduceType);
+ long[] stridesResult = strides(reducedType);
// mapping of dimension indexes
int[] mapOnlyAToA = Join.mapIndexes(onlyInA, a.type());
int[] mapCommonToA = Join.mapIndexes(common, a.type());
int[] mapOnlyBToB = Join.mapIndexes(onlyInB, b.type());
int[] mapCommonToB = Join.mapIndexes(common, b.type());
- int[] mapOnlyAToResult = Join.mapIndexes(onlyInA, reduceType);
- int[] mapOnlyBToResult = Join.mapIndexes(onlyInB, reduceType);
+ int[] mapOnlyAToResult = Join.mapIndexes(onlyInA, reducedType);
+ int[] mapOnlyBToResult = Join.mapIndexes(onlyInB, reducedType);
// TODO: refactor with code in IndexedTensor and Join
@@ -307,6 +285,34 @@ public class ReduceJoin extends CompositeTensorFunction {
return builder.build();
}
+ /**
+ * Tests if there is exactly one reduce dimension and it is the innermost
+ * dimension in both tensors.
+ */
+ private boolean reduceDimensionIsInnermost(Tensor a, Tensor b) {
+ List<String> reducingDimensions = dimensions;
+ if (reducingDimensions.isEmpty()) {
+ reducingDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b).dimensions().stream()
+ .map(TensorType.Dimension::name)
+ .collect(Collectors.toList());
+ }
+ if (reducingDimensions.size() != 1) {
+ return false;
+ }
+ String dimension = reducingDimensions.get(0);
+ int indexInA = a.type().indexOfDimension(dimension).orElseThrow(() ->
+ new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor A."));
+ if (indexInA != (a.type().dimensions().size() - 1)) {
+ return false;
+ }
+ int indexInB = b.type().indexOfDimension(dimension).orElseThrow(() ->
+ new IllegalArgumentException("Reduce-Join dimension '" + dimension + "' missing in tensor B."));
+ if (indexInB < (b.type().dimensions().size() - 1)) {
+ return false;
+ }
+ return true;
+ }
+
@Override
public String toString(ToStringContext context) {
return "reduce_join(" + argumentA.toString(context) + ", " +
@@ -318,15 +324,15 @@ public class ReduceJoin extends CompositeTensorFunction {
private static class MultiDimensionIterator {
- private long[] max;
+ private long[] bounds;
private long[] iterator;
private int remaining;
MultiDimensionIterator(TensorType type) {
- max = new long[type.dimensions().size()];
+ bounds = new long[type.dimensions().size()];
iterator = new long[type.dimensions().size()];
- for (int i = 0; i < max.length; ++i) {
- max[i] = type.dimensions().get(i).size().get();
+ for (int i = 0; i < bounds.length; ++i) {
+ bounds[i] = type.dimensions().get(i).size().get();
}
reset();
}
@@ -343,14 +349,14 @@ public class ReduceJoin extends CompositeTensorFunction {
remaining = 1;
for (int i = iterator.length - 1; i >= 0; --i) {
iterator[i] = 0;
- remaining *= max[i];
+ remaining *= bounds[i];
}
}
public void next() {
for (int i = iterator.length - 1; i >= 0; --i) {
iterator[i] += 1;
- if (iterator[i] < max[i]) {
+ if (iterator[i] < bounds[i]) {
break;
}
iterator[i] = 0;