aboutsummaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@vespa.ai>2024-02-15 16:39:32 +0100
committerJon Bratseth <bratseth@vespa.ai>2024-02-15 16:39:32 +0100
commit4a3f24577249443b6fdaeea36b7b8110da107f06 (patch)
treea3eda6efcc2371c3a13d32b5112524388a175299 /container-search
parent55d29fa838e69cf5013deb4ca15f4b131d35876c (diff)
Resolve embed refs from query profile
Diffstat (limited to 'container-search')
-rw-r--r--container-search/src/main/java/com/yahoo/search/Query.java3
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java21
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java3
-rw-r--r--container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java21
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java37
8 files changed, 69 insertions, 22 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java
index 879f49c455f..3227b047984 100644
--- a/container-search/src/main/java/com/yahoo/search/Query.java
+++ b/container-search/src/main/java/com/yahoo/search/Query.java
@@ -441,7 +441,6 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
private void setFrom(String prefix, Properties originalProperties, QueryProfileType arguments, Map<String, String> context) {
prefix = append(prefix, getPrefix(arguments).toString());
for (FieldDescription field : arguments.fields().values()) {
-
if (field.getType() == FieldType.genericQueryProfileType) { // Generic map
String fullName = append(prefix, field.getCompoundName().toString());
for (Map.Entry<String, Object> entry : originalProperties.listProperties(CompoundName.from(fullName), context).entrySet()) {
@@ -463,7 +462,7 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
}
}
- /** Calls properties.set on all entries in requestMap */
+ /** Calls properties#set on all entries in requestMap */
private void setPropertiesFromRequestMap(Map<String, String> requestMap, Properties properties, boolean ignoreSelect) {
var entrySet = requestMap.entrySet();
for (var entry : entrySet) {
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
index 4cc727ea23a..8778dcc7348 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java
@@ -192,7 +192,7 @@ public class QueryProfileProperties extends Properties {
if (fieldDescription != null) {
if (i == name.size() - 1) { // at the end of the path, check the assignment type
- var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context);
+ var conversionContext = new ConversionContext(localName, profile.getRegistry(), embedders, context, this);
var convertedValue = fieldDescription.getType().convertFrom(value, conversionContext);
if (convertedValue == null
&& fieldDescription.getType() instanceof QueryProfileFieldType
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
index 70f6e405a92..9bb770cf527 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java
@@ -3,7 +3,9 @@ package com.yahoo.search.query.profile.types;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
+import com.yahoo.search.query.Properties;
import com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry;
+import com.yahoo.search.query.properties.PropertyMap;
import java.util.Map;
@@ -16,22 +18,25 @@ public class ConversionContext {
private final CompiledQueryProfileRegistry registry;
private final Map<String, Embedder> embedders;
private final Map<String, String> contextValues;
+ private final Properties properties;
private final Language language;
public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder,
- Map<String, String> context) {
- this(destination, registry, Map.of(Embedder.defaultEmbedderId, embedder), context);
+ Map<String, String> context, Properties properties) {
+ this(destination, registry, Map.of(Embedder.defaultEmbedderId, embedder), context, properties);
}
public ConversionContext(String destination, CompiledQueryProfileRegistry registry,
Map<String, Embedder> embedders,
- Map<String, String> context) {
+ Map<String, String> context,
+ Properties properties) {
this.destination = destination;
this.registry = registry;
this.embedders = embedders;
this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language"))
: Language.UNKNOWN;
this.contextValues = context;
+ this.properties = properties;
}
/** Returns the local name of the field which will receive the converted value (or null when this is empty) */
@@ -47,11 +52,17 @@ public class ConversionContext {
Language language() { return language; }
/** Returns a read-only map of context key-values which can be looked up during conversion. */
- Map<String,String> contextValues() { return contextValues; }
+ Map<String, String> contextValues() { return contextValues; }
+
+ /**
+ * Returns properties that can supply values referenced during conversion.
+ * This contains the context values as well, but may also contain additional values, e.g. from query profiles.
+ */
+ Properties properties() { return properties; }
/** Returns an empty context */
public static ConversionContext empty() {
- return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of());
+ return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of(), new PropertyMap());
}
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index e16f8e7b0cd..abcd9641d46 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -49,7 +49,7 @@ public class TensorFieldType extends FieldType {
public Object convertFrom(Object o, ConversionContext context) {
if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type);
return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o,
- context.language(), context.contextValues());
+ context.language(), context.contextValues(), context.properties());
}
public static TensorFieldType fromTypeString(String s) {
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
index 505759d8967..8806854b9ce 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java
@@ -250,7 +250,7 @@ public class QueryProperties extends Properties {
if (type == null) return value; // no type info -> keep as string
FieldDescription field = type.getField(key);
if (field == null) return value; // ditto
- return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context));
+ return field.getType().convertFrom(value, new ConversionContext(key, profileRegistry, embedders, context, this));
}
private void throwIllegalParameter(String key, String namespace) {
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java
index 25a5c277dce..82e1409c4eb 100644
--- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java
+++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java
@@ -45,7 +45,8 @@ public class RankProfileInputProperties extends Properties {
name.last(),
value,
query.getModel().getLanguage(),
- context);
+ context,
+ this);
}
}
catch (IllegalArgumentException e) {
diff --git a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java
index 94f92c7fd48..b1d6dbd5859 100644
--- a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java
+++ b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java
@@ -3,6 +3,7 @@ package com.yahoo.search.schema.internal;
import com.yahoo.language.Language;
import com.yahoo.language.process.Embedder;
+import com.yahoo.processing.request.Properties;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -29,18 +30,18 @@ public class TensorConverter {
}
public Tensor convertTo(TensorType type, String key, Object value, Language language,
- Map<String, String> contextValues) {
+ Map<String, String> contextValues, Properties properties) {
var context = new Embedder.Context(key).setLanguage(language).setContextValues(contextValues);
- Tensor tensor = toTensor(type, value, context);
+ Tensor tensor = toTensor(type, value, context, properties);
if (tensor == null) return null;
if (! tensor.type().isAssignableTo(type))
throw new IllegalArgumentException("Require a tensor of type " + type);
return tensor;
}
- private Tensor toTensor(TensorType type, Object value, Embedder.Context context) {
+ private Tensor toTensor(TensorType type, Object value, Embedder.Context context, Properties properties) {
if (value instanceof Tensor) return (Tensor)value;
- if (value instanceof String && isEmbed((String)value)) return embed((String)value, type, context);
+ if (value instanceof String && isEmbed((String)value)) return embed((String)value, type, context, properties);
if (value instanceof String) return Tensor.from(type, (String)value);
return null;
}
@@ -49,7 +50,7 @@ public class TensorConverter {
return value.startsWith("embed(");
}
- private Tensor embed(String s, TensorType type, Embedder.Context embedderContext) {
+ private Tensor embed(String s, TensorType type, Embedder.Context embedderContext, Properties properties) {
if ( ! s.endsWith(")"))
throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'");
String argument = s.substring("embed(".length(), s.length() - 1);
@@ -76,7 +77,7 @@ public class TensorConverter {
embedderId = entry.getKey();
embedder = entry.getValue();
}
- return embedder.embed(resolve(argument, embedderContext), embedderContext.copy().setEmbedderId(embedderId), type);
+ return embedder.embed(resolve(argument, properties), embedderContext.copy().setEmbedderId(embedderId), type);
}
private Embedder requireEmbedder(String embedderId) {
@@ -86,19 +87,19 @@ public class TensorConverter {
return embedders.get(embedderId);
}
- private static String resolve(String s, Embedder.Context embedderContext) {
+ private static String resolve(String s, Properties properties) {
if (s.startsWith("'") && s.endsWith("'"))
return s.substring(1, s.length() - 1);
if (s.startsWith("\"") && s.endsWith("\""))
return s.substring(1, s.length() - 1);
if (s.startsWith("@"))
- return resolveReference(s, embedderContext);
+ return resolveReference(s, properties);
return s;
}
- private static String resolveReference(String s, Embedder.Context embedderContext) {
+ private static String resolveReference(String s, Properties properties) {
String referenceKey = s.substring(1);
- String referencedValue = embedderContext.getContextValues().get(referenceKey);
+ String referencedValue = properties.getString(referenceKey);
if (referencedValue == null)
throw new IllegalArgumentException("Could not resolve query parameter reference '" + referenceKey +
"' used in an embed() argument");
diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
index 429b8d1c6cb..3f9beacffe8 100644
--- a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
+++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java
@@ -200,6 +200,28 @@ public class RankProfileInputTest {
"used in an embed() argument");
}
+ @Test
+ void testUnembeddedTensorRankFeatureInRequestReferencedFromAParameterSuppliedByQueryProfile() {
+ String text = "text to embed into a tensor";
+
+ var registry = new QueryProfileRegistry();
+ var profile = new QueryProfile("test");
+ profile.set("param1", text, registry);
+ registry.register(profile);
+ var cProfile = registry.compile().findQueryProfile("test");
+
+ Tensor embedding1 = Tensor.from("tensor<float>(x[5]):[3,7,4,0,0]]");
+
+ Map<String, Embedder> embedders = Map.of(
+ "emb1", new MockEmbedder(text, Language.UNKNOWN, embedding1)
+ );
+ assertEmbedQuery("embed(@param1)", embedding1, embedders, null, null, cProfile);
+ assertEmbedQuery("embed(emb1, @param1)", embedding1, embedders, null, null, cProfile);
+ assertEmbedQueryFails("embed(emb1, @noSuchParam)", embedding1, embedders,
+ "Could not resolve query parameter reference 'noSuchParam' " +
+ "used in an embed() argument");
+ }
+
private Query createTensor1Query(String tensorString, String profile, String additionalParams) {
return new Query.Builder()
.setSchemaInfo(createSchemaInfo())
@@ -223,7 +245,19 @@ public class RankProfileInputTest {
private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language) {
assertEmbedQuery(embed, expected, embedders, language, null);
}
- private void assertEmbedQuery(String embed, Tensor expected, Map<String, Embedder> embedders, String language, String param1Value) {
+ private void assertEmbedQuery(String embed,
+ Tensor expected,
+ Map<String, Embedder> embedders,
+ String language,
+ String param1Value) {
+ assertEmbedQuery(embed, expected, embedders, language, param1Value, null);
+ }
+ private void assertEmbedQuery(String embed,
+ Tensor expected,
+ Map<String, Embedder> embedders,
+ String language,
+ String param1Value,
+ CompiledQueryProfile queryProfile) {
String languageParam = language == null ? "" : "&language=" + language;
String param1 = param1Value == null ? "" : "&param1=" + urlEncode(param1Value);
@@ -239,6 +273,7 @@ public class RankProfileInputTest {
.setSchemaInfo(createSchemaInfo())
.setQueryProfile(createQueryProfile())
.setEmbedders(embedders)
+ .setQueryProfile(queryProfile)
.build();
assertEquals(0, query.errors().size());
assertEquals(expected, query.properties().get("ranking.features." + destination));