From adf85e21711e8f1b184309986911b86ed92be89e Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Sat, 20 Jan 2024 12:59:20 +0100 Subject: Revert "Merge pull request #29905 from vespa-engine/revert-29884-bratseth/param-refs-in-embed" This reverts commit c6b547c0c2898a324983356aa677ea3082533f7d, reversing changes made to 8c7f8c17ad5e1de5adcc71ee34f2a3c1cd36d6bd. --- .../query/profile/types/ConversionContext.java | 5 +++ .../query/profile/types/TensorFieldType.java | 3 +- .../properties/RankProfileInputProperties.java | 3 +- .../search/schema/internal/TensorConverter.java | 52 +++++++++++++++------- .../yahoo/search/query/RankProfileInputTest.java | 27 +++++++++-- 5 files changed, 68 insertions(+), 22 deletions(-) (limited to 'container-search') diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index bef766e7ef9..70f6e405a92 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -15,6 +15,7 @@ public class ConversionContext { private final String destination; private final CompiledQueryProfileRegistry registry; private final Map embedders; + private final Map contextValues; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, @@ -30,6 +31,7 @@ public class ConversionContext { this.embedders = embedders; this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; + this.contextValues = context; } /** Returns the local name of the field which will receive the converted value (or null when this is empty) */ @@ -44,6 +46,9 @@ public class ConversionContext { /** Returns the language, which is never null but may be UNKNOWN */ Language language() { return language; } + /** Returns a read-only map of context key-values which can be looked up during conversion. */ + Map contextValues() { return contextValues; } + /** Returns an empty context */ public static ConversionContext empty() { return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index cfadd79de8f..e16f8e7b0cd 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -48,7 +48,8 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type); - return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, context.language()); + return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, + context.language(), context.contextValues()); } public static TensorFieldType fromTypeString(String s) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java index c9f935e5f52..25a5c277dce 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java @@ -44,7 +44,8 @@ public class RankProfileInputProperties extends Properties { value = tensorConverter.convertTo(expectedType, name.last(), value, - query.getModel().getLanguage()); + query.getModel().getLanguage(), + context); } } catch (IllegalArgumentException e) { diff --git a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java index 6da53ae699c..94f92c7fd48 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java +++ b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java @@ -19,7 +19,8 @@ import java.util.regex.Pattern; */ public class TensorConverter { - private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); + private static final Pattern embedderArgumentAndQuotedTextRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*([\"'].*[\"'])"); + private static final Pattern embedderArgumentAndReferenceRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*(@.*)"); private final Map embedders; @@ -27,8 +28,9 @@ public class TensorConverter { this.embedders = embedders; } - public Tensor convertTo(TensorType type, String key, Object value, Language language) { - var context = new Embedder.Context(key).setLanguage(language); + public Tensor convertTo(TensorType type, String key, Object value, Language language, + Map contextValues) { + var context = new Embedder.Context(key).setLanguage(language).setContextValues(contextValues); Tensor tensor = toTensor(type, value, context); if (tensor == null) return null; if (! tensor.type().isAssignableTo(type)) @@ -55,16 +57,16 @@ public class TensorConverter { String embedderId; // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") - Matcher matcher = embedderArgumentRegexp.matcher(argument); - if (matcher.matches()) { + Matcher matcher; + if (( matcher = embedderArgumentAndQuotedTextRegexp.matcher(argument)).matches()) { embedderId = matcher.group(1); + embedder = requireEmbedder(embedderId); argument = matcher.group(2); - if ( ! embedders.containsKey(embedderId)) { - throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(embedders)); - } - embedder = embedders.get(embedderId); - } else if (embedders.size() == 0) { + } else if (( matcher = embedderArgumentAndReferenceRegexp.matcher(argument)).matches()) { + embedderId = matcher.group(1); + embedder = requireEmbedder(embedderId); + argument = matcher.group(2); + } else if (embedders.isEmpty()) { throw new IllegalStateException("No embedders provided"); // should never happen } else if (embedders.size() > 1) { throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + @@ -74,19 +76,35 @@ public class TensorConverter { embedderId = entry.getKey(); embedder = entry.getValue(); } - return embedder.embed(removeQuotes(argument), embedderContext.copy().setEmbedderId(embedderId), type); + return embedder.embed(resolve(argument, embedderContext), embedderContext.copy().setEmbedderId(embedderId), type); } - private static String removeQuotes(String s) { - if (s.startsWith("'") && s.endsWith("'")) { + private Embedder requireEmbedder(String embedderId) { + if ( ! embedders.containsKey(embedderId)) + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(embedders)); + return embedders.get(embedderId); + } + + private static String resolve(String s, Embedder.Context embedderContext) { + if (s.startsWith("'") && s.endsWith("'")) return s.substring(1, s.length() - 1); - } - if (s.startsWith("\"") && s.endsWith("\"")) { + if (s.startsWith("\"") && s.endsWith("\"")) return s.substring(1, s.length() - 1); - } + if (s.startsWith("@")) + return resolveReference(s, embedderContext); return s; } + private static String resolveReference(String s, Embedder.Context embedderContext) { + String referenceKey = s.substring(1); + String referencedValue = embedderContext.getContextValues().get(referenceKey); + if (referencedValue == null) + throw new IllegalArgumentException("Could not resolve query parameter reference '" + referenceKey + + "' used in an embed() argument"); + return referencedValue; + } + private static String validEmbedders(Map embedders) { List embedderIds = new ArrayList<>(); embedders.forEach((key, value) -> embedderIds.add(key)); diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java index 90e21e5f3b0..429b8d1c6cb 100644 --- a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java +++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java @@ -185,6 +185,21 @@ public class RankProfileInputTest { assertEmbedQuery("embed(emb2, '" + text + "')", embedding2, embedders, Language.UNKNOWN.languageCode()); } + @Test + void testUnembeddedTensorRankFeatureInRequestReferencedFromAParameter() { + String text = "text to embed into a tensor"; + Tensor embedding1 = Tensor.from("tensor(x[5]):[3,7,4,0,0]]"); + + Map embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1) + ); + assertEmbedQuery("embed(@param1)", embedding1, embedders, null, text); + assertEmbedQuery("embed(emb1, @param1)", embedding1, embedders, null, text); + assertEmbedQueryFails("embed(emb1, @noSuchParam)", embedding1, embedders, + "Could not resolve query parameter reference 'noSuchParam' " + + "used in an embed() argument"); + } + private Query createTensor1Query(String tensorString, String profile, String additionalParams) { return new Query.Builder() .setSchemaInfo(createSchemaInfo()) @@ -202,18 +217,24 @@ public class RankProfileInputTest { } private void assertEmbedQuery(String embed, Tensor expected, Map embedders) { - assertEmbedQuery(embed, expected, embedders, null); + assertEmbedQuery(embed, expected, embedders, null, null); } private void assertEmbedQuery(String embed, Tensor expected, Map embedders, String language) { + assertEmbedQuery(embed, expected, embedders, language, null); + } + private void assertEmbedQuery(String embed, Tensor expected, Map embedders, String language, String param1Value) { String languageParam = language == null ? "" : "&language=" + language; + String param1 = param1Value == null ? "" : "¶m1=" + urlEncode(param1Value); + String destination = "query(myTensor4)"; Query query = new Query.Builder().setRequest(HttpRequest.createTestRequest( "?" + urlEncode("ranking.features." + destination) + "=" + urlEncode(embed) + "&ranking=commonProfile" + - languageParam, + languageParam + + param1, com.yahoo.jdisc.http.HttpRequest.Method.GET)) .setSchemaInfo(createSchemaInfo()) .setQueryProfile(createQueryProfile()) @@ -230,7 +251,7 @@ public class RankProfileInputTest { if (t.getMessage().equals(errMsg)) return; t = t.getCause(); } - fail("Error '" + errMsg + "' not thrown"); + fail("Exception with message '" + errMsg + "' not thrown"); } private CompiledQueryProfile createQueryProfile() { -- cgit v1.2.3