aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-22 11:12:53 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-22 11:12:53 +0100
commit39a652ce439a42fb8db372c821c834d02c95b0f1 (patch)
treeb246abdbddfcf8fd84f2f74590aeb56743f2dfa3
parented8ec5305f6838e31de94ef87ddd3a75390b59ed (diff)
Tensor generate functions
-rw-r--r--container-core/src/main/java/com/yahoo/container/jdisc/ThreadedRequestHandler.java2
-rw-r--r--jdisc_core/src/main/java/com/yahoo/jdisc/Response.java4
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj81
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java19
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java56
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java53
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java51
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java82
9 files changed, 314 insertions, 36 deletions
diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/ThreadedRequestHandler.java b/container-core/src/main/java/com/yahoo/container/jdisc/ThreadedRequestHandler.java
index 5cabe8a9ec6..ccfb858c473 100644
--- a/container-core/src/main/java/com/yahoo/container/jdisc/ThreadedRequestHandler.java
+++ b/container-core/src/main/java/com/yahoo/container/jdisc/ThreadedRequestHandler.java
@@ -33,7 +33,7 @@ import javax.annotation.concurrent.GuardedBy;
* Note that this means that subclass handlers are synchronous - the request io can
* continue after completion of the worker thread.
*
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
*/
public abstract class ThreadedRequestHandler extends AbstractRequestHandler {
diff --git a/jdisc_core/src/main/java/com/yahoo/jdisc/Response.java b/jdisc_core/src/main/java/com/yahoo/jdisc/Response.java
index 809805fdcc4..d624e070a2e 100644
--- a/jdisc_core/src/main/java/com/yahoo/jdisc/Response.java
+++ b/jdisc_core/src/main/java/com/yahoo/jdisc/Response.java
@@ -208,9 +208,9 @@ public class Response {
}
/**
- * <p>This is a convenience method for creating a Response with status {@link Status#REQUEST_TIMEOUT} and passing
+ * This is a convenience method for creating a Response with status {@link Status#REQUEST_TIMEOUT} and passing
* that to the given {@link ResponseHandler#handleResponse(Response)}. For trivial implementations of {@link
- * RequestHandler#handleTimeout(Request, ResponseHandler)}, simply call this method.</p>
+ * RequestHandler#handleTimeout(Request, ResponseHandler)}, simply call this method.
*
* @param handler The handler to pass the timeout {@link Response} to.
*/
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 564b2cd9801..fab80304f6d 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -123,6 +123,9 @@ TOKEN :
<JOIN: "join"> |
<RENAME: "rename"> |
<TENSOR: "tensor"> |
+ <RANGE: "range"> |
+ <DIAG: "diag"> |
+ <RANDOM: "random"> |
<L1_NORMALIZE: "l1_normalize"> |
<L2_NORMALIZE: "l2_normalize"> |
<MATMUL: "matmul"> |
@@ -352,6 +355,9 @@ ExpressionNode tensorFunction() :
tensorExpression = tensorJoin() |
tensorExpression = tensorRename() |
tensorExpression = tensorGenerate() |
+ tensorExpression = tensorRange() |
+ tensorExpression = tensorDiag() |
+ tensorExpression = tensorRandom() |
tensorExpression = tensorL1Normalize() |
tensorExpression = tensorL2Normalize() |
tensorExpression = tensorMatmul() |
@@ -426,31 +432,35 @@ ExpressionNode tensorGenerate() :
ExpressionNode generator;
}
{
- type = tensorType() <LBRACE> generator = expression() <RBRACE>
+ <TENSOR> type = tensorTypeArgument() <LBRACE> generator = expression() <RBRACE>
{ return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asIntegerListToDoubleOperator())); }
}
-TensorType tensorType() :
+ExpressionNode tensorRange() :
{
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType type;
}
{
- <TENSOR> <LBRACE>
- ( tensorTypeDimension(builder) ) ?
- ( <COMMA> tensorTypeDimension(builder) ) *
- <RBRACE>
- { return builder.build(); }
+ <RANGE> type = tensorTypeArgument()
+ { return new TensorFunctionNode(new Range(type)); }
}
-// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need
-void tensorTypeDimension(TensorType.Builder builder) :
+ExpressionNode tensorDiag() :
{
- String name;
- int size;
+ TensorType type;
}
{
- name = identifier() <LSQUARE> size = integerNumber() <RSQUARE>
- { builder.indexed(name, size); }
+ <DIAG> type = tensorTypeArgument()
+ { return new TensorFunctionNode(new Diag(type)); }
+}
+
+ExpressionNode tensorRandom() :
+{
+ TensorType type;
+}
+{
+ <RANDOM> type = tensorTypeArgument()
+ { return new TensorFunctionNode(new Random(type)); }
}
ExpressionNode tensorL1Normalize() :
@@ -529,18 +539,49 @@ Reduce.Aggregator tensorReduceAggregator() :
{ return Reduce.Aggregator.valueOf(token.image); }
}
+TensorType tensorTypeArgument() :
+{
+ TensorType.Builder builder = new TensorType.Builder();
+}
+{
+ <LBRACE>
+ ( tensorTypeDimension(builder) ) ?
+ ( <COMMA> tensorTypeDimension(builder) ) *
+ <RBRACE>
+ { return builder.build(); }
+}
+
+// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need
+void tensorTypeDimension(TensorType.Builder builder) :
+{
+ String name;
+ int size;
+}
+{
+ name = identifier() <LSQUARE> size = integerNumber() <RSQUARE>
+ { builder.indexed(name, size); }
+}
+
// This is needed not to parse tensor functions but for the "reserved names as literals" workaround cludge
String tensorFunctionName() :
{
Reduce.Aggregator aggregator;
}
{
- ( <F> { return token.image; } ) |
- ( <MAP> { return token.image; } ) |
- ( <REDUCE> { return token.image; } ) |
- ( <JOIN> { return token.image; } ) |
- ( <RENAME> { return token.image; } ) |
- ( <TENSOR> { return token.image; } ) |
+ ( <F> { return token.image; } ) |
+ ( <MAP> { return token.image; } ) |
+ ( <REDUCE> { return token.image; } ) |
+ ( <JOIN> { return token.image; } ) |
+ ( <RENAME> { return token.image; } ) |
+ ( <TENSOR> { return token.image; } ) |
+ ( <RANGE> { return token.image; } ) |
+ ( <DIAG> { return token.image; } ) |
+ ( <RANDOM> { return token.image; } ) |
+ ( <L1_NORMALIZE> { return token.image; } ) |
+ ( <L2_NORMALIZE> { return token.image; } ) |
+ ( <MATMUL> { return token.image; } ) |
+ ( <SOFTMAX> { return token.image; } ) |
+ ( <XW_PLUS_B> { return token.image; } ) |
( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 55638c3687b..dc451b1dc5c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -5,6 +5,7 @@ import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.*;
+import com.yahoo.tensor.Tensor;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -227,7 +228,7 @@ public class EvaluationTestCase {
// argmin
tester.assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0 }",
"tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
-
+
// tensor rename
tester.assertEvaluates("{ {newX:0,y:0}:3 }", "rename(tensor0, x, newX)", "{ {x:0,y:0}:3.0 }");
tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:5 }", "rename(tensor0, (x, y), (y, x))", "{ {x:0,y:0}:3.0, {x:0,y:1}:5.0 }");
@@ -235,11 +236,11 @@ public class EvaluationTestCase {
// tensor generate
tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:0, {x:0,y:1}:1, {x:1,y:1}:0, {x:0,y:2}:0, {x:1,y:2}:1 }", "tensor(x[2],y[3])(x+1==y)");
tester.assertEvaluates("{ {y:0,x:0}:0, {y:1,x:0}:0, {y:0,x:1}:1, {y:1,x:1}:0, {y:0,x:2}:0, {y:1,x:2}:1 }", "tensor(y[2],x[3])(y+1==x)");
- // TODO
- // range
- // diag
- // fill
- // random
+ tester.assertEvaluates("{ {x:0,y:0,z:0}:1 }", "tensor(x[1],y[1],z[1])((x==y)*(y==z))");
+ // - generate composites
+ tester.assertEvaluates("{ {x:0}:0, {x:1}:1, {x:2}:2 }", "range(x[3])");
+ tester.assertEvaluates("{ {x:0,y:0,z:0}:1, {x:0,y:0,z:1}:0, {x:0,y:1,z:0}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:0}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:1, }", "diag(x[2],y[2],z[2])");
+ tester.assertEvaluates("6", "reduce(random(x[2],y[3]), count)");
// composite functions
tester.assertEvaluates("{ {x:0}:0.25, {x:1}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }");
@@ -265,12 +266,6 @@ public class EvaluationTestCase {
}
@Test
- public void testItz() {
- EvaluationTester tester = new EvaluationTester();
- tester.assertEvaluates("{ {x:0}:0.25, {x:1}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }");
- }
-
- @Test
public void testProgrammaticBuildingAndPrecedence() {
RankingExpression standardPrecedence = new RankingExpression(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, new ArithmeticNode(constant(3), ArithmeticOperator.MULTIPLY, constant(4))));
RankingExpression oppositePrecedence = new RankingExpression(new ArithmeticNode(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, constant(3)), ArithmeticOperator.MULTIPLY, constant(4)));
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index c1a24abd878..e99e7da7415 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -443,7 +443,7 @@ public class IndexedTensor implements Tensor {
public boolean equals(Object o) {
if (o == this) return true;
if ( ! ( o instanceof Map.Entry)) return false;
- Map.Entry other = (Map.Entry)o;
+ Map.Entry<?,?> other = (Map.Entry)o;
if ( ! this.getValue().equals(other.getValue())) return false;
if ( ! this.getKey().equals(other.getKey())) return false;
return true;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
new file mode 100644
index 00000000000..0bb92bc2a6f
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -0,0 +1,56 @@
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.CompositeTensorFunction;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.functions.ToStringContext;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere.
+ *
+ * @author bratseth
+ */
+public class Diag extends CompositeTensorFunction {
+
+ private final TensorType type;
+ private final Function<List<Integer>, Double> diagFunction;
+
+ public Diag(TensorType type) {
+ this.type = type;
+ this.diagFunction = ScalarFunctions.equalArguments(dimensionNames().collect(Collectors.toList()));
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size());
+ return this;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ return new Generate(type, diagFunction);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
+ }
+
+ private Stream<String> dimensionNames() {
+ return type.dimensions().stream().map(TensorType.Dimension::toString);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
new file mode 100644
index 00000000000..ba34c0d9748
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -0,0 +1,53 @@
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.CompositeTensorFunction;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.functions.ToStringContext;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * A tensor generator which returns a tensor of any dimension filled with random numbers between 0 and 1.
+ *
+ * @author bratseth
+ */
+public class Random extends CompositeTensorFunction {
+
+ private final TensorType type;
+
+ public Random(TensorType type) {
+ this.type = type;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size());
+ return this;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ return new Generate(type, ScalarFunctions.random());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
+ }
+
+ private Stream<String> dimensionNames() {
+ return type.dimensions().stream().map(TensorType.Dimension::toString);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
new file mode 100644
index 00000000000..e18edd48127
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -0,0 +1,51 @@
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.TensorType;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+/**
+ * A tensor generator which returns a tensor of any dimension filled with the sum of the tensor
+ * indexes of each position.
+ *
+ * @author bratseth
+ */
+public class Range extends CompositeTensorFunction {
+
+ private final TensorType type;
+ private final Function<List<Integer>, Double> rangeFunction;
+
+ public Range(TensorType type) {
+ this.type = type;
+ this.rangeFunction = ScalarFunctions.sumArguments(dimensionNames().collect(Collectors.toList()));
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+
+ @Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size());
+ return this;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ return new Generate(type, rangeFunction);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction;
+ }
+
+ private Stream<String> dimensionNames() {
+ return type.dimensions().stream().map(TensorType.Dimension::toString);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index c1b1cb2243d..a0b60f53df3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -1,9 +1,15 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import java.util.List;
+import java.util.concurrent.ThreadLocalRandom;
import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleSupplier;
import java.util.function.DoubleUnaryOperator;
+import java.util.function.Function;
+import java.util.stream.Collectors;
/**
* Factory of scalar Java functions.
@@ -21,6 +27,13 @@ public class ScalarFunctions {
public static DoubleUnaryOperator square() { return new Square(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator exp() { return new Exponent(); }
+ public static Function<List<Integer>, Double> random() { return new Random(); }
+ public static Function<List<Integer>, Double> equalArguments(List<String> argumentNames) {
+ return new EqualArguments(argumentNames);
+ }
+ public static Function<List<Integer>, Double> sumArguments(List<String> argumentNames) {
+ return new SumArguments(argumentNames);
+ }
public static class Addition implements DoubleBinaryOperator {
@@ -81,4 +94,73 @@ public class ScalarFunctions {
}
+ public static class Random implements Function<List<Integer>, Double> {
+
+ @Override
+ public Double apply(List<Integer> values) {
+ return ThreadLocalRandom.current().nextDouble();
+ }
+
+ @Override
+ public String toString() { return "random()"; }
+
+ }
+
+ public static class EqualArguments implements Function<List<Integer>, Double> {
+
+ private final ImmutableList<String> argumentNames;
+
+ private EqualArguments(List<String> argumentNames) {
+ this.argumentNames = ImmutableList.copyOf(argumentNames);
+ }
+
+ @Override
+ public Double apply(List<Integer> values) {
+ if (values.isEmpty()) return 1.0;
+ for (Integer value : values)
+ if ( ! value.equals(values.get(0)))
+ return 0.0;
+ return 1.0;
+ }
+
+ @Override
+ public String toString() {
+ if (argumentNames.size() == 0) return "(1)";
+ if (argumentNames.size() == 1) return "(1)";
+ if (argumentNames.size() == 2) return "(" + argumentNames.get(0) + "==" + argumentNames.get(1) + ")";
+
+ StringBuilder b = new StringBuilder("(");
+ for (int i = 0; i < argumentNames.size() -1; i++) {
+ b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")");
+ if ( i < argumentNames.size() -2)
+ b.append("*");
+ }
+ return b.toString();
+ }
+
+ }
+
+ public static class SumArguments implements Function<List<Integer>, Double> {
+
+ private final ImmutableList<String> argumentNames;
+
+ private SumArguments(List<String> argumentNames) {
+ this.argumentNames = ImmutableList.copyOf(argumentNames);
+ }
+
+ @Override
+ public Double apply(List<Integer> values) {
+ int sum = 0;
+ for (Integer value : values)
+ sum += value;
+ return (double)sum;
+ }
+
+ @Override
+ public String toString() {
+ return "(" + argumentNames.stream().collect(Collectors.joining("+")) + ")";
+ }
+
+ }
+
}