summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java10
-rw-r--r--config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java71
-rw-r--r--searchlib/abi-spec.json130
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java81
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java73
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java139
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java27
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java53
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java24
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java69
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java50
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj57
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java4
16 files changed, 388 insertions, 408 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
index 30d9a3766b3..2695fa79588 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
@@ -5,7 +5,6 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
-import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
@@ -13,7 +12,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
@@ -141,10 +139,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
ExpressionNode expr = new IfNode(
- new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr),
+ new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, queryLengthExpr),
ZERO,
new IfNode(
- new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr),
+ new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, restLengthExpr),
ONE,
ZERO
)
@@ -176,7 +174,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
List<ExpressionNode> tokenSequence = createTokenSequence(feature);
ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
- ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
+ ArithmeticNode comparison = new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, lengthExpr);
ExpressionNode expr = new IfNode(comparison, ONE, ZERO);
return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
}
@@ -256,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
*/
private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) {
ExpressionNode lengthExpr = createLengthExpr(iter, sequence);
- ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr);
+ ArithmeticNode comparison = new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, lengthExpr);
ExpressionNode trueExpr = sequence.get(iter);
if (sequence.get(iter) instanceof ReferenceNode) {
diff --git a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
index 39d9be905a5..13d21884c7d 100644
--- a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java
@@ -68,7 +68,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase {
Schema s = builder.getSchema();
RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels());
- assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))",
+ assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + (if (7.0 < attribute(a), 1, 2) == 0)))",
parent.getFirstPhaseRanking().getRoot().toString());
RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels());
assertEquals("7.0 * (9 + attribute(a))",
@@ -97,7 +97,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase {
Schema s = builder.getSchema();
RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels());
- assertEquals("3 * ( query(test) > 2.0 )",
+ assertEquals("3 * (query(test) > 2.0)",
parent.getFunctions().get("foo").function().getBody().getRoot().toString());
}
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
index 22681858fc3..e9b674a8c87 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -171,7 +171,7 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name());
- assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 0.0, if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value());
+ assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value());
assertEquals("test_unbound_model", config.rankprofile(8).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name());
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
index 8788eda572d..e0c99706e4a 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java
@@ -4,7 +4,6 @@ package ai.vespa.models.evaluation;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -70,53 +69,83 @@ class LazyValue extends Value {
}
@Override
- public Value add(Value value) {
- return computedValue().add(value);
+ public Value not() {
+ return computedValue().not();
}
@Override
- public Value subtract(Value value) {
- return computedValue().subtract(value);
+ public Value or(Value value) {
+ return computedValue().or(value);
}
@Override
- public Value multiply(Value value) {
- return computedValue().multiply(value);
+ public Value and(Value value) {
+ return computedValue().and(value);
}
@Override
- public Value divide(Value value) {
- return computedValue().divide(value);
+ public Value greaterEqual(Value value) {
+ return computedValue().greaterEqual(value);
}
@Override
- public Value modulo(Value value) {
- return computedValue().modulo(value);
+ public Value greater(Value value) {
+ return computedValue().greater(value);
}
@Override
- public Value and(Value value) {
- return computedValue().and(value);
+ public Value lessEqual(Value value) {
+ return computedValue().lessEqual(value);
}
@Override
- public Value or(Value value) {
- return computedValue().or(value);
+ public Value less(Value value) {
+ return computedValue().less(value);
}
@Override
- public Value not() {
- return computedValue().not();
+ public Value approx(Value value) {
+ return computedValue().approx(value);
}
@Override
- public Value power(Value value) {
- return computedValue().power(value);
+ public Value notEqual(Value value) {
+ return computedValue().notEqual(value);
+ }
+
+ @Override
+ public Value equal(Value value) {
+ return computedValue().equal(value);
}
@Override
- public Value compare(TruthOperator operator, Value value) {
- return computedValue().compare(operator, value);
+ public Value add(Value value) {
+ return computedValue().add(value);
+ }
+
+ @Override
+ public Value subtract(Value value) {
+ return computedValue().subtract(value);
+ }
+
+ @Override
+ public Value multiply(Value value) {
+ return computedValue().multiply(value);
+ }
+
+ @Override
+ public Value divide(Value value) {
+ return computedValue().divide(value);
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ return computedValue().modulo(value);
+ }
+
+ @Override
+ public Value power(Value value) {
+ return computedValue().power(value);
}
@Override
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 7caf5a06032..a4adf92210c 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -496,16 +496,22 @@
"public boolean hasDouble()",
"public com.yahoo.tensor.Tensor asTensor()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value negate()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value greaterEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value greater(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value lessEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value less(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value approx(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)"
],
"fields": []
@@ -691,16 +697,22 @@
"public boolean hasDouble()",
"public boolean asBoolean()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value negate()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value greaterEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value greater(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value lessEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value less(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value approx(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value asMutable()",
"public java.lang.String toString()",
@@ -723,17 +735,23 @@
"public boolean hasDouble()",
"public boolean asBoolean()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value negate()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value greaterEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value greater(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value lessEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value less(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value approx(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.tensor.Tensor asTensor()",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value asMutable()",
"public java.lang.String toString()",
@@ -760,16 +778,22 @@
"public abstract boolean hasDouble()",
"public abstract boolean asBoolean()",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value negate()",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value greaterEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value greater(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value lessEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value less(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value approx(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value not()",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value freeze()",
"public final boolean isFrozen()",
@@ -893,7 +917,6 @@
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode expression()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode arithmeticExpression()",
"public final com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator arithmetic()",
- "public final com.yahoo.searchlib.rankingexpression.rule.TruthOperator comparator()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode value()",
"public final com.yahoo.searchlib.rankingexpression.rule.IfNode ifExpression()",
"public final com.yahoo.searchlib.rankingexpression.rule.ReferenceNode feature()",
@@ -1013,13 +1036,13 @@
"public static final int DOLLAR",
"public static final int COMMA",
"public static final int COLON",
- "public static final int LE",
- "public static final int LT",
- "public static final int EQ",
- "public static final int NQ",
- "public static final int AQ",
- "public static final int GE",
- "public static final int GT",
+ "public static final int GREATEREQUAL",
+ "public static final int GREATER",
+ "public static final int LESSEQUAL",
+ "public static final int LESS",
+ "public static final int APPROX",
+ "public static final int NOTEQUAL",
+ "public static final int EQUAL",
"public static final int STRING",
"public static final int IF",
"public static final int IN",
@@ -1246,19 +1269,26 @@
"interfaces": [],
"attributes": [
"public",
- "abstract",
+ "final",
"enum"
],
"methods": [
"public static com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator[] values()",
"public static com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator valueOf(java.lang.String)",
"public boolean hasPrecedenceOver(com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator)",
- "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Value, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
+ "public final com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Value, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public java.lang.String toString()"
],
"fields": [
"public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator OR",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator AND",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator GREATEREQUAL",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator GREATER",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator LESSEQUAL",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator LESS",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator APPROX",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator NOTEQUAL",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator EQUAL",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator PLUS",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator MINUS",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator MULTIPLY",
@@ -1280,27 +1310,6 @@
],
"fields": []
},
- "com.yahoo.searchlib.rankingexpression.rule.ComparisonNode": {
- "superClass": "com.yahoo.searchlib.rankingexpression.rule.BooleanNode",
- "interfaces": [],
- "attributes": [
- "public"
- ],
- "methods": [
- "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
- "public java.util.List children()",
- "public com.yahoo.searchlib.rankingexpression.rule.TruthOperator getOperator()",
- "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode getLeftCondition()",
- "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode getRightCondition()",
- "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
- "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
- "public com.yahoo.searchlib.rankingexpression.rule.ComparisonNode setChildren(java.util.List)",
- "public int hashCode()",
- "public bridge synthetic com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)"
- ],
- "fields": []
- },
"com.yahoo.searchlib.rankingexpression.rule.CompositeNode": {
"superClass": "com.yahoo.searchlib.rankingexpression.rule.ExpressionNode",
"interfaces": [],
@@ -1693,32 +1702,5 @@
"public int hashCode()"
],
"fields": []
- },
- "com.yahoo.searchlib.rankingexpression.rule.TruthOperator": {
- "superClass": "java.lang.Enum",
- "interfaces": [
- "java.io.Serializable"
- ],
- "attributes": [
- "public",
- "abstract",
- "enum"
- ],
- "methods": [
- "public static com.yahoo.searchlib.rankingexpression.rule.TruthOperator[] values()",
- "public static com.yahoo.searchlib.rankingexpression.rule.TruthOperator valueOf(java.lang.String)",
- "public abstract boolean evaluate(double, double)",
- "public java.lang.String toString()",
- "public static com.yahoo.searchlib.rankingexpression.rule.TruthOperator fromString(java.lang.String)"
- ],
- "fields": [
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator SMALLER",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator SMALLEREQUAL",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator EQUAL",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator APPROX_EQUAL",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator LARGER",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator LARGEREQUAL",
- "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator NOTEQUAL"
- ]
}
} \ No newline at end of file
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
index afd263f1553..e1db6378fcf 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java
@@ -2,7 +2,6 @@
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -28,53 +27,83 @@ public abstract class DoubleCompatibleValue extends Value {
public Value negate() { return new DoubleValue(-asDouble()); }
@Override
- public Value add(Value value) {
- return new DoubleValue(asDouble() + value.asDouble());
+ public Value not() {
+ return new BooleanValue(!asBoolean());
}
@Override
- public Value subtract(Value value) {
- return new DoubleValue(asDouble() - value.asDouble());
+ public Value or(Value value) {
+ return new BooleanValue(asBoolean() || value.asBoolean());
}
@Override
- public Value multiply(Value value) {
- return new DoubleValue(asDouble() * value.asDouble());
+ public Value and(Value value) {
+ return new BooleanValue(asBoolean() && value.asBoolean());
}
@Override
- public Value divide(Value value) {
- return new DoubleValue(asDouble() / value.asDouble());
+ public Value greaterEqual(Value value) {
+ return new BooleanValue(this.asDouble() >= value.asDouble());
}
@Override
- public Value modulo(Value value) {
- return new DoubleValue(asDouble() % value.asDouble());
+ public Value greater(Value value) {
+ return new BooleanValue(this.asDouble() > value.asDouble());
}
@Override
- public Value and(Value value) {
- return new BooleanValue(asBoolean() && value.asBoolean());
+ public Value lessEqual(Value value) {
+ return new BooleanValue(this.asDouble() <= value.asDouble());
}
@Override
- public Value or(Value value) {
- return new BooleanValue(asBoolean() || value.asBoolean());
+ public Value less(Value value) {
+ return new BooleanValue(this.asDouble() < value.asDouble());
}
@Override
- public Value not() {
- return new BooleanValue(!asBoolean());
+ public Value approx(Value value) {
+ return new BooleanValue(approxEqual(this.asDouble(), value.asDouble()));
}
@Override
- public Value power(Value value) {
- return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble()));
+ public Value notEqual(Value value) {
+ return new BooleanValue(this.asDouble() != value.asDouble());
+ }
+
+ @Override
+ public Value equal(Value value) {
+ return new BooleanValue(this.asDouble() == value.asDouble());
+ }
+
+ @Override
+ public Value add(Value value) {
+ return new DoubleValue(asDouble() + value.asDouble());
+ }
+
+ @Override
+ public Value subtract(Value value) {
+ return new DoubleValue(asDouble() - value.asDouble());
+ }
+
+ @Override
+ public Value multiply(Value value) {
+ return new DoubleValue(asDouble() * value.asDouble());
}
@Override
- public Value compare(TruthOperator operator, Value value) {
- return new BooleanValue(operator.evaluate(asDouble(), value.asDouble()));
+ public Value divide(Value value) {
+ return new DoubleValue(asDouble() / value.asDouble());
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ return new DoubleValue(asDouble() % value.asDouble());
+ }
+
+ @Override
+ public Value power(Value value) {
+ return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble()));
}
@Override
@@ -82,4 +111,14 @@ public abstract class DoubleCompatibleValue extends Value {
return new DoubleValue(function.evaluate(asDouble(),value.asDouble()));
}
+ static boolean approxEqual(double x, double y) {
+ if (y < -1.0 || y > 1.0) {
+ x = Math.nextAfter(x/y, 1.0);
+ y = 1.0;
+ } else {
+ x = Math.nextAfter(x, y);
+ }
+ return x == y;
+ }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
index 2c2d5eead05..3c09c644147 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java
@@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -57,55 +56,83 @@ public class StringValue extends Value {
}
@Override
- public Value add(Value value) {
- return new StringValue(value + value.toString());
+ public Value not() {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support not");
}
@Override
- public Value subtract(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction");
+ public Value or(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support or");
}
@Override
- public Value multiply(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication");
+ public Value and(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support and");
}
@Override
- public Value divide(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support division");
+ public Value greaterEqual(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support greaterEqual");
}
@Override
- public Value modulo(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo");
+ public Value greater(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support greater");
}
@Override
- public Value and(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support and");
+ public Value lessEqual(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support lessEqual");
}
@Override
- public Value or(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support or");
+ public Value less(Value argument) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support less");
}
@Override
- public Value not() {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support not");
+ public Value approx(Value argument) {
+ return new BooleanValue(this.asDouble() == argument.asDouble());
}
@Override
- public Value power(Value value) {
- throw new UnsupportedOperationException("String values ('" + value + "') do not support ^");
+ public Value notEqual(Value argument) {
+ return new BooleanValue(this.asDouble() != argument.asDouble());
+ }
+
+ @Override
+ public Value equal(Value argument) {
+ return new BooleanValue(this.asDouble() == argument.asDouble());
}
@Override
- public Value compare(TruthOperator operator, Value value) {
- if (operator.equals(TruthOperator.EQUAL))
- return new BooleanValue(this.equals(value));
- throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='");
+ public Value add(Value value) {
+ return new StringValue(value + value.toString());
+ }
+
+ @Override
+ public Value subtract(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction");
+ }
+
+ @Override
+ public Value multiply(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication");
+ }
+
+ @Override
+ public Value divide(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support division");
+ }
+
+ @Override
+ public Value modulo(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo");
+ }
+
+ @Override
+ public Value power(Value value) {
+ throw new UnsupportedOperationException("String values ('" + value + "') do not support ^");
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
index b37bbb543eb..73ea0b23986 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java
@@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.api.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -52,6 +51,83 @@ public class TensorValue extends Value {
}
@Override
+ public Value not() {
+ return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value or(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value and(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
+ else
+ return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
+ }
+
+ @Override
+ public Value greaterEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.largerOrEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value >= argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value greater(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.larger(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value > argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value lessEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value <= argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value less(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.smaller(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value < argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value approx(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.approxEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> DoubleCompatibleValue.approxEqual(value, argument.asDouble()) ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value notEqual(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.notEqual(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value != argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
+ public Value equal(Value argument) {
+ if (argument instanceof TensorValue)
+ return new TensorValue(value.equal(((TensorValue)argument).value));
+ else
+ return new TensorValue(value.map((value) -> value == argument.asDouble() ? 1.0 : 0.0));
+ }
+
+ @Override
public Value add(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.add(((TensorValue)argument).value));
@@ -92,27 +168,6 @@ public class TensorValue extends Value {
}
@Override
- public Value and(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 ));
- else
- return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0));
- }
-
- @Override
- public Value or(Value argument) {
- if (argument instanceof TensorValue)
- return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 ));
- else
- return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0));
- }
-
- @Override
- public Value not() {
- return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0));
- }
-
- @Override
public Value power(Value argument) {
if (argument instanceof TensorValue)
return new TensorValue(value.pow(((TensorValue)argument).value));
@@ -123,24 +178,6 @@ public class TensorValue extends Value {
public Tensor asTensor() { return value; }
@Override
- public Value compare(TruthOperator operator, Value argument) {
- return new TensorValue(compareTensor(operator, argument.asTensor()));
- }
-
- private Tensor compareTensor(TruthOperator operator, Tensor argument) {
- switch (operator) {
- case LARGER: return value.larger(argument);
- case LARGEREQUAL: return value.largerOrEqual(argument);
- case SMALLER: return value.smaller(argument);
- case SMALLEREQUAL: return value.smallerOrEqual(argument);
- case EQUAL: return value.equal(argument);
- case NOTEQUAL: return value.notEqual(argument);
- case APPROX_EQUAL: return value.approxEqual(argument);
- default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator);
- }
- }
-
- @Override
public Value function(Function function, Value arg) {
if (arg instanceof TensorValue)
return new TensorValue(functionOnTensor(function, arg.asTensor()));
@@ -149,17 +186,17 @@ public class TensorValue extends Value {
}
private Tensor functionOnTensor(Function function, Tensor argument) {
- switch (function) {
- case min: return value.min(argument);
- case max: return value.max(argument);
- case atan2: return value.atan2(argument);
- case pow: return value.pow(argument);
- case fmod: return value.fmod(argument);
- case ldexp: return value.ldexp(argument);
- case bit: return value.bit(argument);
- case hamming: return value.hamming(argument);
- default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
- }
+ return switch (function) {
+ case min -> value.min(argument);
+ case max -> value.max(argument);
+ case atan2 -> value.atan2(argument);
+ case pow -> value.pow(argument);
+ case fmod -> value.fmod(argument);
+ case ldexp -> value.ldexp(argument);
+ case bit -> value.bit(argument);
+ case hamming -> value.hamming(argument);
+ default -> throw new UnsupportedOperationException("Cannot combine two tensors using " + function);
+ };
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 207603c5038..99663fe8d0d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -1,9 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;
-import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -51,27 +49,24 @@ public abstract class Value {
public abstract Value negate();
- public abstract Value add(Value value);
+ public abstract Value not();
+ public abstract Value or(Value value);
+ public abstract Value and(Value value);
+ public abstract Value greaterEqual(Value value);
+ public abstract Value greater(Value value);
+ public abstract Value lessEqual(Value value);
+ public abstract Value less(Value value);
+ public abstract Value approx(Value value);
+ public abstract Value notEqual(Value value);
+ public abstract Value equal(Value value);
+ public abstract Value add(Value value);
public abstract Value subtract(Value value);
-
public abstract Value multiply(Value value);
-
public abstract Value divide(Value value);
-
public abstract Value modulo(Value value);
-
- public abstract Value and(Value value);
-
- public abstract Value or(Value value);
-
- public abstract Value not();
-
public abstract Value power(Value value);
- /** Perform the comparison specified by the operator between this value and the given value */
- public abstract Value compare(TruthOperator operator, Value value);
-
/** Perform the given binary function on this value and the given value */
public abstract Value function(Function function, Value value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
index 420f1f459f3..cf4c35d94af 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
@@ -114,15 +114,15 @@ public class GBDTOptimizer extends Optimizer {
/** Consumes the if condition and return the size of the values resulting, for convenience */
private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) {
- if (condition instanceof ComparisonNode) {
- ComparisonNode comparison = (ComparisonNode)condition;
- if (comparison.getOperator() == TruthOperator.SMALLER)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.getLeftCondition(), context));
- else if (comparison.getOperator() == TruthOperator.EQUAL)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.getLeftCondition(), context));
+ if (isBinaryComparison(condition)) {
+ ArithmeticNode comparison = (ArithmeticNode)condition;
+ if (comparison.operators().get(0) == ArithmeticOperator.LESS)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context));
+ else if (comparison.operators().get(0) == ArithmeticOperator.EQUAL)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context));
else
- throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.getOperator());
- values.add(toValue(comparison.getRightCondition()));
+ throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0));
+ values.add(toValue(comparison.children().get(1)));
}
else if (condition instanceof SetMembershipNode) {
SetMembershipNode setMembership = (SetMembershipNode)condition;
@@ -131,17 +131,15 @@ public class GBDTOptimizer extends Optimizer {
for (ExpressionNode setElementNode : setMembership.getSetValues())
values.add(toValue(setElementNode));
}
- else if (condition instanceof NotNode) { // handle if inversion: !(a >= b)
- NotNode notNode = (NotNode)condition;
- if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) {
- EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0);
- if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) {
- ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0);
- if (comparison.getOperator() == TruthOperator.LARGEREQUAL)
- values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context));
+ else if (condition instanceof NotNode notNode) { // handle if inversion: !(a >= b)
+ if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) {
+ if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) {
+ ArithmeticNode comparison = (ArithmeticNode)embracedNode.children().get(0);
+ if (comparison.operators().get(0) == ArithmeticOperator.GREATEREQUAL)
+ values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context));
else
- throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator());
- values.add(toValue(comparison.getRightCondition()));
+ throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0));
+ values.add(toValue(comparison.children().get(1)));
}
}
}
@@ -152,12 +150,24 @@ public class GBDTOptimizer extends Optimizer {
return values.size();
}
+ private boolean isBinaryComparison(ExpressionNode condition) {
+ if ( ! (condition instanceof ArithmeticNode binaryNode)) return false;
+ if (binaryNode.operators().size() != 1) return false;
+ if (binaryNode.operators().get(0) == ArithmeticOperator.GREATEREQUAL) return true;
+ if (binaryNode.operators().get(0) == ArithmeticOperator.GREATER) return true;
+ if (binaryNode.operators().get(0) == ArithmeticOperator.LESSEQUAL) return true;
+ if (binaryNode.operators().get(0) == ArithmeticOperator.LESS) return true;
+ if (binaryNode.operators().get(0) == ArithmeticOperator.APPROX) return true;
+ if (binaryNode.operators().get(0) == ArithmeticOperator.NOTEQUAL) return true;
+ if (binaryNode.operators().get(0) == ArithmeticOperator.EQUAL) return true;
+ return false;
+ }
+
private double getVariableIndex(ExpressionNode node, ContextIndex context) {
- if (!(node instanceof ReferenceNode)) {
+ if (!(node instanceof ReferenceNode fNode)) {
throw new IllegalArgumentException("Contained a left-hand comparison expression " +
"which was not a feature value but was: " + node);
}
- ReferenceNode fNode = (ReferenceNode)node;
Integer index = context.getIndex(fNode.toString());
if (index == null) {
throw new IllegalStateException("The ranking expression contained feature '" + fNode.getName() +
@@ -177,8 +187,7 @@ public class GBDTOptimizer extends Optimizer {
value.getClass().getSimpleName() + " (" + value + ") in a set test: " + node);
}
- if (node instanceof NegativeNode) {
- NegativeNode nNode = (NegativeNode)node;
+ if (node instanceof NegativeNode nNode) {
if (!(nNode.getValue() instanceof ConstantNode)) {
throw new IllegalArgumentException("Contained a negation of a non-number: " + nNode.getValue());
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
index a2521398529..435c92ff7da 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
@@ -14,26 +14,16 @@ import java.util.function.BiFunction;
*/
public enum ArithmeticOperator {
-/*
-struct Sub : OperatorHelper<Sub> { Sub() : Helper("-", 101, LEFT) {}};
-struct Mul : OperatorHelper<Mul> { Mul() : Helper("*", 102, LEFT) {}};
-struct Div : OperatorHelper<Div> { Div() : Helper("/", 102, LEFT) {}};
-struct Mod : OperatorHelper<Mod> { Mod() : Helper("%", 102, LEFT) {}};
-struct Pow : OperatorHelper<Pow> { Pow() : Helper("^", 103, RIGHT) {}};
-struct Equal : OperatorHelper<Equal> { Equal() : Helper("==", 10, LEFT) {}};
-struct NotEqual : OperatorHelper<NotEqual> { NotEqual() : Helper("!=", 10, LEFT) {}};
-struct Approx : OperatorHelper<Approx> { Approx() : Helper("~=", 10, LEFT) {}};
-struct Less : OperatorHelper<Less> { Less() : Helper("<", 10, LEFT) {}};
-struct LessEqual : OperatorHelper<LessEqual> { LessEqual() : Helper("<=", 10, LEFT) {}};
-struct Greater : OperatorHelper<Greater> { Greater() : Helper(">", 10, LEFT) {}};
-struct GreaterEqual : OperatorHelper<GreaterEqual> { GreaterEqual() : Helper(">=", 10, LEFT) {}};
-struct And : OperatorHelper<And> { And() : Helper("&&", 2, LEFT) {}};
-struct Or : OperatorHelper<Or> { Or() : Helper("||", 1, LEFT) {}};
- */
-
// In order from lowest to highest precedence
OR("||", (x, y) -> x.or(y)),
AND("&&", (x, y) -> x.and(y)),
+ GREATEREQUAL(">=", (x, y) -> x.greaterEqual(y)),
+ GREATER(">", (x, y) -> x.greater(y)),
+ LESSEQUAL("<=", (x, y) -> x.lessEqual(y)),
+ LESS("<", (x, y) -> x.less(y)),
+ APPROX("~=", (x, y) -> x.approx(y)),
+ NOTEQUAL("!=", (x, y) -> x.notEqual(y)),
+ EQUAL("==", (x, y) -> x.equal(y)),
PLUS("+", (x, y) -> x.add(y)),
MINUS("-", (x, y) -> x.subtract(y)),
MULTIPLY("*", (x, y) -> x.multiply(y)),
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
deleted file mode 100644
index e726a351f74..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java
+++ /dev/null
@@ -1,69 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-
-import java.util.Deque;
-import java.util.List;
-import java.util.Objects;
-
-/**
- * A node which returns the outcome of a comparison.
- *
- * @author bratseth
- */
-public class ComparisonNode extends BooleanNode {
-
- /** The operator string of this condition. */
- private final TruthOperator operator;
-
- private final List<ExpressionNode> conditions;
-
- public ComparisonNode(ExpressionNode leftCondition, TruthOperator operator, ExpressionNode rightCondition) {
- conditions = List.of(leftCondition, rightCondition);
- this.operator = operator;
- }
-
- @Override
- public List<ExpressionNode> children() {
- return conditions;
- }
-
- public TruthOperator getOperator() { return operator; }
-
- public ExpressionNode getLeftCondition() { return conditions.get(0); }
-
- public ExpressionNode getRightCondition() { return conditions.get(1); }
-
- @Override
- public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
- getLeftCondition().toString(string, context, path, this).append(' ').append(operator).append(' ');
- return getRightCondition().toString(string, context, path, this);
- }
-
- @Override
- public TensorType type(TypeContext<Reference> context) {
- return TensorType.empty; // by definition
- }
-
- @Override
- public Value evaluate(Context context) {
- Value leftValue = getLeftCondition().evaluate(context);
- Value rightValue = getRightCondition().evaluate(context);
- return leftValue.compare(operator,rightValue);
- }
-
- @Override
- public ComparisonNode setChildren(List<ExpressionNode> children) {
- if (children.size() != 2) throw new IllegalArgumentException("A comparison test must have 2 children");
- return new ComparisonNode(children.get(0), operator, children.get(1));
- }
-
- @Override
- public int hashCode() { return Objects.hash(operator, conditions); }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
deleted file mode 100644
index fc259867923..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.rule;
-
-import java.io.Serializable;
-
-/**
- * A mathematical operator
- *
- * @author bratseth
- */
-public enum TruthOperator implements Serializable {
-
- SMALLER("<") { public boolean evaluate(double x, double y) { return x<y; } },
- SMALLEREQUAL("<=") { public boolean evaluate(double x, double y) { return x<=y; } },
- EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } },
- APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } },
- LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } },
- LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } },
- NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } };
-
- private final String operatorString;
-
- TruthOperator(String operatorString) {
- this.operatorString = operatorString;
- }
-
- /** Perform the truth operation on the input */
- public abstract boolean evaluate(double x, double y);
-
- @Override
- public String toString() { return operatorString; }
-
- public static TruthOperator fromString(String string) {
- for (TruthOperator operator : values())
- if (operator.toString().equals(string))
- return operator;
- throw new IllegalArgumentException("Illegal truth operator '" + string + "'");
- }
-
- private static boolean approxEqual(double x,double y) {
- if (y < -1.0 || y > 1.0) {
- x = Math.nextAfter(x/y, 1.0);
- y = 1.0;
- } else {
- x = Math.nextAfter(x, y);
- }
- return x==y;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index 90861e64164..7a34f5b7b03 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -6,7 +6,6 @@ import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
-import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
@@ -44,7 +43,6 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean hasSingleUndividableChild(EmbracedNode node) {
if (node.children().size() > 1) return false;
if (node.children().get(0) instanceof ArithmeticNode) return false;
- if (node.children().get(0) instanceof ComparisonNode) return false;
return true;
}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index ebe1e048247..2261d39829c 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -72,13 +72,13 @@ TOKEN :
<COMMA: ","> |
<COLON: ":"> |
- <LE: "<="> |
- <LT: "<"> |
- <EQ: "=="> |
- <NQ: "!="> |
- <AQ: "~="> |
- <GE: ">="> |
- <GT: ">"> |
+ <GREATEREQUAL: ">="> |
+ <GREATER: ">"> |
+ <LESSEQUAL: "<="> |
+ <LESS: "<"> |
+ <APPROX: "~="> |
+ <NOTEQUAL: "!="> |
+ <EQUAL: "=="> |
<STRING: ("\"" (~["\""] | "\\\"")* "\"") |
("'" (~["'"] | "\\'")* "'")> |
@@ -188,14 +188,12 @@ ExpressionNode expression() :
{
ExpressionNode left, right;
List<ExpressionNode> rightList;
- TruthOperator comparatorOp;
}
{
( left = arithmeticExpression()
(
- ( comparatorOp = comparator() right = arithmeticExpression() { left = new ComparisonNode(left, comparatorOp, right); } ) |
( <IN> rightList = expressionList() { left = new SetMembershipNode(left, rightList); } )
- ) *
+ ) ?
)
{ return left; }
}
@@ -214,29 +212,26 @@ ExpressionNode arithmeticExpression() :
ArithmeticOperator arithmetic() : { }
{
- ( <ADD> { return ArithmeticOperator.PLUS; } |
- <SUB> { return ArithmeticOperator.MINUS; } |
- <DIV> { return ArithmeticOperator.DIVIDE; } |
- <MUL> { return ArithmeticOperator.MULTIPLY; } |
- <MOD> { return ArithmeticOperator.MODULO; } |
- <AND> { return ArithmeticOperator.AND; } |
- <OR> { return ArithmeticOperator.OR; } |
- <POWOP> { return ArithmeticOperator.POWER; } )
+ (
+ <OR> { return ArithmeticOperator.OR; } |
+ <AND> { return ArithmeticOperator.AND; } |
+ <GREATEREQUAL> { return ArithmeticOperator.GREATEREQUAL; } |
+ <GREATER> { return ArithmeticOperator.GREATER; } |
+ <LESSEQUAL> { return ArithmeticOperator.LESSEQUAL; } |
+ <LESS> { return ArithmeticOperator.LESS; } |
+ <APPROX> { return ArithmeticOperator.APPROX; } |
+ <NOTEQUAL> { return ArithmeticOperator.NOTEQUAL; } |
+ <EQUAL> { return ArithmeticOperator.EQUAL; } |
+ <ADD> { return ArithmeticOperator.PLUS; } |
+ <SUB> { return ArithmeticOperator.MINUS; } |
+ <DIV> { return ArithmeticOperator.DIVIDE; } |
+ <MUL> { return ArithmeticOperator.MULTIPLY; } |
+ <MOD> { return ArithmeticOperator.MODULO; } |
+ <POWOP> { return ArithmeticOperator.POWER; }
+ )
{ return null; }
}
-TruthOperator comparator() : { }
-{
- ( <LE> { return TruthOperator.SMALLEREQUAL; } |
- <LT> { return TruthOperator.SMALLER; } |
- <EQ> { return TruthOperator.EQUAL; } |
- <NQ> { return TruthOperator.NOTEQUAL; } |
- <AQ> { return TruthOperator.APPROX_EQUAL; } |
- <GE> { return TruthOperator.LARGEREQUAL; } |
- <GT> { return TruthOperator.LARGER; } )
- { return null; }
-}
-
ExpressionNode value() :
{
ExpressionNode value;
@@ -665,7 +660,7 @@ TensorType.Value optionalTensorValueTypeParameter() :
String valueType = "double";
}
{
- ( <LT> valueType = identifier() <GT> )?
+ ( <LESS> valueType = identifier() <GREATER> )?
{ return TensorType.Value.fromId(valueType); }
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index ad50a423eb9..b1ac4b9e3ca 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -198,9 +198,9 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
"tensor0 % 2 == map(tensor0, f(x) (x % 2))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
- "tensor0 || 1 == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ "(tensor0 || 1) == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
- "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
+ "(tensor0 && 1) == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }");
tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }",
"!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }");