aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/test/java/com/yahoo/search/query
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-03-30 17:22:08 +0200
committerJon Bratseth <bratseth@gmail.com>2022-03-30 17:22:08 +0200
commitc0cd66e75bd5a8ff5cd94a72529cb951a1046746 (patch)
tree739856a9eb2ba3069f6427b9aa562ff6ad8376df /container-search/src/test/java/com/yahoo/search/query
parent211fdb61ecf7379c8113e0da413a8cc16f72494d (diff)
Support substitutions in tensors
Diffstat (limited to 'container-search/src/test/java/com/yahoo/search/query')
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java66
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml2
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml1
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml1
4 files changed, 62 insertions, 8 deletions
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java
index c47e1e5b23c..ad85b5e37a1 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java
@@ -3,7 +3,12 @@ package com.yahoo.search.query.profile.config.test;
import com.yahoo.jdisc.http.HttpRequest.Method;
import com.yahoo.container.jdisc.HttpRequest;
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
import com.yahoo.processing.request.CompoundName;
+import com.yahoo.search.query.profile.types.test.QueryProfileTypeTestCase;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.yolean.Exceptions;
import com.yahoo.search.Query;
import com.yahoo.search.query.Properties;
@@ -17,6 +22,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType;
import org.junit.Test;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
@@ -474,20 +480,64 @@ public class XmlReadingTestCase {
CompiledQueryProfileRegistry registry = new QueryProfileXMLReader().read("src/test/java/com/yahoo/search/query/profile/config/test/tensortypes").compile();
QueryProfileType type1 = registry.getTypeRegistry().getComponent("type1");
- assertEquals("tensor<float>(x[1])", type1.getFieldType(new CompoundName("ranking.features.query(tensor_1)")).stringValue());
+ assertEquals(TensorType.fromSpec("tensor<float>(x[1])"),
+ type1.getFieldType(new CompoundName("ranking.features.query(tensor_1)")).asTensorType());
assertNull(type1.getFieldType(new CompoundName("ranking.features.query(tensor_2)")));
assertNull(type1.getFieldType(new CompoundName("ranking.features.query(tensor_3)")));
+ assertEquals(TensorType.fromSpec("tensor(key{})"),
+ type1.getFieldType(new CompoundName("ranking.features.query(tensor_4)")).asTensorType());
QueryProfileType type2 = registry.getTypeRegistry().getComponent("type2");
assertNull(type2.getFieldType(new CompoundName("ranking.features.query(tensor_1)")));
- assertEquals("tensor<float>(x[2])", type2.getFieldType(new CompoundName("ranking.features.query(tensor_2)")).stringValue());
- assertEquals("tensor<float>(x[3])", type2.getFieldType(new CompoundName("ranking.features.query(tensor_3)")).stringValue());
+ assertEquals(TensorType.fromSpec("tensor<float>(x[2])"),
+ type2.getFieldType(new CompoundName("ranking.features.query(tensor_2)")).asTensorType());
+ assertEquals(TensorType.fromSpec("tensor<float>(x[3])"),
+ type2.getFieldType(new CompoundName("ranking.features.query(tensor_3)")).asTensorType());
+
+ Query queryProfile1 = new Query.Builder().setQueryProfile(registry.getComponent("profile1"))
+ .setRequest("?query=test&ranking.features.query(tensor_1)=[1.200]")
+ .build();
+ assertEquals("tensor_1 received as a tensor tensor",
+ Tensor.from("tensor<float>(x[1]):[1.2]"),
+ queryProfile1.properties().get("ranking.features.query(tensor_1)"));
+ assertEquals("tensor_4 contained in the profile is a tensor",
+ Tensor.from("tensor(key{}):{key1:1.0}"),
+ queryProfile1.properties().get("ranking.features.query(tensor_4)"));
+
+ Query queryProfile2 = new Query.Builder().setQueryProfile(registry.getComponent("profile2"))
+ .setEmbedder(new MockEmbedder("text-to-embed",
+ Tensor.from("tensor(x[3]):[1, 2, 3]")))
+ .setRequest("?query=test&ranking.features.query(tensor_1)=[1.200]")
+ .build();
+ assertEquals("tensor_1 received as a string as it is not in type2",
+ "[1.200]",
+ queryProfile2.properties().get("ranking.features.query(tensor_1)"));
+ //assertEquals(Tensor.from("tensor(x[3]):[1, 2, 3]"),
+ // queryProfile2.properties().get("ranking.features.query(tensor_3)"));
+ }
- Query queryProfile1 = new Query("?query=test&ranking.features.query(tensor_1)=[1.200]", registry.getComponent("profile1"));
- assertEquals("Is received as a tensor tensor", "tensor<float>(x[1]):[1.2]", queryProfile1.properties().get("ranking.features.query(tensor_1)").toString());
+ private static final class MockEmbedder implements Embedder {
- Query queryProfile2 = new Query("?query=test&ranking.features.query(tensor_1)=[1.200]", registry.getComponent("profile2"));
- assertEquals("Is received as a string", "[1.200]", queryProfile2.properties().get("ranking.features.query(tensor_1)").toString());
- }
+ private final String expectedText;
+ private final Tensor tensorToReturn;
+ public MockEmbedder(String expectedText, Tensor tensorToReturn) {
+ this.expectedText = expectedText;
+ this.tensorToReturn = tensorToReturn;
+ }
+
+ @Override
+ public List<Integer> embed(String text, Embedder.Context context) {
+ fail("Unexpected call");
+ return null;
+ }
+
+ @Override
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+ assertEquals(expectedText, text);
+ assertEquals(tensorToReturn.type(), tensorType);
+ return tensorToReturn;
+ }
+
+ }
}
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml
index 3de0e219158..cb1260bad8e 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml
@@ -1,3 +1,5 @@
<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
<query-profile id="profile1" type="type1">
+ <field name="mykey">key1</field>
+ <field name="ranking.features.query(tensor_4)">{{key:"%{mykey}"}:1.0}</field>
</query-profile>
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml
index 6a2065d05eb..f3017327638 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml
@@ -1,3 +1,4 @@
<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
<query-profile id="profile2" type="type2">
+ <!-- <field name="ranking.features.query(tensor_3)">embed(%{mytext})</field> embed(...) at config time is not supported -->
</query-profile>
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml
index a928c5f4deb..71e8e0f9d71 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml
@@ -1,4 +1,5 @@
<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
<query-profile-type id="type1">
<field name="ranking.features.query(tensor_1)" type="tensor&lt;float&gt;(x[1])" />
+ <field name="ranking.features.query(tensor_4)" type="tensor(key{})" />
</query-profile-type>