diff options
author | Jon Bratseth <bratseth@vespa.ai> | 2024-02-15 16:39:32 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@vespa.ai> | 2024-02-15 16:39:32 +0100 |
commit | 4a3f24577249443b6fdaeea36b7b8110da107f06 (patch) | |
tree | a3eda6efcc2371c3a13d32b5112524388a175299 /container-search | |
parent | 55d29fa838e69cf5013deb4ca15f4b131d35876c (diff) |
Resolve embed refs from query profile
Diffstat (limited to 'container-search')
8 files changed, 69 insertions, 22 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java index 879f49c455f..3227b047984 100644 --- a/container-search/src/main/java/com/yahoo/search/Query.java +++ b/container-search/src/main/java/com/yahoo/search/Query.java @@ -441,7 +441,6 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { private void setFrom(String prefix, Properties originalProperties, QueryProfileType arguments, Map<String, String> context) { prefix = append(prefix, getPrefix(arguments).toString()); for (FieldDescription field : arguments.fields().values()) { - if (field.getType() == FieldType.genericQueryProfileType) { // Generic map String fullName = append(prefix, field.getCompoundName().toString()); for (Map.Entry<String, Object> entry : originalProperties.listProperties(CompoundName.from(fullName), context).entrySet()) { @@ -463,7 +462,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable { } } - /** Calls properties.set on all entries in requestMap */ + /** Calls properties#set on all entries in requestMap */ private void setPropertiesFromRequestMap(Map<String, String> requestMap, Properties properties, boolean ignoreSelect) { var entrySet = requestMap.entrySet(); for (var entry : entrySet) { diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java index 4cc727ea23a..8778dcc7348 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java @@ -192,7 +192,7 @@ public class QueryProfileProperties extends Properties { if (fieldDescription != null) { if (i == name.size() - 1) { // at the end of the path, check the assignment type - var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context); + var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context, this); var convertedValue = fieldDescription.getType().convertFrom(value, conversionContext); if (convertedValue == null && fieldDescription.getType() instanceof QueryProfileFieldType diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index 70f6e405a92..9bb770cf527 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -3,7 +3,9 @@ package com.yahoo.search.query.profile.types; import com.yahoo.language.Language; import com.yahoo.language.process.Embedder; +import com.yahoo.search.query.Properties; import com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry; +import com.yahoo.search.query.properties.PropertyMap; import java.util.Map; @@ -16,22 +18,25 @@ public class ConversionContext { private final CompiledQueryProfileRegistry registry; private final Map<String, Embedder> embedders; private final Map<String, String> contextValues; + private final Properties properties; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, - Map<String, String> context) { - this(destination, registry, Map.of(Embedder.defaultEmbedderId, embedder), context); + Map<String, String> context, Properties properties) { + this(destination, registry, Map.of(Embedder.defaultEmbedderId, embedder), context, properties); } public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Map<String, Embedder> embedders, - Map<String, String> context) { + Map<String, String> context, + Properties properties) { this.destination = destination; this.registry = registry; this.embedders = embedders; this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; this.contextValues = context; + this.properties = properties; } /** Returns the local name of the field which will receive the converted value (or null when this is empty) */ @@ -47,11 +52,17 @@ public class ConversionContext { Language language() { return language; } /** Returns a read-only map of context key-values which can be looked up during conversion. */ - Map<String,String> contextValues() { return contextValues; } + Map<String, String> contextValues() { return contextValues; } + + /** + * Returns properties that can supply values referenced during conversion. + * This contains the context values as well, but may also contain additional values, e.g. from query profiles. + */ + Properties properties() { return properties; } /** Returns an empty context */ public static ConversionContext empty() { - return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); + return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of(), new PropertyMap()); } } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index e16f8e7b0cd..abcd9641d46 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -49,7 +49,7 @@ public class TensorFieldType extends FieldType { public Object convertFrom(Object o, ConversionContext context) { if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type); return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, - context.language(), context.contextValues()); + context.language(), context.contextValues(), context.properties()); } public static TensorFieldType fromTypeString(String s) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java index 505759d8967..8806854b9ce 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java @@ -250,7 +250,7 @@ public class QueryProperties extends Properties { if (type == null) return value; // no type info -> keep as string FieldDescription field = type.getField(key); if (field == null) return value; // ditto - return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context)); + return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context, this)); } private void throwIllegalParameter(String key, String namespace) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java index 25a5c277dce..82e1409c4eb 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java @@ -45,7 +45,8 @@ public class RankProfileInputProperties extends Properties { name.last(), value, query.getModel().getLanguage(), - context); + context, + this); } } catch (IllegalArgumentException e) { diff --git a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java index 94f92c7fd48..b1d6dbd5859 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java +++ b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java @@ -3,6 +3,7 @@ package com.yahoo.search.schema.internal; import com.yahoo.language.Language; import com.yahoo.language.process.Embedder; +import com.yahoo.processing.request.Properties; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -29,18 +30,18 @@ public class TensorConverter { } public Tensor convertTo(TensorType type, String key, Object value, Language language, - Map<String, String> contextValues) { + Map<String, String> contextValues, Properties properties) { var context = new Embedder.Context(key).setLanguage(language).setContextValues(contextValues); - Tensor tensor = toTensor(type, value, context); + Tensor tensor = toTensor(type, value, context, properties); if (tensor == null) return null; if (! tensor.type().isAssignableTo(type)) throw new IllegalArgumentException("Require a tensor of type " + type); return tensor; } - private Tensor toTensor(TensorType type, Object value, Embedder.Context context) { + private Tensor toTensor(TensorType type, Object value, Embedder.Context context, Properties properties) { if (value instanceof Tensor) return (Tensor)value; - if (value instanceof String && isEmbed((String)value)) return embed((String)value, type, context); + if (value instanceof String && isEmbed((String)value)) return embed((String)value, type, context, properties); if (value instanceof String) return Tensor.from(type, (String)value); return null; } @@ -49,7 +50,7 @@ public class TensorConverter { return value.startsWith("embed("); } - private Tensor embed(String s, TensorType type, Embedder.Context embedderContext) { + private Tensor embed(String s, TensorType type, Embedder.Context embedderContext, Properties properties) { if ( ! s.endsWith(")")) throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'"); String argument = s.substring("embed(".length(), s.length() - 1); @@ -76,7 +77,7 @@ public class TensorConverter { embedderId = entry.getKey(); embedder = entry.getValue(); } - return embedder.embed(resolve(argument, embedderContext), embedderContext.copy().setEmbedderId(embedderId), type); + return embedder.embed(resolve(argument, properties), embedderContext.copy().setEmbedderId(embedderId), type); } private Embedder requireEmbedder(String embedderId) { @@ -86,19 +87,19 @@ public class TensorConverter { return embedders.get(embedderId); } - private static String resolve(String s, Embedder.Context embedderContext) { + private static String resolve(String s, Properties properties) { if (s.startsWith("'") && s.endsWith("'")) return s.substring(1, s.length() - 1); if (s.startsWith("\"") && s.endsWith("\"")) return s.substring(1, s.length() - 1); if (s.startsWith("@")) - return resolveReference(s, embedderContext); + return resolveReference(s, properties); return s; } - private static String resolveReference(String s, Embedder.Context embedderContext) { + private static String resolveReference(String s, Properties properties) { String referenceKey = s.substring(1); - String referencedValue = embedderContext.getContextValues().get(referenceKey); + String referencedValue = properties.getString(referenceKey); if (referencedValue == null) throw new IllegalArgumentException("Could not resolve query parameter reference '" + referenceKey + "' used in an embed() argument"); diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java index 429b8d1c6cb..3f9beacffe8 100644 --- a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java +++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java @@ -200,6 +200,28 @@ public class RankProfileInputTest { "used in an embed() argument"); } + @Test + void testUnembeddedTensorRankFeatureInRequestReferencedFromAParameterSuppliedByQueryProfile() { + String text = "text to embed into a tensor"; + + var registry = new QueryProfileRegistry(); + var profile = new QueryProfile("test"); + profile.set("param1", text, registry); + registry.register(profile); + var cProfile = registry.compile().findQueryProfile("test"); + + Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]"); + + Map<String, Embedder> embedders = Map.of( + "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1) + ); + assertEmbedQuery("embed(@param1)", embedding1, embedders, null, null, cProfile); + assertEmbedQuery("embed(emb1, @param1)", embedding1, embedders, null, null, cProfile); + assertEmbedQueryFails("embed(emb1, @noSuchParam)", embedding1, embedders, + "Could not resolve query parameter reference 'noSuchParam' " + + "used in an embed() argument"); + } + private Query createTensor1Query(String tensorString, String profile, String additionalParams) { return new Query.Builder() .setSchemaInfo(createSchemaInfo()) @@ -223,7 +245,19 @@ public class RankProfileInputTest { private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) { assertEmbedQuery(embed, expected, embedders, language, null); } - private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language, String param1Value) { + private void assertEmbedQuery(String embed, + Tensor expected, + Map<String, Embedder> embedders, + String language, + String param1Value) { + assertEmbedQuery(embed, expected, embedders, language, param1Value, null); + } + private void assertEmbedQuery(String embed, + Tensor expected, + Map<String, Embedder> embedders, + String language, + String param1Value, + CompiledQueryProfile queryProfile) { String languageParam = language == null ? "" : "&language=" + language; String param1 = param1Value == null ? "" : "¶m1=" + urlEncode(param1Value); @@ -239,6 +273,7 @@ public class RankProfileInputTest { .setSchemaInfo(createSchemaInfo()) .setQueryProfile(createQueryProfile()) .setEmbedders(embedders) + .setQueryProfile(queryProfile) .build(); assertEquals(0, query.errors().size()); assertEquals(expected, query.properties().get("ranking.features." + destination)); |