aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-09-28 17:42:55 +0200
committerJon Bratseth <bratseth@gmail.com>2022-09-28 17:42:55 +0200
commitbcbb2009c44380055b2670e7cdefcad232f9ece4 (patch)
treeb1f1716e6dcb7cb79dcb9871af21758cf6c0a5c2
parent3d49f155fccfa4fc08882b01e7a6e3a982c55212 (diff)
Drop 'arithmetic' from name
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java40
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java20
-rw-r--r--config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java34
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java6
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java)42
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java)12
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java38
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java8
18 files changed, 150 insertions, 144 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
index 8fa4b469590..9ffc73e1863 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformer.java
@@ -2,8 +2,8 @@
package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -35,20 +35,20 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
if (node instanceof CompositeNode composite)
node = transformChildren(composite, context);
- if (node instanceof ArithmeticNode arithmetic)
+ if (node instanceof OperationNode arithmetic)
node = transformBooleanArithmetics(arithmetic);
return node;
}
- private ExpressionNode transformBooleanArithmetics(ArithmeticNode node) {
+ private ExpressionNode transformBooleanArithmetics(OperationNode node) {
Iterator<ExpressionNode> child = node.children().iterator();
// Transform in precedence order:
Deque<ChildNode> stack = new ArrayDeque<>();
stack.push(new ChildNode(null, child.next()));
- for (Iterator<ArithmeticOperator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) {
- ArithmeticOperator op = it.next();
+ for (Iterator<Operator> it = node.operators().iterator(); it.hasNext() && child.hasNext();) {
+ Operator op = it.next();
if ( ! stack.isEmpty()) {
while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) {
popStack(stack);
@@ -66,9 +66,9 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
ChildNode lhs = stack.peek();
ExpressionNode combination;
- if (rhs.op == ArithmeticOperator.AND)
+ if (rhs.op == Operator.AND)
combination = andByIfNode(lhs.child, rhs.child);
- else if (rhs.op == ArithmeticOperator.OR)
+ else if (rhs.op == Operator.OR)
combination = orByIfNode(lhs.child, rhs.child);
else {
combination = resolve(lhs, rhs);
@@ -77,28 +77,28 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
lhs.child = combination;
}
- private static ArithmeticNode resolve(ChildNode left, ChildNode right) {
- if ( ! (left.child instanceof ArithmeticNode) && ! (right.child instanceof ArithmeticNode))
- return new ArithmeticNode(left.child, right.op, right.child);
+ private static OperationNode resolve(ChildNode left, ChildNode right) {
+ if (! (left.child instanceof OperationNode) && ! (right.child instanceof OperationNode))
+ return new OperationNode(left.child, right.op, right.child);
// Collapse inserted ArithmeticNodes
- List<ArithmeticOperator> joinedOps = new ArrayList<>();
+ List<Operator> joinedOps = new ArrayList<>();
joinOps(left, joinedOps);
joinedOps.add(right.op);
joinOps(right, joinedOps);
List<ExpressionNode> joinedChildren = new ArrayList<>();
joinChildren(left, joinedChildren);
joinChildren(right, joinedChildren);
- return new ArithmeticNode(joinedChildren, joinedOps);
+ return new OperationNode(joinedChildren, joinedOps);
}
- private static void joinOps(ChildNode node, List<ArithmeticOperator> joinedOps) {
- if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode)
- joinedOps.addAll(arithmeticNode.operators());
+ private static void joinOps(ChildNode node, List<Operator> joinedOps) {
+ if (node.artificial && node.child instanceof OperationNode operationNode)
+ joinedOps.addAll(operationNode.operators());
}
private static void joinChildren(ChildNode node, List<ExpressionNode> joinedChildren) {
- if (node.artificial && node.child instanceof ArithmeticNode arithmeticNode)
- joinedChildren.addAll(arithmeticNode.children());
+ if (node.artificial && node.child instanceof OperationNode operationNode)
+ joinedChildren.addAll(operationNode.children());
else
joinedChildren.add(node.child);
}
@@ -115,11 +115,11 @@ public class BooleanExpressionTransformer extends ExpressionTransformer<Transfor
/** A child with the operator to be applied to it when combining it with the previous child. */
private static class ChildNode {
- final ArithmeticOperator op;
+ final Operator op;
ExpressionNode child;
boolean artificial;
- public ChildNode(ArithmeticOperator op, ExpressionNode child) {
+ public ChildNode(Operator op, ExpressionNode child) {
this.op = op;
this.child = child;
}
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
index 2695fa79588..dd714501a41 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java
@@ -3,8 +3,8 @@ package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
@@ -139,10 +139,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence);
ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
ExpressionNode expr = new IfNode(
- new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, queryLengthExpr),
+ new OperationNode(new ReferenceNode("d1"), Operator.LESS, queryLengthExpr),
ZERO,
new IfNode(
- new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, restLengthExpr),
+ new OperationNode(new ReferenceNode("d1"), Operator.LESS, restLengthExpr),
ONE,
ZERO
)
@@ -174,7 +174,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
List<ExpressionNode> tokenSequence = createTokenSequence(feature);
ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence);
- ArithmeticNode comparison = new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, lengthExpr);
+ OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.LESS, lengthExpr);
ExpressionNode expr = new IfNode(comparison, ONE, ZERO);
return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr)));
}
@@ -254,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
*/
private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) {
ExpressionNode lengthExpr = createLengthExpr(iter, sequence);
- ArithmeticNode comparison = new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, lengthExpr);
+ OperationNode comparison = new OperationNode(new ReferenceNode("d1"), Operator.LESS, lengthExpr);
ExpressionNode trueExpr = sequence.get(iter);
if (sequence.get(iter) instanceof ReferenceNode) {
@@ -278,7 +278,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
*/
private ExpressionNode createLengthExpr(int iter, List<ExpressionNode> sequence) {
List<ExpressionNode> factors = new ArrayList<>();
- List<ArithmeticOperator> operators = new ArrayList<>();
+ List<Operator> operators = new ArrayList<>();
for (int i = 0; i < iter + 1; ++i) {
if (sequence.get(i) instanceof ConstantNode) {
factors.add(ONE);
@@ -286,10 +286,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
factors.add(new ReferenceNode(lengthFunctionName((ReferenceNode) sequence.get(i))));
}
if (i >= 1) {
- operators.add(ArithmeticOperator.PLUS);
+ operators.add(Operator.PLUS);
}
}
- return new ArithmeticNode(factors, operators);
+ return new OperationNode(factors, operators);
}
/**
@@ -299,7 +299,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
ExpressionNode expr;
if (iter >= 1) {
ExpressionNode lengthExpr = new EmbracedNode(createLengthExpr(iter - 1, sequence));
- expr = new EmbracedNode(new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.MINUS, lengthExpr));
+ expr = new EmbracedNode(new OperationNode(new ReferenceNode("d1"), Operator.MINUS, lengthExpr));
} else {
expr = new ReferenceNode("d1");
}
diff --git a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
index 65d12f1cdcf..d692b69d3c8 100644
--- a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java
@@ -4,7 +4,7 @@ package com.yahoo.schema.expressiontransforms;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
import org.junit.jupiter.api.Test;
@@ -43,8 +43,8 @@ public class BooleanExpressionTransformerTestCase {
var expr = new BooleanExpressionTransformer()
.transform(new RankingExpression("a + b + c * d + e + f"),
new TransformContext(Map.of(), new MapTypeContext()));
- assertTrue(expr.getRoot() instanceof ArithmeticNode);
- ArithmeticNode root = (ArithmeticNode) expr.getRoot();
+ assertTrue(expr.getRoot() instanceof OperationNode);
+ OperationNode root = (OperationNode) expr.getRoot();
assertEquals(5, root.operators().size());
assertEquals(6, root.children().size());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
index cd0c4da6d0f..c66022975c7 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -92,7 +92,7 @@ public class Gather extends IntermediateOperation {
ExpressionNode indexExpression = new ConstantNode(new DoubleValue(constantValue));
if (constantValue < 0) {
ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get()));
- indexExpression = new EmbracedNode(new ArithmeticNode(indexExpression, ArithmeticOperator.PLUS, axisSize));
+ indexExpression = new EmbracedNode(new OperationNode(indexExpression, Operator.PLUS, axisSize));
}
addSliceDimension(dataSliceDimensions, dataType.dimensions().get(axis).name(), indexExpression);
} else {
@@ -125,8 +125,8 @@ public class Gather extends IntermediateOperation {
/** to support negative indexing */
private ExpressionNode createIndexExpression(OrderedTensorType dataType, ExpressionNode slice) {
ExpressionNode axisSize = new ConstantNode(new DoubleValue(dataType.dimensions().get(axis).size().get()));
- ExpressionNode plus = new EmbracedNode(new ArithmeticNode(slice, ArithmeticOperator.PLUS, axisSize));
- ExpressionNode mod = new ArithmeticNode(plus, ArithmeticOperator.MODULO, axisSize);
+ ExpressionNode plus = new EmbracedNode(new OperationNode(slice, Operator.PLUS, axisSize));
+ ExpressionNode mod = new OperationNode(plus, Operator.MODULO, axisSize);
return mod;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index 1f447f2a575..81d633dea4b 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.TensorType;
@@ -97,17 +97,17 @@ public class Gemm extends IntermediateOperation {
TensorFunction<Reference> AxB = new com.yahoo.tensor.functions.Matmul<>(aFunction.get(), bFunction.get(), joinDimension);
TensorFunction<Reference> alphaxAxB = new TensorFunctionNode.ExpressionTensorFunction(
- new ArithmeticNode(
+ new OperationNode(
new TensorFunctionNode(AxB),
- ArithmeticOperator.MULTIPLY,
+ Operator.MULTIPLY,
new ConstantNode(new DoubleValue(alpha))));
if (inputs.size() == 3) {
Optional<TensorFunction<Reference>> cFunction = inputs.get(2).function();
TensorFunction<Reference> betaxC = new TensorFunctionNode.ExpressionTensorFunction(
- new ArithmeticNode(
+ new OperationNode(
new TensorFunctionNode(cFunction.get()),
- ArithmeticOperator.MULTIPLY,
+ Operator.MULTIPLY,
new ConstantNode(new DoubleValue(beta))));
return new com.yahoo.tensor.functions.Join<>(alphaxAxB, betaxC, ScalarFunctions.add());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
index 5c4e8cd6cd0..66e810b954e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -65,8 +65,8 @@ public class Range extends IntermediateOperation {
ExpressionNode startExpr = new ConstantNode(new DoubleValue(start));
ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta));
ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName));
- ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr);
- ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr);
+ ExpressionNode stepExpr = new OperationNode(deltaExpr, Operator.MULTIPLY, dimExpr);
+ ExpressionNode addExpr = new OperationNode(startExpr, Operator.PLUS, stepExpr);
TensorFunction<Reference> function = Generate.bound(type.type(), wrapScalar(addExpr));
return function;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 7b675fa79af..ce93461bff3 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -6,13 +6,11 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.Function;
-import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
@@ -159,13 +157,13 @@ public class Reshape extends IntermediateOperation {
inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
} else if (dim == (inputType.rank() - 1)) {
ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
- ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size);
+ ExpressionNode div = new OperationNode(unrolled, Operator.MODULO, size);
inputDimensionExpression = new EmbracedNode(div);
} else {
ExpressionNode size = new ConstantNode(new DoubleValue(innerSize));
ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize));
- ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize);
- ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size);
+ ExpressionNode mod = new OperationNode(unrolled, Operator.MODULO, previousSize);
+ ExpressionNode div = new OperationNode(new EmbracedNode(mod), Operator.DIVIDE, size);
inputDimensionExpression = new EmbracedNode(div);
}
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression)));
@@ -183,21 +181,21 @@ public class Reshape extends IntermediateOperation {
return new ConstantNode(DoubleValue.zero);
List<ExpressionNode> children = new ArrayList<>();
- List<ArithmeticOperator> operators = new ArrayList<>();
+ List<Operator> operators = new ArrayList<>();
int size = 1;
for (int i = type.dimensions().size() - 1; i >= 0; --i) {
TensorType.Dimension dimension = type.dimensions().get(i);
children.add(0, new ReferenceNode(dimension.name()));
if (size > 1) {
- operators.add(0, ArithmeticOperator.MULTIPLY);
+ operators.add(0, Operator.MULTIPLY);
children.add(0, new ConstantNode(new DoubleValue(size)));
}
size *= OrderedTensorType.dimensionSize(dimension);
if (i > 0) {
- operators.add(0, ArithmeticOperator.PLUS);
+ operators.add(0, Operator.PLUS);
}
}
- return new ArithmeticNode(children, operators);
+ return new OperationNode(children, operators);
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
index 91b7064b19c..617e1f00c94 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
@@ -6,8 +6,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -166,8 +166,8 @@ public class Slice extends IntermediateOperation {
// step * (d0 + start)
ExpressionNode reference = new ReferenceNode(outputDimensionName);
- ExpressionNode plus = new EmbracedNode(new ArithmeticNode(reference, ArithmeticOperator.PLUS, startIndex));
- ExpressionNode mul = new ArithmeticNode(stepSize, ArithmeticOperator.MULTIPLY, plus);
+ ExpressionNode plus = new EmbracedNode(new OperationNode(reference, Operator.PLUS, startIndex));
+ ExpressionNode mul = new OperationNode(stepSize, Operator.MULTIPLY, plus);
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mul))));
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
index 6f720716adb..42901259821 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
@@ -5,8 +5,8 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -96,7 +96,7 @@ public class Split extends IntermediateOperation {
for (int i = 0; i < inputType.rank(); ++i) {
String inputDimensionName = inputType.dimensions().get(i).name();
ExpressionNode reference = new ReferenceNode(inputDimensionName);
- ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
+ ExpressionNode offset = new OperationNode(reference, Operator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset))));
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
index 4bfab284cc2..cd88c625d81 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
@@ -4,8 +4,8 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -77,7 +77,7 @@ public class Tile extends IntermediateOperation {
ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
ExpressionNode reference = new ReferenceNode(inputDimensionName);
- ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size);
+ ExpressionNode mod = new OperationNode(reference, Operator.MODULO, size);
dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod))));
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
index 6ab483800a7..78acd2e5af1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestOptimizer.java
@@ -5,8 +5,8 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -84,12 +84,12 @@ public class GBDTForestOptimizer extends Optimizer {
currentTreesOptimized++;
return true;
}
- if (!(node instanceof ArithmeticNode)) {
+ if (!(node instanceof OperationNode)) {
return false;
}
- ArithmeticNode aNode = (ArithmeticNode)node;
- for (ArithmeticOperator op : aNode.operators()) {
- if (op != ArithmeticOperator.PLUS) {
+ OperationNode aNode = (OperationNode)node;
+ for (Operator op : aNode.operators()) {
+ if (op != Operator.PLUS) {
return false;
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
index cf4c35d94af..b3cbe252dfc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTOptimizer.java
@@ -61,13 +61,13 @@ public class GBDTOptimizer extends Optimizer {
* @return the optimized expression
*/
private ExpressionNode optimize(ExpressionNode node, ContextIndex context) {
- if (node instanceof ArithmeticNode) {
- Iterator<ExpressionNode> childIt = ((ArithmeticNode)node).children().iterator();
+ if (node instanceof OperationNode) {
+ Iterator<ExpressionNode> childIt = ((OperationNode)node).children().iterator();
ExpressionNode ret = optimize(childIt.next(), context);
- Iterator<ArithmeticOperator> operIt = ((ArithmeticNode)node).operators().iterator();
+ Iterator<Operator> operIt = ((OperationNode)node).operators().iterator();
while (childIt.hasNext() && operIt.hasNext()) {
- ret = ArithmeticNode.resolve(ret, operIt.next(), optimize(childIt.next(), context));
+ ret = OperationNode.resolve(ret, operIt.next(), optimize(childIt.next(), context));
}
return ret;
}
@@ -115,10 +115,10 @@ public class GBDTOptimizer extends Optimizer {
/** Consumes the if condition and return the size of the values resulting, for convenience */
private int consumeIfCondition(ExpressionNode condition, List<Double> values, ContextIndex context) {
if (isBinaryComparison(condition)) {
- ArithmeticNode comparison = (ArithmeticNode)condition;
- if (comparison.operators().get(0) == ArithmeticOperator.LESS)
+ OperationNode comparison = (OperationNode)condition;
+ if (comparison.operators().get(0) == Operator.LESS)
values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*0 + getVariableIndex(comparison.children().get(0), context));
- else if (comparison.operators().get(0) == ArithmeticOperator.EQUAL)
+ else if (comparison.operators().get(0) == Operator.EQUAL)
values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*1 + getVariableIndex(comparison.children().get(0), context));
else
throw new IllegalArgumentException("Cannot optimize other conditions than < and ==, encountered: " + comparison.operators().get(0));
@@ -134,8 +134,8 @@ public class GBDTOptimizer extends Optimizer {
else if (condition instanceof NotNode notNode) { // handle if inversion: !(a >= b)
if (notNode.children().size() == 1 && notNode.children().get(0) instanceof EmbracedNode embracedNode) {
if (embracedNode.children().size() == 1 && isBinaryComparison(embracedNode.children().get(0))) {
- ArithmeticNode comparison = (ArithmeticNode)embracedNode.children().get(0);
- if (comparison.operators().get(0) == ArithmeticOperator.GREATEREQUAL)
+ OperationNode comparison = (OperationNode)embracedNode.children().get(0);
+ if (comparison.operators().get(0) == Operator.GREATEREQUAL)
values.add(GBDTNode.MAX_LEAF_VALUE + GBDTNode.MAX_VARIABLES*3 + getVariableIndex(comparison.children().get(0), context));
else
throw new IllegalArgumentException("Cannot optimize other conditions than >=, encountered: " + comparison.operators().get(0));
@@ -151,15 +151,15 @@ public class GBDTOptimizer extends Optimizer {
}
private boolean isBinaryComparison(ExpressionNode condition) {
- if ( ! (condition instanceof ArithmeticNode binaryNode)) return false;
+ if ( ! (condition instanceof OperationNode binaryNode)) return false;
if (binaryNode.operators().size() != 1) return false;
- if (binaryNode.operators().get(0) == ArithmeticOperator.GREATEREQUAL) return true;
- if (binaryNode.operators().get(0) == ArithmeticOperator.GREATER) return true;
- if (binaryNode.operators().get(0) == ArithmeticOperator.LESSEQUAL) return true;
- if (binaryNode.operators().get(0) == ArithmeticOperator.LESS) return true;
- if (binaryNode.operators().get(0) == ArithmeticOperator.APPROX) return true;
- if (binaryNode.operators().get(0) == ArithmeticOperator.NOTEQUAL) return true;
- if (binaryNode.operators().get(0) == ArithmeticOperator.EQUAL) return true;
+ if (binaryNode.operators().get(0) == Operator.GREATEREQUAL) return true;
+ if (binaryNode.operators().get(0) == Operator.GREATER) return true;
+ if (binaryNode.operators().get(0) == Operator.LESSEQUAL) return true;
+ if (binaryNode.operators().get(0) == Operator.LESS) return true;
+ if (binaryNode.operators().get(0) == Operator.APPROX) return true;
+ if (binaryNode.operators().get(0) == Operator.NOTEQUAL) return true;
+ if (binaryNode.operators().get(0) == Operator.EQUAL) return true;
return false;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
index 9f07f146264..3c7f48aa38c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java
@@ -106,10 +106,10 @@ public class LambdaFunctionNode extends CompositeNode {
}
private Optional<DoubleBinaryOperator> getDirectEvaluator() {
- if ( ! (functionExpression instanceof ArithmeticNode)) {
+ if ( ! (functionExpression instanceof OperationNode)) {
return Optional.empty();
}
- ArithmeticNode node = (ArithmeticNode) functionExpression;
+ OperationNode node = (OperationNode) functionExpression;
if ( ! (node.children().get(0) instanceof ReferenceNode) || ! (node.children().get(1) instanceof ReferenceNode)) {
return Optional.empty();
}
@@ -121,7 +121,7 @@ public class LambdaFunctionNode extends CompositeNode {
if (node.operators().size() != 1) {
return Optional.empty();
}
- ArithmeticOperator operator = node.operators().get(0);
+ Operator operator = node.operators().get(0);
switch (operator) {
case OR: return asFunctionExpression((left, right) -> ((left != 0.0) || (right != 0.0)) ? 1.0 : 0.0);
case AND: return asFunctionExpression((left, right) -> ((left != 0.0) && (right != 0.0)) ? 1.0 : 0.0);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
index c3e39197316..392f42f6cbe 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/OperationNode.java
@@ -16,26 +16,26 @@ import java.util.List;
import java.util.Objects;
/**
- * A binary mathematical operation
+ * A sequence of binary operations.
*
* @author bratseth
*/
-public final class ArithmeticNode extends CompositeNode {
+public final class OperationNode extends CompositeNode {
private final List<ExpressionNode> children;
- private final List<ArithmeticOperator> operators;
+ private final List<Operator> operators;
- public ArithmeticNode(List<ExpressionNode> children, List<ArithmeticOperator> operators) {
+ public OperationNode(List<ExpressionNode> children, List<Operator> operators) {
this.children = List.copyOf(children);
this.operators = List.copyOf(operators);
}
- public ArithmeticNode(ExpressionNode leftExpression, ArithmeticOperator operator, ExpressionNode rightExpression) {
+ public OperationNode(ExpressionNode leftExpression, Operator operator, ExpressionNode rightExpression) {
this.children = List.of(leftExpression, rightExpression);
this.operators = List.of(operator);
}
- public List<ArithmeticOperator> operators() { return operators; }
+ public List<Operator> operators() { return operators; }
@Override
public List<ExpressionNode> children() { return children; }
@@ -50,7 +50,7 @@ public final class ArithmeticNode extends CompositeNode {
child.next().toString(string, context, path, this);
if (child.hasNext())
string.append(" ");
- for (Iterator<ArithmeticOperator> op = operators.iterator(); op.hasNext() && child.hasNext();) {
+ for (Iterator<Operator> op = operators.iterator(); op.hasNext() && child.hasNext();) {
string.append(op.next().toString()).append(" ");
child.next().toString(string, context, path, this);
if (op.hasNext())
@@ -68,7 +68,7 @@ public final class ArithmeticNode extends CompositeNode {
*/
private boolean nonDefaultPrecedence(CompositeNode parent) {
if ( parent == null) return false;
- if ( ! (parent instanceof ArithmeticNode arithmeticParent)) return false;
+ if ( ! (parent instanceof OperationNode arithmeticParent)) return false;
// The line below can only be correct in both only have one operator.
// Getting this correct is impossible without more work.
@@ -96,8 +96,8 @@ public final class ArithmeticNode extends CompositeNode {
// Apply in precedence order:
Deque<ValueItem> stack = new ArrayDeque<>();
stack.push(new ValueItem(null, child.next().evaluate(context)));
- for (Iterator<ArithmeticOperator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
- ArithmeticOperator op = it.next();
+ for (Iterator<Operator> it = operators.iterator(); it.hasNext() && child.hasNext();) {
+ Operator op = it.next();
if ( ! stack.isEmpty()) {
while (stack.size() > 1 && ! op.hasPrecedenceOver(stack.peek().op)) {
popStack(stack);
@@ -121,30 +121,38 @@ public final class ArithmeticNode extends CompositeNode {
public CompositeNode setChildren(List<ExpressionNode> newChildren) {
if (children.size() != newChildren.size())
throw new IllegalArgumentException("Expected " + children.size() + " children but got " + newChildren.size());
- return new ArithmeticNode(newChildren, operators);
+ return new OperationNode(newChildren, operators);
}
@Override
public int hashCode() { return Objects.hash(children, operators); }
- public static ArithmeticNode resolve(ExpressionNode left, ArithmeticOperator op, ExpressionNode right) {
- if ( ! (left instanceof ArithmeticNode leftArithmetic)) return new ArithmeticNode(left, op, right);
+ @Override
+ public boolean equals(Object o) {
+ if ( ! (o instanceof OperationNode other)) return false;
+ if ( ! this.children().equals(other.children())) return false;
+ if ( ! this.operators().equals(other.operators())) return false;
+ return true;
+ }
+
+ public static OperationNode resolve(ExpressionNode left, Operator op, ExpressionNode right) {
+ if ( ! (left instanceof OperationNode leftArithmetic)) return new OperationNode(left, op, right);
List<ExpressionNode> newChildren = new ArrayList<>(leftArithmetic.children());
newChildren.add(right);
- List<ArithmeticOperator> newOperators = new ArrayList<>(leftArithmetic.operators());
+ List<Operator> newOperators = new ArrayList<>(leftArithmetic.operators());
newOperators.add(op);
- return new ArithmeticNode(newChildren, newOperators);
+ return new OperationNode(newChildren, newOperators);
}
private static class ValueItem {
- final ArithmeticOperator op;
+ final Operator op;
Value value;
- public ValueItem(ArithmeticOperator op, Value value) {
+ public ValueItem(Operator op, Value value) {
this.op = op;
this.value = value;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
index 435c92ff7da..4ddbfa4ea9f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Operator.java
@@ -12,7 +12,7 @@ import java.util.function.BiFunction;
*
* @author bratseth
*/
-public enum ArithmeticOperator {
+public enum Operator {
// In order from lowest to highest precedence
OR("||", (x, y) -> x.or(y)),
@@ -31,25 +31,25 @@ public enum ArithmeticOperator {
MODULO("%", (x, y) -> x.modulo(y)),
POWER("^", true, (x, y) -> x.power(y));
- /** A list of all the operators in this in order of decreasing precedence */
- public static final List<ArithmeticOperator> operatorsByPrecedence = Arrays.stream(ArithmeticOperator.values()).toList();
+ /** A list of all the operators in this in order of increasing precedence */
+ public static final List<Operator> operatorsByPrecedence = Arrays.stream(Operator.values()).toList();
private final String image;
private final boolean bindsRight; // TODO: Implement
private final BiFunction<Value, Value, Value> function;
- ArithmeticOperator(String image, BiFunction<Value, Value, Value> function) {
+ Operator(String image, BiFunction<Value, Value, Value> function) {
this(image, false, function);
}
- ArithmeticOperator(String image, boolean bindsRight, BiFunction<Value, Value, Value> function) {
+ Operator(String image, boolean bindsRight, BiFunction<Value, Value, Value> function) {
this.image = image;
this.bindsRight = bindsRight;
this.function = function;
}
/** Returns true if this operator has precedence over the given operator */
- public boolean hasPrecedenceOver(ArithmeticOperator op) {
+ public boolean hasPrecedenceOver(Operator op) {
return operatorsByPrecedence.indexOf(this) > operatorsByPrecedence.indexOf(op);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index 7a34f5b7b03..04728966bc1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -4,8 +4,8 @@ package com.yahoo.searchlib.rankingexpression.transform;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
@@ -33,8 +33,8 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
node = transformIf((IfNode) node);
if (node instanceof EmbracedNode e && hasSingleUndividableChild(e))
node = e.children().get(0);
- if (node instanceof ArithmeticNode)
- node = transformArithmetic((ArithmeticNode) node);
+ if (node instanceof OperationNode)
+ node = transformArithmetic((OperationNode) node);
if (node instanceof NegativeNode)
node = transformNegativeNode((NegativeNode) node);
return node;
@@ -42,18 +42,18 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
private boolean hasSingleUndividableChild(EmbracedNode node) {
if (node.children().size() > 1) return false;
- if (node.children().get(0) instanceof ArithmeticNode) return false;
+ if (node.children().get(0) instanceof OperationNode) return false;
return true;
}
- private ExpressionNode transformArithmetic(ArithmeticNode node) {
+ private ExpressionNode transformArithmetic(OperationNode node) {
// Fold the subset of expressions that are constant (such that in "1 + 2 + var")
if (node.children().size() > 1) {
List<ExpressionNode> children = new ArrayList<>(node.children());
- List<ArithmeticOperator> operators = new ArrayList<>(node.operators());
- for (ArithmeticOperator operator : ArithmeticOperator.operatorsByPrecedence)
+ List<Operator> operators = new ArrayList<>(node.operators());
+ for (Operator operator : Operator.operatorsByPrecedence)
transform(operator, children, operators);
- node = new ArithmeticNode(children, operators);
+ node = new OperationNode(children, operators);
}
if (isConstant(node) && ! node.evaluate(null).isNaN())
@@ -64,8 +64,8 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
return node;
}
- private void transform(ArithmeticOperator operatorToTransform,
- List<ExpressionNode> children, List<ArithmeticOperator> operators) {
+ private void transform(Operator operatorToTransform,
+ List<ExpressionNode> children, List<Operator> operators) {
int i = 0;
while (i < children.size()-1) {
boolean transformed = false;
@@ -73,7 +73,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
ExpressionNode child1 = children.get(i);
ExpressionNode child2 = children.get(i + 1);
if (isConstant(child1) && isConstant(child2) && hasPrecedence(operators, i)) {
- Value evaluated = new ArithmeticNode(child1, operators.get(i), child2).evaluate(null);
+ Value evaluated = new OperationNode(child1, operators.get(i), child2).evaluate(null);
if ( ! evaluated.isNaN()) { // Don't replace by NaN
operators.remove(i);
children.set(i, new ConstantNode(evaluated.freeze()));
@@ -92,7 +92,7 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
* This check works because we simplify by decreasing precedence, so neighbours will either be single constant values
* or a more complex expression that can't be simplified and hence also prevents the simplification in question here.
*/
- private boolean hasPrecedence(List<ArithmeticOperator> operators, int i) {
+ private boolean hasPrecedence(List<Operator> operators, int i) {
if (i > 0 && operators.get(i-1).hasPrecedenceOver(operators.get(i))) return false;
if (i < operators.size()-1 && operators.get(i+1).hasPrecedenceOver(operators.get(i))) return false;
return true;
@@ -114,14 +114,14 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
return new ConstantNode(constant.getValue().negate() );
}
- private boolean allMultiplicationOrDivision(ArithmeticNode node) {
- for (ArithmeticOperator o : node.operators())
- if (o != ArithmeticOperator.MULTIPLY && o != ArithmeticOperator.DIVIDE)
+ private boolean allMultiplicationOrDivision(OperationNode node) {
+ for (Operator o : node.operators())
+ if (o != Operator.MULTIPLY && o != Operator.DIVIDE)
return false;
return true;
}
- private boolean hasZero(ArithmeticNode node) {
+ private boolean hasZero(OperationNode node) {
for (ExpressionNode child : node.children()) {
if (isZero(child))
return true;
@@ -129,9 +129,9 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
return false;
}
- private boolean hasDivisionByZero(ArithmeticNode node) {
+ private boolean hasDivisionByZero(OperationNode node) {
for (int i = 1; i < node.children().size(); i++) {
- if ( node.operators().get(i - 1) == ArithmeticOperator.DIVIDE && isZero(node.children().get(i)))
+ if (node.operators().get(i - 1) == Operator.DIVIDE && isZero(node.children().get(i)))
return true;
}
return false;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index 1ab9ee11252..83de8e04a7d 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -2,8 +2,8 @@
package com.yahoo.searchlib.rankingexpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -66,7 +66,7 @@ public class RankingExpressionTestCase {
public void testProgrammaticBuilding() throws ParseException {
ReferenceNode input = new ReferenceNode("input");
ReferenceNode constant = new ReferenceNode("constant");
- ArithmeticNode product = new ArithmeticNode(input, ArithmeticOperator.MULTIPLY, constant);
+ OperationNode product = new OperationNode(input, Operator.MULTIPLY, constant);
Reduce<Reference> sum = new Reduce<>(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum);
RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index b1ac4b9e3ca..019f76521e9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -5,8 +5,8 @@ import com.yahoo.javacc.UnicodeUtilities;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.OperationNode;
+import com.yahoo.searchlib.rankingexpression.rule.Operator;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
@@ -779,8 +779,8 @@ public class EvaluationTestCase {
@Test
public void testProgrammaticBuildingAndPrecedence() {
- RankingExpression standardPrecedence = new RankingExpression(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, new ArithmeticNode(constant(3), ArithmeticOperator.MULTIPLY, constant(4))));
- RankingExpression oppositePrecedence = new RankingExpression(new ArithmeticNode(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, constant(3)), ArithmeticOperator.MULTIPLY, constant(4)));
+ RankingExpression standardPrecedence = new RankingExpression(new OperationNode(constant(2), Operator.PLUS, new OperationNode(constant(3), Operator.MULTIPLY, constant(4))));
+ RankingExpression oppositePrecedence = new RankingExpression(new OperationNode(new OperationNode(constant(2), Operator.PLUS, constant(3)), Operator.MULTIPLY, constant(4)));
assertEquals(14.0, standardPrecedence.evaluate(null).asDouble(), tolerance);
assertEquals(20.0, oppositePrecedence.evaluate(null).asDouble(), tolerance);
assertEquals("2.0 + 3.0 * 4.0", standardPrecedence.toString());