diff options
30 files changed, 711 insertions, 638 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java index 8fa4b469590..ad050d4ca63 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java @@ -2,8 +2,8 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -35,20 +35,20 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor if (node instanceof CompositeNode composite) node = transformChildren(composite, context); - if (node instanceof ArithmeticNode arithmetic) + if (node instanceof OperationNode arithmetic) node = transformBooleanArithmetics(arithmetic); return node; } - private ExpressionNode transformBooleanArithmetics(ArithmeticNode node) { + private ExpressionNode transformBooleanArithmetics(OperationNode node) { Iterator<ExpressionNode> child = node.children().iterator(); // Transform in precedence order: Deque<ChildNode> stack = new ArrayDeque<>(); stack.push(new ChildNode(null, child.next())); - for (Iterator<ArithmeticOperator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) { - ArithmeticOperator op = it.next(); + for (Iterator<Operator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) { + Operator op = it.next(); if ( ! stack.isEmpty()) { while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) { popStack(stack); @@ -66,9 +66,9 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor ChildNode lhs = stack.peek(); ExpressionNode combination; - if (rhs.op == ArithmeticOperator.AND) + if (rhs.op == Operator.and) combination = andByIfNode(lhs.child, rhs.child); - else if (rhs.op == ArithmeticOperator.OR) + else if (rhs.op == Operator.or) combination = orByIfNode(lhs.child, rhs.child); else { combination = resolve(lhs, rhs); @@ -77,28 +77,28 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor lhs.child = combination; } - private static ArithmeticNode resolve(ChildNode left, ChildNode right) { - if ( ! (left.child instanceof ArithmeticNode) && ! (right.child instanceof ArithmeticNode)) - return new ArithmeticNode(left.child, right.op, right.child); + private static OperationNode resolve(ChildNode left, ChildNode right) { + if (! (left.child instanceof OperationNode) && ! (right.child instanceof OperationNode)) + return new OperationNode(left.child, right.op, right.child); // Collapse inserted ArithmeticNodes - List<ArithmeticOperator> joinedOps = new ArrayList<>(); + List<Operator> joinedOps = new ArrayList<>(); joinOps(left, joinedOps); joinedOps.add(right.op); joinOps(right, joinedOps); List<ExpressionNode> joinedChildren = new ArrayList<>(); joinChildren(left, joinedChildren); joinChildren(right, joinedChildren); - return new ArithmeticNode(joinedChildren, joinedOps); + return new OperationNode(joinedChildren, joinedOps); } - private static void joinOps(ChildNode node, List<ArithmeticOperator> joinedOps) { - if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode) - joinedOps.addAll(arithmeticNode.operators()); + private static void joinOps(ChildNode node, List<Operator> joinedOps) { + if (node.artificial && node.child instanceof OperationNode operationNode) + joinedOps.addAll(operationNode.operators()); } private static void joinChildren(ChildNode node, List<ExpressionNode> joinedChildren) { - if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode) - joinedChildren.addAll(arithmeticNode.children()); + if (node.artificial && node.child instanceof OperationNode operationNode) + joinedChildren.addAll(operationNode.children()); else joinedChildren.add(node.child); } @@ -115,11 +115,11 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor /** A child with the operator to be applied to it when combining it with the previous child. */ private static class ChildNode { - final ArithmeticOperator op; + final Operator op; ExpressionNode child; boolean artificial; - public ChildNode(ArithmeticOperator op, ExpressionNode child) { + public ChildNode(Operator op, ExpressionNode child) { this.op = op; this.child = child; } diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java index 30d9a3766b3..cf354a05a93 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java @@ -3,9 +3,8 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; -import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; @@ -13,7 +12,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Generate; @@ -141,10 +139,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence); ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); ExpressionNode expr = new IfNode( - new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr), + new OperationNode(new ReferenceNode("d1"), Operator.smaller, queryLengthExpr), ZERO, new IfNode( - new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr), + new OperationNode(new ReferenceNode("d1"), Operator.smaller, restLengthExpr), ONE, ZERO ) @@ -176,7 +174,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform List<ExpressionNode> tokenSequence = createTokenSequence(feature); ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); - ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr); ExpressionNode expr = new IfNode(comparison, ONE, ZERO); return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); } @@ -256,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform */ private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) { ExpressionNode lengthExpr = createLengthExpr(iter, sequence); - ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.smaller, lengthExpr); ExpressionNode trueExpr = sequence.get(iter); if (sequence.get(iter) instanceof ReferenceNode) { @@ -280,7 +278,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform */ private ExpressionNode createLengthExpr(int iter, List<ExpressionNode> sequence) { List<ExpressionNode> factors = new ArrayList<>(); - List<ArithmeticOperator> operators = new ArrayList<>(); + List<Operator> operators = new ArrayList<>(); for (int i = 0; i < iter + 1; ++i) { if (sequence.get(i) instanceof ConstantNode) { factors.add(ONE); @@ -288,10 +286,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i)))); } if (i >= 1) { - operators.add(ArithmeticOperator.PLUS); + operators.add(Operator.plus); } } - return new ArithmeticNode(factors, operators); + return new OperationNode(factors, operators); } /** @@ -301,7 +299,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform ExpressionNode expr; if (iter >= 1) { ExpressionNode lengthExpr = new EmbracedNode(createLengthExpr(iter - 1, sequence)); - expr = new EmbracedNode(new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.MINUS, lengthExpr)); + expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.minus, lengthExpr)); } else { expr = new ReferenceNode("d1"); } diff --git a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java index 5eecee516ec..13d21884c7d 100644 --- a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java @@ -68,7 +68,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase { Schema s = builder.getSchema(); RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels()); - assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))", + assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + (if (7.0 < attribute(a), 1, 2) == 0)))", parent.getFirstPhaseRanking().getRoot().toString()); RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("7.0 * (9 + attribute(a))", @@ -76,6 +76,33 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase { } @Test + void testInlinedComparison() throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry); + builder.addSchema("search test {\n" + + " document test { \n" + + " }\n" + + " \n" + + " rank-profile parent {\n" + + "function foo() {\n" + + " expression: 3 * bar\n" + + "}\n" + + "\n" + + "function inline bar() {\n" + + " expression: query(test) > 2.0\n" + + "}\n" + + "}\n" + + "}\n"); + builder.build(true); + Schema s = builder.getSchema(); + + RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels()); + assertEquals("3 * (query(test) > 2.0)", + parent.getFunctions().get("foo").function().getBody().getRoot().toString()); + + } + + @Test void testConstants() throws ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry); diff --git a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java index 65d12f1cdcf..d692b69d3c8 100644 --- a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java @@ -4,7 +4,7 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; import org.junit.jupiter.api.Test; @@ -43,8 +43,8 @@ public class BooleanExpressionTransformerTestCase { var expr = new BooleanExpressionTransformer() .transform(new RankingExpression("a + b + c * d + e + f"), new TransformContext(Map.of(), new MapTypeContext())); - assertTrue(expr.getRoot() instanceof ArithmeticNode); - ArithmeticNode root = (ArithmeticNode) expr.getRoot(); + assertTrue(expr.getRoot() instanceof OperationNode); + OperationNode root = (OperationNode) expr.getRoot(); assertEquals(5, root.operators().size()); assertEquals(6, root.children().size()); } diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java index 22681858fc3..e9b674a8c87 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java @@ -171,7 +171,7 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name()); - assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 0.0, if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value()); + assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value()); assertEquals("test_unbound_model", config.rankprofile(8).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java index 8788eda572d..3aa60e0b9a2 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyValue.java @@ -4,7 +4,6 @@ package ai.vespa.models.evaluation; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -70,53 +69,83 @@ class LazyValue extends Value { } @Override - public Value add(Value value) { - return computedValue().add(value); + public Value not() { + return computedValue().not(); } @Override - public Value subtract(Value value) { - return computedValue().subtract(value); + public Value or(Value value) { + return computedValue().or(value); } @Override - public Value multiply(Value value) { - return computedValue().multiply(value); + public Value and(Value value) { + return computedValue().and(value); } @Override - public Value divide(Value value) { - return computedValue().divide(value); + public Value largerOrEqual(Value value) { + return computedValue().largerOrEqual(value); } @Override - public Value modulo(Value value) { - return computedValue().modulo(value); + public Value larger(Value value) { + return computedValue().larger(value); } @Override - public Value and(Value value) { - return computedValue().and(value); + public Value smallerOrEqual(Value value) { + return computedValue().smallerOrEqual(value); } @Override - public Value or(Value value) { - return computedValue().or(value); + public Value smaller(Value value) { + return computedValue().smaller(value); } @Override - public Value not() { - return computedValue().not(); + public Value approxEqual(Value value) { + return computedValue().approxEqual(value); } @Override - public Value power(Value value) { - return computedValue().power(value); + public Value notEqual(Value value) { + return computedValue().notEqual(value); + } + + @Override + public Value equal(Value value) { + return computedValue().equal(value); } @Override - public Value compare(TruthOperator operator, Value value) { - return computedValue().compare(operator, value); + public Value add(Value value) { + return computedValue().add(value); + } + + @Override + public Value subtract(Value value) { + return computedValue().subtract(value); + } + + @Override + public Value multiply(Value value) { + return computedValue().multiply(value); + } + + @Override + public Value divide(Value value) { + return computedValue().divide(value); + } + + @Override + public Value modulo(Value value) { + return computedValue().modulo(value); + } + + @Override + public Value power(Value value) { + return computedValue().power(value); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java index cd0c4da6d0f..7ac2f4bff84 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -92,7 +92,7 @@ public class Gather extends IntermediateOperation { ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantValue)); if (constantValue < 0) { ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); - indexExpression = new EmbracedNode(new ArithmeticNode(indexExpression, ArithmeticOperator.PLUS, axisSize)); + indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.plus, axisSize)); } addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression); } else { @@ -125,8 +125,8 @@ public class Gather extends IntermediateOperation { /** to support negative indexing */ private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) { ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get())); - ExpressionNode plus = new EmbracedNode(new ArithmeticNode(slice, ArithmeticOperator.PLUS, axisSize)); - ExpressionNode mod = new ArithmeticNode(plus, ArithmeticOperator.MODULO, axisSize); + ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.plus, axisSize)); + ExpressionNode mod = new OperationNode(plus, Operator.modulo, axisSize); return mod; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 1f447f2a575..97bfdda385e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.TensorType; @@ -97,17 +97,17 @@ public class Gemm extends IntermediateOperation { TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension); TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction( - new ArithmeticNode( + new OperationNode( new TensorFunctionNode(AxB), - ArithmeticOperator.MULTIPLY, + Operator.multiply, new ConstantNode(new DoubleValue(alpha)))); if (inputs.size() == 3) { Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function(); TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction( - new ArithmeticNode( + new OperationNode( new TensorFunctionNode(cFunction.get()), - ArithmeticOperator.MULTIPLY, + Operator.multiply, new ConstantNode(new DoubleValue(beta)))); return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java index 5c4e8cd6cd0..9a38ab9dfde 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -65,8 +65,8 @@ public class Range extends IntermediateOperation { ExpressionNode startExpr = new ConstantNode(new DoubleValue(start)); ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta)); ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName)); - ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr); - ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr); + ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.multiply, dimExpr); + ExpressionNode addExpr = new OperationNode(startExpr, Operator.plus, stepExpr); TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(addExpr)); return function; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 7b675fa79af..bc94fc6aa76 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -6,13 +6,11 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; @@ -159,13 +157,13 @@ public class Reshape extends IntermediateOperation { inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); } else if (dim == (inputType.rank() - 1)) { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); - ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size); + ExpressionNode div = new OperationNode(unrolled, Operator.modulo, size); inputDimensionExpression = new EmbracedNode(div); } else { ExpressionNode size = new ConstantNode(new DoubleValue(innerSize)); ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize)); - ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize); - ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size); + ExpressionNode mod = new OperationNode(unrolled, Operator.modulo, previousSize); + ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.divide, size); inputDimensionExpression = new EmbracedNode(div); } dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression))); @@ -183,21 +181,21 @@ public class Reshape extends IntermediateOperation { return new ConstantNode(DoubleValue.zero); List<ExpressionNode> children = new ArrayList<>(); - List<ArithmeticOperator> operators = new ArrayList<>(); + List<Operator> operators = new ArrayList<>(); int size = 1; for (int i = type.dimensions().size() - 1; i >= 0; --i) { TensorType.Dimension dimension = type.dimensions().get(i); children.add(0, new ReferenceNode(dimension.name())); if (size > 1) { - operators.add(0, ArithmeticOperator.MULTIPLY); + operators.add(0, Operator.multiply); children.add(0, new ConstantNode(new DoubleValue(size))); } size *= OrderedTensorType.dimensionSize(dimension); if (i > 0) { - operators.add(0, ArithmeticOperator.PLUS); + operators.add(0, Operator.plus); } } - return new ArithmeticNode(children, operators); + return new OperationNode(children, operators); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java index 91b7064b19c..916e9980131 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -6,8 +6,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -166,8 +166,8 @@ public class Slice extends IntermediateOperation { // step * (d0 + start) ExpressionNode reference = new ReferenceNode(outputDimensionName); - ExpressionNode plus = new EmbracedNode(new ArithmeticNode(reference, ArithmeticOperator.PLUS, startIndex)); - ExpressionNode mul = new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, plus); + ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.plus, startIndex)); + ExpressionNode mul = new OperationNode(stepSize, Operator.multiply, plus); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul)))); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java index 6f720716adb..ef46d222941 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java @@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -96,7 +96,7 @@ public class Split extends IntermediateOperation { for (int i = 0; i < inputType.rank(); ++i) { String inputDimensionName = inputType.dimensions().get(i).name(); ExpressionNode reference = new ReferenceNode(inputDimensionName); - ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0))); + ExpressionNode offset = new OperationNode(reference, Operator.plus, new ConstantNode(new DoubleValue(i == axis ? start : 0))); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset)))); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java index 4bfab284cc2..a880bff87be 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -4,8 +4,8 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -77,7 +77,7 @@ public class Tile extends IntermediateOperation { ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); ExpressionNode reference = new ReferenceNode(inputDimensionName); - ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size); + ExpressionNode mod = new OperationNode(reference, Operator.modulo, size); dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod)))); } diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 7caf5a06032..a0517f408bf 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -496,16 +496,22 @@ "public boolean hasDouble()", "public com.yahoo.tensor.Tensor asTensor()", "public com.yahoo.searchlib.rankingexpression.evaluation.Value negate()", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value largerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value larger(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value smallerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value smaller(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value approxEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()", "public com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)" ], "fields": [] @@ -691,16 +697,22 @@ "public boolean hasDouble()", "public boolean asBoolean()", "public com.yahoo.searchlib.rankingexpression.evaluation.Value negate()", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value largerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value larger(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value smallerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value smaller(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value approxEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()", "public com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value asMutable()", "public java.lang.String toString()", @@ -723,17 +735,23 @@ "public boolean hasDouble()", "public boolean asBoolean()", "public com.yahoo.searchlib.rankingexpression.evaluation.Value negate()", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value largerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value larger(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value smallerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value smaller(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value approxEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value not()", "public com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.tensor.Tensor asTensor()", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value asMutable()", "public java.lang.String toString()", @@ -760,16 +778,22 @@ "public abstract boolean hasDouble()", "public abstract boolean asBoolean()", "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value negate()", + "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value not()", + "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value largerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value larger(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value smallerOrEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value smaller(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value approxEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value notEqual(com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value equal(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value and(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value or(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value not()", "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value power(com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value compare(com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value freeze()", "public final boolean isFrozen()", @@ -891,9 +915,8 @@ "public final java.util.List featureList()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode rankingExpression()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode expression()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode arithmeticExpression()", - "public final com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator arithmetic()", - "public final com.yahoo.searchlib.rankingexpression.rule.TruthOperator comparator()", + "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode operationExpression()", + "public final com.yahoo.searchlib.rankingexpression.rule.Operator binaryOperator()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode value()", "public final com.yahoo.searchlib.rankingexpression.rule.IfNode ifExpression()", "public final com.yahoo.searchlib.rankingexpression.rule.ReferenceNode feature()", @@ -1013,13 +1036,13 @@ "public static final int DOLLAR", "public static final int COMMA", "public static final int COLON", - "public static final int LE", - "public static final int LT", - "public static final int EQ", - "public static final int NQ", - "public static final int AQ", - "public static final int GE", - "public static final int GT", + "public static final int GREATEREQUAL", + "public static final int GREATER", + "public static final int LESSEQUAL", + "public static final int LESS", + "public static final int APPROX", + "public static final int NOTEQUAL", + "public static final int EQUAL", "public static final int STRING", "public static final int IF", "public static final int IN", @@ -1220,54 +1243,6 @@ "public static final com.yahoo.searchlib.rankingexpression.rule.Arguments EMPTY" ] }, - "com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode": { - "superClass": "com.yahoo.searchlib.rankingexpression.rule.CompositeNode", - "interfaces": [], - "attributes": [ - "public", - "final" - ], - "methods": [ - "public void <init>(java.util.List, java.util.List)", - "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", - "public java.util.List operators()", - "public java.util.List children()", - "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)", - "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", - "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)", - "public int hashCode()", - "public static com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode resolve(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)" - ], - "fields": [] - }, - "com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator": { - "superClass": "java.lang.Enum", - "interfaces": [], - "attributes": [ - "public", - "abstract", - "enum" - ], - "methods": [ - "public static com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator[] values()", - "public static com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator valueOf(java.lang.String)", - "public boolean hasPrecedenceOver(com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator)", - "public abstract com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Value, com.yahoo.searchlib.rankingexpression.evaluation.Value)", - "public java.lang.String toString()" - ], - "fields": [ - "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator OR", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator AND", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator PLUS", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator MINUS", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator MULTIPLY", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator DIVIDE", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator MODULO", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator POWER", - "public static final java.util.List operatorsByPrecedence" - ] - }, "com.yahoo.searchlib.rankingexpression.rule.BooleanNode": { "superClass": "com.yahoo.searchlib.rankingexpression.rule.CompositeNode", "interfaces": [], @@ -1280,27 +1255,6 @@ ], "fields": [] }, - "com.yahoo.searchlib.rankingexpression.rule.ComparisonNode": { - "superClass": "com.yahoo.searchlib.rankingexpression.rule.BooleanNode", - "interfaces": [], - "attributes": [ - "public" - ], - "methods": [ - "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.TruthOperator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", - "public java.util.List children()", - "public com.yahoo.searchlib.rankingexpression.rule.TruthOperator getOperator()", - "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode getLeftCondition()", - "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode getRightCondition()", - "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)", - "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", - "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", - "public com.yahoo.searchlib.rankingexpression.rule.ComparisonNode setChildren(java.util.List)", - "public int hashCode()", - "public bridge synthetic com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)" - ], - "fields": [] - }, "com.yahoo.searchlib.rankingexpression.rule.CompositeNode": { "superClass": "com.yahoo.searchlib.rankingexpression.rule.ExpressionNode", "interfaces": [], @@ -1581,6 +1535,61 @@ ], "fields": [] }, + "com.yahoo.searchlib.rankingexpression.rule.OperationNode": { + "superClass": "com.yahoo.searchlib.rankingexpression.rule.CompositeNode", + "interfaces": [], + "attributes": [ + "public", + "final" + ], + "methods": [ + "public void <init>(java.util.List, java.util.List)", + "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.Operator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", + "public java.util.List operators()", + "public java.util.List children()", + "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", + "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)", + "public int hashCode()", + "public static com.yahoo.searchlib.rankingexpression.rule.OperationNode resolve(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.searchlib.rankingexpression.rule.Operator, com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)" + ], + "fields": [] + }, + "com.yahoo.searchlib.rankingexpression.rule.Operator": { + "superClass": "java.lang.Enum", + "interfaces": [], + "attributes": [ + "public", + "final", + "enum" + ], + "methods": [ + "public static com.yahoo.searchlib.rankingexpression.rule.Operator[] values()", + "public static com.yahoo.searchlib.rankingexpression.rule.Operator valueOf(java.lang.String)", + "public boolean hasPrecedenceOver(com.yahoo.searchlib.rankingexpression.rule.Operator)", + "public final com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Value, com.yahoo.searchlib.rankingexpression.evaluation.Value)", + "public java.lang.String toString()" + ], + "fields": [ + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator or", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator and", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator largerOrEqual", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator larger", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator smallerOrEqual", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator smaller", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator approxEqual", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator notEqual", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator equal", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator plus", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator minus", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator multiply", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator divide", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator modulo", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Operator power", + "public static final java.util.List operatorsByPrecedence" + ] + }, "com.yahoo.searchlib.rankingexpression.rule.ReferenceNode": { "superClass": "com.yahoo.searchlib.rankingexpression.rule.CompositeNode", "interfaces": [], @@ -1693,32 +1702,5 @@ "public int hashCode()" ], "fields": [] - }, - "com.yahoo.searchlib.rankingexpression.rule.TruthOperator": { - "superClass": "java.lang.Enum", - "interfaces": [ - "java.io.Serializable" - ], - "attributes": [ - "public", - "abstract", - "enum" - ], - "methods": [ - "public static com.yahoo.searchlib.rankingexpression.rule.TruthOperator[] values()", - "public static com.yahoo.searchlib.rankingexpression.rule.TruthOperator valueOf(java.lang.String)", - "public abstract boolean evaluate(double, double)", - "public java.lang.String toString()", - "public static com.yahoo.searchlib.rankingexpression.rule.TruthOperator fromString(java.lang.String)" - ], - "fields": [ - "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator SMALLER", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator SMALLEREQUAL", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator EQUAL", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator APPROX_EQUAL", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator LARGER", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator LARGEREQUAL", - "public static final enum com.yahoo.searchlib.rankingexpression.rule.TruthOperator NOTEQUAL" - ] } }
\ No newline at end of file diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index afd263f1553..2405d1bd528 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -2,7 +2,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -28,58 +27,146 @@ public abstract class DoubleCompatibleValue extends Value { public Value negate() { return new DoubleValue(-asDouble()); } @Override - public Value add(Value value) { - return new DoubleValue(asDouble() + value.asDouble()); + public Value not() { + return new BooleanValue(!asBoolean()); } @Override - public Value subtract(Value value) { - return new DoubleValue(asDouble() - value.asDouble()); + public Value or(Value value) { + if (value instanceof TensorValue tensor) + return tensor.or(this); + else + return new BooleanValue(asBoolean() || value.asBoolean()); } @Override - public Value multiply(Value value) { - return new DoubleValue(asDouble() * value.asDouble()); + public Value and(Value value) { + if (value instanceof TensorValue tensor) + return tensor.and(this); + else + return new BooleanValue(asBoolean() && value.asBoolean()); } @Override - public Value divide(Value value) { - return new DoubleValue(asDouble() / value.asDouble()); + public Value largerOrEqual(Value value) { + if (value instanceof TensorValue tensor) + return tensor.largerOrEqual(this); + else + return new BooleanValue(this.asDouble() >= value.asDouble()); } @Override - public Value modulo(Value value) { - return new DoubleValue(asDouble() % value.asDouble()); + public Value larger(Value value) { + if (value instanceof TensorValue tensor) + return tensor.larger(this); + else + return new BooleanValue(this.asDouble() > value.asDouble()); } @Override - public Value and(Value value) { - return new BooleanValue(asBoolean() && value.asBoolean()); + public Value smallerOrEqual(Value value) { + if (value instanceof TensorValue tensor) + return tensor.smallerOrEqual(this); + else + return new BooleanValue(this.asDouble() <= value.asDouble()); } @Override - public Value or(Value value) { - return new BooleanValue(asBoolean() || value.asBoolean()); + public Value smaller(Value value) { + if (value instanceof TensorValue tensor) + return tensor.smaller(this); + else + return new BooleanValue(this.asDouble() < value.asDouble()); } @Override - public Value not() { - return new BooleanValue(!asBoolean()); + public Value approxEqual(Value value) { + if (value instanceof TensorValue tensor) + return tensor.approxEqual(this); + else + return new BooleanValue(approxEqual(this.asDouble(), value.asDouble())); } @Override - public Value power(Value value) { - return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble())); + public Value notEqual(Value value) { + if (value instanceof TensorValue tensor) + return tensor.notEqual(this); + else + return new BooleanValue(this.asDouble() != value.asDouble()); + } + + @Override + public Value equal(Value value) { + if (value instanceof TensorValue tensor) + return tensor.equal(this); + else + return new BooleanValue(this.asDouble() == value.asDouble()); + } + + @Override + public Value add(Value value) { + if (value instanceof TensorValue tensor) + return tensor.add(this); + else + return new DoubleValue(asDouble() + value.asDouble()); + } + + @Override + public Value subtract(Value value) { + if (value instanceof TensorValue tensor) + return tensor.subtract(this); + else + return new DoubleValue(asDouble() - value.asDouble()); + } + + @Override + public Value multiply(Value value) { + if (value instanceof TensorValue tensor) + return tensor.multiply(this); + else + return new DoubleValue(asDouble() * value.asDouble()); + } + + @Override + public Value divide(Value value) { + if (value instanceof TensorValue tensor) + return tensor.divide(this); + else + return new DoubleValue(asDouble() / value.asDouble()); + } + + @Override + public Value modulo(Value value) { + if (value instanceof TensorValue tensor) + return tensor.modulo(this); + else + return new DoubleValue(asDouble() % value.asDouble()); } @Override - public Value compare(TruthOperator operator, Value value) { - return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); + public Value power(Value value) { + if (value instanceof TensorValue tensor) + return tensor.power(this); + else + return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble())); } @Override public Value function(Function function, Value value) { - return new DoubleValue(function.evaluate(asDouble(),value.asDouble())); + if (value instanceof TensorValue tensor) + return tensor.function(function, this); + else + return new DoubleValue(function.evaluate(asDouble(), value.asDouble())); + } + + static boolean approxEqual(double x, double y) { + if (y < -1.0 || y > 1.0) { + x = Math.nextAfter(x/y, 1.0); + y = 1.0; + } else { + x = Math.nextAfter(x, y); + } + return x == y; } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index 2c2d5eead05..a585c989954 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -57,55 +56,83 @@ public class StringValue extends Value { } @Override - public Value add(Value value) { - return new StringValue(value + value.toString()); + public Value not() { + throw new UnsupportedOperationException("String values ('" + value + "') do not support not"); } @Override - public Value subtract(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction"); + public Value or(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support or"); } @Override - public Value multiply(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication"); + public Value and(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support and"); } @Override - public Value divide(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support division"); + public Value largerOrEqual(Value argument) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support greaterEqual"); } @Override - public Value modulo(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo"); + public Value larger(Value argument) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support greater"); } @Override - public Value and(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support and"); + public Value smallerOrEqual(Value argument) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support lessEqual"); } @Override - public Value or(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support or"); + public Value smaller(Value argument) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support less"); } @Override - public Value not() { - throw new UnsupportedOperationException("String values ('" + value + "') do not support not"); + public Value approxEqual(Value argument) { + return new BooleanValue(this.asDouble() == argument.asDouble()); } @Override - public Value power(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') do not support ^"); + public Value notEqual(Value argument) { + return new BooleanValue(this.asDouble() != argument.asDouble()); + } + + @Override + public Value equal(Value argument) { + return new BooleanValue(this.asDouble() == argument.asDouble()); } @Override - public Value compare(TruthOperator operator, Value value) { - if (operator.equals(TruthOperator.EQUAL)) - return new BooleanValue(this.equals(value)); - throw new UnsupportedOperationException("String values ('" + value + "') cannot be compared except with '=='"); + public Value add(Value value) { + return new StringValue(value + value.toString()); + } + + @Override + public Value subtract(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction"); + } + + @Override + public Value multiply(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication"); + } + + @Override + public Value divide(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support division"); + } + + @Override + public Value modulo(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo"); + } + + @Override + public Value power(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support ^"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index b37bbb543eb..25e03c75376 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -3,7 +3,6 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.api.annotations.Beta; import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -52,6 +51,83 @@ public class TensorValue extends Value { } @Override + public Value not() { + return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0)); + } + + @Override + public Value or(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 )); + else + return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0)); + } + + @Override + public Value and(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 )); + else + return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0)); + } + + @Override + public Value largerOrEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.largerOrEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value >= argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value larger(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.larger(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value > argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value smallerOrEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.smallerOrEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value <= argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value smaller(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.smaller(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value < argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value approxEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.approxEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> DoubleCompatibleValue.approxEqual(value, argument.asDouble()) ? 1.0 : 0.0)); + } + + @Override + public Value notEqual(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.notEqual(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value != argument.asDouble() ? 1.0 : 0.0)); + } + + @Override + public Value equal(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.equal(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> value == argument.asDouble() ? 1.0 : 0.0)); + } + + @Override public Value add(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.add(((TensorValue)argument).value)); @@ -92,27 +168,6 @@ public class TensorValue extends Value { } @Override - public Value and(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 )); - else - return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0)); - } - - @Override - public Value or(Value argument) { - if (argument instanceof TensorValue) - return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 )); - else - return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0)); - } - - @Override - public Value not() { - return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0)); - } - - @Override public Value power(Value argument) { if (argument instanceof TensorValue) return new TensorValue(value.pow(((TensorValue)argument).value)); @@ -123,24 +178,6 @@ public class TensorValue extends Value { public Tensor asTensor() { return value; } @Override - public Value compare(TruthOperator operator, Value argument) { - return new TensorValue(compareTensor(operator, argument.asTensor())); - } - - private Tensor compareTensor(TruthOperator operator, Tensor argument) { - switch (operator) { - case LARGER: return value.larger(argument); - case LARGEREQUAL: return value.largerOrEqual(argument); - case SMALLER: return value.smaller(argument); - case SMALLEREQUAL: return value.smallerOrEqual(argument); - case EQUAL: return value.equal(argument); - case NOTEQUAL: return value.notEqual(argument); - case APPROX_EQUAL: return value.approxEqual(argument); - default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); - } - } - - @Override public Value function(Function function, Value arg) { if (arg instanceof TensorValue) return new TensorValue(functionOnTensor(function, arg.asTensor())); @@ -149,17 +186,17 @@ public class TensorValue extends Value { } private Tensor functionOnTensor(Function function, Tensor argument) { - switch (function) { - case min: return value.min(argument); - case max: return value.max(argument); - case atan2: return value.atan2(argument); - case pow: return value.pow(argument); - case fmod: return value.fmod(argument); - case ldexp: return value.ldexp(argument); - case bit: return value.bit(argument); - case hamming: return value.hamming(argument); - default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); - } + return switch (function) { + case min -> value.min(argument); + case max -> value.max(argument); + case atan2 -> value.atan2(argument); + case pow -> value.pow(argument); + case fmod -> value.fmod(argument); + case ldexp -> value.ldexp(argument); + case bit -> value.bit(argument); + case hamming -> value.hamming(argument); + default -> throw new UnsupportedOperationException("Cannot combine two tensors using " + function); + }; } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 207603c5038..5de2138147e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -1,9 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation; -import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.rule.Function; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -51,27 +49,24 @@ public abstract class Value { public abstract Value negate(); - public abstract Value add(Value value); + public abstract Value not(); + public abstract Value or(Value value); + public abstract Value and(Value value); + public abstract Value largerOrEqual(Value value); + public abstract Value larger(Value value); + public abstract Value smallerOrEqual(Value value); + public abstract Value smaller(Value value); + public abstract Value approxEqual(Value value); + public abstract Value notEqual(Value value); + public abstract Value equal(Value value); + public abstract Value add(Value value); public abstract Value subtract(Value value); - public abstract Value multiply(Value value); - public abstract Value divide(Value value); - public abstract Value modulo(Value value); - - public abstract Value and(Value value); - - public abstract Value or(Value value); - - public abstract Value not(); - public abstract Value power(Value value); - /** Perform the comparison specified by the operator between this value and the given value */ - public abstract Value compare(TruthOperator operator, Value value); - /** Perform the given binary function on this value and the given value */ public abstract Value function(Function function, Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java index 6ab483800a7..a3fc6aae9ac 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java @@ -5,8 +5,8 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -84,12 +84,12 @@ public class GBDTForestOptimizer extends Optimizer { currentTreesOptimized++; return true; } - if (!(node instanceof ArithmeticNode)) { + if (!(node instanceof OperationNode)) { return false; } - ArithmeticNode aNode = (ArithmeticNode)node; - for (ArithmeticOperator op : aNode.operators()) { - if (op != ArithmeticOperator.PLUS) { + OperationNode aNode = (OperationNode)node; + for (Operator op : aNode.operators()) { + if (op != Operator.plus) { return false; } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java index 420f1f459f3..7ba671e62eb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java @@ -61,13 +61,13 @@ public class GBDTOptimizer extends Optimizer { * @return the optimized expression */ private ExpressionNode optimize(ExpressionNode node, ContextIndex context) { - if (node instanceof ArithmeticNode) { - Iterator<ExpressionNode> childIt = ((ArithmeticNode)node).children().iterator(); + if (node instanceof OperationNode) { + Iterator<ExpressionNode> childIt = ((OperationNode)node).children().iterator(); ExpressionNode ret = optimize(childIt.next(), context); - Iterator<ArithmeticOperator> operIt = ((ArithmeticNode)node).operators().iterator(); + Iterator<Operator> operIt = ((OperationNode)node).operators().iterator(); while (childIt.hasNext() && operIt.hasNext()) { - ret = ArithmeticNode.resolve(ret, operIt.next(), optimize(childIt.next(), context)); + ret = OperationNode.resolve(ret, operIt.next(), optimize(childIt.next(), context)); } return ret; } @@ -114,15 +114,15 @@ public class GBDTOptimizer extends Optimizer { /** Consumes the if condition and return the size of the values resulting, for convenience */ private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) { - if (condition instanceof ComparisonNode) { - ComparisonNode comparison = (ComparisonNode)condition; - if (comparison.getOperator() == TruthOperator.SMALLER) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.getLeftCondition(), context)); - else if (comparison.getOperator() == TruthOperator.EQUAL) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.getLeftCondition(), context)); + if (isBinaryComparison(condition)) { + OperationNode comparison = (OperationNode)condition; + if (comparison.operators().get(0) == Operator.smaller) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context)); + else if (comparison.operators().get(0) == Operator.equal) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context)); else - throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.getOperator()); - values.add(toValue(comparison.getRightCondition())); + throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0)); + values.add(toValue(comparison.children().get(1))); } else if (condition instanceof SetMembershipNode) { SetMembershipNode setMembership = (SetMembershipNode)condition; @@ -131,17 +131,15 @@ public class GBDTOptimizer extends Optimizer { for (ExpressionNode setElementNode : setMembership.getSetValues()) values.add(toValue(setElementNode)); } - else if (condition instanceof NotNode) { // handle if inversion: !(a >= b) - NotNode notNode = (NotNode)condition; - if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode) { - EmbracedNode embracedNode = (EmbracedNode)notNode.children().get(0); - if (embracedNode.children().size() == 1 && embracedNode.children().get(0) instanceof ComparisonNode) { - ComparisonNode comparison = (ComparisonNode)embracedNode.children().get(0); - if (comparison.getOperator() == TruthOperator.LARGEREQUAL) - values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.getLeftCondition(), context)); + else if (condition instanceof NotNode notNode) { // handle if inversion: !(a >= b) + if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) { + if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) { + OperationNode comparison = (OperationNode)embracedNode.children().get(0); + if (comparison.operators().get(0) == Operator.largerOrEqual) + values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context)); else - throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.getOperator()); - values.add(toValue(comparison.getRightCondition())); + throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0)); + values.add(toValue(comparison.children().get(1))); } } } @@ -152,12 +150,24 @@ public class GBDTOptimizer extends Optimizer { return values.size(); } + private boolean isBinaryComparison(ExpressionNode condition) { + if ( ! (condition instanceof OperationNode binaryNode)) return false; + if (binaryNode.operators().size() != 1) return false; + if (binaryNode.operators().get(0) == Operator.largerOrEqual) return true; + if (binaryNode.operators().get(0) == Operator.larger) return true; + if (binaryNode.operators().get(0) == Operator.smallerOrEqual) return true; + if (binaryNode.operators().get(0) == Operator.smaller) return true; + if (binaryNode.operators().get(0) == Operator.approxEqual) return true; + if (binaryNode.operators().get(0) == Operator.notEqual) return true; + if (binaryNode.operators().get(0) == Operator.equal) return true; + return false; + } + private double getVariableIndex(ExpressionNode node, ContextIndex context) { - if (!(node instanceof ReferenceNode)) { + if (!(node instanceof ReferenceNode fNode)) { throw new IllegalArgumentException("Contained a left-hand comparison expression " + "which was not a feature value but was: " + node); } - ReferenceNode fNode = (ReferenceNode)node; Integer index = context.getIndex(fNode.toString()); if (index == null) { throw new IllegalStateException("The ranking expression contained feature '" + fNode.getName() + @@ -177,8 +187,7 @@ public class GBDTOptimizer extends Optimizer { value.getClass().getSimpleName() + " (" + value + ") in a set test: " + node); } - if (node instanceof NegativeNode) { - NegativeNode nNode = (NegativeNode)node; + if (node instanceof NegativeNode nNode) { if (!(nNode.getValue() instanceof ConstantNode)) { throw new IllegalArgumentException("Contained a negation of a non-number: " + nNode.getValue()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java deleted file mode 100644 index 959045a63a0..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java +++ /dev/null @@ -1,63 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.yahoo.searchlib.rankingexpression.evaluation.Value; - -import java.util.List; - -/** - * A mathematical operator - * - * @author bratseth - */ -public enum ArithmeticOperator { - - OR(0, "||") { public Value evaluate(Value x, Value y) { - return x.or(y); - }}, - AND(1, "&&") { public Value evaluate(Value x, Value y) { - return x.and(y); - }}, - PLUS(2, "+") { public Value evaluate(Value x, Value y) { - return x.add(y); - }}, - MINUS(3, "-") { public Value evaluate(Value x, Value y) { - return x.subtract(y); - }}, - MULTIPLY(4, "*") { public Value evaluate(Value x, Value y) { - return x.multiply(y); - }}, - DIVIDE(5, "/") { public Value evaluate(Value x, Value y) { - return x.divide(y); - }}, - MODULO(6, "%") { public Value evaluate(Value x, Value y) { - return x.modulo(y); - }}, - POWER(7, "^") { public Value evaluate(Value x, Value y) { - return x.power(y); - }}; - - /** A list of all the operators in this in order of decreasing precedence */ - public static final List<ArithmeticOperator> operatorsByPrecedence = List.of(POWER, MODULO, DIVIDE, MULTIPLY, MINUS, PLUS, AND, OR); - - private final int precedence; - private final String image; - - private ArithmeticOperator(int precedence, String image) { - this.precedence = precedence; - this.image = image; - } - - /** Returns true if this operator has precedence over the given operator */ - public boolean hasPrecedenceOver(ArithmeticOperator op) { - return precedence > op.precedence; - } - - public abstract Value evaluate(Value x, Value y); - - @Override - public String toString() { - return image; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java deleted file mode 100644 index e726a351f74..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ /dev/null @@ -1,69 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; - -import java.util.Deque; -import java.util.List; -import java.util.Objects; - -/** - * A node which returns the outcome of a comparison. - * - * @author bratseth - */ -public class ComparisonNode extends BooleanNode { - - /** The operator string of this condition. */ - private final TruthOperator operator; - - private final List<ExpressionNode> conditions; - - public ComparisonNode(ExpressionNode leftCondition, TruthOperator operator, ExpressionNode rightCondition) { - conditions = List.of(leftCondition, rightCondition); - this.operator = operator; - } - - @Override - public List<ExpressionNode> children() { - return conditions; - } - - public TruthOperator getOperator() { return operator; } - - public ExpressionNode getLeftCondition() { return conditions.get(0); } - - public ExpressionNode getRightCondition() { return conditions.get(1); } - - @Override - public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { - getLeftCondition().toString(string, context, path, this).append(' ').append(operator).append(' '); - return getRightCondition().toString(string, context, path, this); - } - - @Override - public TensorType type(TypeContext<Reference> context) { - return TensorType.empty; // by definition - } - - @Override - public Value evaluate(Context context) { - Value leftValue = getLeftCondition().evaluate(context); - Value rightValue = getRightCondition().evaluate(context); - return leftValue.compare(operator,rightValue); - } - - @Override - public ComparisonNode setChildren(List<ExpressionNode> children) { - if (children.size() != 2) throw new IllegalArgumentException("A comparison test must have 2 children"); - return new ComparisonNode(children.get(0), operator, children.get(1)); - } - - @Override - public int hashCode() { return Objects.hash(operator, conditions); } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index 9f07f146264..97e9a74f9c8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -106,10 +106,10 @@ public class LambdaFunctionNode extends CompositeNode { } private Optional<DoubleBinaryOperator> getDirectEvaluator() { - if ( ! (functionExpression instanceof ArithmeticNode)) { + if ( ! (functionExpression instanceof OperationNode)) { return Optional.empty(); } - ArithmeticNode node = (ArithmeticNode) functionExpression; + OperationNode node = (OperationNode) functionExpression; if ( ! (node.children().get(0) instanceof ReferenceNode) || ! (node.children().get(1) instanceof ReferenceNode)) { return Optional.empty(); } @@ -121,16 +121,16 @@ public class LambdaFunctionNode extends CompositeNode { if (node.operators().size() != 1) { return Optional.empty(); } - ArithmeticOperator operator = node.operators().get(0); + Operator operator = node.operators().get(0); switch (operator) { - case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0); - case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0); - case PLUS: return asFunctionExpression((left, right) -> left + right); - case MINUS: return asFunctionExpression((left, right) -> left - right); - case MULTIPLY: return asFunctionExpression((left, right) -> left * right); - case DIVIDE: return asFunctionExpression((left, right) -> left / right); - case MODULO: return asFunctionExpression((left, right) -> left % right); - case POWER: return asFunctionExpression(Math::pow); + case or: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0); + case and: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0); + case plus: return asFunctionExpression((left, right) -> left + right); + case minus: return asFunctionExpression((left, right) -> left - right); + case multiply: return asFunctionExpression((left, right) -> left * right); + case divide: return asFunctionExpression((left, right) -> left / right); + case modulo: return asFunctionExpression((left, right) -> left % right); + case power: return asFunctionExpression(Math::pow); } return Optional.empty(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java index c3e39197316..0512e1dad2f 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java @@ -16,26 +16,26 @@ import java.util.List; import java.util.Objects; /** - * A binary mathematical operation + * A sequence of binary operations. * * @author bratseth */ -public final class ArithmeticNode extends CompositeNode { +public final class OperationNode extends CompositeNode { private final List<ExpressionNode> children; - private final List<ArithmeticOperator> operators; + private final List<Operator> operators; - public ArithmeticNode(List<ExpressionNode> children, List<ArithmeticOperator> operators) { + public OperationNode(List<ExpressionNode> children, List<Operator> operators) { this.children = List.copyOf(children); this.operators = List.copyOf(operators); } - public ArithmeticNode(ExpressionNode leftExpression, ArithmeticOperator operator, ExpressionNode rightExpression) { + public OperationNode(ExpressionNode leftExpression, Operator operator, ExpressionNode rightExpression) { this.children = List.of(leftExpression, rightExpression); this.operators = List.of(operator); } - public List<ArithmeticOperator> operators() { return operators; } + public List<Operator> operators() { return operators; } @Override public List<ExpressionNode> children() { return children; } @@ -50,7 +50,7 @@ public final class ArithmeticNode extends CompositeNode { child.next().toString(string, context, path, this); if (child.hasNext()) string.append(" "); - for (Iterator<ArithmeticOperator> op = operators.iterator(); op.hasNext() && child.hasNext();) { + for (Iterator<Operator> op = operators.iterator(); op.hasNext() && child.hasNext();) { string.append(op.next().toString()).append(" "); child.next().toString(string, context, path, this); if (op.hasNext()) @@ -68,14 +68,14 @@ public final class ArithmeticNode extends CompositeNode { */ private boolean nonDefaultPrecedence(CompositeNode parent) { if ( parent == null) return false; - if ( ! (parent instanceof ArithmeticNode arithmeticParent)) return false; + if ( ! (parent instanceof OperationNode operationParent)) return false; // The line below can only be correct in both only have one operator. // Getting this correct is impossible without more work. // So for now we only handle the simple case correctly, and use a safe approach by adding // extra parenthesis just in case.... - return arithmeticParent.operators.get(0).hasPrecedenceOver(this.operators.get(0)) - || ((arithmeticParent.operators.size() > 1) || (operators.size() > 1)); + return operationParent.operators.get(0).hasPrecedenceOver(this.operators.get(0)) + || ((operationParent.operators.size() > 1) || (operators.size() > 1)); } @Override @@ -96,8 +96,8 @@ public final class ArithmeticNode extends CompositeNode { // Apply in precedence order: Deque<ValueItem> stack = new ArrayDeque<>(); stack.push(new ValueItem(null, child.next().evaluate(context))); - for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) { - ArithmeticOperator op = it.next(); + for (Iterator<Operator> it = operators.iterator(); it.hasNext() && child.hasNext();) { + Operator op = it.next(); if ( ! stack.isEmpty()) { while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) { popStack(stack); @@ -121,30 +121,30 @@ public final class ArithmeticNode extends CompositeNode { public CompositeNode setChildren(List<ExpressionNode> newChildren) { if (children.size() != newChildren.size()) throw new IllegalArgumentException("Expected " + children.size() + " children but got " + newChildren.size()); - return new ArithmeticNode(newChildren, operators); + return new OperationNode(newChildren, operators); } @Override public int hashCode() { return Objects.hash(children, operators); } - public static ArithmeticNode resolve(ExpressionNode left, ArithmeticOperator op, ExpressionNode right) { - if ( ! (left instanceof ArithmeticNode leftArithmetic)) return new ArithmeticNode(left, op, right); + public static OperationNode resolve(ExpressionNode left, Operator op, ExpressionNode right) { + if ( ! (left instanceof OperationNode leftArithmetic)) return new OperationNode(left, op, right); List<ExpressionNode> newChildren = new ArrayList<>(leftArithmetic.children()); newChildren.add(right); - List<ArithmeticOperator> newOperators = new ArrayList<>(leftArithmetic.operators()); + List<Operator> newOperators = new ArrayList<>(leftArithmetic.operators()); newOperators.add(op); - return new ArithmeticNode(newChildren, newOperators); + return new OperationNode(newChildren, newOperators); } private static class ValueItem { - final ArithmeticOperator op; + final Operator op; Value value; - public ValueItem(ArithmeticOperator op, Value value) { + public ValueItem(Operator op, Value value) { this.op = op; this.value = value; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java new file mode 100644 index 00000000000..02af88b2c58 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java @@ -0,0 +1,67 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Arrays; +import java.util.List; +import java.util.function.BiFunction; + +/** + * A mathematical operator + * + * @author bratseth + */ +public enum Operator { + + // In order from lowest to highest precedence + or("||", (x, y) -> x.or(y)), + and("&&", (x, y) -> x.and(y)), + largerOrEqual(">=", (x, y) -> x.largerOrEqual(y)), + larger(">", (x, y) -> x.larger(y)), + smallerOrEqual("<=", (x, y) -> x.smallerOrEqual(y)), + smaller("<", (x, y) -> x.smaller(y)), + approxEqual("~=", (x, y) -> x.approxEqual(y)), + notEqual("!=", (x, y) -> x.notEqual(y)), + equal("==", (x, y) -> x.equal(y)), + plus("+", (x, y) -> x.add(y)), + minus("-", (x, y) -> x.subtract(y)), + multiply("*", (x, y) -> x.multiply(y)), + divide("/", (x, y) -> x.divide(y)), + modulo("%", (x, y) -> x.modulo(y)), + power("^", true, (x, y) -> x.power(y)); + + /** A list of all the operators in this in order of increasing precedence */ + public static final List<Operator> operatorsByPrecedence = Arrays.stream(Operator.values()).toList(); + + private final String image; + private final boolean rightPrecedence; + private final BiFunction<Value, Value, Value> function; + + Operator(String image, BiFunction<Value, Value, Value> function) { + this(image, false, function); + } + + Operator(String image, boolean rightPrecedence, BiFunction<Value, Value, Value> function) { + this.image = image; + this.rightPrecedence = rightPrecedence; + this.function = function; + } + + /** Returns true if this operator has precedence over the given operator */ + public boolean hasPrecedenceOver(Operator other) { + if (operatorsByPrecedence.indexOf(this) == operatorsByPrecedence.indexOf(other)) + return rightPrecedence; + return operatorsByPrecedence.indexOf(this) > operatorsByPrecedence.indexOf(other); + } + + public final Value evaluate(Value x, Value y) { + return function.apply(x, y); + } + + @Override + public String toString() { + return image; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java deleted file mode 100644 index fc259867923..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TruthOperator.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.rule; - -import java.io.Serializable; - -/** - * A mathematical operator - * - * @author bratseth - */ -public enum TruthOperator implements Serializable { - - SMALLER("<") { public boolean evaluate(double x, double y) { return x<y; } }, - SMALLEREQUAL("<=") { public boolean evaluate(double x, double y) { return x<=y; } }, - EQUAL("==") { public boolean evaluate(double x, double y) { return x==y; } }, - APPROX_EQUAL("~=") { public boolean evaluate(double x, double y) { return approxEqual(x,y); } }, - LARGER(">") { public boolean evaluate(double x, double y) { return x>y; } }, - LARGEREQUAL(">=") { public boolean evaluate(double x, double y) { return x>=y; } }, - NOTEQUAL("!=") { public boolean evaluate(double x, double y) { return x!=y; } }; - - private final String operatorString; - - TruthOperator(String operatorString) { - this.operatorString = operatorString; - } - - /** Perform the truth operation on the input */ - public abstract boolean evaluate(double x, double y); - - @Override - public String toString() { return operatorString; } - - public static TruthOperator fromString(String string) { - for (TruthOperator operator : values()) - if (operator.toString().equals(string)) - return operator; - throw new IllegalArgumentException("Illegal truth operator '" + string + "'"); - } - - private static boolean approxEqual(double x,double y) { - if (y < -1.0 || y > 1.0) { - x = Math.nextAfter(x/y, 1.0); - y = 1.0; - } else { - x = Math.nextAfter(x, y); - } - return x==y; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index 90861e64164..1c1f7509ce8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -4,9 +4,8 @@ package com.yahoo.searchlib.rankingexpression.transform; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; -import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; @@ -34,8 +33,8 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { node = transformIf((IfNode) node); if (node instanceof EmbracedNode e && hasSingleUndividableChild(e)) node = e.children().get(0); - if (node instanceof ArithmeticNode) - node = transformArithmetic((ArithmeticNode) node); + if (node instanceof OperationNode) + node = transformArithmetic((OperationNode) node); if (node instanceof NegativeNode) node = transformNegativeNode((NegativeNode) node); return node; @@ -43,19 +42,18 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { private boolean hasSingleUndividableChild(EmbracedNode node) { if (node.children().size() > 1) return false; - if (node.children().get(0) instanceof ArithmeticNode) return false; - if (node.children().get(0) instanceof ComparisonNode) return false; + if (node.children().get(0) instanceof OperationNode) return false; return true; } - private ExpressionNode transformArithmetic(ArithmeticNode node) { + private ExpressionNode transformArithmetic(OperationNode node) { // Fold the subset of expressions that are constant (such that in "1 + 2 + var") if (node.children().size() > 1) { List<ExpressionNode> children = new ArrayList<>(node.children()); - List<ArithmeticOperator> operators = new ArrayList<>(node.operators()); - for (ArithmeticOperator operator : ArithmeticOperator.operatorsByPrecedence) + List<Operator> operators = new ArrayList<>(node.operators()); + for (Operator operator : Operator.operatorsByPrecedence) transform(operator, children, operators); - node = new ArithmeticNode(children, operators); + node = new OperationNode(children, operators); } if (isConstant(node) && ! node.evaluate(null).isNaN()) @@ -66,8 +64,8 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { return node; } - private void transform(ArithmeticOperator operatorToTransform, - List<ExpressionNode> children, List<ArithmeticOperator> operators) { + private void transform(Operator operatorToTransform, + List<ExpressionNode> children, List<Operator> operators) { int i = 0; while (i < children.size()-1) { boolean transformed = false; @@ -75,7 +73,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { ExpressionNode child1 = children.get(i); ExpressionNode child2 = children.get(i + 1); if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) { - Value evaluated = new ArithmeticNode(child1, operators.get(i), child2).evaluate(null); + Value evaluated = new OperationNode(child1, operators.get(i), child2).evaluate(null); if ( ! evaluated.isNaN()) { // Don't replace by NaN operators.remove(i); children.set(i, new ConstantNode(evaluated.freeze())); @@ -94,7 +92,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { * This check works because we simplify by decreasing precedence, so neighbours will either be single constant values * or a more complex expression that can't be simplified and hence also prevents the simplification in question here. */ - private boolean hasPrecedence(List<ArithmeticOperator> operators, int i) { + private boolean hasPrecedence(List<Operator> operators, int i) { if (i > 0 && operators.get(i-1).hasPrecedenceOver(operators.get(i))) return false; if (i < operators.size()-1 && operators.get(i+1).hasPrecedenceOver(operators.get(i))) return false; return true; @@ -116,14 +114,14 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { return new ConstantNode(constant.getValue().negate() ); } - private boolean allMultiplicationOrDivision(ArithmeticNode node) { - for (ArithmeticOperator o : node.operators()) - if (o != ArithmeticOperator.MULTIPLY && o != ArithmeticOperator.DIVIDE) + private boolean allMultiplicationOrDivision(OperationNode node) { + for (Operator o : node.operators()) + if (o != Operator.multiply && o != Operator.divide) return false; return true; } - private boolean hasZero(ArithmeticNode node) { + private boolean hasZero(OperationNode node) { for (ExpressionNode child : node.children()) { if (isZero(child)) return true; @@ -131,9 +129,9 @@ public class Simplifier extends ExpressionTransformer<TransformContext> { return false; } - private boolean hasDivisionByZero(ArithmeticNode node) { + private boolean hasDivisionByZero(OperationNode node) { for (int i = 1; i < node.children().size(); i++) { - if ( node.operators().get(i - 1) == ArithmeticOperator.DIVIDE && isZero(node.children().get(i))) + if (node.operators().get(i - 1) == Operator.divide && isZero(node.children().get(i))) return true; } return false; diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index ebe1e048247..42b5f2c191a 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -72,13 +72,13 @@ TOKEN : <COMMA: ","> | <COLON: ":"> | - <LE: "<="> | - <LT: "<"> | - <EQ: "=="> | - <NQ: "!="> | - <AQ: "~="> | - <GE: ">="> | - <GT: ">"> | + <GREATEREQUAL: ">="> | + <GREATER: ">"> | + <LESSEQUAL: "<="> | + <LESS: "<"> | + <APPROX: "~="> | + <NOTEQUAL: "!="> | + <EQUAL: "=="> | <STRING: ("\"" (~["\""] | "\\\"")* "\"") | ("'" (~["'"] | "\\'")* "'")> | @@ -188,55 +188,50 @@ ExpressionNode expression() : { ExpressionNode left, right; List<ExpressionNode> rightList; - TruthOperator comparatorOp; } { - ( left = arithmeticExpression() + ( left = operationExpression() ( - ( comparatorOp = comparator() right = arithmeticExpression() { left = new ComparisonNode(left, comparatorOp, right); } ) | ( <IN> rightList = expressionList() { left = new SetMembershipNode(left, rightList); } ) - ) * + ) ? ) { return left; } } -ExpressionNode arithmeticExpression() : +ExpressionNode operationExpression() : { ExpressionNode left, right = null; - ArithmeticOperator arithmeticOp; + Operator operator; } { ( left = value() - ( arithmeticOp = arithmetic() right = value() { left = ArithmeticNode.resolve(left, arithmeticOp, right); } ) * + ( operator = binaryOperator() right = value() { left = OperationNode.resolve(left, operator, right); } ) * ) { return left; } } -ArithmeticOperator arithmetic() : { } +Operator binaryOperator() : { } { - ( <ADD> { return ArithmeticOperator.PLUS; } | - <SUB> { return ArithmeticOperator.MINUS; } | - <DIV> { return ArithmeticOperator.DIVIDE; } | - <MUL> { return ArithmeticOperator.MULTIPLY; } | - <MOD> { return ArithmeticOperator.MODULO; } | - <AND> { return ArithmeticOperator.AND; } | - <OR> { return ArithmeticOperator.OR; } | - <POWOP> { return ArithmeticOperator.POWER; } ) + ( + <OR> { return Operator.or; } | + <AND> { return Operator.and; } | + <GREATEREQUAL> { return Operator.largerOrEqual; } | + <GREATER> { return Operator.larger; } | + <LESSEQUAL> { return Operator.smallerOrEqual; } | + <LESS> { return Operator.smaller; } | + <APPROX> { return Operator.approxEqual; } | + <NOTEQUAL> { return Operator.notEqual; } | + <EQUAL> { return Operator.equal; } | + <ADD> { return Operator.plus; } | + <SUB> { return Operator.minus; } | + <DIV> { return Operator.divide; } | + <MUL> { return Operator.multiply; } | + <MOD> { return Operator.modulo; } | + <POWOP> { return Operator.power; } + ) { return null; } } -TruthOperator comparator() : { } -{ - ( <LE> { return TruthOperator.SMALLEREQUAL; } | - <LT> { return TruthOperator.SMALLER; } | - <EQ> { return TruthOperator.EQUAL; } | - <NQ> { return TruthOperator.NOTEQUAL; } | - <AQ> { return TruthOperator.APPROX_EQUAL; } | - <GE> { return TruthOperator.LARGEREQUAL; } | - <GT> { return TruthOperator.LARGER; } ) - { return null; } -} - ExpressionNode value() : { ExpressionNode value; @@ -665,7 +660,7 @@ TensorType.Value optionalTensorValueTypeParameter() : String valueType = "double"; } { - ( <LT> valueType = identifier() <GT> )? + ( <LESS> valueType = identifier() <GREATER> )? { return TensorType.Value.fromId(valueType); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 1ab9ee11252..f18240c3222 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -2,8 +2,8 @@ package com.yahoo.searchlib.rankingexpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -66,7 +66,7 @@ public class RankingExpressionTestCase { public void testProgrammaticBuilding() throws ParseException { ReferenceNode input = new ReferenceNode("input"); ReferenceNode constant = new ReferenceNode("constant"); - ArithmeticNode product = new ArithmeticNode(input, ArithmeticOperator.MULTIPLY, constant); + OperationNode product = new OperationNode(input, Operator.multiply, constant); Reduce<Reference> sum = new Reduce<>(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum); RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum)); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index ad50a423eb9..acab7bb38c6 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -5,8 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.Arguments; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.OperationNode; +import com.yahoo.searchlib.rankingexpression.rule.Operator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; @@ -88,6 +88,7 @@ public class EvaluationTestCase { tester.assertEvaluates(10.0, "3 ^ 2 + 1"); tester.assertEvaluates(18.0, "2 * 3 ^ 2"); tester.assertEvaluates(-4, "1 - 2 - 3"); // Means 1 + -2 + -3 + tester.assertEvaluates(Math.pow(4, 9), "4^3^2"); // Right precedence, by 51% majority // Conditionals tester.assertEvaluates(2 * (3 * 4 + 3) * (4 * 5 - 4 * 200) / 10, "2*(3*4+3)*(4*5-4*200)/10"); @@ -185,6 +186,11 @@ public class EvaluationTestCase { "map(tensor0, f(x) (log10(x)))", "{ {d1:0}:10, {d1:1}:100, {d1:2}:1000 }"); tester.assertEvaluates("{ {d1:0}:4, {d1:1}:9, {d1:2 }:16 }", "map(tensor0, f(x) (x * x))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + // -- tensor map shorthands + tester.assertEvaluates("{ {d1:0}:0, {d1:1}:1, {d1:2 }:0 }", + "tensor0 == 3", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + tester.assertEvaluates("{ {d1:0}:0, {d1:1}:1, {d1:2 }:0 }", + "3 == tensor0", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); // -- tensor map composites tester.assertEvaluates("{ {d1:0}:1, {d1:1}:2, {d1:2 }:3 }", "log10(tensor0)", "{ {d1:0}:10, {d1:1}:100, {d1:2}:1000 }"); @@ -198,9 +204,9 @@ public class EvaluationTestCase { tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", "tensor0 % 2 == map(tensor0, f(x) (x % 2))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", - "tensor0 || 1 == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + "(tensor0 || 1) == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", - "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + "(tensor0 && 1) == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }"); @@ -779,8 +785,8 @@ public class EvaluationTestCase { @Test public void testProgrammaticBuildingAndPrecedence() { - RankingExpression standardPrecedence = new RankingExpression(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, new ArithmeticNode(constant(3), ArithmeticOperator.MULTIPLY, constant(4)))); - RankingExpression oppositePrecedence = new RankingExpression(new ArithmeticNode(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, constant(3)), ArithmeticOperator.MULTIPLY, constant(4))); + RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.plus, new OperationNode(constant(3), Operator.multiply, constant(4)))); + RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.plus, constant(3)), Operator.multiply, constant(4))); assertEquals(14.0, standardPrecedence.evaluate(null).asDouble(), tolerance); assertEquals(20.0, oppositePrecedence.evaluate(null).asDouble(), tolerance); assertEquals("2.0 + 3.0 * 4.0", standardPrecedence.toString()); |