diff options
Diffstat (limited to 'container-search/src/test/java')
-rw-r--r-- | container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java | 25 |
1 files changed, 18 insertions, 7 deletions
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java index e22263070e0..f11e5614635 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java @@ -447,24 +447,31 @@ public class QueryProfileTypeTestCase { CompiledQueryProfileRegistry cRegistry = registry.compile(); String textToEmbed = "text to embed into a tensor"; + String destinationFeature = "query(myTensor4)"; Tensor expectedTensor = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]"); - Query query1 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features.query(myTensor4)") + + Query query1 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features." + destinationFeature) + "=" + urlEncode("embed(" + textToEmbed + ")"), com.yahoo.jdisc.http.HttpRequest.Method.GET)) .setQueryProfile(cRegistry.getComponent("test")) - .setEmbedder(new MockEmbedder(textToEmbed, Language.UNKNOWN, expectedTensor)) + .setEmbedder(new MockEmbedder(textToEmbed, + Language.UNKNOWN, + destinationFeature, + expectedTensor)) .build(); assertEquals(0, query1.errors().size()); assertEquals(expectedTensor, query1.properties().get("ranking.features.query(myTensor4)")); assertEquals(expectedTensor, query1.getRanking().getFeatures().getTensor("query(myTensor4)").get()); // Explicit language - Query query2 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features.query(myTensor4)") + + Query query2 = new Query.Builder().setRequest(HttpRequest.createTestRequest("?" + urlEncode("ranking.features." + destinationFeature) + "=" + urlEncode("embed(" + textToEmbed + ")") + "&language=en", com.yahoo.jdisc.http.HttpRequest.Method.GET)) .setQueryProfile(cRegistry.getComponent("test")) - .setEmbedder(new MockEmbedder(textToEmbed, Language.ENGLISH, expectedTensor)) + .setEmbedder(new MockEmbedder(textToEmbed, + Language.ENGLISH, + destinationFeature, + expectedTensor)) .build(); assertEquals(0, query2.errors().size()); assertEquals(expectedTensor, query2.properties().get("ranking.features.query(myTensor4)")); @@ -726,26 +733,30 @@ public class QueryProfileTypeTestCase { private final String expectedText; private final Language expectedLanguage; + private final String expectedDestination; private final Tensor tensorToReturn; public MockEmbedder(String expectedText, Language expectedLanguage, + String expectedDestination, Tensor tensorToReturn) { this.expectedText = expectedText; this.expectedLanguage = expectedLanguage; + this.expectedDestination = expectedDestination; this.tensorToReturn = tensorToReturn; } @Override - public List<Integer> embed(String text, Language language) { + public List<Integer> embed(String text, Embedder.Context context) { fail("Unexpected call"); return null; } @Override - public Tensor embed(String text, Language language, TensorType tensorType) { + public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { assertEquals(expectedText, text); - assertEquals(expectedLanguage, language); + assertEquals(expectedLanguage, context.getLanguage()); + assertEquals(expectedDestination, context.getDestination()); assertEquals(tensorToReturn.type(), tensorType); return tensorToReturn; } |