aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-04-26 10:01:15 +0200
committerLester Solbakken <lesters@oath.com>2020-04-26 10:01:15 +0200
commit12b09c46a3e32df3a90d326f30171bc789d1bcb8 (patch)
tree013d544eca009f2da2009372928c43c17d0387eb
parentee76bcb34de37f748d4377c0dc2d4899a0c41d6a (diff)
Fix unit tests
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java47
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java16
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java4
5 files changed, 43 insertions, 36 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index 5c1134f928c..e4ca83640e9 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -50,10 +50,10 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase {
new QueryProfileRegistry(),
new ImportedMlModels(),
new AttributeFields(s)).configProperties();
- assertEquals("(rankingExpression(sin).rankingScript,x * x)",
- testRankProperties.get(0).toString());
assertEquals("(rankingExpression(sin@).rankingScript,2 * 2)",
- censorBindingHash(testRankProperties.get(1).toString()));
+ censorBindingHash(testRankProperties.get(0).toString()));
+ assertEquals("(rankingExpression(sin).rankingScript,x * x)",
+ testRankProperties.get(1).toString());
assertEquals("(vespa.rank.firstphase,rankingExpression(sin@))",
censorBindingHash(testRankProperties.get(2).toString()));
}
@@ -94,27 +94,26 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase {
new QueryProfileRegistry(),
new ImportedMlModels(),
new AttributeFields(s)).configProperties();
+ assertEquals("(rankingExpression(tan@).rankingScript,2 * 2)",
+ censorBindingHash(testRankProperties.get(0).toString()));
+ assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))",
+ censorBindingHash(testRankProperties.get(1).toString()));
+ assertEquals("(rankingExpression(sin@).rankingScript,rankingExpression(cos@))",
+ censorBindingHash(testRankProperties.get(2).toString()));
assertEquals("(rankingExpression(tan).rankingScript,x * x)",
- testRankProperties.get(0).toString());
+ testRankProperties.get(3).toString());
assertEquals("(rankingExpression(tan@).rankingScript,x * x)",
- censorBindingHash(testRankProperties.get(1).toString()));
- assertEquals("(rankingExpression(cos).rankingScript,rankingExpression(tan@))",
- censorBindingHash(testRankProperties.get(2).toString()));
- assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))",
- censorBindingHash(testRankProperties.get(3).toString()));
- assertEquals("(rankingExpression(sin).rankingScript,rankingExpression(cos@))",
censorBindingHash(testRankProperties.get(4).toString()));
- assertEquals("(rankingExpression(tan@).rankingScript,2 * 2)",
+ assertEquals("(rankingExpression(cos).rankingScript,rankingExpression(tan@))",
censorBindingHash(testRankProperties.get(5).toString()));
assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))",
- censorBindingHash(testRankProperties.get(6).toString()));
- assertEquals("(rankingExpression(sin@).rankingScript,rankingExpression(cos@))",
+ censorBindingHash(testRankProperties.get(6).toString()));
+ assertEquals("(rankingExpression(sin).rankingScript,rankingExpression(cos@))",
censorBindingHash(testRankProperties.get(7).toString()));
assertEquals("(vespa.rank.firstphase,rankingExpression(sin@))",
censorBindingHash(testRankProperties.get(8).toString()));
}
-
@Test
public void testFunctionShadowingArguments() throws ParseException {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
@@ -144,12 +143,12 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase {
new QueryProfileRegistry(),
new ImportedMlModels(),
new AttributeFields(s)).configProperties();
- assertEquals("(rankingExpression(sin).rankingScript,x * x)",
- testRankProperties.get(0).toString());
assertEquals("(rankingExpression(sin@).rankingScript,4.0 * 4.0)",
- censorBindingHash(testRankProperties.get(1).toString()));
+ censorBindingHash(testRankProperties.get(0).toString()));
assertEquals("(rankingExpression(sin@).rankingScript,cos(5.0) * cos(5.0))",
- censorBindingHash(testRankProperties.get(2).toString()));
+ censorBindingHash(testRankProperties.get(1).toString()));
+ assertEquals("(rankingExpression(sin).rankingScript,x * x)",
+ testRankProperties.get(2).toString());
assertEquals("(vespa.rank.firstphase,rankingExpression(firstphase))",
censorBindingHash(testRankProperties.get(3).toString()));
assertEquals("(rankingExpression(firstphase).rankingScript,cos(rankingExpression(sin@)) + rankingExpression(sin@))",
@@ -208,17 +207,17 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase {
queryProfiles,
new ImportedMlModels(),
new AttributeFields(s)).configProperties();
- assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))",
- testRankProperties.get(0).toString());
assertEquals("(rankingExpression(relu@).rankingScript,max(1.0,reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input)))",
- censorBindingHash(testRankProperties.get(1).toString()));
+ censorBindingHash(testRankProperties.get(0).toString()));
assertEquals("(rankingExpression(hidden_layer).rankingScript,rankingExpression(relu@))",
- censorBindingHash(testRankProperties.get(2).toString()));
+ censorBindingHash(testRankProperties.get(1).toString()));
assertEquals("(rankingExpression(hidden_layer).type,tensor(x[]))",
- censorBindingHash(testRankProperties.get(3).toString()));
+ censorBindingHash(testRankProperties.get(2).toString()));
assertEquals("(rankingExpression(final_layer).rankingScript,sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))",
- testRankProperties.get(4).toString());
+ testRankProperties.get(3).toString());
assertEquals("(rankingExpression(final_layer).type,tensor(x[]))",
+ testRankProperties.get(4).toString());
+ assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))",
testRankProperties.get(5).toString());
assertEquals("(vespa.rank.secondphase,rankingExpression(secondphase))",
testRankProperties.get(6).toString());
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 0cd6674751e..e6616ce0dd1 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -41,6 +41,14 @@ class RankProfileSearchFixture {
private Search search;
private Map<String, RankProfile> compiledRankProfiles = new HashMap<>();
+ public RankProfileRegistry getRankProfileRegistry() {
+ return rankProfileRegistry;
+ }
+
+ public QueryProfileRegistry getQueryProfileRegistry() {
+ return queryProfileRegistry;
+ }
+
RankProfileSearchFixture(String rankProfiles) throws ParseException {
this(MockApplicationPackage.createEmpty(), new QueryProfileRegistry(), rankProfiles);
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index a64a964727c..680f2dd9659 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -319,7 +319,7 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testFunctionGeneration() {
final String name = "mnist_saved";
- final String expression = "join(join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))";
final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))";
final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))";
@@ -349,7 +349,7 @@ public class RankingExpressionWithTensorFlowTestCase {
" rank-profile my_profile_child inherits my_profile {\n" +
" }";
- final String expression = "join(join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))";
final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))";
final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))";
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
index b3eda9b7e13..1567a4c3b5e 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
@@ -44,20 +44,20 @@ public class RankingExpressionsTestCase extends SchemaTestCase {
new AttributeFields(search)).configProperties();
assertEquals(6, rankProperties.size());
- assertEquals("rankingExpression(titlematch$).rankingScript", rankProperties.get(0).getFirst());
- assertEquals("var1 * var2 + 890", rankProperties.get(0).getSecond());
+ assertEquals("rankingExpression(titlematch$).rankingScript", rankProperties.get(2).getFirst());
+ assertEquals("var1 * var2 + 890", rankProperties.get(2).getSecond());
- assertEquals("rankingExpression(artistmatch).rankingScript", rankProperties.get(1).getFirst());
- assertEquals("78 + closeness(distance)", rankProperties.get(1).getSecond());
+ assertEquals("rankingExpression(artistmatch).rankingScript", rankProperties.get(3).getFirst());
+ assertEquals("78 + closeness(distance)", rankProperties.get(3).getSecond());
assertEquals("rankingExpression(firstphase).rankingScript", rankProperties.get(5).getFirst());
assertEquals("0.8 + 0.2 * rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c) + 0.8 * rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6) * closeness(distance)", rankProperties.get(5).getSecond());
- assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(3).getFirst());
- assertEquals("7 * 8 + 890", rankProperties.get(3).getSecond());
+ assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(1).getFirst());
+ assertEquals("7 * 8 + 890", rankProperties.get(1).getSecond());
- assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(2).getFirst());
- assertEquals("4 * 5 + 890", rankProperties.get(2).getSecond());
+ assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(0).getFirst());
+ assertEquals("4 * 5 + 890", rankProperties.get(0).getSecond());
}
@Test(expected = IllegalArgumentException.class)
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
index ca84eb5eed7..57dbf132883 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
@@ -64,10 +64,10 @@ public class MlModelsTest {
private final String testProfile =
"rankingExpression(input).rankingScript: attribute(argument)\n" +
"rankingExpression(input).type: tensor<float>(d0[],d1[784])\n" +
- "rankingExpression(Placeholder).rankingScript: attribute(argument)\n" +
- "rankingExpression(Placeholder).type: tensor<float>(d0[],d1[784])\n" +
"rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(rankingExpression(input), (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" +
"rankingExpression(mnist_tensorflow).rankingScript: join(reduce(join(map(join(reduce(join(join(join(0.009999999776482582, rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.0507009873554805 * if (a >= 0, a, 1.6732632423543772 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" +
+ "rankingExpression(Placeholder).rankingScript: attribute(argument)\n" +
+ "rankingExpression(Placeholder).type: tensor<float>(d0[],d1[784])\n" +
"rankingExpression(mnist_softmax_tensorflow).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))\n" +
"rankingExpression(mnist_softmax_onnx).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))\n" +
"rankingExpression(my_xgboost).rankingScript: if (f29 < -0.1234567, if (!(f56 >= -0.242398), 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (!(f60 >= -0.482947), if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)\n" +