aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-22 13:55:10 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-22 13:55:10 +0100
commita7209cf3f8f11e916d70c4eb5db0bf13f181ef1f (patch)
tree9a1a9aabed30711341afae11ebee23e6fd054b67 /searchlib/src
parentdda0f64dafcb2696d04960b73c1d1a3148a0315c (diff)
Add tensor generate functions
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/CompositeNode.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java18
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java2
3 files changed, 18 insertions, 4 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/CompositeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/CompositeNode.java
index d181c29b516..43658fcfa59 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/CompositeNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/CompositeNode.java
@@ -4,7 +4,7 @@ package com.yahoo.searchlib.rankingexpression.rule;
import java.util.List;
/**
- * <p>The parent of all node types which contains child nodes.</p>
+ * The parent of all node types which contains child nodes.
*
* @author bratseth
*/
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 15ad6ba647a..1947b00ac16 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -46,7 +46,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
// Serialize as primitive
return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
}
@@ -105,9 +105,19 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
+ public String toString() {
+ return toString(ExpressionNodeToStringContext.empty);
+ }
+
+ @Override
public String toString(ToStringContext c) {
- ExpressionNodeToStringContext context = (ExpressionNodeToStringContext)c;
- return expression.toString(context.context, context.path, context.parent);
+ if (c instanceof ExpressionNodeToStringContext) {
+ ExpressionNodeToStringContext context = (ExpressionNodeToStringContext) c;
+ return expression.toString(context.context, context.path, context.parent);
+ }
+ else {
+ return expression.toString();
+ }
}
}
@@ -119,6 +129,8 @@ public class TensorFunctionNode extends CompositeNode {
final Deque<String> path;
final CompositeNode parent;
+ public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(null, null, null);
+
public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
this.context = context;
this.path = path;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index dc451b1dc5c..26d19bcec37 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -250,6 +250,8 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {x:0,y:0}:81.0, {x:1,y:0}:88.0 }", "xw_plus_b(tensor0, tensor1, tensor2, x)", "{ {x:0}:15, {x:1}:12 }", "{ {y:0}:3 }", "{ {x:0}:0, {x:1}:7 }");
// expressions combining functions
+ tester.assertEvaluates("tensor(y{}):{{y:6}:0}}", "matmul(tensor0, diag(x[5],y[7]), x)", "tensor(x{},y{}):{{x:4,y:6}:1})");
+ tester.assertEvaluates("tensor(y{}):{{y:6}:10}}", "matmul(tensor0, range(x[5],y[7]), x)", "tensor(x{},y{}):{{x:4,y:6}:1})");
tester.assertEvaluates(String.valueOf(7.5 + 45 + 1.7),
"sum( " + // model computation:
" tensor0 * tensor1 * tensor2 " + // - feature combinations