diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/query')
4 files changed, 74 insertions, 14 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java index f58395fd5bb..6e778a0fac6 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java @@ -30,7 +30,7 @@ import java.util.Map; public class QueryProfileProperties extends Properties { private final CompiledQueryProfile profile; - private final Embedder embedder; + private final Map<String, Embedder> embedders; // Note: The priority order is: values has precedence over references @@ -45,14 +45,18 @@ public class QueryProfileProperties extends Properties { private List<Pair<CompoundName, CompiledQueryProfile>> references = null; public QueryProfileProperties(CompiledQueryProfile profile) { - this(profile, Embedder.throwsOnUse); + this(profile, Embedder.throwsOnUse.asMap()); } - /** Creates an instance from a profile, throws an exception if the given profile is null */ public QueryProfileProperties(CompiledQueryProfile profile, Embedder embedder) { + this(profile, Map.of(Embedder.defaultEmbedderName, embedder)); + } + + /** Creates an instance from a profile, throws an exception if the given profile is null */ + public QueryProfileProperties(CompiledQueryProfile profile, Map<String, Embedder> embedders) { Validator.ensureNotNull("The profile wrapped by this cannot be null", profile); this.profile = profile; - this.embedder = embedder; + this.embedders = embedders; } /** Returns the query profile backing this, or null if none */ @@ -147,7 +151,7 @@ public class QueryProfileProperties extends Properties { if (fieldDescription != null) { if (i == name.size() - 1) { // at the end of the path, check the assignment type - var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedder, context); + var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context); var convertedValue = fieldDescription.getType().convertFrom(value, conversionContext); if (convertedValue == null && fieldDescription.getType() instanceof QueryProfileFieldType diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index 8dfb67a9d5f..1fc405051ac 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -14,14 +14,20 @@ public class ConversionContext { private final String destination; private final CompiledQueryProfileRegistry registry; - private final Embedder embedder; + private final Map<String, Embedder> embedders; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, Map<String, String> context) { + this(destination, registry, Map.of(Embedder.defaultEmbedderName, embedder), context); + } + + public ConversionContext(String destination, CompiledQueryProfileRegistry registry, + Map<String, Embedder> embedders, + Map<String, String> context) { this.destination = destination; this.registry = registry; - this.embedder = embedder; + this.embedders = embedders; this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; } @@ -33,14 +39,14 @@ public class ConversionContext { CompiledQueryProfileRegistry registry() {return registry;} /** Returns the configured embedder, never null */ - Embedder embedder() { return embedder; } + Map<String, Embedder> embedders() { return embedders; } /** Returns the language, which is never null but may be UNKNOWN */ Language language() { return language; } /** Returns an empty context */ public static ConversionContext empty() { - return new ConversionContext(null, null, Embedder.throwsOnUse, Map.of()); + return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); } } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index d6676db3774..6f1cfccc16b 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -6,6 +6,12 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + /** * A tensor field type in a query profile * @@ -13,6 +19,8 @@ import com.yahoo.tensor.TensorType; */ public class TensorFieldType extends FieldType { + private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); + private final TensorType type; /** Creates a tensor field type with information about the kind of tensor this will hold */ @@ -52,8 +60,46 @@ public class TensorFieldType extends FieldType { private Tensor encode(String s, ConversionContext context) { if ( ! s.endsWith(")")) throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'"); - String text = s.substring("embed(".length(), s.length() - 1); - return context.embedder().embed(text, toEmbedderContext(context), type); + String argument = s.substring("embed(".length(), s.length() - 1); + Embedder embedder; + + // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") + Matcher matcher = embedderArgumentRegexp.matcher(argument); + if (matcher.matches()) { + String embedderId = matcher.group(1); + argument = matcher.group(2); + if (!context.embedders().containsKey(embedderId)) { + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(context.embedders())); + } + embedder = context.embedders().get(embedderId); + } else if (context.embedders().size() == 0) { + throw new IllegalStateException("No embedders provided"); // should never happen + } else if (context.embedders().size() > 1) { + throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + + "Valid embedders are " + validEmbedders(context.embedders())); + } else { + embedder = context.embedders().entrySet().stream().findFirst().get().getValue(); + } + + return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type); + } + + private static String removeQuotes(String s) { + if (s.startsWith("'") && s.endsWith("'")) { + return s.substring(1, s.length() - 1); + } + if (s.startsWith("\"") && s.endsWith("\"")) { + return s.substring(1, s.length() - 1); + } + return s; + } + + private static String validEmbedders(Map<String, Embedder> embedders) { + List<String> embedderIds = new ArrayList<>(); + embedders.forEach((key, value) -> embedderIds.add(key)); + embedderIds.sort(null); + return String.join(",", embedderIds); } private Embedder.Context toEmbedderContext(ConversionContext context) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java index 243915662d2..dc901589cde 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java @@ -34,12 +34,16 @@ public class QueryProperties extends Properties { private Query query; private final CompiledQueryProfileRegistry profileRegistry; - private final Embedder embedder; + private final Map<String, Embedder> embedders; public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Embedder embedder) { + this(query, profileRegistry, Map.of(Embedder.defaultEmbedderName, embedder)); + } + + public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Map<String, Embedder> embedders) { this.query = query; this.profileRegistry = profileRegistry; - this.embedder = embedder; + this.embedders = embedders; } public void setParentQuery(Query query) { @@ -394,7 +398,7 @@ public class QueryProperties extends Properties { if (type == null) return value; // no type info -> keep as string FieldDescription field = type.getField(key); if (field == null) return value; // ditto - return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedder, context)); + return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context)); } private void throwIllegalParameter(String key,String namespace) { |