aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2021-11-04 00:25:45 +0100
committerGitHub <noreply@github.com>2021-11-04 00:25:45 +0100
commit19113f6da3db1abbe8a3e36f081a0ec03f878f12 (patch)
treeecf3063c4fbde0ef6d2ece92255527aeffa27e53
parent7e43b15f2e427ac08af82e2292e8649328f729e5 (diff)
parent599946ff2d1838914ffbab7d74fdbe6055187189 (diff)
Merge pull request #19855 from vespa-engine/balder/optimize-negative-constantsv7.495.22
Avoid intermediate NegativeNode in the leaf nodes, adding approximate…
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java10
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java4
-rw-r--r--searchlib/abi-spec.json7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/LongValue.java4
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java14
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj41
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java18
13 files changed, 78 insertions, 40 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
index faf3c8085d8..6d9f4cdec92 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxModelTransformer.java
@@ -135,7 +135,7 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
public static String asString(ExpressionNode node) {
if ( ! (node instanceof ConstantNode))
throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
+ return stripQuotes(node.toString());
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
index cd56de18527..dbb32d88ef6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -43,11 +43,11 @@ import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrap
*/
public class TokenTransformer extends ExpressionTransformer<RankProfileTransformContext> {
- static private final ConstantNode ZERO = new ConstantNode(new DoubleValue(0.0), "0");
- static private final ConstantNode ONE = new ConstantNode(new DoubleValue(1.0), "1");
- static private final ConstantNode TWO = new ConstantNode(new DoubleValue(2.0), "2");
- static private final ConstantNode CLS = new ConstantNode(new DoubleValue(101), "101");
- static private final ConstantNode SEP = new ConstantNode(new DoubleValue(102), "102");
+ static private final ConstantNode ZERO = new ConstantNode(new DoubleValue(0.0));
+ static private final ConstantNode ONE = new ConstantNode(new DoubleValue(1.0));
+ static private final ConstantNode TWO = new ConstantNode(new DoubleValue(2.0));
+ static private final ConstantNode CLS = new ConstantNode(new DoubleValue(101));
+ static private final ConstantNode SEP = new ConstantNode(new DoubleValue(102));
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
index c7c1db3b862..fdbde08d926 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/OnnxModelConfigGenerator.java
@@ -74,7 +74,7 @@ public class OnnxModelConfigGenerator extends Processor {
if (feature.getArguments().size() > 0) {
if (feature.getArguments().expressions().get(0) instanceof ConstantNode) {
ConstantNode node = (ConstantNode) feature.getArguments().expressions().get(0);
- String path = OnnxModelTransformer.stripQuotes(node.sourceString());
+ String path = OnnxModelTransformer.stripQuotes(node.toString());
String modelConfigName = OnnxModelTransformer.asValidIdentifier(path);
// Only add the configuration if the model can actually be found.
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java
index 372b355b2d3..b52fd060a1c 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java
@@ -57,7 +57,7 @@ public class FeatureArguments {
private static String asString(ExpressionNode node) {
if ( ! (node instanceof ConstantNode))
throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
+ return stripQuotes(node.toString());
}
private static String stripQuotes(String s) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 40a0a1be5fc..d97235d11d2 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -293,8 +293,8 @@ public final class LazyArrayContext extends Context implements ContextIndex {
ReferenceNode reference = (ReferenceNode) node;
if (reference.getArguments().size() > 0) {
if (reference.getArguments().expressions().get(0) instanceof ConstantNode) {
- ConstantNode constantNode = (ConstantNode) reference.getArguments().expressions().get(0);
- return Optional.of(stripQuotes(constantNode.sourceString()));
+ ExpressionNode constantNode = reference.getArguments().expressions().get(0);
+ return Optional.of(stripQuotes(constantNode.toString()));
}
if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0);
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index d3ce502fe75..e5611324254 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -594,14 +594,13 @@
"public java.lang.String toString()",
"public boolean equals(java.lang.Object)",
"public int hashCode()",
- "public com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue negate()",
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value negate()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value add(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value subtract(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value multiply(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value divide(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value modulo(com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)",
- "public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.Value negate()"
+ "public com.yahoo.searchlib.rankingexpression.evaluation.Value function(com.yahoo.searchlib.rankingexpression.rule.Function, com.yahoo.searchlib.rankingexpression.evaluation.Value)"
],
"fields": []
},
@@ -943,7 +942,7 @@
"public final java.util.List bracedIdentifierList()",
"public final java.lang.String tag()",
"public final java.util.List tagCommaLeadingList()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode constantPrimitive(boolean)",
"public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType, java.util.List)",
"public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/LongValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/LongValue.java
index bb4af7d71f3..b9323e1ccb8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/LongValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/LongValue.java
@@ -47,8 +47,8 @@ public class LongValue extends DoubleCompatibleValue {
}
@Override
- public DoubleValue negate() {
- return new DoubleValue(-value);
+ public Value negate() {
+ return new LongValue(-value);
}
private UnsupportedOperationException unsupported(String operation, Value value) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
index 68134cc85b4..ffbeec37c78 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java
@@ -19,7 +19,8 @@ public final class ConstantNode extends ExpressionNode {
private final Value value;
public ConstantNode(Value value) {
- this(value,null);
+ value.freeze();
+ this.value = value;
}
/**
@@ -28,19 +29,20 @@ public final class ConstantNode extends ExpressionNode {
* @param value the value. Ownership of this value is transferred to this.
* @param sourceImage the source string image producing this value
*/
+ @Deprecated
public ConstantNode(Value value, String sourceImage) {
- value.freeze();
- this.value = value;
+ this(value);
}
public Value getValue() { return value; }
@Override
public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
- return string.append(sourceString());
+ return string.append(value.toString());
}
/** Returns the string which created this, or the value.toString() if not known */
+ @Deprecated
public String sourceString() {
return value.toString();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
index 57b349fdc2e..9516f38a155 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java
@@ -30,7 +30,7 @@ public class NegativeNode extends CompositeNode {
@Override
public List<ExpressionNode> children() {
- return Collections.singletonList(value);
+ return List.of(value);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
index a56106e8f9d..b48303ae98b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java
@@ -1,7 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.transform;
-import com.yahoo.document.update.ArithmeticValueUpdate;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
@@ -12,9 +12,9 @@ import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
+import com.yahoo.searchlib.rankingexpression.rule.NegativeNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.functions.TensorFunction;
import java.util.ArrayList;
import java.util.List;
@@ -36,6 +36,8 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
node = ((EmbracedNode)node).children().get(0);
if (node instanceof ArithmeticNode)
node = transformArithmetic((ArithmeticNode) node);
+ if (node instanceof NegativeNode)
+ node = transformNegativeNode((NegativeNode) node);
return node;
}
@@ -107,6 +109,14 @@ public class Simplifier extends ExpressionTransformer<TransformContext> {
return node.getFalseExpression();
}
+ private ExpressionNode transformNegativeNode(NegativeNode node) {
+ if ( ! (node.getValue() instanceof ConstantNode) ) return node;
+
+ ConstantNode constant = (ConstantNode) node.getValue();
+ if ( ! (constant.getValue() instanceof DoubleCompatibleValue)) return node;
+ return new ConstantNode(constant.getValue().negate() );
+ }
+
private boolean allMultiplicationOrDivision(ArithmeticNode node) {
for (ArithmeticOperator o : node.operators())
if (o != ArithmeticOperator.MULTIPLY && o != ArithmeticOperator.DIVIDE)
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 61e35647b89..0d46ab4ddb6 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -246,18 +246,20 @@ ExpressionNode value() :
(
[ <NOT> { not = true; } ]
[ LOOKAHEAD(2) <SUB> { neg = true; } ]
- ( value = constantPrimitive() |
- LOOKAHEAD(2) value = ifExpression() |
- LOOKAHEAD(4) value = function() |
- value = feature() |
- value = legacyQueryFeature() |
- ( <LBRACE> value = expression() <RBRACE> { value = new EmbracedNode(value); } ) )
-
+ ( value = constantPrimitive(neg) |
+ (
+ LOOKAHEAD(2) value = ifExpression() |
+ LOOKAHEAD(4) value = function() |
+ value = feature() |
+ value = legacyQueryFeature() |
+ ( <LBRACE> value = expression() <RBRACE> { value = new EmbracedNode(value); } )
+ ) { value = neg ? new NegativeNode(value) : value; }
)
- [ LOOKAHEAD(2) valueAddress = valueAddress() { value = new TensorFunctionNode(new Slice(TensorFunctionNode.wrap(value), valueAddress)); } ]
+
+ )
+ [ LOOKAHEAD(2) valueAddress = valueAddress() { value = new TensorFunctionNode(new Slice(TensorFunctionNode.wrap(value), valueAddress)); } ]
{
value = not ? new NotNode(value) : value;
- value = neg ? new NegativeNode(value) : value;
return value;
}
}
@@ -841,17 +843,24 @@ List<String> tagCommaLeadingList() :
{ return list; }
}
-ConstantNode constantPrimitive() :
+ExpressionNode constantPrimitive(boolean negate) :
{
- String sign = "";
String value;
+ ExpressionNode node;
}
{
- ( <SUB> { sign = "-";} ) ?
- ( <INTEGER> { value = token.image; } |
- <FLOAT> { value = token.image; } |
- <STRING> { value = token.image; } )
- { return new ConstantNode(Value.parse(sign + value)); }
+ ( <SUB> { negate = !negate; } ) ?
+ (
+ ( <INTEGER> { value = token.image; } |
+ <FLOAT> { value = token.image; }
+ ) { node = new ConstantNode(Value.parse(negate ? ("-" + value) : value)); } |
+ <STRING> {
+ value = token.image;
+ node = new ConstantNode(Value.parse(value));
+ if (negate) node = new NegativeNode(node);
+ }
+ )
+ { return node; }
}
Value primitiveValue() :
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 13e19ff4d35..efa98fba2eb 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -815,7 +815,7 @@ public class EvaluationTestCase {
try {
ExpressionNode e = arguments.expressions().get(index);
if (e instanceof ConstantNode) {
- return new DoubleValue(new RankingExpression(UnicodeUtilities.unquote(((ConstantNode)e).sourceString())).evaluate(this).asDouble());
+ return new DoubleValue(new RankingExpression(UnicodeUtilities.unquote(e.toString())).evaluate(this).asDouble());
}
return e.evaluate(this);
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
index f7735b4f5ca..c93830abda9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
@@ -5,8 +5,11 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.NegativeNode;
import org.junit.Test;
import java.util.Collections;
@@ -83,4 +86,19 @@ public class SimplifierTestCase {
assertEquals("a + (b + c) / 1.0E8", transformed.toString());
}
+ @Test
+ public void testOptimizingNegativeConstants() throws ParseException {
+ Simplifier s = new Simplifier();
+ TransformContext c = new TransformContext(Collections.emptyMap(), new MapTypeContext());
+ assertEquals("-3", s.transform(new RankingExpression("-3"), c).toString());
+ assertEquals("-9.0", s.transform(new RankingExpression("-3 + -6"), c).toString());
+ assertEquals("-a", s.transform(new RankingExpression("-a"), c).toString());
+ assertEquals("-\"a\"", s.transform(new RankingExpression("-'a'"), c).toString());
+
+ RankingExpression r = new RankingExpression(new NegativeNode(new ConstantNode(Value.parse("3"))));
+ assertTrue(r.getRoot() instanceof NegativeNode);
+ r = s.transform(r, c);
+ assertTrue(r.getRoot() instanceof ConstantNode);
+ }
+
}