summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-11-26 16:51:50 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-11-26 16:51:50 +0200
commitf4203c3cc571722f08ee65047437c1290ed63f69 (patch)
tree7d06d17091a2e388e6771187a11cf4f4023a0c1e
parent316c941e90f39d2e9bc46f12b96ca0f87471d1bd (diff)
Allow bound functions in tensor generate
-rw-r--r--searchlib/abi-spec.json5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java46
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java3
-rw-r--r--vespajlib/abi-spec.json3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java5
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java5
10 files changed, 67 insertions, 28 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index d5970a4b69e..dcf42069373 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1615,8 +1615,9 @@
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
- "public static java.util.Map wrap(java.util.Map)",
- "public static java.util.List wrap(java.util.List)"
+ "public static java.util.Map wrapScalars(java.util.Map)",
+ "public static java.util.List wrapScalars(java.util.List)",
+ "public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
],
"fields": []
},
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index d68f8c85ad1..cf17c6465f3 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -34,7 +34,7 @@ public abstract class Context implements EvaluationContext<Reference> {
@Override
public TensorType getType(String reference) {
- throw new UnsupportedOperationException("Not able to parse gereral references from string form");
+ throw new UnsupportedOperationException("Not able to parse general references from string form");
}
/** Returns a variable as a tensor */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 4ffd40f00f7..18f1fa8a78f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
+import java.sql.Ref;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
@@ -81,21 +82,22 @@ public class TensorFunctionNode extends CompositeNode {
return new ExpressionTensorFunction(node);
}
- public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) {
- Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>();
+ public static Map<TensorAddress, ScalarFunction<Reference>> wrapScalars(Map<TensorAddress, ExpressionNode> nodes) {
+ Map<TensorAddress, ScalarFunction<Reference>> functions = new LinkedHashMap<>();
for (var entry : nodes.entrySet())
- functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue()));
+ functions.put(entry.getKey(), wrapScalar(entry.getValue()));
return functions;
}
- public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) {
- List<ScalarFunction> functions = new ArrayList<>();
- for (var entry : nodes)
- functions.add(new ExpressionScalarFunction(entry));
- return functions;
+ public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) {
+ return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList());
+ }
+
+ public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) {
+ return new ExpressionScalarFunction(node);
}
- private static class ExpressionScalarFunction implements ScalarFunction {
+ private static class ExpressionScalarFunction implements ScalarFunction<Reference> {
private final ExpressionNode expression;
@@ -104,8 +106,8 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public Double apply(EvaluationContext<?> context) {
- return expression.evaluate((Context)context).asDouble();
+ public Double apply(EvaluationContext<Reference> context) {
+ return expression.evaluate(new ContextWrapper(context)).asDouble();
}
@Override
@@ -209,4 +211,26 @@ public class TensorFunctionNode extends CompositeNode {
}
+ /** Turns an EvaluationContext into a Context */
+ // TODO: We should be able to change RankingExpression.evaluate to take an EvaluationContext and then get rid of this
+ private static class ContextWrapper extends Context {
+
+ private final EvaluationContext<Reference> delegate;
+
+ public ContextWrapper(EvaluationContext<Reference> delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public Value get(String name) {
+ return new TensorValue(delegate.getTensor(name));
+ }
+
+ @Override
+ public TensorType getType(Reference name) {
+ return delegate.getType(name);
+ }
+
+ }
+
}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 3e9649cd9c6..beab722a1eb 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -487,7 +487,7 @@ TensorFunctionNode tensorGenerateBody(TensorType type) :
}
{
<LBRACE> generator = expression() <RBRACE>
- { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); }
+ { return new TensorFunctionNode(Generate.bound(type, TensorFunctionNode.wrapScalar(generator))); }
}
TensorFunctionNode tensorRange() :
@@ -847,7 +847,7 @@ DynamicTensor mappedTensorValueBody(TensorType type) :
( tensorCell(type, cells))*
( <COMMA> tensorCell(type, cells))*
<RCURLY>
- { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
+ { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
}
DynamicTensor indexedTensorValueBody(TensorType type) :
@@ -860,7 +860,7 @@ DynamicTensor indexedTensorValueBody(TensorType type) :
( (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )*
( <COMMA> (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )*
// <RSQUARE>
- { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
+ { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
}
void tensorCell(TensorType type, java.util.Map cells) :
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index a8afc230bde..05ad8c97c7f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -358,6 +358,9 @@ public class EvaluationTestCase {
tester.assertEvaluates("500",
"join(tensor0, tensor1, f(x,y) (x*y)){tag2}",
"tensor(tag{}):{{tag:tag1}:10, {tag:tag2}:20}", "{25}");
+ tester.assertEvaluates("tensor(j[3]):[3, 3, 3]",
+ "tensor(j[3])(tensor0[2])",
+ "tensor(values[5]):[1, 2, 3, 4, 5]");
// tensor result dimensions are given from argument dimensions, not the resulting values
tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{}):{ {x:1}:1 }");
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 15cfce09793..8cba1ccdef8 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1620,7 +1620,8 @@
],
"methods": [
"public void <init>(com.yahoo.tensor.TensorType, java.util.function.Function)",
- "public void <init>(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)",
+ "public static com.yahoo.tensor.functions.Generate free(com.yahoo.tensor.TensorType, java.util.function.Function)",
+ "public static com.yahoo.tensor.functions.Generate bound(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index b8b644f8b49..a75e49c6402 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -73,6 +73,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
+ @SuppressWarnings("unchecked")
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type());
for (var cell : cells.entrySet())
@@ -114,6 +115,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
+ @SuppressWarnings("unchecked")
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());
for (int i = 0; i < cells.size(); i++)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 28fc2c61426..52620814ecd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -27,28 +27,33 @@ public class Generate extends PrimitiveTensorFunction {
private final Function<List<Long>, Double> freeGenerator;
private final ScalarFunction boundGenerator;
+ /** The same as Generate.free */
+ public Generate(TensorType type, Function<List<Long>, Double> generator) {
+ this(type, Objects.requireNonNull(generator), null);
+ }
+
/**
- * Creates a generated tensor
+ * Creates a generated tensor from a free function
*
* @param type the type of the tensor
* @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, Function<List<Long>, Double> generator) {
- this(type, Objects.requireNonNull(generator), null);
+ public static Generate free(TensorType type, Function<List<Long>, Double> generator) {
+ return new Generate(type, Objects.requireNonNull(generator), null);
}
/**
- * Creates a generated tensor
+ * Creates a generated tensor from a bound function
*
* @param type the type of the tensor
* @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, ScalarFunction generator) {
- this(type, null, Objects.requireNonNull(generator));
+ public static Generate bound(TensorType type, ScalarFunction generator) {
+ return new Generate(type, null, Objects.requireNonNull(generator));
}
private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) {
@@ -127,6 +132,7 @@ public class Generate extends PrimitiveTensorFunction {
this.context = context;
}
+ @SuppressWarnings("unchecked")
double apply(IndexedTensor.Indexes indexes) {
if (freeGenerator != null) {
return freeGenerator.apply(indexes.toList());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
index c6a244b64df..70e08af16b6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.function.Function;
@@ -10,10 +11,10 @@ import java.util.function.Function;
*
* @author bratseth
*/
-public interface ScalarFunction extends Function<EvaluationContext<?>, Double> {
+public interface ScalarFunction<NAMETYPE extends TypeContext.Name> extends Function<EvaluationContext<NAMETYPE>, Double> {
@Override
- Double apply(EvaluationContext<?> context);
+ Double apply(EvaluationContext<NAMETYPE> context);
default String toString(ToStringContext context) {
return toString();
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
index 925da9d3c89..e1ae7f13c48 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -5,6 +5,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import org.junit.Test;
import java.util.Collections;
@@ -34,14 +35,14 @@ public class DynamicTensorTestCase {
assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString());
}
- private static class Constant implements ScalarFunction {
+ private static class Constant implements ScalarFunction<TypeContext.Name> {
private final double value;
public Constant(double value) { this.value = value; }
@Override
- public Double apply(EvaluationContext<?> evaluationContext) { return value; }
+ public Double apply(EvaluationContext<TypeContext.Name> evaluationContext) { return value; }
@Override
public String toString() { return String.valueOf(value); }