aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/schema/RankingExpressionShadowingTestCase.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-05-19 12:03:06 +0200
committerJon Bratseth <bratseth@gmail.com>2022-05-19 12:03:06 +0200
commit5c24dc5c9642a8d9ed70aee4c950fd0678a1ebec (patch)
treebd9b74bf00c832456f0b83c1b2cd7010be387d68 /config-model/src/test/java/com/yahoo/schema/RankingExpressionShadowingTestCase.java
parentf17c4fe7de4c55f5c4ee61897eab8c2f588d8405 (diff)
Rename the 'searchdefinition' package to 'schema'
Diffstat (limited to 'config-model/src/test/java/com/yahoo/schema/RankingExpressionShadowingTestCase.java')
-rw-r--r--config-model/src/test/java/com/yahoo/schema/RankingExpressionShadowingTestCase.java251
1 files changed, 251 insertions, 0 deletions
diff --git a/config-model/src/test/java/com/yahoo/schema/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/schema/RankingExpressionShadowingTestCase.java
new file mode 100644
index 00000000000..250879b1570
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/schema/RankingExpressionShadowingTestCase.java
@@ -0,0 +1,251 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.schema;
+
+import com.yahoo.collections.Pair;
+import com.yahoo.config.model.application.provider.MockFileRegistry;
+import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.search.query.profile.QueryProfile;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.search.query.profile.types.FieldDescription;
+import com.yahoo.search.query.profile.types.QueryProfileType;
+import com.yahoo.schema.derived.AttributeFields;
+import com.yahoo.schema.derived.RawRankProfile;
+import com.yahoo.schema.parser.ParseException;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class RankingExpressionShadowingTestCase extends AbstractSchemaTestCase {
+
+ @Test
+ public void testBasicFunctionShadowing() throws ParseException {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry);
+ builder.addSchema(
+ "search test {\n" +
+ " document test { \n" +
+ " field a type string { \n" +
+ " indexing: index \n" +
+ " }\n" +
+ " }\n" +
+ " \n" +
+ " rank-profile test {\n" +
+ " function sin(x) {\n" +
+ " expression: x * x\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: sin(2)\n" +
+ " }\n" +
+ " }\n" +
+ "\n" +
+ "}\n");
+ builder.build(true);
+ Schema s = builder.getSchema();
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels());
+ List<Pair<String, String>> testRankProperties = createRawRankProfile(test, new QueryProfileRegistry(), s).configProperties();
+ assertEquals("(rankingExpression(sin@).rankingScript, 2 * 2)",
+ censorBindingHash(testRankProperties.get(0).toString()));
+ assertEquals("(rankingExpression(sin).rankingScript, x * x)",
+ testRankProperties.get(1).toString());
+ assertEquals("(vespa.rank.firstphase, rankingExpression(sin@))",
+ censorBindingHash(testRankProperties.get(2).toString()));
+ }
+
+
+ @Test
+ public void testMultiLevelFunctionShadowing() throws ParseException {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry);
+ builder.addSchema(
+ "search test {\n" +
+ " document test { \n" +
+ " field a type string { \n" +
+ " indexing: index \n" +
+ " }\n" +
+ " }\n" +
+ " \n" +
+ " rank-profile test {\n" +
+ " function tan(x) {\n" +
+ " expression: x * x\n" +
+ " }\n" +
+ " function cos(x) {\n" +
+ " expression: tan(x)\n" +
+ " }\n" +
+ " function sin(x) {\n" +
+ " expression: cos(x)\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: sin(2)\n" +
+ " }\n" +
+ " }\n" +
+ "\n" +
+ "}\n");
+ builder.build(true);
+ Schema s = builder.getSchema();
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels());
+ List<Pair<String, String>> testRankProperties = createRawRankProfile(test, new QueryProfileRegistry(), s).configProperties();
+ assertEquals("(rankingExpression(tan@).rankingScript, 2 * 2)",
+ censorBindingHash(testRankProperties.get(0).toString()));
+ assertEquals("(rankingExpression(cos@).rankingScript, rankingExpression(tan@))",
+ censorBindingHash(testRankProperties.get(1).toString()));
+ assertEquals("(rankingExpression(sin@).rankingScript, rankingExpression(cos@))",
+ censorBindingHash(testRankProperties.get(2).toString()));
+ assertEquals("(rankingExpression(tan).rankingScript, x * x)",
+ testRankProperties.get(3).toString());
+ assertEquals("(rankingExpression(tan@).rankingScript, x * x)",
+ censorBindingHash(testRankProperties.get(4).toString()));
+ assertEquals("(rankingExpression(cos).rankingScript, rankingExpression(tan@))",
+ censorBindingHash(testRankProperties.get(5).toString()));
+ assertEquals("(rankingExpression(cos@).rankingScript, rankingExpression(tan@))",
+ censorBindingHash(testRankProperties.get(6).toString()));
+ assertEquals("(rankingExpression(sin).rankingScript, rankingExpression(cos@))",
+ censorBindingHash(testRankProperties.get(7).toString()));
+ assertEquals("(vespa.rank.firstphase, rankingExpression(sin@))",
+ censorBindingHash(testRankProperties.get(8).toString()));
+ }
+
+ @Test
+ public void testFunctionShadowingArguments() throws ParseException {
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry);
+ builder.addSchema(
+ "search test {\n" +
+ " document test { \n" +
+ " field a type string { \n" +
+ " indexing: index \n" +
+ " }\n" +
+ " }\n" +
+ " \n" +
+ " rank-profile test {\n" +
+ " function sin(x) {\n" +
+ " expression: x * x\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: cos(sin(2*2)) + sin(cos(1+4))\n" +
+ " }\n" +
+ " }\n" +
+ "\n" +
+ "}\n");
+ builder.build(true);
+ Schema s = builder.getSchema();
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels());
+ List<Pair<String, String>> testRankProperties = createRawRankProfile(test, new QueryProfileRegistry(), s).configProperties();
+ assertEquals("(rankingExpression(sin@).rankingScript, 4.0 * 4.0)",
+ censorBindingHash(testRankProperties.get(0).toString()));
+ assertEquals("(rankingExpression(sin@).rankingScript, cos(5.0) * cos(5.0))",
+ censorBindingHash(testRankProperties.get(1).toString()));
+ assertEquals("(rankingExpression(sin).rankingScript, x * x)",
+ testRankProperties.get(2).toString());
+ assertEquals("(vespa.rank.firstphase, rankingExpression(firstphase))",
+ censorBindingHash(testRankProperties.get(3).toString()));
+ assertEquals("(rankingExpression(firstphase).rankingScript, cos(rankingExpression(sin@)) + rankingExpression(sin@))",
+ censorBindingHash(testRankProperties.get(4).toString()));
+ }
+
+ @Test
+ public void testNeuralNetworkSetup() throws ParseException {
+ // Note: the type assigned to query profile and constant tensors here is not the correct type
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(input[1])");
+ ApplicationBuilder builder = new ApplicationBuilder(rankProfileRegistry, queryProfiles);
+ builder.addSchema(
+ "search test {\n" +
+ " document test { \n" +
+ " field a type string { \n" +
+ " indexing: index \n" +
+ " }\n" +
+ " }\n" +
+ " \n" +
+ " rank-profile test {\n" +
+ " function relu(x) {\n" + // relu is a built in function, redefined here
+ " expression: max(1.0, x)\n" +
+ " }\n" +
+ " function hidden_layer() {\n" +
+ " expression: relu(sum(query(q) * constant(W_hidden), input) + constant(b_input))\n" +
+ " }\n" +
+ " function final_layer() {\n" +
+ " expression: sigmoid(sum(hidden_layer * constant(W_final), hidden) + constant(b_final))\n" +
+ " }\n" +
+ " second-phase {\n" +
+ " expression: sum(final_layer)\n" +
+ " }\n" +
+ " }\n" +
+ " constant W_hidden {\n" +
+ " type: tensor(hidden[1])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant b_input {\n" +
+ " type: tensor(hidden[1])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant W_final {\n" +
+ " type: tensor(final[1])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ " constant b_final {\n" +
+ " type: tensor(final[1])\n" +
+ " file: ignored.json\n" +
+ " }\n" +
+ "}\n");
+ builder.build(true);
+ Schema s = builder.getSchema();
+ RankProfile test = rankProfileRegistry.get(s, "test").compile(queryProfiles, new ImportedMlModels());
+ List<Pair<String, String>> testRankProperties = createRawRankProfile(test, queryProfiles, s).configProperties();
+ assertEquals("(rankingExpression(autogenerated_ranking_feature@).rankingScript, reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input))",
+ censorBindingHash(testRankProperties.get(0).toString()));
+ assertEquals("(rankingExpression(relu@).rankingScript, max(1.0,rankingExpression(autogenerated_ranking_feature@)))",
+ censorBindingHash(testRankProperties.get(1).toString()));
+ assertEquals("(rankingExpression(hidden_layer).rankingScript, rankingExpression(relu@))",
+ censorBindingHash(testRankProperties.get(2).toString()));
+ assertEquals("(rankingExpression(final_layer).rankingScript, sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))",
+ testRankProperties.get(4).toString());
+ assertEquals("(rankingExpression(relu).rankingScript, max(1.0,x))",
+ testRankProperties.get(6).toString());
+ assertEquals("(vespa.rank.secondphase, rankingExpression(secondphase))",
+ testRankProperties.get(7).toString());
+ assertEquals("(rankingExpression(secondphase).rankingScript, reduce(rankingExpression(final_layer), sum))",
+ testRankProperties.get(8).toString());
+ }
+
+ private static RawRankProfile createRawRankProfile(RankProfile profile, QueryProfileRegistry queryProfiles, Schema schema) {
+ return new RawRankProfile(profile,
+ new LargeRankExpressions(new MockFileRegistry()),
+ queryProfiles,
+ new ImportedMlModels(),
+ new AttributeFields(schema),
+ new TestProperties());
+ }
+
+ private QueryProfileRegistry queryProfileWith(String field, String type) {
+ QueryProfileType queryProfileType = new QueryProfileType("root");
+ queryProfileType.addField(new FieldDescription(field, type));
+ QueryProfileRegistry queryProfileRegistry = new QueryProfileRegistry();
+ queryProfileRegistry.getTypeRegistry().register(queryProfileType);
+ QueryProfile profile = new QueryProfile("default");
+ profile.setType(queryProfileType);
+ queryProfileRegistry.register(profile);
+ return queryProfileRegistry;
+ }
+
+ private String censorBindingHash(String s) {
+ StringBuilder b = new StringBuilder();
+ boolean areInHash = false;
+ for (int i = 0; i < s.length() ; i++) {
+ char current = s.charAt(i);
+ if ( ! Character.isLetterOrDigit(current)) // end of hash
+ areInHash = false;
+ if ( ! areInHash)
+ b.append(current);
+ if (current == '@') // start of hash
+ areInHash = true;
+ }
+ return b.toString();
+ }
+
+}