aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-11-05 22:49:08 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2019-11-05 22:49:08 +0100
commited8c274dc76794efa692efba6cf509b058b13648 (patch)
treec1dcb9fbc70b851be5cfdb8c335089283715f698 /searchlib
parent64c5daa351557869e64786188afa75ed3b59991b (diff)
Literal tensors with value expressions
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/abi-spec.json14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java41
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj96
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java3
6 files changed, 106 insertions, 57 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 5ef3cd61366..1258601a2d1 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -909,11 +909,11 @@
"public final java.util.List tagCommaLeadingList()",
"public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()",
"public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode tensorValueBody(com.yahoo.tensor.TensorType)",
- "public final void mappedTensorValueBody(com.yahoo.tensor.Tensor$Builder)",
- "public final void indexedTensorValueBody(com.yahoo.tensor.Tensor$Builder)",
- "public final void tensorCell(com.yahoo.tensor.Tensor$Builder$CellBuilder)",
- "public final void labelAndDimension(com.yahoo.tensor.Tensor$Builder$CellBuilder)",
+ "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final void tensorCell(com.yahoo.tensor.TensorType, java.util.Map)",
+ "public final void labelAndDimension(com.yahoo.tensor.TensorAddress$Builder)",
"public void <init>(java.io.InputStream)",
"public void <init>(java.io.InputStream, java.lang.String)",
"public void ReInit(java.io.InputStream)",
@@ -1612,7 +1612,9 @@
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
- "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode wrapArgument(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
+ "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
+ "public static java.util.Map wrap(java.util.Map)",
+ "public static java.util.List wrap(java.util.List)"
],
"fields": []
},
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index c1732aabf0b..e6e49e07c34 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -14,9 +15,13 @@ import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
+import java.util.LinkedHashMap;
import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
import java.util.stream.Collectors;
/**
@@ -72,10 +77,44 @@ public class TensorFunctionNode extends CompositeNode {
return new TensorValue(function.evaluate(context));
}
- public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
+ public static TensorFunctionExpressionNode wrap(ExpressionNode node) {
return new TensorFunctionExpressionNode(node);
}
+ public static Map<TensorAddress, Function<EvaluationContext<?>, Double>> wrap(Map<TensorAddress, ExpressionNode> nodes) {
+ Map<TensorAddress, Function<EvaluationContext<?>, Double>> closures = new LinkedHashMap<>();
+ for (var entry : nodes.entrySet())
+ closures.put(entry.getKey(), new ExpressionClosure(entry.getValue()));
+ return closures;
+ }
+
+ public static List<Function<EvaluationContext<?>, Double>> wrap(List<ExpressionNode> nodes) {
+ List<Function<EvaluationContext<?>, Double>> closures = new ArrayList<>();
+ for (var entry : nodes)
+ closures.add(new ExpressionClosure(entry));
+ return closures;
+ }
+
+ private static class ExpressionClosure implements java.util.function.Function<EvaluationContext<?> , Double> {
+
+ private final ExpressionNode expression;
+
+ public ExpressionClosure(ExpressionNode expression) {
+ this.expression = expression;
+ }
+
+ @Override
+ public Double apply(EvaluationContext<?> context) {
+ return expression.evaluate((Context)context).asDouble();
+ }
+
+ @Override
+ public String toString() {
+ return expression.toString();
+ }
+
+ }
+
/**
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
index 979c5b0f88c..6d687b015f1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
@@ -83,7 +83,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E
ExpressionNode arg1 = node.children().get(0);
ExpressionNode arg2 = node.children().get(1);
- TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1);
+ TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrap(arg1);
Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
String dimension = ((ReferenceNode) arg2).getName();
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 8f411bf6593..47555d95e58 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -30,6 +30,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+@SuppressWarnings({"rawtypes", "unchecked"})
public class RankingExpressionParser {
}
@@ -401,8 +402,8 @@ ExpressionNode tensorMap() :
}
{
<MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
- { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor),
- doubleMapper.asDoubleUnaryOperator())); }
+ { return new TensorFunctionNode(new Map(TensorFunctionNode.wrap(tensor),
+ doubleMapper.asDoubleUnaryOperator())); }
}
ExpressionNode tensorReduce() :
@@ -413,7 +414,7 @@ ExpressionNode tensorReduce() :
}
{
<REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); }
}
ExpressionNode tensorReduceComposites() :
@@ -425,7 +426,7 @@ ExpressionNode tensorReduceComposites() :
{
aggregator = tensorReduceAggregator()
<LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); }
}
ExpressionNode tensorJoin() :
@@ -435,9 +436,9 @@ ExpressionNode tensorJoin() :
}
{
<JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE>
- { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- doubleJoiner.asDoubleBinaryOperator())); }
+ { return new TensorFunctionNode(new Join(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ doubleJoiner.asDoubleBinaryOperator())); }
}
ExpressionNode tensorRename() :
@@ -450,7 +451,7 @@ ExpressionNode tensorRename() :
fromDimensions = bracedIdentifierList() <COMMA>
toDimensions = bracedIdentifierList()
<RBRACE>
- { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); }
+ { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrap(tensor), fromDimensions, toDimensions)); }
}
ExpressionNode tensorConcat() :
@@ -460,8 +461,8 @@ ExpressionNode tensorConcat() :
}
{
<CONCAT> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = tag() <RBRACE>
- { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
+ { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
dimension)); }
}
@@ -522,7 +523,7 @@ ExpressionNode tensorL1Normalize() :
}
{
<L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorL2Normalize() :
@@ -532,7 +533,7 @@ ExpressionNode tensorL2Normalize() :
}
{
<L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorMatmul() :
@@ -542,9 +543,9 @@ ExpressionNode tensorMatmul() :
}
{
<MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- dimension)); }
+ { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ dimension)); }
}
ExpressionNode tensorSoftmax() :
@@ -554,7 +555,7 @@ ExpressionNode tensorSoftmax() :
}
{
<SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorXwPlusB() :
@@ -567,9 +568,9 @@ ExpressionNode tensorXwPlusB() :
tensor2 = expression() <COMMA>
tensor3 = expression() <COMMA>
dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- TensorFunctionNode.wrapArgument(tensor3),
+ { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ TensorFunctionNode.wrap(tensor3),
dimension)); }
}
@@ -580,7 +581,7 @@ ExpressionNode tensorArgmax() :
}
{
<ARGMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorArgmin() :
@@ -590,7 +591,7 @@ ExpressionNode tensorArgmin() :
}
{
<ARGMIN> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrap(tensor), dimension)); }
}
LambdaFunctionNode lambdaFunction() :
@@ -823,63 +824,62 @@ Value primitiveValue() :
{ return Value.parse(sign + token.image); }
}
-ConstantNode tensorValueBody(TensorType type) :
+ExpressionNode tensorValueBody(TensorType type) :
{
- Tensor.Builder builder = Tensor.Builder.of(type);
+ DynamicTensor dynamicTensor;
}
{
<COLON>
(
- mappedTensorValueBody(builder) |
- indexedTensorValueBody(builder)
+ dynamicTensor = mappedTensorValueBody(type) |
+ dynamicTensor = indexedTensorValueBody(type)
)
- { return new ConstantNode(new TensorValue(builder.build())); }
+ { return new TensorFunctionNode(dynamicTensor); }
}
-void mappedTensorValueBody(Tensor.Builder builder) : {}
+DynamicTensor mappedTensorValueBody(TensorType type) :
+{
+ java.util.Map cells = new LinkedHashMap();
+}
{
<LCURLY>
- ( tensorCell(builder.cell()))*
- ( <COMMA> tensorCell(builder.cell()))*
+ ( tensorCell(type, cells))*
+ ( <COMMA> tensorCell(type, cells))*
<RCURLY>
+ { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
}
-void indexedTensorValueBody(Tensor.Builder builder) :
+DynamicTensor indexedTensorValueBody(TensorType type) :
{
- IndexedTensor.BoundBuilder indexedBuilder;
- long index = 0;
- double value;
+ List cells = new ArrayList();
+ ExpressionNode value;
}
{
- {
- if ( ! (builder instanceof IndexedTensor.BoundBuilder))
- throw new IllegalArgumentException("The tensor short form [n, n, ...] can only be used for indexed " +
- "bound tensors, not " + builder.type());
- indexedBuilder = (IndexedTensor.BoundBuilder)builder;
- }
<LSQUARE>
- ( value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )*
- ( <COMMA> value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )*
+ ( value = expression() { cells.add(value); } )*
+ ( <COMMA> value = expression() { cells.add(value); } )*
<RSQUARE>
+ { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
}
-void tensorCell(Tensor.Builder.CellBuilder cellBuilder) :
+void tensorCell(TensorType type, java.util.Map cells) :
{
- double value;
+ ExpressionNode value;
+ TensorAddress.Builder addressBuilder = new TensorAddress.Builder(type);
}
{
<LCURLY>
- ( labelAndDimension(cellBuilder))*
- ( <COMMA> labelAndDimension(cellBuilder))*
+ ( labelAndDimension(addressBuilder))*
+ ( <COMMA> labelAndDimension(addressBuilder))*
<RCURLY>
- <COLON> value = doubleNumber() { cellBuilder.value(value); }
+ <COLON> value = expression() { cells.put(addressBuilder.build(), value); }
}
-void labelAndDimension(Tensor.Builder.CellBuilder cellBuilder) :
+void labelAndDimension(TensorAddress.Builder addressBuilder) :
{
String dimension, label;
}
{
dimension = identifier() <COLON> label = tag()
- { cellBuilder.label(dimension, label); }
+ { addressBuilder.add(dimension, label); }
} \ No newline at end of file
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index 571e1f4d608..a41f24b3b8a 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -154,7 +154,12 @@ public class RankingExpressionTestCase {
"map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)");
assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))",
"xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)");
-
+ assertSerialization("tensor(x{}):{{x:a}:1 + 2 + 3,{x:b}:if (1 > 2, 3, 4),{x:c}:reduce(tensor0 * tensor1, sum)}",
+ "tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }");
+ assertSerialization("tensor(x[3]):[1.0,2.0,3]",
+ "tensor(x[3]):[1.0, 2.0, 3]");
+ assertSerialization("tensor(x[3]):[1.0,reduce(tensor0 * tensor1, sum),3]",
+ "tensor(x[3]):[1.0, sum(tensor0*tensor1), 3]");
}
@Test
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 7aafb8efee7..e28daefdabf 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -368,6 +368,9 @@ public class EvaluationTestCase {
"tensor(x{}):{}");
tester.assertEvaluates("tensor():{{}:1}",
"tensor():{{}:1}");
+ tester.assertEvaluates("tensor(x{}):{ {x:a}:6.0, {x:b}:4.0, {x:c}:14.0 }",
+ "tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }",
+ "{ {x:0}:7 }", "tensor(x{}):{ {x:0}:2 }");
}
@Test