From 067c45bdff39f88b7c7ce586c03f217532272ba5 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Thu, 29 Sep 2022 13:18:05 +0200 Subject: Only optimize boolean expressions on primitives Only transform a && b to if(a, b, false) etc.0 if both a and b are primitives, not tensors, as if requires both branches to return the same type. --- .../BooleanExpressionTransformerTestCase.java | 68 +++++++++++++--------- .../RankingExpressionWithOnnxTestCase.java | 51 ++++++++++------ 2 files changed, 76 insertions(+), 43 deletions(-) (limited to 'config-model/src/test/java/com/yahoo/schema') diff --git a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java index d692b69d3c8..d06573f7bae 100644 --- a/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/expressiontransforms/BooleanExpressionTransformerTestCase.java @@ -2,10 +2,13 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext; import com.yahoo.searchlib.rankingexpression.rule.OperationNode; import com.yahoo.searchlib.rankingexpression.transform.TransformContext; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import org.junit.jupiter.api.Test; import java.util.Map; @@ -20,7 +23,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; public class BooleanExpressionTransformerTestCase { @Test - public void testTransformer() throws Exception { + public void booleanTransformation() throws Exception { assertTransformed("if (a, b, false)", "a && b"); assertTransformed("if (a, true, b)", "a || b"); assertTransformed("if (a, true, b + c)", "a || b + c"); @@ -33,16 +36,17 @@ public class BooleanExpressionTransformerTestCase { } @Test - public void testIt() throws Exception { - assertTransformed("if(1 - 1, true, 1 - 1)", "1 - 1 || 1 - 1"); + public void noTransformationOnTensorTypes() throws Exception { + var typeContext = new MapTypeContext(); + typeContext.setType(Reference.fromIdentifier("tensorA"), TensorType.fromSpec("tensor(x{})")); + typeContext.setType(Reference.fromIdentifier("tensorB"), TensorType.fromSpec("tensor(x{})")); + assertUntransformed("tensorA && tensorB", typeContext); + assertTransformed("a && (tensorA * tensorB)","a && ( tensorA * tensorB)", typeContext); } @Test public void testNotSkewingNonBoolean() throws Exception { - assertTransformed("a + b + c * d + e + f", "a + b + c * d + e + f"); - var expr = new BooleanExpressionTransformer() - .transform(new RankingExpression("a + b + c * d + e + f"), - new TransformContext(Map.of(), new MapTypeContext())); + var expr = assertTransformed("a + b + c * d + e + f", "a + b + c * d + e + f"); assertTrue(expr.getRoot() instanceof OperationNode); OperationNode root = (OperationNode) expr.getRoot(); assertEquals(5, root.operators().size()); @@ -51,41 +55,53 @@ public class BooleanExpressionTransformerTestCase { @Test public void testTransformPreservesPrecedence() throws Exception { - assertUnTransformed("a"); - assertUnTransformed("a + b"); - assertUnTransformed("a + b + c"); - assertUnTransformed("a * b"); - assertUnTransformed("a + b * c + d"); - assertUnTransformed("a + b + c * d + e + f"); - assertUnTransformed("a * b + c + d + e * f"); - assertUnTransformed("(a * b) + c + d + e * f"); - assertUnTransformed("(a * b + c) + d + e * f"); - assertUnTransformed("a * (b + c) + d + e * f"); - assertUnTransformed("(a * b) + (c + (d + e)) * f"); + assertUntransformed("a"); + assertUntransformed("a + b"); + assertUntransformed("a + b + c"); + assertUntransformed("a * b"); + assertUntransformed("a + b * c + d"); + assertUntransformed("a + b + c * d + e + f"); + assertUntransformed("a * b + c + d + e * f"); + assertUntransformed("(a * b) + c + d + e * f"); + assertUntransformed("(a * b + c) + d + e * f"); + assertUntransformed("a * (b + c) + d + e * f"); + assertUntransformed("(a * b) + (c + (d + e)) * f"); + } + + private void assertUntransformed(String input) throws Exception { + assertUntransformed(input, new MapTypeContext()); + } + + private void assertUntransformed(String input, MapTypeContext typeContext) throws Exception { + assertTransformed(input, input, typeContext); } - private void assertUnTransformed(String input) throws Exception { - assertTransformed(input, input); + private RankingExpression assertTransformed(String expected, String input) throws Exception { + return assertTransformed(expected, input, new MapTypeContext()); } - private void assertTransformed(String expected, String input) throws Exception { + private RankingExpression assertTransformed(String expected, String input, MapTypeContext typeContext) throws Exception { + MapContext context = contextWithSingleLetterVariables(typeContext); var transformedExpression = new BooleanExpressionTransformer() .transform(new RankingExpression(input), - new TransformContext(Map.of(), new MapTypeContext())); + new TransformContext(Map.of(), typeContext)); assertEquals(new RankingExpression(expected), transformedExpression, "Transformed as expected"); - MapContext context = contextWithSingleLetterVariables(); var inputExpression = new RankingExpression(input); assertEquals(inputExpression.evaluate(context).asBoolean(), transformedExpression.evaluate(context).asBoolean(), "Transform and original input are equivalent"); + return transformedExpression; } - private MapContext contextWithSingleLetterVariables() { + private MapContext contextWithSingleLetterVariables(MapTypeContext typeContext) { var context = new MapContext(); - for (int i = 0; i < 26; i++) - context.put(Character.toString(i + 97), Math.floorMod(i, 2)); + for (int i = 0; i < 26; i++) { + String name = Character.toString(i + 97); + typeContext.setType(Reference.fromIdentifier(name), TensorType.empty); + context.put(name, Math.floorMod(i, 2)); + } return context; } diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java index 83d19b010bb..2f53dba7bb4 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxTestCase.java @@ -31,6 +31,8 @@ public class RankingExpressionWithOnnxTestCase { private final static String name = "mnist_softmax"; private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_layer_Variable), f(a,b)(a * b)), sum, d2) * 1.0, constant(mnist_softmax_layer_Variable_1) * 1.0, f(a,b)(a + b))"; + private final static String vespaExpressionConstants = "constant mnist_softmax_layer_Variable { file: ignored\ntype: tensor(d0[1],d1[784]) }\n" + + "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor(d0[1],d1[784]) }\n"; @AfterEach public void removeGeneratedModelFiles() { @@ -41,7 +43,7 @@ public class RankingExpressionWithOnnxTestCase { void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx_vespa('mnist_softmax.onnx')", - "constant mytensor { file: ignored\ntype: tensor(d0[1],d1[784]) }", + vespaExpressionConstants + "constant mytensor { file: ignored\ntype: tensor(d0[1],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -58,7 +60,7 @@ public class RankingExpressionWithOnnxTestCase { queryProfileType); RankProfileSearchFixture search = fixtureWith("query(mytensor)", "onnx_vespa('mnist_softmax.onnx')", - null, + vespaExpressionConstants, null, "Placeholder", application); @@ -70,7 +72,7 @@ public class RankingExpressionWithOnnxTestCase { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", "onnx_vespa('mnist_softmax.onnx')", - null, + vespaExpressionConstants, "field mytensor type tensor(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); @@ -88,7 +90,7 @@ public class RankingExpressionWithOnnxTestCase { StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)", "onnx_vespa('mnist_softmax.onnx')", - "constant mytensor { file: ignored\ntype: tensor(d0[1],d1[784]) }", + vespaExpressionConstants + "constant mytensor { file: ignored\ntype: tensor(d0[1],d1[784]) }", "field mytensor type tensor(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); @@ -99,21 +101,24 @@ public class RankingExpressionWithOnnxTestCase { @Test void testNestedOnnxReference() { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "5 + sum(onnx_vespa('mnist_softmax.onnx'))"); + "5 + sum(onnx_vespa('mnist_softmax.onnx'))", + vespaExpressionConstants); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } @Test void testOnnxReferenceWithSpecifiedOutput() { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx', 'layer_add')"); + "onnx_vespa('mnist_softmax.onnx', 'layer_add')", + vespaExpressionConstants); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test void testOnnxReferenceWithSpecifiedOutputAndSignature() { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')"); + "onnx_vespa('mnist_softmax.onnx', 'default.layer_add')", + vespaExpressionConstants); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -177,7 +182,8 @@ public class RankingExpressionWithOnnxTestCase { @Test void testImportingFromStoredExpressions() throws IOException { RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "onnx_vespa(\"mnist_softmax.onnx\")"); + "onnx_vespa(\"mnist_softmax.onnx\")", + vespaExpressionConstants); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); // At this point the expression is stored - copy application to another location which do not have a models dir @@ -187,12 +193,14 @@ public class RankingExpressionWithOnnxTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + String constants = "constant mnist_softmax_layer_Variable { file: ignored\ntype: tensor(d0[2],d1[784]) }\n" + + "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor(d0[2],d1[784]) }\n"; RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", - "onnx_vespa('mnist_softmax.onnx')", - null, - null, - "Placeholder", - storedApplication); + "onnx_vespa('mnist_softmax.onnx')", + constants, + null, + "Placeholder", + storedApplication); searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test @@ -221,7 +229,8 @@ public class RankingExpressionWithOnnxTestCase { String vespaExpressionWithoutConstant = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), " + name + "_layer_Variable, f(a,b)(a * b)), sum, d2) * 1.0, constant(" + name + "_layer_Variable_1) * 1.0, f(a,b)(a + b))"; - RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); + String constant = "constant mnist_softmax_layer_Variable_1 { file: ignored\ntype: tensor(d0[1],d1[10]) }\n"; + RankProfileSearchFixture search = uncompiledFixtureWith(rankProfile, new StoringApplicationPackage(applicationDir), constant); search.compileRankProfile("my_profile", applicationDir.append("models")); search.compileRankProfile("my_profile_child", applicationDir.append("models")); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); @@ -237,7 +246,7 @@ public class RankingExpressionWithOnnxTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication); + RankProfileSearchFixture searchFromStored = uncompiledFixtureWith(rankProfile, storedApplication, constant); searchFromStored.compileRankProfile("my_profile", applicationDir.append("models")); searchFromStored.compileRankProfile("my_profile_child", applicationDir.append("models")); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); @@ -326,7 +335,11 @@ public class RankingExpressionWithOnnxTestCase { } private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { - return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder", + return fixtureWith(placeholderExpression, firstPhaseExpression, null); + } + + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression, String constant) { + return fixtureWith(placeholderExpression, firstPhaseExpression, constant, null, "Placeholder", new StoringApplicationPackage(applicationDir)); } @@ -337,9 +350,13 @@ public class RankingExpressionWithOnnxTestCase { } private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application) { + return uncompiledFixtureWith(rankProfile, application, null); + } + + private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application, String constant) { try { return new RankProfileSearchFixture(application, application.getQueryProfiles(), - rankProfile, null, null); + rankProfile, constant, null); } catch (ParseException e) { throw new IllegalArgumentException(e); -- cgit v1.2.3