diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-03-30 17:22:08 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-03-30 17:22:08 +0200 |
commit | c0cd66e75bd5a8ff5cd94a72529cb951a1046746 (patch) | |
tree | 739856a9eb2ba3069f6427b9aa562ff6ad8376df /container-search/src/test/java/com/yahoo/search/query | |
parent | 211fdb61ecf7379c8113e0da413a8cc16f72494d (diff) |
Support substitutions in tensors
Diffstat (limited to 'container-search/src/test/java/com/yahoo/search/query')
4 files changed, 62 insertions, 8 deletions
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java index c47e1e5b23c..ad85b5e37a1 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java @@ -3,7 +3,12 @@ package com.yahoo.search.query.profile.config.test; import com.yahoo.jdisc.http.HttpRequest.Method; import com.yahoo.container.jdisc.HttpRequest; +import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; import com.yahoo.processing.request.CompoundName; +import com.yahoo.search.query.profile.types.test.QueryProfileTypeTestCase; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.yolean.Exceptions; import com.yahoo.search.Query; import com.yahoo.search.query.Properties; @@ -17,6 +22,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType; import org.junit.Test; import java.util.HashMap; +import java.util.List; import java.util.Map; import static org.junit.Assert.assertEquals; @@ -474,20 +480,64 @@ public class XmlReadingTestCase { CompiledQueryProfileRegistry registry = new QueryProfileXMLReader().read("src/test/java/com/yahoo/search/query/profile/config/test/tensortypes").compile(); QueryProfileType type1 = registry.getTypeRegistry().getComponent("type1"); - assertEquals("tensor<float>(x[1])", type1.getFieldType(new CompoundName("ranking.features.query(tensor_1)")).stringValue()); + assertEquals(TensorType.fromSpec("tensor<float>(x[1])"), + type1.getFieldType(new CompoundName("ranking.features.query(tensor_1)")).asTensorType()); assertNull(type1.getFieldType(new CompoundName("ranking.features.query(tensor_2)"))); assertNull(type1.getFieldType(new CompoundName("ranking.features.query(tensor_3)"))); + assertEquals(TensorType.fromSpec("tensor(key{})"), + type1.getFieldType(new CompoundName("ranking.features.query(tensor_4)")).asTensorType()); QueryProfileType type2 = registry.getTypeRegistry().getComponent("type2"); assertNull(type2.getFieldType(new CompoundName("ranking.features.query(tensor_1)"))); - assertEquals("tensor<float>(x[2])", type2.getFieldType(new CompoundName("ranking.features.query(tensor_2)")).stringValue()); - assertEquals("tensor<float>(x[3])", type2.getFieldType(new CompoundName("ranking.features.query(tensor_3)")).stringValue()); + assertEquals(TensorType.fromSpec("tensor<float>(x[2])"), + type2.getFieldType(new CompoundName("ranking.features.query(tensor_2)")).asTensorType()); + assertEquals(TensorType.fromSpec("tensor<float>(x[3])"), + type2.getFieldType(new CompoundName("ranking.features.query(tensor_3)")).asTensorType()); + + Query queryProfile1 = new Query.Builder().setQueryProfile(registry.getComponent("profile1")) + .setRequest("?query=test&ranking.features.query(tensor_1)=[1.200]") + .build(); + assertEquals("tensor_1 received as a tensor tensor", + Tensor.from("tensor<float>(x[1]):[1.2]"), + queryProfile1.properties().get("ranking.features.query(tensor_1)")); + assertEquals("tensor_4 contained in the profile is a tensor", + Tensor.from("tensor(key{}):{key1:1.0}"), + queryProfile1.properties().get("ranking.features.query(tensor_4)")); + + Query queryProfile2 = new Query.Builder().setQueryProfile(registry.getComponent("profile2")) + .setEmbedder(new MockEmbedder("text-to-embed", + Tensor.from("tensor(x[3]):[1, 2, 3]"))) + .setRequest("?query=test&ranking.features.query(tensor_1)=[1.200]") + .build(); + assertEquals("tensor_1 received as a string as it is not in type2", + "[1.200]", + queryProfile2.properties().get("ranking.features.query(tensor_1)")); + //assertEquals(Tensor.from("tensor(x[3]):[1, 2, 3]"), + // queryProfile2.properties().get("ranking.features.query(tensor_3)")); + } - Query queryProfile1 = new Query("?query=test&ranking.features.query(tensor_1)=[1.200]", registry.getComponent("profile1")); - assertEquals("Is received as a tensor tensor", "tensor<float>(x[1]):[1.2]", queryProfile1.properties().get("ranking.features.query(tensor_1)").toString()); + private static final class MockEmbedder implements Embedder { - Query queryProfile2 = new Query("?query=test&ranking.features.query(tensor_1)=[1.200]", registry.getComponent("profile2")); - assertEquals("Is received as a string", "[1.200]", queryProfile2.properties().get("ranking.features.query(tensor_1)").toString()); - } + private final String expectedText; + private final Tensor tensorToReturn; + public MockEmbedder(String expectedText, Tensor tensorToReturn) { + this.expectedText = expectedText; + this.tensorToReturn = tensorToReturn; + } + + @Override + public List<Integer> embed(String text, Embedder.Context context) { + fail("Unexpected call"); + return null; + } + + @Override + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { + assertEquals(expectedText, text); + assertEquals(tensorToReturn.type(), tensorType); + return tensorToReturn; + } + + } } diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml index 3de0e219158..cb1260bad8e 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml +++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml @@ -1,3 +1,5 @@ <!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> <query-profile id="profile1" type="type1"> + <field name="mykey">key1</field> + <field name="ranking.features.query(tensor_4)">{{key:"%{mykey}"}:1.0}</field> </query-profile> diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml index 6a2065d05eb..f3017327638 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml +++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml @@ -1,3 +1,4 @@ <!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> <query-profile id="profile2" type="type2"> + <!-- <field name="ranking.features.query(tensor_3)">embed(%{mytext})</field> embed(...) at config time is not supported --> </query-profile> diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml index a928c5f4deb..71e8e0f9d71 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml +++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml @@ -1,4 +1,5 @@ <!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> <query-profile-type id="type1"> <field name="ranking.features.query(tensor_1)" type="tensor<float>(x[1])" /> + <field name="ranking.features.query(tensor_4)" type="tensor(key{})" /> </query-profile-type> |