summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-11-28 21:35:16 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-11-28 21:35:16 +0100
commit1d6791e6fa004ae80e85dbc6a6c7c2e4b8037a4f (patch)
tree650307f35d321145410248f703943ef7525f94fb /searchlib
parent0606896d63cc8bbe4919c7c37126fb9bc3f6e34e (diff)
parent7e8f8da8f249cf3c529cec8ecdcf13b69c99da13 (diff)
Merge with master
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/pom.xml10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java25
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java31
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java41
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java10
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java50
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java32
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj39
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java64
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java138
-rw-r--r--searchlib/src/tests/features/constant/constant_test.cpp5
-rw-r--r--searchlib/src/tests/features/tensor/tensor_test.cpp2
-rw-r--r--searchlib/src/tests/features/tensor_from_labels/tensor_from_labels_test.cpp2
-rw-r--r--searchlib/src/tests/features/tensor_from_weighted_set/tensor_from_weighted_set_test.cpp2
-rw-r--r--searchlib/src/tests/postinglistbm/andstress.cpp2
-rw-r--r--searchlib/src/tests/rankingexpression/rankingexpressionlist4
-rw-r--r--searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp2
-rw-r--r--searchlib/src/vespa/searchlib/features/constant_tensor_executor.h11
-rw-r--r--searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h1
-rw-r--r--searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp13
-rw-r--r--searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h2
-rw-r--r--searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h4
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/multisearch.cpp26
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/multisearch.h1
31 files changed, 495 insertions, 77 deletions
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index bb305f460ca..8e15e0d425c 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -43,6 +43,16 @@
<artifactId>proto</artifactId>
<version>1.4.0</version>
</dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-core</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index dab89fe8955..ea750295423 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -39,6 +39,31 @@ public abstract class DoubleCompatibleValue extends Value {
}
@Override
+ public Value modulo(Value value) {
+ return new DoubleValue(asDouble() % value.asDouble());
+ }
+
+ @Override
+ public Value and(Value value) {
+ return new BooleanValue(asBoolean() && value.asBoolean());
+ }
+
+ @Override
+ public Value or(Value value) {
+ return new BooleanValue(asBoolean() || value.asBoolean());
+ }
+
+ @Override
+ public Value not() {
+ return new BooleanValue(!asBoolean());
+ }
+
+ @Override
+ public Value power(Value value) {
+ return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble()));
+ }
+
+ @Override
public Value compare(TruthOperator operator, Value value) {
return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
index 28272e58c91..17157ab385f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
@@ -98,6 +98,17 @@ public final class DoubleValue extends DoubleCompatibleValue {
}
@Override
+ public Value modulo(Value value) {
+ try {
+ return mutable(this.value % value.asDouble());
+ }
+ catch (UnsupportedOperationException e) {
+ throw unsupported("modulo",value);
+ }
+ }
+
+
+ @Override
public Value function(Function function, Value value) {
// use the tensor implementation of max and min if the argument is a tensor
if ( (function.equals(Function.min) || function.equals(Function.max)) && value instanceof TensorValue)
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index acf301f3b80..ac8aba6a617 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -54,17 +54,42 @@ public class StringValue extends Value {
@Override
public Value subtract(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') does not support subtraction");
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction");
}
@Override
public Value multiply(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') does not support multiplication");
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication");
}
@Override
public Value divide(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') does not support division");
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support division");
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo");
+ }
+
+ @Override
+ public Value and(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support and");
+ }
+
+ @Override
+ public Value or(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support or");
+ }
+
+ @Override
+ public Value not() {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support not");
+ }
+
+ @Override
+ public Value power(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support ^");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index 6cf15837da1..49c3ccb7b01 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -81,6 +81,43 @@ public class TensorValue extends Value {
return new TensorValue(value.map((value) -> value / argument.asDouble()));
}
+ @Override
+ public Value modulo(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.fmod(((TensorValue) argument).value));
+ else
+ return new TensorValue(value.map((value) -> value % argument.asDouble()));
+ }
+
+ @Override
+ public Value and(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value or(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value not() {
+ return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value power(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.pow(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> Math.pow(value, argument.asDouble())));
+ }
+
private Tensor asTensor(Value value, String operationName) {
if ( ! (value instanceof TensorValue))
throw new UnsupportedOperationException("Could not perform " + operationName +
@@ -103,6 +140,7 @@ public class TensorValue extends Value {
case SMALLEREQUAL: return value.smallerOrEqual(argument);
case EQUAL: return value.equal(argument);
case NOTEQUAL: return value.notEqual(argument);
+ case APPROX_EQUAL: return value.approxEqual(argument);
default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
}
}
@@ -120,6 +158,9 @@ public class TensorValue extends Value {
case min: return value.min(argument);
case max: return value.max(argument);
case atan2: return value.atan2(argument);
+ case pow: return value.pow(argument);
+ case fmod: return value.fmod(argument);
+ case ldexp: return value.ldexp(argument);
default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index a63387506a0..b2ccbe572d0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -41,6 +41,16 @@ public abstract class Value {
public abstract Value divide(Value value);
+ public abstract Value modulo(Value value);
+
+ public abstract Value and(Value value);
+
+ public abstract Value or(Value value);
+
+ public abstract Value not();
+
+ public abstract Value power(Value value);
+
/** Perform the comparison specified by the operator between this value and the given value */
public abstract Value compare(TruthOperator operator, Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
index 91d8abec1be..518a15bcc87 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
@@ -77,7 +77,7 @@ public final class ArithmeticNode extends CompositeNode {
Iterator<ExpressionNode> child = children.iterator();
Deque<ValueItem> stack = new ArrayDeque<>();
- stack.push(new ValueItem(ArithmeticOperator.PLUS, child.next().evaluate(context)));
+ stack.push(new ValueItem(ArithmeticOperator.OR, child.next().evaluate(context)));
for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
ArithmeticOperator op = it.next();
if (!stack.isEmpty()) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
index 5a5237c2608..a715490e95a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
@@ -14,17 +14,29 @@ import java.util.List;
*/
public enum ArithmeticOperator {
- PLUS(0, "+") { public Value evaluate(Value x, Value y) {
+ OR(0, "||") { public Value evaluate(Value x, Value y) {
+ return x.or(y);
+ }},
+ AND(1, "&&") { public Value evaluate(Value x, Value y) {
+ return x.and(y);
+ }},
+ PLUS(2, "+") { public Value evaluate(Value x, Value y) {
return x.add(y);
}},
- MINUS(1, "-") { public Value evaluate(Value x, Value y) {
+ MINUS(3, "-") { public Value evaluate(Value x, Value y) {
return x.subtract(y);
}},
- MULTIPLY(2, "*") { public Value evaluate(Value x, Value y) {
+ MULTIPLY(4, "*") { public Value evaluate(Value x, Value y) {
return x.multiply(y);
}},
- DIVIDE(3, "/") { public Value evaluate(Value x, Value y) {
+ DIVIDE(5, "/") { public Value evaluate(Value x, Value y) {
return x.divide(y);
+ }},
+ MODULO(6, "%") { public Value evaluate(Value x, Value y) {
+ return x.modulo(y);
+ }},
+ POWER(7, "^") { public Value evaluate(Value x, Value y) {
+ return x.power(y);
}};
/** A list of all the operators in this in order of decreasing precedence */
@@ -52,10 +64,14 @@ public enum ArithmeticOperator {
private static List<ArithmeticOperator> operatorsByPrecedence() {
List<ArithmeticOperator> operators = new ArrayList<>();
+ operators.add(POWER);
+ operators.add(MODULO);
operators.add(DIVIDE);
operators.add(MULTIPLY);
operators.add(MINUS);
operators.add(PLUS);
+ operators.add(AND);
+ operators.add(OR);
return Collections.unmodifiableList(operators);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
index fc4a511b307..c3c1c371a68 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
@@ -39,7 +39,7 @@ public enum Function implements Serializable {
atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } },
fmod(2) { public double evaluate(double x, double y) { return x % y; } },
- ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } },
+ ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,(int)y); } },
max(2) { public double evaluate(double x, double y) { return max(x,y); } },
min(2) { public double evaluate(double x, double y) { return min(x,y); } },
pow(2) { public double evaluate(double x, double y) { return pow(x,y); } };
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
new file mode 100644
index 00000000000..8c459a032bd
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java
@@ -0,0 +1,50 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+
+/**
+ * A node which flips the logical value produced from the nested expression.
+ *
+ * @author lesters
+ */
+public class NotNode extends BooleanNode {
+
+ private final ExpressionNode value;
+
+ public NotNode(ExpressionNode value) {
+ this.value = value;
+ }
+
+ public ExpressionNode getValue() {
+ return value;
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(value);
+ }
+
+ @Override
+ public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ return "!" + value.toString(context, path, parent);
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ return value.evaluate(context).not();
+ }
+
+ @Override
+ public NotNode setChildren(List<ExpressionNode> children) {
+ if (children.size() != 1) throw new IllegalArgumentException("Expected 1 children but got " + children.size());
+ return new NotNode(children.get(0));
+ }
+
+}
+
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
index f8e44f1087c..f6b1a1a8979 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java
@@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableList;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.Tensor;
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Deque;
+import java.util.List;
+import java.util.function.Predicate;
/**
* A node which returns true or false depending on a set membership test
@@ -55,11 +60,30 @@ public class SetMembershipNode extends BooleanNode {
@Override
public Value evaluate(Context context) {
Value value = testValue.evaluate(context);
+ if (value instanceof TensorValue) {
+ return evaluateTensor(((TensorValue) value).asTensor(), context);
+ }
+ return evaluateValue(value, context);
+ }
+
+ private Value evaluateValue(Value value, Context context) {
+ return new BooleanValue(testMembership(value::equals, context));
+ }
+
+ private Value evaluateTensor(Tensor tensor, Context context) {
+ return new TensorValue(tensor.map((value) -> contains(value, context) ? 1.0 : 0.0));
+ }
+
+ private boolean contains(double value, Context context) {
+ return testMembership((setValue) -> setValue.asDouble() == value, context);
+ }
+
+ private boolean testMembership(Predicate<Value> test, Context context) {
for (ExpressionNode setValue : setValues) {
- if (setValue.evaluate(context).equals(value))
- return new BooleanValue(true);
+ if (test.test(setValue.evaluate(context)))
+ return true;
}
- return new BooleanValue(false);
+ return false;
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index ede7c861d98..ebad0d5c21f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -94,7 +94,7 @@ public class Simplifier extends ExpressionTransformer {
private ExpressionNode transformIf(IfNode node) {
if ( ! isConstant(node.getCondition())) return node;
- if (((BooleanValue)node.getCondition().evaluate(null)).asBoolean())
+ if ((node.getCondition().evaluate(null)).asBoolean())
return node.getTrueExpression();
else
return node.getFalseExpression();
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index c3b9235cc93..7821ab88b86 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -65,6 +65,8 @@ TOKEN :
<DIV: "/"> |
<MUL: "*"> |
<DOT: "."> |
+ <MOD: "%"> |
+ <POWOP: "^"> |
<DOLLAR: "$"> |
<COMMA: ","> |
@@ -85,6 +87,10 @@ TOKEN :
<IN: "in"> |
<F: "f"> |
+ <NOT: "!"> |
+ <AND: "&&"> |
+ <OR: "||"> |
+
<ABS: "abs"> |
<ACOS: "acos"> |
<ASIN: "asin"> |
@@ -199,10 +205,14 @@ ExpressionNode arithmeticExpression() :
ArithmeticOperator arithmetic() : { }
{
- ( <ADD> { return ArithmeticOperator.PLUS; } |
- <SUB> { return ArithmeticOperator.MINUS; } |
- <DIV> { return ArithmeticOperator.DIVIDE; } |
- <MUL> { return ArithmeticOperator.MULTIPLY; } )
+ ( <ADD> { return ArithmeticOperator.PLUS; } |
+ <SUB> { return ArithmeticOperator.MINUS; } |
+ <DIV> { return ArithmeticOperator.DIVIDE; } |
+ <MUL> { return ArithmeticOperator.MULTIPLY; } |
+ <MOD> { return ArithmeticOperator.MODULO; } |
+ <AND> { return ArithmeticOperator.AND; } |
+ <OR> { return ArithmeticOperator.OR; } |
+ <POWOP> { return ArithmeticOperator.POWER; } )
{ return null; }
}
@@ -222,16 +232,23 @@ ExpressionNode value() :
{
ExpressionNode ret;
boolean neg = false;
+ boolean not = false;
}
{
- ( [ LOOKAHEAD(2) <SUB> { neg = true; } ]
- ( ret = constantPrimitive() |
- LOOKAHEAD(2) ret = ifExpression() |
- LOOKAHEAD(4) ret = function() |
- ret = feature() |
- ret = queryFeature() |
+ (
+ [ <NOT> { not = true; } ]
+ [ LOOKAHEAD(2) <SUB> { neg = true; } ]
+ ( ret = constantPrimitive() |
+ LOOKAHEAD(2) ret = ifExpression() |
+ LOOKAHEAD(4) ret = function() |
+ ret = feature() |
+ ret = queryFeature() |
( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) )
- { return neg ? new NegativeNode(ret) : ret; }
+ {
+ ret = not ? new NotNode(ret) : ret;
+ ret = neg ? new NegativeNode(ret) : ret;
+ return ret;
+ }
}
IfNode ifExpression() :
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 5d357777657..82e5d0cfe5b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -29,6 +29,7 @@ public class EvaluationTestCase {
tester.assertEvaluates(0.75, "0.5 + 0.25");
tester.assertEvaluates(0.75, "one_half + a_quarter");
tester.assertEvaluates(1.25, "0.5 - 0.25 + one");
+ tester.assertEvaluates(9.0, "3 ^ 2");
// String
tester.assertEvaluates(1, "if(\"a\"==\"a\",1,0)");
@@ -37,6 +38,9 @@ public class EvaluationTestCase {
tester.assertEvaluates(26, "2*3+4*5");
tester.assertEvaluates(1, "2/6+4/6");
tester.assertEvaluates(2 * 3 * 4 + 3 * 4 * 5 - 4 * 200 / 10, "2*3*4+3*4*5-4*200/10");
+ tester.assertEvaluates(3, "1 + 10 % 6 / 2");
+ tester.assertEvaluates(10.0, "3 ^ 2 + 1");
+ tester.assertEvaluates(18.0, "2 * 3 ^ 2");
// Conditionals
tester.assertEvaluates(2 * (3 * 4 + 3) * (4 * 5 - 4 * 200) / 10, "2*(3*4+3)*(4*5-4*200)/10");
@@ -89,6 +93,38 @@ public class EvaluationTestCase {
}
@Test
+ public void testBooleanEvaluation() {
+ EvaluationTester tester = new EvaluationTester();
+
+ // and
+ tester.assertEvaluates(false, "0 && 0");
+ tester.assertEvaluates(false, "0 && 1");
+ tester.assertEvaluates(false, "1 && 0");
+ tester.assertEvaluates(true, "1 && 1");
+ tester.assertEvaluates(true, "1 && 2");
+ tester.assertEvaluates(true, "1 && 0.1");
+
+ // or
+ tester.assertEvaluates(false, "0 || 0");
+ tester.assertEvaluates(true, "0 || 0.1");
+ tester.assertEvaluates(true, "0 || 1");
+ tester.assertEvaluates(true, "1 || 0");
+ tester.assertEvaluates(true, "1 || 1");
+
+ // not
+ tester.assertEvaluates(true, "!0");
+ tester.assertEvaluates(false, "!1");
+ tester.assertEvaluates(false, "!2");
+ tester.assertEvaluates(true, "!0 && 1");
+
+ // precedence
+ tester.assertEvaluates(0, "2 * (0 && 1)");
+ tester.assertEvaluates(2, "2 * (1 && 1)");
+ tester.assertEvaluates(true, "2 + 0 && 1");
+ tester.assertEvaluates(true, "1 && 0 + 2");
+ }
+
+ @Test
public void testTensorEvaluation() {
EvaluationTester tester = new EvaluationTester();
tester.assertEvaluates("{}", "tensor0", "{}");
@@ -107,6 +143,16 @@ public class EvaluationTestCase {
"min(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }");
tester.assertEvaluates("{ {d1:0}:0, {d1:1}:0, {d1:2 }:10 }",
"max(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }");
+ // operators
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "tensor0 % 2 == map(tensor0, f(x) (x % 2))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "tensor0 || 1 == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
+ "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }");
+
// -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence)
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }");
@@ -122,8 +168,9 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "isNan(tensor0)", "{ {x:0}:1, {x:1}:2 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "log(tensor0)", "{ {x:0}:1, {x:1}:1 }");
tester.assertEvaluates("{ {x:0}:0, {x:1}:1 }", "log10(tensor0)", "{ {x:0}:1, {x:1}:10 }");
- tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)", "{ {x:0}:3, {x:1}:8 }");
+ tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)","{ {x:0}:3, {x:1}:8 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:8 }", "pow(tensor0, 3)", "{ {x:0}:1, {x:1}:2 }");
+ tester.assertEvaluates("{ {x:0}:8, {x:1}:16 }", "ldexp(tensor0,3.1)","{ {x:0}:1, {x:1}:2 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "relu(tensor0)", "{ {x:0}:1, {x:1}:2 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "round(tensor0)", "{ {x:0}:1, {x:1}:1.8 }");
tester.assertEvaluates("{ {x:0}:0.5, {x:1}:0.5 }", "sigmoid(tensor0)","{ {x:0}:0, {x:1}:0 }");
@@ -201,6 +248,16 @@ public class EvaluationTestCase {
"max(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:5 }",
"min(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }",
+ "pow(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }",
+ "tensor0 ^ tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }",
+ "fmod(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }",
+ "tensor0 % tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
+ tester.assertEvaluates("{ {x:0,y:0}:96, {x:1,y:0}:224 }",
+ "ldexp(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5.1 }");
tester.assertEvaluates("{ {x:0,y:0,z:0}:7, {x:0,y:0,z:1}:13, {x:1,y:0,z:0}:21, {x:1,y:0,z:1}:39, {x:0,y:1,z:0}:55, {x:0,y:1,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:0 }",
"tensor0 * tensor1", "{ {x:0,y:0}:1, {x:1,y:0}:3, {x:0,y:1}:5, {x:1,y:1}:0 }", "{ {y:0,z:0}:7, {y:1,z:0}:11, {y:0,z:1}:13, {y:1,z:1}:0 }");
tester.assertEvaluates("{ {x:0,y:1,z:0}:35, {x:0,y:1,z:1}:65 }",
@@ -225,8 +282,13 @@ public class EvaluationTestCase {
"tensor0 <= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }");
tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }",
"tensor0 == tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
+ tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }",
+ "tensor0 ~= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
tester.assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0 }",
"tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
+ tester.assertEvaluates("{ {x:0}:1, {x:1}:0 }",
+ "tensor0 in [1,2,3]", "{ {x:0}:3, {x:1}:7 }");
+
// TODO
// argmax
// argmin
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
index d67c9dfd9dc..ee2b1c147e3 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
@@ -58,10 +58,18 @@ public class EvaluationTester {
return assertEvaluates(value, expressionString, defaultContext);
}
+ public RankingExpression assertEvaluates(boolean value, String expressionString) {
+ return assertEvaluates(value, expressionString, defaultContext);
+ }
+
public RankingExpression assertEvaluates(double value, String expressionString, Context context) {
return assertEvaluates(new DoubleValue(value), expressionString, context, "");
}
+ public RankingExpression assertEvaluates(boolean value, String expressionString, Context context) {
+ return assertEvaluates(new BooleanValue(value), expressionString, context, "");
+ }
+
public RankingExpression assertEvaluates(Value value, String expressionString, Context context, String explanation) {
try {
RankingExpression expression = new RankingExpression(expressionString);
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
new file mode 100644
index 00000000000..dde9d4bf21e
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java
@@ -0,0 +1,138 @@
+package com.yahoo.searchlib.tensor;
+
+import com.fasterxml.jackson.databind.JsonNode;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+import static org.junit.Assert.assertEquals;
+
+public class TensorConformanceTest {
+
+ private static String testPath = "eval/src/apps/tensor_conformance/test_spec.json";
+
+ @Test
+ public void testConformance() throws IOException {
+ File testSpec = new File(testPath);
+ if (!testSpec.exists()) {
+ testSpec = new File("../" + testPath);
+ }
+ int count = 0;
+ List<Integer> failList = new ArrayList<>();
+
+ try(BufferedReader br = new BufferedReader(new FileReader(testSpec))) {
+ String test = br.readLine();
+ while (test != null) {
+ boolean success = testCase(test, count);
+ if (!success) {
+ failList.add(count);
+ }
+ test = br.readLine();
+ count++;
+ }
+ }
+ assertEquals(failList.size() + " conformance test fails: " + failList, 0, failList.size());
+ }
+
+ private boolean testCase(String test, int count) {
+ try {
+ ObjectMapper mapper = new ObjectMapper();
+ JsonNode node = mapper.readTree(test);
+
+ if (node.has("num_tests")) {
+ Assert.assertEquals(node.get("num_tests").asInt(), count);
+ return true;
+ }
+ if (!node.has("expression")) {
+ return true; // ignore
+ }
+
+ String expression = node.get("expression").asText();
+ MapContext context = getInput(node.get("inputs"));
+ Tensor expect = getTensor(node.get("result").get("expect").asText());
+ Tensor result = evaluate(expression, context);
+ boolean equals = Tensor.equals(result, expect);
+ if (!equals) {
+ System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\"");
+ }
+ return equals;
+
+ } catch (Exception e) {
+ System.out.println(count + " : " + e.toString());
+ }
+ return false;
+ }
+
+ private Tensor evaluate(String expression, MapContext context) throws ParseException {
+ Value value = new RankingExpression(expression).evaluate(context);
+ if (!(value instanceof TensorValue)) {
+ throw new IllegalArgumentException("Result is not a tensor");
+ }
+ return ((TensorValue)value).asTensor();
+ }
+
+ private MapContext getInput(JsonNode inputs) {
+ MapContext context = new MapContext();
+ for (Iterator<String> i = inputs.fieldNames(); i.hasNext(); ) {
+ String name = i.next();
+ String value = inputs.get(name).asText();
+ Tensor tensor = getTensor(value);
+ context.put(name, new TensorValue(tensor));
+ }
+ return context;
+ }
+
+ private Tensor getTensor(String binaryRepresentation) {
+ byte[] bin = getBytes(binaryRepresentation);
+ return TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(bin));
+ }
+
+ private byte[] getBytes(String binaryRepresentation) {
+ return parseHexValue(binaryRepresentation.substring(2));
+ }
+
+ private byte[] parseHexValue(String s) {
+ final int len = s.length();
+ byte[] bytes = new byte[len/2];
+ for (int i = 0; i < len; i += 2) {
+ int c1 = hexValue(s.charAt(i)) << 4;
+ int c2 = hexValue(s.charAt(i + 1));
+ bytes[i/2] = (byte)(c1 + c2);
+ }
+ return bytes;
+ }
+
+ private int hexValue(Character c) {
+ if (c >= 'a' && c <= 'f') {
+ return c - 'a' + 10;
+ } else if (c >= 'A' && c <= 'F') {
+ return c - 'A' + 10;
+ } else if (c >= '0' && c <= '9') {
+ return c - '0';
+ }
+ throw new IllegalArgumentException("Hex contains illegal characters");
+ }
+
+}
+
diff --git a/searchlib/src/tests/features/constant/constant_test.cpp b/searchlib/src/tests/features/constant/constant_test.cpp
index a10f76e25ba..4a88fde58ce 100644
--- a/searchlib/src/tests/features/constant/constant_test.cpp
+++ b/searchlib/src/tests/features/constant/constant_test.cpp
@@ -19,7 +19,6 @@ using namespace search::features;
using vespalib::eval::Function;
using vespalib::eval::Value;
using vespalib::eval::DoubleValue;
-using vespalib::eval::TensorValue;
using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
using vespalib::tensor::DefaultTensorEngine;
@@ -39,7 +38,7 @@ Tensor::UP createTensor(const TensorCells &cells,
}
Tensor::UP make_tensor(const TensorSpec &spec) {
- auto tensor = DefaultTensorEngine::ref().create(spec);
+ auto tensor = DefaultTensorEngine::ref().from_spec(spec);
return Tensor::UP(dynamic_cast<Tensor*>(tensor.release()));
}
@@ -80,7 +79,7 @@ struct ExecFixture
ValueType type(tensor->getType());
test.getIndexEnv().addConstantValue(name,
std::move(type),
- std::make_unique<TensorValue>(std::move(tensor)));
+ std::move(tensor));
}
void addDouble(const vespalib::string &name, const double value) {
diff --git a/searchlib/src/tests/features/tensor/tensor_test.cpp b/searchlib/src/tests/features/tensor/tensor_test.cpp
index be7bb9defac..b097f27342d 100644
--- a/searchlib/src/tests/features/tensor/tensor_test.cpp
+++ b/searchlib/src/tests/features/tensor/tensor_test.cpp
@@ -54,7 +54,7 @@ Tensor::UP createTensor(const TensorCells &cells,
}
Tensor::UP make_tensor(const TensorSpec &spec) {
- auto tensor = DefaultTensorEngine::ref().create(spec);
+ auto tensor = DefaultTensorEngine::ref().from_spec(spec);
return Tensor::UP(dynamic_cast<Tensor*>(tensor.release()));
}
diff --git a/searchlib/src/tests/features/tensor_from_labels/tensor_from_labels_test.cpp b/searchlib/src/tests/features/tensor_from_labels/tensor_from_labels_test.cpp
index 0a900ad9ec8..1ac524b5d0b 100644
--- a/searchlib/src/tests/features/tensor_from_labels/tensor_from_labels_test.cpp
+++ b/searchlib/src/tests/features/tensor_from_labels/tensor_from_labels_test.cpp
@@ -36,7 +36,7 @@ typedef search::AttributeVector::SP AttributePtr;
typedef FtTestApp FTA;
Tensor::UP make_tensor(const TensorSpec &spec) {
- auto tensor = DefaultTensorEngine::ref().create(spec);
+ auto tensor = DefaultTensorEngine::ref().from_spec(spec);
return Tensor::UP(dynamic_cast<Tensor*>(tensor.release()));
}
diff --git a/searchlib/src/tests/features/tensor_from_weighted_set/tensor_from_weighted_set_test.cpp b/searchlib/src/tests/features/tensor_from_weighted_set/tensor_from_weighted_set_test.cpp
index cad0c56b0ca..e0eee954a53 100644
--- a/searchlib/src/tests/features/tensor_from_weighted_set/tensor_from_weighted_set_test.cpp
+++ b/searchlib/src/tests/features/tensor_from_weighted_set/tensor_from_weighted_set_test.cpp
@@ -37,7 +37,7 @@ typedef search::AttributeVector::SP AttributePtr;
typedef FtTestApp FTA;
Tensor::UP make_tensor(const TensorSpec &spec) {
- auto tensor = DefaultTensorEngine::ref().create(spec);
+ auto tensor = DefaultTensorEngine::ref().from_spec(spec);
return Tensor::UP(dynamic_cast<Tensor*>(tensor.release()));
}
diff --git a/searchlib/src/tests/postinglistbm/andstress.cpp b/searchlib/src/tests/postinglistbm/andstress.cpp
index 736d53508b4..40f919509e8 100644
--- a/searchlib/src/tests/postinglistbm/andstress.cpp
+++ b/searchlib/src/tests/postinglistbm/andstress.cpp
@@ -280,7 +280,7 @@ AndStressMaster::Task *
AndStressMaster::getTask()
{
Task *result = NULL;
- std::unique_lock<std::mutex> taskGuard(_taskLock);
+ std::lock_guard<std::mutex> taskGuard(_taskLock);
if (_taskIdx < _tasks.size()) {
result = &_tasks[_taskIdx];
++_taskIdx;
diff --git a/searchlib/src/tests/rankingexpression/rankingexpressionlist b/searchlib/src/tests/rankingexpression/rankingexpressionlist
index 327f2b161cd..77b2294c668 100644
--- a/searchlib/src/tests/rankingexpression/rankingexpressionlist
+++ b/searchlib/src/tests/rankingexpression/rankingexpressionlist
@@ -160,3 +160,7 @@ mysum ( mysum(4, 4), value( 4 ), value(4) ); mysum(mysum(4,4),value(4),value(4)
"1008\x1977"
"100819\x77"
if(1.09999~=1.1,2,3); if (1.09999 ~= 1.1, 2, 3)
+10 % 3
+1 && 0 || 1
+!a && (a || a)
+10 ^ 3
diff --git a/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp b/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp
index 28b4ad3c4e4..2e88f0e90b0 100644
--- a/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp
+++ b/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp
@@ -21,7 +21,7 @@ using EntryRef = DenseTensorStore::EntryRef;
Tensor::UP
makeTensor(const TensorSpec &spec)
{
- auto tensor = DefaultTensorEngine::ref().create(spec);
+ auto tensor = DefaultTensorEngine::ref().from_spec(spec);
return Tensor::UP(dynamic_cast<Tensor *>(tensor.release()));
}
diff --git a/searchlib/src/vespa/searchlib/features/constant_tensor_executor.h b/searchlib/src/vespa/searchlib/features/constant_tensor_executor.h
index 1a0e425e0ef..43ce48282ee 100644
--- a/searchlib/src/vespa/searchlib/features/constant_tensor_executor.h
+++ b/searchlib/src/vespa/searchlib/features/constant_tensor_executor.h
@@ -18,10 +18,10 @@ namespace features {
class ConstantTensorExecutor : public fef::FeatureExecutor
{
private:
- const vespalib::eval::TensorValue::UP _tensor;
+ vespalib::eval::Value::UP _tensor;
public:
- ConstantTensorExecutor(vespalib::eval::TensorValue::UP tensor)
+ ConstantTensorExecutor(vespalib::eval::Value::UP tensor)
: _tensor(std::move(tensor))
{}
virtual bool isPure() override { return true; }
@@ -29,11 +29,12 @@ public:
outputs().set_object(0, *_tensor);
}
static fef::FeatureExecutor &create(std::unique_ptr<vespalib::eval::Tensor> tensor, vespalib::Stash &stash) {
- return stash.create<ConstantTensorExecutor>(std::make_unique<vespalib::eval::TensorValue>(std::move(tensor)));
+ return stash.create<ConstantTensorExecutor>(std::move(tensor));
}
static fef::FeatureExecutor &createEmpty(const vespalib::eval::ValueType &valueType, vespalib::Stash &stash) {
- return create(vespalib::tensor::DefaultTensorEngine::ref()
- .create(vespalib::eval::TensorSpec(valueType.to_spec())), stash);
+ const auto &engine = vespalib::tensor::DefaultTensorEngine::ref();
+ auto spec = vespalib::eval::TensorSpec(valueType.to_spec());
+ return stash.create<ConstantTensorExecutor>(engine.from_spec(spec));
}
static fef::FeatureExecutor &createEmpty(vespalib::Stash &stash) {
return createEmpty(vespalib::eval::ValueType::double_type(), stash);
diff --git a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp
index 76252486bf4..487bc724e07 100644
--- a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp
+++ b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp
@@ -5,7 +5,6 @@
using search::tensor::DenseTensorAttribute;
using vespalib::eval::Tensor;
-using vespalib::eval::TensorValue;
using vespalib::tensor::MutableDenseTensorView;
namespace search {
@@ -14,8 +13,7 @@ namespace features {
DenseTensorAttributeExecutor::
DenseTensorAttributeExecutor(const DenseTensorAttribute *attribute)
: _attribute(attribute),
- _tensorView(_attribute->getConfig().tensorType()),
- _tensor(_tensorView)
+ _tensorView(_attribute->getConfig().tensorType())
{
}
@@ -23,7 +21,7 @@ void
DenseTensorAttributeExecutor::execute(uint32_t docId)
{
_attribute->getTensor(docId, _tensorView);
- outputs().set_object(0, _tensor);
+ outputs().set_object(0, _tensorView);
}
} // namespace features
diff --git a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h
index 68042075942..ac3d327c12a 100644
--- a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h
+++ b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h
@@ -19,7 +19,6 @@ class DenseTensorAttributeExecutor : public fef::FeatureExecutor
private:
const search::tensor::DenseTensorAttribute *_attribute;
vespalib::tensor::MutableDenseTensorView _tensorView;
- vespalib::eval::TensorValue _tensor;
public:
DenseTensorAttributeExecutor(const search::tensor::DenseTensorAttribute *attribute);
diff --git a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp
index 6ee7664f2bb..03393d6f590 100644
--- a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp
+++ b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp
@@ -3,8 +3,6 @@
#include "tensor_attribute_executor.h"
#include <vespa/searchlib/tensor/tensor_attribute.h>
-using vespalib::eval::TensorValue;
-
namespace search {
namespace features {
@@ -12,20 +10,19 @@ TensorAttributeExecutor::
TensorAttributeExecutor(const search::tensor::TensorAttribute *attribute)
: _attribute(attribute),
_emptyTensor(attribute->getEmptyTensor()),
- _tensor(*_emptyTensor)
+ _tensor()
{
}
void
TensorAttributeExecutor::execute(uint32_t docId)
{
- auto tensor = _attribute->getTensor(docId);
- if (!tensor) {
- _tensor = TensorValue(*_emptyTensor);
+ _tensor = _attribute->getTensor(docId);
+ if (_tensor) {
+ outputs().set_object(0, *_tensor);
} else {
- _tensor = TensorValue(std::move(tensor));
+ outputs().set_object(0, *_emptyTensor);
}
- outputs().set_object(0, _tensor);
}
} // namespace features
diff --git a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h
index 198b03e3d1d..0f1e21c8cad 100644
--- a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h
+++ b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h
@@ -17,7 +17,7 @@ class TensorAttributeExecutor : public fef::FeatureExecutor
private:
const search::tensor::TensorAttribute *_attribute;
std::unique_ptr<vespalib::eval::Tensor> _emptyTensor;
- vespalib::eval::TensorValue _tensor;
+ std::unique_ptr<vespalib::eval::Tensor> _tensor;
public:
TensorAttributeExecutor(const search::tensor::TensorAttribute *attribute);
diff --git a/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h b/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h
index 31b92f89538..f102749f1b6 100644
--- a/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h
+++ b/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h
@@ -22,7 +22,7 @@ private:
const search::attribute::IAttributeVector *_attribute;
vespalib::string _dimension;
WeightedBufferType _attrBuffer;
- vespalib::eval::TensorValue::UP _tensor;
+ std::unique_ptr<vespalib::tensor::Tensor> _tensor;
public:
TensorFromAttributeExecutor(const search::attribute::IAttributeVector *attribute,
@@ -48,7 +48,7 @@ TensorFromAttributeExecutor<WeightedBufferType>::execute(uint32_t docId)
builder.add_label(dimensionEnum, vespalib::string(_attrBuffer[i].value()));
builder.add_cell(_attrBuffer[i].weight());
}
- _tensor = vespalib::eval::TensorValue::UP(new vespalib::eval::TensorValue(builder.build()));
+ _tensor = builder.build();
outputs().set_object(0, *_tensor);
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/multisearch.cpp b/searchlib/src/vespa/searchlib/queryeval/multisearch.cpp
index 19d744dfd28..b63a54785a4 100644
--- a/searchlib/src/vespa/searchlib/queryeval/multisearch.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/multisearch.cpp
@@ -27,32 +27,16 @@ MultiSearch::remove(size_t index)
void
MultiSearch::doUnpack(uint32_t docid)
{
- size_t sz(_children.size());
- for (size_t i = 0; i < sz; ) {
- if (__builtin_expect(_children[i]->getDocId() < docid, false)) {
- _children[i]->doSeek(docid);
- if (_children[i]->isAtEnd()) {
- sz = deactivate(i);
- continue;
- }
+ for (SearchIterator *child: _children) {
+ if (__builtin_expect(child->getDocId() < docid, false)) {
+ child->doSeek(docid);
}
- if (__builtin_expect(_children[i]->getDocId() == docid, false)) {
- _children[i]->doUnpack(docid);
+ if (__builtin_expect(child->getDocId() == docid, false)) {
+ child->doUnpack(docid);
}
- i++;
}
}
-size_t
-MultiSearch::deactivate(size_t idx)
-{
- assert(idx < _children.size());
- delete _children[idx];
- _children[idx] = _children.back();
- _children.pop_back();
- return _children.size();
-}
-
MultiSearch::MultiSearch(const Children & children)
: _children(children)
{
diff --git a/searchlib/src/vespa/searchlib/queryeval/multisearch.h b/searchlib/src/vespa/searchlib/queryeval/multisearch.h
index 16bbd5d4ecc..d67f895ddb5 100644
--- a/searchlib/src/vespa/searchlib/queryeval/multisearch.h
+++ b/searchlib/src/vespa/searchlib/queryeval/multisearch.h
@@ -54,7 +54,6 @@ private:
virtual void onInsert(size_t index) { (void) index; }
bool isMultiSearch() const override { return true; }
- size_t deactivate(size_t index);
Children _children;
};