diff options
author | Lester Solbakken <lesters@oath.com> | 2020-04-26 10:01:15 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-04-26 10:01:15 +0200 |
commit | 12b09c46a3e32df3a90d326f30171bc789d1bcb8 (patch) | |
tree | 013d544eca009f2da2009372928c43c17d0387eb | |
parent | ee76bcb34de37f748d4377c0dc2d4899a0c41d6a (diff) |
Fix unit tests
5 files changed, 43 insertions, 36 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 5c1134f928c..e4ca83640e9 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -50,10 +50,10 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(s)).configProperties(); - assertEquals("(rankingExpression(sin).rankingScript,x * x)", - testRankProperties.get(0).toString()); assertEquals("(rankingExpression(sin@).rankingScript,2 * 2)", - censorBindingHash(testRankProperties.get(1).toString())); + censorBindingHash(testRankProperties.get(0).toString())); + assertEquals("(rankingExpression(sin).rankingScript,x * x)", + testRankProperties.get(1).toString()); assertEquals("(vespa.rank.firstphase,rankingExpression(sin@))", censorBindingHash(testRankProperties.get(2).toString())); } @@ -94,27 +94,26 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(s)).configProperties(); + assertEquals("(rankingExpression(tan@).rankingScript,2 * 2)", + censorBindingHash(testRankProperties.get(0).toString())); + assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))", + censorBindingHash(testRankProperties.get(1).toString())); + assertEquals("(rankingExpression(sin@).rankingScript,rankingExpression(cos@))", + censorBindingHash(testRankProperties.get(2).toString())); assertEquals("(rankingExpression(tan).rankingScript,x * x)", - testRankProperties.get(0).toString()); + testRankProperties.get(3).toString()); assertEquals("(rankingExpression(tan@).rankingScript,x * x)", - censorBindingHash(testRankProperties.get(1).toString())); - assertEquals("(rankingExpression(cos).rankingScript,rankingExpression(tan@))", - censorBindingHash(testRankProperties.get(2).toString())); - assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))", - censorBindingHash(testRankProperties.get(3).toString())); - assertEquals("(rankingExpression(sin).rankingScript,rankingExpression(cos@))", censorBindingHash(testRankProperties.get(4).toString())); - assertEquals("(rankingExpression(tan@).rankingScript,2 * 2)", + assertEquals("(rankingExpression(cos).rankingScript,rankingExpression(tan@))", censorBindingHash(testRankProperties.get(5).toString())); assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))", - censorBindingHash(testRankProperties.get(6).toString())); - assertEquals("(rankingExpression(sin@).rankingScript,rankingExpression(cos@))", + censorBindingHash(testRankProperties.get(6).toString())); + assertEquals("(rankingExpression(sin).rankingScript,rankingExpression(cos@))", censorBindingHash(testRankProperties.get(7).toString())); assertEquals("(vespa.rank.firstphase,rankingExpression(sin@))", censorBindingHash(testRankProperties.get(8).toString())); } - @Test public void testFunctionShadowingArguments() throws ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); @@ -144,12 +143,12 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(s)).configProperties(); - assertEquals("(rankingExpression(sin).rankingScript,x * x)", - testRankProperties.get(0).toString()); assertEquals("(rankingExpression(sin@).rankingScript,4.0 * 4.0)", - censorBindingHash(testRankProperties.get(1).toString())); + censorBindingHash(testRankProperties.get(0).toString())); assertEquals("(rankingExpression(sin@).rankingScript,cos(5.0) * cos(5.0))", - censorBindingHash(testRankProperties.get(2).toString())); + censorBindingHash(testRankProperties.get(1).toString())); + assertEquals("(rankingExpression(sin).rankingScript,x * x)", + testRankProperties.get(2).toString()); assertEquals("(vespa.rank.firstphase,rankingExpression(firstphase))", censorBindingHash(testRankProperties.get(3).toString())); assertEquals("(rankingExpression(firstphase).rankingScript,cos(rankingExpression(sin@)) + rankingExpression(sin@))", @@ -208,17 +207,17 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { queryProfiles, new ImportedMlModels(), new AttributeFields(s)).configProperties(); - assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", - testRankProperties.get(0).toString()); assertEquals("(rankingExpression(relu@).rankingScript,max(1.0,reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input)))", - censorBindingHash(testRankProperties.get(1).toString())); + censorBindingHash(testRankProperties.get(0).toString())); assertEquals("(rankingExpression(hidden_layer).rankingScript,rankingExpression(relu@))", - censorBindingHash(testRankProperties.get(2).toString())); + censorBindingHash(testRankProperties.get(1).toString())); assertEquals("(rankingExpression(hidden_layer).type,tensor(x[]))", - censorBindingHash(testRankProperties.get(3).toString())); + censorBindingHash(testRankProperties.get(2).toString())); assertEquals("(rankingExpression(final_layer).rankingScript,sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))", - testRankProperties.get(4).toString()); + testRankProperties.get(3).toString()); assertEquals("(rankingExpression(final_layer).type,tensor(x[]))", + testRankProperties.get(4).toString()); + assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(5).toString()); assertEquals("(vespa.rank.secondphase,rankingExpression(secondphase))", testRankProperties.get(6).toString()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 0cd6674751e..e6616ce0dd1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -41,6 +41,14 @@ class RankProfileSearchFixture { private Search search; private Map<String, RankProfile> compiledRankProfiles = new HashMap<>(); + public RankProfileRegistry getRankProfileRegistry() { + return rankProfileRegistry; + } + + public QueryProfileRegistry getQueryProfileRegistry() { + return queryProfileRegistry; + } + RankProfileSearchFixture(String rankProfiles) throws ParseException { this(MockApplicationPackage.createEmpty(), new QueryProfileRegistry(), rankProfiles); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index a64a964727c..680f2dd9659 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -319,7 +319,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testFunctionGeneration() { final String name = "mnist_saved"; - final String expression = "join(join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))"; final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))"; final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))"; @@ -349,7 +349,7 @@ public class RankingExpressionWithTensorFlowTestCase { " rank-profile my_profile_child inherits my_profile {\n" + " }"; - final String expression = "join(join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))"; final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))"; final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))"; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index b3eda9b7e13..1567a4c3b5e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -44,20 +44,20 @@ public class RankingExpressionsTestCase extends SchemaTestCase { new AttributeFields(search)).configProperties(); assertEquals(6, rankProperties.size()); - assertEquals("rankingExpression(titlematch$).rankingScript", rankProperties.get(0).getFirst()); - assertEquals("var1 * var2 + 890", rankProperties.get(0).getSecond()); + assertEquals("rankingExpression(titlematch$).rankingScript", rankProperties.get(2).getFirst()); + assertEquals("var1 * var2 + 890", rankProperties.get(2).getSecond()); - assertEquals("rankingExpression(artistmatch).rankingScript", rankProperties.get(1).getFirst()); - assertEquals("78 + closeness(distance)", rankProperties.get(1).getSecond()); + assertEquals("rankingExpression(artistmatch).rankingScript", rankProperties.get(3).getFirst()); + assertEquals("78 + closeness(distance)", rankProperties.get(3).getSecond()); assertEquals("rankingExpression(firstphase).rankingScript", rankProperties.get(5).getFirst()); assertEquals("0.8 + 0.2 * rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c) + 0.8 * rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6) * closeness(distance)", rankProperties.get(5).getSecond()); - assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(3).getFirst()); - assertEquals("7 * 8 + 890", rankProperties.get(3).getSecond()); + assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(1).getFirst()); + assertEquals("7 * 8 + 890", rankProperties.get(1).getSecond()); - assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(2).getFirst()); - assertEquals("4 * 5 + 890", rankProperties.get(2).getSecond()); + assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(0).getFirst()); + assertEquals("4 * 5 + 890", rankProperties.get(0).getSecond()); } @Test(expected = IllegalArgumentException.class) diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java index ca84eb5eed7..57dbf132883 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java @@ -64,10 +64,10 @@ public class MlModelsTest { private final String testProfile = "rankingExpression(input).rankingScript: attribute(argument)\n" + "rankingExpression(input).type: tensor<float>(d0[],d1[784])\n" + - "rankingExpression(Placeholder).rankingScript: attribute(argument)\n" + - "rankingExpression(Placeholder).type: tensor<float>(d0[],d1[784])\n" + "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(rankingExpression(input), (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" + "rankingExpression(mnist_tensorflow).rankingScript: join(reduce(join(map(join(reduce(join(join(join(0.009999999776482582, rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.0507009873554805 * if (a >= 0, a, 1.6732632423543772 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" + + "rankingExpression(Placeholder).rankingScript: attribute(argument)\n" + + "rankingExpression(Placeholder).type: tensor<float>(d0[],d1[784])\n" + "rankingExpression(mnist_softmax_tensorflow).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))\n" + "rankingExpression(mnist_softmax_onnx).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))\n" + "rankingExpression(my_xgboost).rankingScript: if (f29 < -0.1234567, if (!(f56 >= -0.242398), 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (!(f60 >= -0.482947), if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)\n" + |