diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-22 11:12:53 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-22 11:12:53 +0100 |
commit | 39a652ce439a42fb8db372c821c834d02c95b0f1 (patch) | |
tree | b246abdbddfcf8fd84f2f74590aeb56743f2dfa3 | |
parent | ed8ec5305f6838e31de94ef87ddd3a75390b59ed (diff) |
Tensor generate functions
9 files changed, 314 insertions, 36 deletions
diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/ThreadedRequestHandler.java b/container-core/src/main/java/com/yahoo/container/jdisc/ThreadedRequestHandler.java index 5cabe8a9ec6..ccfb858c473 100644 --- a/container-core/src/main/java/com/yahoo/container/jdisc/ThreadedRequestHandler.java +++ b/container-core/src/main/java/com/yahoo/container/jdisc/ThreadedRequestHandler.java @@ -33,7 +33,7 @@ import javax.annotation.concurrent.GuardedBy; * Note that this means that subclass handlers are synchronous - the request io can * continue after completion of the worker thread. * - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ public abstract class ThreadedRequestHandler extends AbstractRequestHandler { diff --git a/jdisc_core/src/main/java/com/yahoo/jdisc/Response.java b/jdisc_core/src/main/java/com/yahoo/jdisc/Response.java index 809805fdcc4..d624e070a2e 100644 --- a/jdisc_core/src/main/java/com/yahoo/jdisc/Response.java +++ b/jdisc_core/src/main/java/com/yahoo/jdisc/Response.java @@ -208,9 +208,9 @@ public class Response { } /** - * <p>This is a convenience method for creating a Response with status {@link Status#REQUEST_TIMEOUT} and passing + * This is a convenience method for creating a Response with status {@link Status#REQUEST_TIMEOUT} and passing * that to the given {@link ResponseHandler#handleResponse(Response)}. For trivial implementations of {@link - * RequestHandler#handleTimeout(Request, ResponseHandler)}, simply call this method.</p> + * RequestHandler#handleTimeout(Request, ResponseHandler)}, simply call this method. * * @param handler The handler to pass the timeout {@link Response} to. */ diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 564b2cd9801..fab80304f6d 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -123,6 +123,9 @@ TOKEN : <JOIN: "join"> | <RENAME: "rename"> | <TENSOR: "tensor"> | + <RANGE: "range"> | + <DIAG: "diag"> | + <RANDOM: "random"> | <L1_NORMALIZE: "l1_normalize"> | <L2_NORMALIZE: "l2_normalize"> | <MATMUL: "matmul"> | @@ -352,6 +355,9 @@ ExpressionNode tensorFunction() : tensorExpression = tensorJoin() | tensorExpression = tensorRename() | tensorExpression = tensorGenerate() | + tensorExpression = tensorRange() | + tensorExpression = tensorDiag() | + tensorExpression = tensorRandom() | tensorExpression = tensorL1Normalize() | tensorExpression = tensorL2Normalize() | tensorExpression = tensorMatmul() | @@ -426,31 +432,35 @@ ExpressionNode tensorGenerate() : ExpressionNode generator; } { - type = tensorType() <LBRACE> generator = expression() <RBRACE> + <TENSOR> type = tensorTypeArgument() <LBRACE> generator = expression() <RBRACE> { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asIntegerListToDoubleOperator())); } } -TensorType tensorType() : +ExpressionNode tensorRange() : { - TensorType.Builder builder = new TensorType.Builder(); + TensorType type; } { - <TENSOR> <LBRACE> - ( tensorTypeDimension(builder) ) ? - ( <COMMA> tensorTypeDimension(builder) ) * - <RBRACE> - { return builder.build(); } + <RANGE> type = tensorTypeArgument() + { return new TensorFunctionNode(new Range(type)); } } -// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need -void tensorTypeDimension(TensorType.Builder builder) : +ExpressionNode tensorDiag() : { - String name; - int size; + TensorType type; } { - name = identifier() <LSQUARE> size = integerNumber() <RSQUARE> - { builder.indexed(name, size); } + <DIAG> type = tensorTypeArgument() + { return new TensorFunctionNode(new Diag(type)); } +} + +ExpressionNode tensorRandom() : +{ + TensorType type; +} +{ + <RANDOM> type = tensorTypeArgument() + { return new TensorFunctionNode(new Random(type)); } } ExpressionNode tensorL1Normalize() : @@ -529,18 +539,49 @@ Reduce.Aggregator tensorReduceAggregator() : { return Reduce.Aggregator.valueOf(token.image); } } +TensorType tensorTypeArgument() : +{ + TensorType.Builder builder = new TensorType.Builder(); +} +{ + <LBRACE> + ( tensorTypeDimension(builder) ) ? + ( <COMMA> tensorTypeDimension(builder) ) * + <RBRACE> + { return builder.build(); } +} + +// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need +void tensorTypeDimension(TensorType.Builder builder) : +{ + String name; + int size; +} +{ + name = identifier() <LSQUARE> size = integerNumber() <RSQUARE> + { builder.indexed(name, size); } +} + // This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge String tensorFunctionName() : { Reduce.Aggregator aggregator; } { - ( <F> { return token.image; } ) | - ( <MAP> { return token.image; } ) | - ( <REDUCE> { return token.image; } ) | - ( <JOIN> { return token.image; } ) | - ( <RENAME> { return token.image; } ) | - ( <TENSOR> { return token.image; } ) | + ( <F> { return token.image; } ) | + ( <MAP> { return token.image; } ) | + ( <REDUCE> { return token.image; } ) | + ( <JOIN> { return token.image; } ) | + ( <RENAME> { return token.image; } ) | + ( <TENSOR> { return token.image; } ) | + ( <RANGE> { return token.image; } ) | + ( <DIAG> { return token.image; } ) | + ( <RANDOM> { return token.image; } ) | + ( <L1_NORMALIZE> { return token.image; } ) | + ( <L2_NORMALIZE> { return token.image; } ) | + ( <MATMUL> { return token.image; } ) | + ( <SOFTMAX> { return token.image; } ) | + ( <XW_PLUS_B> { return token.image; } ) | ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 55638c3687b..dc451b1dc5c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -5,6 +5,7 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.*; +import com.yahoo.tensor.Tensor; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -227,7 +228,7 @@ public class EvaluationTestCase { // argmin tester.assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0 }", "tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); - + // tensor rename tester.assertEvaluates("{ {newX:0,y:0}:3 }", "rename(tensor0, x, newX)", "{ {x:0,y:0}:3.0 }"); tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:5 }", "rename(tensor0, (x, y), (y, x))", "{ {x:0,y:0}:3.0, {x:0,y:1}:5.0 }"); @@ -235,11 +236,11 @@ public class EvaluationTestCase { // tensor generate tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:0, {x:0,y:1}:1, {x:1,y:1}:0, {x:0,y:2}:0, {x:1,y:2}:1 }", "tensor(x[2],y[3])(x+1==y)"); tester.assertEvaluates("{ {y:0,x:0}:0, {y:1,x:0}:0, {y:0,x:1}:1, {y:1,x:1}:0, {y:0,x:2}:0, {y:1,x:2}:1 }", "tensor(y[2],x[3])(y+1==x)"); - // TODO - // range - // diag - // fill - // random + tester.assertEvaluates("{ {x:0,y:0,z:0}:1 }", "tensor(x[1],y[1],z[1])((x==y)*(y==z))"); + // - generate composites + tester.assertEvaluates("{ {x:0}:0, {x:1}:1, {x:2}:2 }", "range(x[3])"); + tester.assertEvaluates("{ {x:0,y:0,z:0}:1, {x:0,y:0,z:1}:0, {x:0,y:1,z:0}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:0}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:1, }", "diag(x[2],y[2],z[2])"); + tester.assertEvaluates("6", "reduce(random(x[2],y[3]), count)"); // composite functions tester.assertEvaluates("{ {x:0}:0.25, {x:1}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }"); @@ -265,12 +266,6 @@ public class EvaluationTestCase { } @Test - public void testItz() { - EvaluationTester tester = new EvaluationTester(); - tester.assertEvaluates("{ {x:0}:0.25, {x:1}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }"); - } - - @Test public void testProgrammaticBuildingAndPrecedence() { RankingExpression standardPrecedence = new RankingExpression(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, new ArithmeticNode(constant(3), ArithmeticOperator.MULTIPLY, constant(4)))); RankingExpression oppositePrecedence = new RankingExpression(new ArithmeticNode(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, constant(3)), ArithmeticOperator.MULTIPLY, constant(4))); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index c1a24abd878..e99e7da7415 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -443,7 +443,7 @@ public class IndexedTensor implements Tensor { public boolean equals(Object o) { if (o == this) return true; if ( ! ( o instanceof Map.Entry)) return false; - Map.Entry other = (Map.Entry)o; + Map.Entry<?,?> other = (Map.Entry)o; if ( ! this.getValue().equals(other.getValue())) return false; if ( ! this.getKey().equals(other.getKey())) return false; return true; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java new file mode 100644 index 00000000000..0bb92bc2a6f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -0,0 +1,56 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.CompositeTensorFunction; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.functions.ToStringContext; + +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere. + * + * @author bratseth + */ +public class Diag extends CompositeTensorFunction { + + private final TensorType type; + private final Function<List<Integer>, Double> diagFunction; + + public Diag(TensorType type) { + this.type = type; + this.diagFunction = ScalarFunctions.equalArguments(dimensionNames().collect(Collectors.toList())); + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + return new Generate(type, diagFunction); + } + + @Override + public String toString(ToStringContext context) { + return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction; + } + + private Stream<String> dimensionNames() { + return type.dimensions().stream().map(TensorType.Dimension::toString); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java new file mode 100644 index 00000000000..ba34c0d9748 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -0,0 +1,53 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.CompositeTensorFunction; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.functions.ToStringContext; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * A tensor generator which returns a tensor of any dimension filled with random numbers between 0 and 1. + * + * @author bratseth + */ +public class Random extends CompositeTensorFunction { + + private final TensorType type; + + public Random(TensorType type) { + this.type = type; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + return new Generate(type, ScalarFunctions.random()); + } + + @Override + public String toString(ToStringContext context) { + return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; + } + + private Stream<String> dimensionNames() { + return type.dimensions().stream().map(TensorType.Dimension::toString); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java new file mode 100644 index 00000000000..e18edd48127 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -0,0 +1,51 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.TensorType; + +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * A tensor generator which returns a tensor of any dimension filled with the sum of the tensor + * indexes of each position. + * + * @author bratseth + */ +public class Range extends CompositeTensorFunction { + + private final TensorType type; + private final Function<List<Integer>, Double> rangeFunction; + + public Range(TensorType type) { + this.type = type; + this.rangeFunction = ScalarFunctions.sumArguments(dimensionNames().collect(Collectors.toList())); + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + return new Generate(type, rangeFunction); + } + + @Override + public String toString(ToStringContext context) { + return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction; + } + + private Stream<String> dimensionNames() { + return type.dimensions().stream().map(TensorType.Dimension::toString); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index c1b1cb2243d..a0b60f53df3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -1,9 +1,15 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleSupplier; import java.util.function.DoubleUnaryOperator; +import java.util.function.Function; +import java.util.stream.Collectors; /** * Factory of scalar Java functions. @@ -21,6 +27,13 @@ public class ScalarFunctions { public static DoubleUnaryOperator square() { return new Square(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator exp() { return new Exponent(); } + public static Function<List<Integer>, Double> random() { return new Random(); } + public static Function<List<Integer>, Double> equalArguments(List<String> argumentNames) { + return new EqualArguments(argumentNames); + } + public static Function<List<Integer>, Double> sumArguments(List<String> argumentNames) { + return new SumArguments(argumentNames); + } public static class Addition implements DoubleBinaryOperator { @@ -81,4 +94,73 @@ public class ScalarFunctions { } + public static class Random implements Function<List<Integer>, Double> { + + @Override + public Double apply(List<Integer> values) { + return ThreadLocalRandom.current().nextDouble(); + } + + @Override + public String toString() { return "random()"; } + + } + + public static class EqualArguments implements Function<List<Integer>, Double> { + + private final ImmutableList<String> argumentNames; + + private EqualArguments(List<String> argumentNames) { + this.argumentNames = ImmutableList.copyOf(argumentNames); + } + + @Override + public Double apply(List<Integer> values) { + if (values.isEmpty()) return 1.0; + for (Integer value : values) + if ( ! value.equals(values.get(0))) + return 0.0; + return 1.0; + } + + @Override + public String toString() { + if (argumentNames.size() == 0) return "(1)"; + if (argumentNames.size() == 1) return "(1)"; + if (argumentNames.size() == 2) return "(" + argumentNames.get(0) + "==" + argumentNames.get(1) + ")"; + + StringBuilder b = new StringBuilder("("); + for (int i = 0; i < argumentNames.size() -1; i++) { + b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")"); + if ( i < argumentNames.size() -2) + b.append("*"); + } + return b.toString(); + } + + } + + public static class SumArguments implements Function<List<Integer>, Double> { + + private final ImmutableList<String> argumentNames; + + private SumArguments(List<String> argumentNames) { + this.argumentNames = ImmutableList.copyOf(argumentNames); + } + + @Override + public Double apply(List<Integer> values) { + int sum = 0; + for (Integer value : values) + sum += value; + return (double)sum; + } + + @Override + public String toString() { + return "(" + argumentNames.stream().collect(Collectors.joining("+")) + ")"; + } + + } + } |