aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/query
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/query')
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java14
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java14
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java50
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java10
4 files changed, 74 insertions, 14 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
index f58395fd5bb..6e778a0fac6 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
@@ -30,7 +30,7 @@ import java.util.Map;
public class QueryProfileProperties extends Properties {
private final CompiledQueryProfile profile;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
// Note: The priority order is: values has precedence over references
@@ -45,14 +45,18 @@ public class QueryProfileProperties extends Properties {
private List<Pair<CompoundName, CompiledQueryProfile>> references = null;
public QueryProfileProperties(CompiledQueryProfile profile) {
- this(profile, Embedder.throwsOnUse);
+ this(profile, Embedder.throwsOnUse.asMap());
}
- /** Creates an instance from a profile, throws an exception if the given profile is null */
public QueryProfileProperties(CompiledQueryProfile profile, Embedder embedder) {
+ this(profile, Map.of(Embedder.defaultEmbedderName, embedder));
+ }
+
+ /** Creates an instance from a profile, throws an exception if the given profile is null */
+ public QueryProfileProperties(CompiledQueryProfile profile, Map<String, Embedder> embedders) {
Validator.ensureNotNull("The profile wrapped by this cannot be null", profile);
this.profile = profile;
- this.embedder = embedder;
+ this.embedders = embedders;
}
/** Returns the query profile backing this, or null if none */
@@ -147,7 +151,7 @@ public class QueryProfileProperties extends Properties {
if (fieldDescription != null) {
if (i == name.size() - 1) { // at the end of the path, check the assignment type
- var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedder, context);
+ var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context);
var convertedValue = fieldDescription.getType().convertFrom(value, conversionContext);
if (convertedValue == null
&& fieldDescription.getType() instanceof QueryProfileFieldType
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
index 8dfb67a9d5f..1fc405051ac 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
@@ -14,14 +14,20 @@ public class ConversionContext {
private final String destination;
private final CompiledQueryProfileRegistry registry;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
private final Language language;
public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder,
Map<String, String> context) {
+ this(destination, registry, Map.of(Embedder.defaultEmbedderName, embedder), context);
+ }
+
+ public ConversionContext(String destination, CompiledQueryProfileRegistry registry,
+ Map<String, Embedder> embedders,
+ Map<String, String> context) {
this.destination = destination;
this.registry = registry;
- this.embedder = embedder;
+ this.embedders = embedders;
this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language"))
: Language.UNKNOWN;
}
@@ -33,14 +39,14 @@ public class ConversionContext {
CompiledQueryProfileRegistry registry() {return registry;}
/** Returns the configured embedder, never null */
- Embedder embedder() { return embedder; }
+ Map<String, Embedder> embedders() { return embedders; }
/** Returns the language, which is never null but may be UNKNOWN */
Language language() { return language; }
/** Returns an empty context */
public static ConversionContext empty() {
- return new ConversionContext(null, null, Embedder.throwsOnUse, Map.of());
+ return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of());
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index d6676db3774..6f1cfccc16b 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -6,6 +6,12 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
/**
* A tensor field type in a query profile
*
@@ -13,6 +19,8 @@ import com.yahoo.tensor.TensorType;
*/
public class TensorFieldType extends FieldType {
+ private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])");
+
private final TensorType type;
/** Creates a tensor field type with information about the kind of tensor this will hold */
@@ -52,8 +60,46 @@ public class TensorFieldType extends FieldType {
private Tensor encode(String s, ConversionContext context) {
if ( ! s.endsWith(")"))
throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'");
- String text = s.substring("embed(".length(), s.length() - 1);
- return context.embedder().embed(text, toEmbedderContext(context), type);
+ String argument = s.substring("embed(".length(), s.length() - 1);
+ Embedder embedder;
+
+ // Check if arguments specifies an embedder with the format embed(embedder, "text to encode")
+ Matcher matcher = embedderArgumentRegexp.matcher(argument);
+ if (matcher.matches()) {
+ String embedderId = matcher.group(1);
+ argument = matcher.group(2);
+ if (!context.embedders().containsKey(embedderId)) {
+ throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " +
+ "Valid embedders are " + validEmbedders(context.embedders()));
+ }
+ embedder = context.embedders().get(embedderId);
+ } else if (context.embedders().size() == 0) {
+ throw new IllegalStateException("No embedders provided"); // should never happen
+ } else if (context.embedders().size() > 1) {
+ throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " +
+ "Valid embedders are " + validEmbedders(context.embedders()));
+ } else {
+ embedder = context.embedders().entrySet().stream().findFirst().get().getValue();
+ }
+
+ return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type);
+ }
+
+ private static String removeQuotes(String s) {
+ if (s.startsWith("'") && s.endsWith("'")) {
+ return s.substring(1, s.length() - 1);
+ }
+ if (s.startsWith("\"") && s.endsWith("\"")) {
+ return s.substring(1, s.length() - 1);
+ }
+ return s;
+ }
+
+ private static String validEmbedders(Map<String, Embedder> embedders) {
+ List<String> embedderIds = new ArrayList<>();
+ embedders.forEach((key, value) -> embedderIds.add(key));
+ embedderIds.sort(null);
+ return String.join(",", embedderIds);
}
private Embedder.Context toEmbedderContext(ConversionContext context) {
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
index 243915662d2..dc901589cde 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
@@ -34,12 +34,16 @@ public class QueryProperties extends Properties {
private Query query;
private final CompiledQueryProfileRegistry profileRegistry;
- private final Embedder embedder;
+ private final Map<String, Embedder> embedders;
public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Embedder embedder) {
+ this(query, profileRegistry, Map.of(Embedder.defaultEmbedderName, embedder));
+ }
+
+ public QueryProperties(Query query, CompiledQueryProfileRegistry profileRegistry, Map<String, Embedder> embedders) {
this.query = query;
this.profileRegistry = profileRegistry;
- this.embedder = embedder;
+ this.embedders = embedders;
}
public void setParentQuery(Query query) {
@@ -394,7 +398,7 @@ public class QueryProperties extends Properties {
if (type == null) return value; // no type info -> keep as string
FieldDescription field = type.getField(key);
if (field == null) return value; // ditto
- return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedder, context));
+ return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context));
}
private void throwIllegalParameter(String key,String namespace) {