diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-09-28 16:19:30 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-09-28 16:19:30 +0200 |
commit | 3d49f155fccfa4fc08882b01e7a6e3a982c55212 (patch) | |
tree | 865d6e301e5fcd3fba248807ff980bcc7e18d41f /config-model | |
parent | 7cfc4fa47828261ee1f839a27a437d8bc49eb26f (diff) |
Fold comparisons into the other operators
Diffstat (limited to 'config-model')
3 files changed, 7 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java index 30d9a3766b3..2695fa79588 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/TokenTransformer.java @@ -5,7 +5,6 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; -import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; @@ -13,7 +12,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Generate; @@ -141,10 +139,10 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform ExpressionNode queryLengthExpr = createLengthExpr(2, tokenSequence); ExpressionNode restLengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); ExpressionNode expr = new IfNode( - new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, queryLengthExpr), + new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, queryLengthExpr), ZERO, new IfNode( - new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, restLengthExpr), + new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, restLengthExpr), ONE, ZERO ) @@ -176,7 +174,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform List<ExpressionNode> tokenSequence = createTokenSequence(feature); ExpressionNode lengthExpr = createLengthExpr(tokenSequence.size() - 1, tokenSequence); - ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + ArithmeticNode comparison = new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, lengthExpr); ExpressionNode expr = new IfNode(comparison, ONE, ZERO); return new TensorFunctionNode(Generate.bound(type, wrapScalar(expr))); } @@ -256,7 +254,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform */ private ExpressionNode createTokenSequenceExpr(int iter, List<ExpressionNode> sequence) { ExpressionNode lengthExpr = createLengthExpr(iter, sequence); - ComparisonNode comparison = new ComparisonNode(new ReferenceNode("d1"), TruthOperator.SMALLER, lengthExpr); + ArithmeticNode comparison = new ArithmeticNode(new ReferenceNode("d1"), ArithmeticOperator.LESS, lengthExpr); ExpressionNode trueExpr = sequence.get(iter); if (sequence.get(iter) instanceof ReferenceNode) { diff --git a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java index 39d9be905a5..13d21884c7d 100644 --- a/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/RankingExpressionInliningTestCase.java @@ -68,7 +68,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase { Schema s = builder.getSchema(); RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels()); - assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + if (7.0 < attribute(a), 1, 2) == 0))", + assertEquals("7.0 * (3 + attribute(a) + attribute(b) * (attribute(a) * 3 + (if (7.0 < attribute(a), 1, 2) == 0)))", parent.getFirstPhaseRanking().getRoot().toString()); RankProfile child = rankProfileRegistry.get(s, "child").compile(new QueryProfileRegistry(), new ImportedMlModels()); assertEquals("7.0 * (9 + attribute(a))", @@ -97,7 +97,7 @@ public class RankingExpressionInliningTestCase extends AbstractSchemaTestCase { Schema s = builder.getSchema(); RankProfile parent = rankProfileRegistry.get(s, "parent").compile(new QueryProfileRegistry(), new ImportedMlModels()); - assertEquals("3 * ( query(test) > 2.0 )", + assertEquals("3 * (query(test) > 2.0)", parent.getFunctions().get("foo").function().getBody().getRoot().toString()); } diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java index 22681858fc3..e9b674a8c87 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java @@ -171,7 +171,7 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name()); - assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 0.0, if (d1 < 1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0, 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value()); + assertEquals("tensor<float>(d0[1],d1[10])((if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 0.0, if (d1 < (1.0 + rankingExpression(__token_length@-1993461420) + 1.0 + rankingExpression(__token_length@-1993461420) + 1.0), 1.0, 0.0))))", config.rankprofile(7).fef().property(1).value()); assertEquals("test_unbound_model", config.rankprofile(8).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); |