aboutsummaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-03-30 17:22:08 +0200
committerJon Bratseth <bratseth@gmail.com>2022-03-30 17:22:08 +0200
commitc0cd66e75bd5a8ff5cd94a72529cb951a1046746 (patch)
tree739856a9eb2ba3069f6427b9aa562ff6ad8376df /container-search
parent211fdb61ecf7379c8113e0da413a8cc16f72494d (diff)
Support substitutions in tensors
Diffstat (limited to 'container-search')
-rw-r--r--container-search/abi-spec.json4
-rw-r--r--container-search/src/main/java/com/yahoo/search/Query.java5
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java27
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileVariants.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/SubstituteString.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/compiled/CompiledQueryProfile.java2
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java18
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/PrimitiveFieldType.java28
-rw-r--r--container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java39
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java66
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml2
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml1
-rw-r--r--container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml1
13 files changed, 138 insertions, 59 deletions
diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json
index 7cc7168c79a..d0c305cea51 100644
--- a/container-search/abi-spec.json
+++ b/container-search/abi-spec.json
@@ -5983,8 +5983,6 @@
"public com.yahoo.search.query.profile.QueryProfile clone()",
"public static void validateName(java.lang.String)",
"protected void set(com.yahoo.processing.request.CompoundName, java.lang.Object, com.yahoo.search.query.profile.DimensionBinding, com.yahoo.search.query.profile.QueryProfileRegistry)",
- "protected java.lang.Object convertToSubstitutionString(java.lang.Object)",
- "protected com.yahoo.search.query.profile.types.FieldDescription getFieldDescription(com.yahoo.processing.request.CompoundName, com.yahoo.search.query.profile.DimensionBinding)",
"protected java.lang.Boolean isLocalInstanceOverridable(java.lang.String)",
"protected java.lang.Object lookup(com.yahoo.processing.request.CompoundName, boolean, com.yahoo.search.query.profile.DimensionBinding)",
"protected final void accept(com.yahoo.search.query.profile.QueryProfileVisitor, com.yahoo.search.query.profile.DimensionBinding, com.yahoo.search.query.profile.QueryProfile)",
@@ -6236,7 +6234,7 @@
"public static com.yahoo.search.query.profile.SubstituteString create(java.lang.String)",
"public void <init>(java.util.List, java.lang.String)",
"public boolean hasRelative()",
- "public java.lang.String substitute(java.util.Map, com.yahoo.processing.request.Properties)",
+ "public java.lang.Object substitute(java.util.Map, com.yahoo.processing.request.Properties)",
"public java.util.List components()",
"public java.lang.String stringValue()",
"public int hashCode()",
diff --git a/container-search/src/main/java/com/yahoo/search/Query.java b/container-search/src/main/java/com/yahoo/search/Query.java
index 9d2c7f6eec9..3d411adeb75 100644
--- a/container-search/src/main/java/com/yahoo/search/Query.java
+++ b/container-search/src/main/java/com/yahoo/search/Query.java
@@ -351,7 +351,10 @@ public class Query extends com.yahoo.processing.Request implements Cloneable {
builder.getZoneInfo());
}
- private Query(HttpRequest request, Map<String, String> requestMap, CompiledQueryProfile queryProfile, Map<String, Embedder> embedders,
+ private Query(HttpRequest request,
+ Map<String, String> requestMap,
+ CompiledQueryProfile queryProfile,
+ Map<String, Embedder> embedders,
ZoneInfo zoneInfo) {
super(new QueryPropertyAliases(propertyAliases));
this.httpRequest = request;
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java
index c30a78da57d..989f12172b3 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java
@@ -477,22 +477,6 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
}
}
- /** Returns this value, or its corresponding substitution string if it contains substitutions */
- protected Object convertToSubstitutionString(Object value) {
- if (value == null) return value;
- if (value.getClass() != String.class) return value;
- SubstituteString substituteString = SubstituteString.create((String)value);
- if (substituteString == null) return value;
- return substituteString;
- }
-
- /** Returns the field description of this field, or null if it is not typed */
- protected FieldDescription getFieldDescription(CompoundName name, DimensionBinding binding) {
- FieldDescriptionQueryProfileVisitor visitor = new FieldDescriptionQueryProfileVisitor(name.asList());
- accept(visitor, binding, null);
- return visitor.result();
- }
-
/**
* Returns true if this value is definitely overridable in this (set and not unoverridable),
* false if it is declared unoverridable (in instance or type), and null if this profile has no
@@ -620,6 +604,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
if (parentType != null && type == null && ! isFrozen())
type = parentType;
+ value = convertToSubstitutionString(value);
value = checkAndConvertAssignment(localName, value, registry);
localPut(localName, value, dimensionBinding);
return this;
@@ -841,7 +826,6 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
localName = type.unalias(localName);
validateName(localName);
- value = convertToSubstitutionString(value);
if (dimensionBinding.isNull()) {
Object combinedValue = value instanceof QueryProfile
@@ -857,6 +841,15 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable
}
}
+ /** Returns this value, or its corresponding substitution string if it contains substitutions */
+ static Object convertToSubstitutionString(Object value) {
+ if (value == null) return value;
+ if (value.getClass() != String.class) return value;
+ SubstituteString substituteString = SubstituteString.create((String)value);
+ if (substituteString == null) return value;
+ return substituteString;
+ }
+
private static final Pattern namePattern = Pattern.compile("[$a-zA-Z_/][-$a-zA-Z0-9_/()]*");
/**
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileVariants.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileVariants.java
index 5be0fc9ea10..845c2cfd384 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileVariants.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileVariants.java
@@ -455,7 +455,7 @@ public class QueryProfileVariants implements Freezable, Cloneable {
public Object getValue() { return value; }
/** Sets the value to use for this set of dimension values */
- public void setValue(Object value) { this.value=value; }
+ public void setValue(Object value) { this.value = value; }
public boolean matches(DimensionValues givenDimensionValues) {
return dimensionValues.matches(givenDimensionValues);
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/SubstituteString.java b/container-search/src/main/java/com/yahoo/search/query/profile/SubstituteString.java
index 2a3feb084db..5035a5ccd49 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/SubstituteString.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/SubstituteString.java
@@ -68,7 +68,7 @@ public class SubstituteString {
* @param context the content which is used to resolve profile variants when looking up substitution values
* @param substitution the properties in which values to be substituted are looked up
*/
- public String substitute(Map<String, String> context, Properties substitution) {
+ public Object substitute(Map<String, String> context, Properties substitution) {
StringBuilder b = new StringBuilder();
for (Component component : components)
b.append(component.getValue(context, substitution));
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/CompiledQueryProfile.java b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/CompiledQueryProfile.java
index a600389f4d5..ad9d3f4c1a5 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/CompiledQueryProfile.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/CompiledQueryProfile.java
@@ -191,7 +191,7 @@ public class CompiledQueryProfile extends AbstractComponent implements Cloneable
private Object substitute(Object value, Map<String, String> context, Properties substitution) {
if (value == null) return value;
if (substitution == null) return value;
- if (value.getClass() != SubstituteString.class) return value;
+ if ( ! (value instanceof SubstituteString)) return value;
return ((SubstituteString)value).substitute(context, substitution);
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java
index c94987ce88e..70a08b9ab8a 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java
@@ -47,7 +47,7 @@ public class QueryProfileXMLReader {
if ( ! file.getName().endsWith(".xml")) continue;
queryProfileReaders.add(new NamedReader(file.getName(), new FileReader(file)));
}
- File typeDir=new File(dir,"types");
+ File typeDir = new File(dir,"types");
if (typeDir.isDirectory()) {
for (File file : sortFiles(typeDir)) {
if ( ! file.getName().endsWith(".xml")) continue;
@@ -106,7 +106,7 @@ public class QueryProfileXMLReader {
}
String idString = root.getAttribute("id");
- if (idString == null || idString.equals(""))
+ if (idString.isEmpty())
throw new IllegalArgumentException("'" + reader.getName() + "' has no 'id' attribute in the root element");
ComponentId id = new ComponentId(idString);
validateFileNameToId(reader.getName(), id,"query profile type");
@@ -129,7 +129,7 @@ public class QueryProfileXMLReader {
}
String idString = root.getAttribute("id");
- if (idString == null || idString.equals(""))
+ if (idString.isEmpty())
throw new IllegalArgumentException("Query profile '" + reader.getName() +
"' has no 'id' attribute in the root element");
ComponentId id = new ComponentId(idString);
@@ -137,7 +137,7 @@ public class QueryProfileXMLReader {
QueryProfile queryProfile = new QueryProfile(id, reader.getName(), registry);
String typeId = root.getAttribute("type");
- if (typeId != null && ! typeId.equals("")) {
+ if (! typeId.isEmpty()) {
QueryProfileType type = registry.getType(typeId);
if (type == null)
throw new IllegalArgumentException("Query profile '" + reader.getName() +
@@ -197,7 +197,7 @@ public class QueryProfileXMLReader {
private void readInheritedTypes(Element element,QueryProfileType type, QueryProfileTypeRegistry registry) {
String inheritedString = element.getAttribute("inherits");
- if (inheritedString == null || inheritedString.equals("")) return;
+ if (inheritedString.equals("")) return;
for (String inheritedId : inheritedString.split(" ")) {
inheritedId = inheritedId.trim();
if (inheritedId.equals("")) continue;
@@ -211,10 +211,10 @@ public class QueryProfileXMLReader {
private void readFieldDefinitions(Element element, QueryProfileType type, QueryProfileTypeRegistry registry) {
for (Element field : XML.getChildren(element,"field")) {
String name = field.getAttribute("name");
- if (name == null || name.equals("")) throw new IllegalArgumentException("A field has no 'name' attribute");
+ if (name.isEmpty()) throw new IllegalArgumentException("A field has no 'name' attribute");
try {
String fieldTypeName = field.getAttribute("type");
- if (fieldTypeName == null) throw new IllegalArgumentException("Field '" + field + "' has no 'type' attribute");
+ if (fieldTypeName.isEmpty()) throw new IllegalArgumentException("Field '" + field + "' has no 'type' attribute");
FieldType fieldType = FieldType.fromString(fieldTypeName, registry);
type.addField(new FieldDescription(name,
fieldType,
@@ -247,7 +247,7 @@ public class QueryProfileXMLReader {
private void readInherited(Element element, QueryProfile profile, QueryProfileRegistry registry,
DimensionValues dimensionValues, String sourceDescription) {
String inheritedString = element.getAttribute("inherits");
- if (inheritedString == null || inheritedString.equals("")) return;
+ if (inheritedString.isEmpty()) return;
for (String inheritedId : inheritedString.split(" ")) {
inheritedId = inheritedId.trim();
if (inheritedId.equals("")) continue;
@@ -265,7 +265,7 @@ public class QueryProfileXMLReader {
List<KeyValue> properties = new ArrayList<>();
for (Element field : XML.getChildren(element,"field")) {
String name = field.getAttribute("name");
- if (name == null || name.equals(""))
+ if (name.isEmpty())
throw new IllegalArgumentException("A field in " + sourceDescription + " has no 'name' attribute");
try {
Boolean overridable = getBooleanAttribute("overridable", null, field);
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/PrimitiveFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/PrimitiveFieldType.java
index 564f90be422..2da72772e64 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/PrimitiveFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/PrimitiveFieldType.java
@@ -1,7 +1,9 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.query.profile.types;
+import com.yahoo.search.query.profile.QueryProfile;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.search.query.profile.SubstituteString;
import static com.yahoo.text.Lowercase.toLowerCase;
@@ -13,10 +15,10 @@ import static com.yahoo.text.Lowercase.toLowerCase;
@SuppressWarnings("rawtypes")
public class PrimitiveFieldType extends FieldType {
- private Class primitiveClass;
+ private final Class primitiveClass;
PrimitiveFieldType(Class primitiveClass) {
- this.primitiveClass=primitiveClass;
+ this.primitiveClass = primitiveClass;
}
@Override
@@ -43,20 +45,20 @@ public class PrimitiveFieldType extends FieldType {
@Override
public Object convertFrom(Object object, QueryProfileRegistry registry) {
if (primitiveClass == object.getClass()) return object;
+ if (primitiveClass == String.class && object.getClass() == SubstituteString.class) return object;
if (object.getClass() == String.class) return convertFromString((String)object);
if (object instanceof Number) return convertFromNumber((Number)object);
-
return null;
}
private Object convertFromString(String string) {
try {
- if (primitiveClass==Integer.class) return Integer.valueOf(string);
- if (primitiveClass==Double.class) return Double.valueOf(string);
- if (primitiveClass==Float.class) return Float.valueOf(string);
- if (primitiveClass==Long.class) return Long.valueOf(string);
- if (primitiveClass==Boolean.class) return Boolean.valueOf(string);
+ if (primitiveClass == Integer.class) return Integer.valueOf(string);
+ if (primitiveClass == Double.class) return Double.valueOf(string);
+ if (primitiveClass == Float.class) return Float.valueOf(string);
+ if (primitiveClass == Long.class) return Long.valueOf(string);
+ if (primitiveClass == Boolean.class) return Boolean.valueOf(string);
}
catch (NumberFormatException e) {
return null; // Handled in caller
@@ -65,11 +67,11 @@ public class PrimitiveFieldType extends FieldType {
}
private Object convertFromNumber(Number number) {
- if (primitiveClass==Integer.class) return number.intValue();
- if (primitiveClass==Double.class) return number.doubleValue();
- if (primitiveClass==Float.class) return number.floatValue();
- if (primitiveClass==Long.class) return number.longValue();
- if (primitiveClass==String.class) return String.valueOf(number);
+ if (primitiveClass == Integer.class) return number.intValue();
+ if (primitiveClass == Double.class) return number.doubleValue();
+ if (primitiveClass == Float.class) return number.floatValue();
+ if (primitiveClass == Long.class) return number.longValue();
+ if (primitiveClass == String.class) return String.valueOf(number);
throw new RuntimeException("Programming error: Input type is " + number.getClass() +
" primitiveClass is " + primitiveClass);
}
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index db6a58a4dd3..cc6b18af820 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -2,7 +2,9 @@
package com.yahoo.search.query.profile.types;
import com.yahoo.language.process.Embedder;
+import com.yahoo.processing.request.Properties;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.search.query.profile.SubstituteString;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -51,6 +53,7 @@ public class TensorFieldType extends FieldType {
@Override
public Object convertFrom(Object o, ConversionContext context) {
+ if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type);
Tensor tensor = toTensor(o, context);
if (tensor == null) return null;
if (! tensor.type().isAssignableTo(type))
@@ -60,12 +63,16 @@ public class TensorFieldType extends FieldType {
private Tensor toTensor(Object o, ConversionContext context) {
if (o instanceof Tensor) return (Tensor)o;
- if (o instanceof String && ((String)o).startsWith("embed(")) return encode((String)o, context);
+ if (o instanceof String && isEmbed((String)o)) return embed((String)o, type, context);
if (o instanceof String) return Tensor.from(type, (String)o);
return null;
}
- private Tensor encode(String s, ConversionContext context) {
+ static boolean isEmbed(String value) {
+ return value.startsWith("embed(");
+ }
+
+ static Tensor embed(String s, TensorType type, ConversionContext context) {
if ( ! s.endsWith(")"))
throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'");
String argument = s.substring("embed(".length(), s.length() - 1);
@@ -78,14 +85,14 @@ public class TensorFieldType extends FieldType {
argument = matcher.group(2);
if (!context.embedders().containsKey(embedderId)) {
throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " +
- "Valid embedders are " + validEmbedders(context.embedders()));
+ "Valid embedders are " + validEmbedders(context.embedders()));
}
embedder = context.embedders().get(embedderId);
} else if (context.embedders().size() == 0) {
throw new IllegalStateException("No embedders provided"); // should never happen
} else if (context.embedders().size() > 1) {
throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " +
- "Valid embedders are " + validEmbedders(context.embedders()));
+ "Valid embedders are " + validEmbedders(context.embedders()));
} else {
embedder = context.embedders().entrySet().stream().findFirst().get().getValue();
}
@@ -110,7 +117,7 @@ public class TensorFieldType extends FieldType {
return String.join(",", embedderIds);
}
- private Embedder.Context toEmbedderContext(ConversionContext context) {
+ private static Embedder.Context toEmbedderContext(ConversionContext context) {
return new Embedder.Context(context.destination()).setLanguage(context.language());
}
@@ -118,4 +125,26 @@ public class TensorFieldType extends FieldType {
return new TensorFieldType(TensorType.fromSpec(s));
}
+ /**
+ * A substitute string that should become a tensor once the substitution is performed at lookup time.
+ * This is to support substitution strings in tensor values by parsing (only) such tensors at
+ * lookup time rather than at construction time.
+ */
+ private static class SubstituteStringTensor extends SubstituteString {
+
+ private final TensorType type;
+
+ SubstituteStringTensor(SubstituteString string, TensorType type) {
+ super(string.components(), string.stringValue());
+ this.type = type;
+ }
+
+ @Override
+ public Object substitute(Map<String, String> context, Properties substitution) {
+ String substituted = super.substitute(context, substitution).toString();
+ return Tensor.from(type, substituted);
+ }
+
+ }
+
}
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java
index c47e1e5b23c..ad85b5e37a1 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/XmlReadingTestCase.java
@@ -3,7 +3,12 @@ package com.yahoo.search.query.profile.config.test;
import com.yahoo.jdisc.http.HttpRequest.Method;
import com.yahoo.container.jdisc.HttpRequest;
+import com.yahoo.language.Language;
+import com.yahoo.language.process.Embedder;
import com.yahoo.processing.request.CompoundName;
+import com.yahoo.search.query.profile.types.test.QueryProfileTypeTestCase;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.yolean.Exceptions;
import com.yahoo.search.Query;
import com.yahoo.search.query.Properties;
@@ -17,6 +22,7 @@ import com.yahoo.search.query.profile.types.QueryProfileType;
import org.junit.Test;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import static org.junit.Assert.assertEquals;
@@ -474,20 +480,64 @@ public class XmlReadingTestCase {
CompiledQueryProfileRegistry registry = new QueryProfileXMLReader().read("src/test/java/com/yahoo/search/query/profile/config/test/tensortypes").compile();
QueryProfileType type1 = registry.getTypeRegistry().getComponent("type1");
- assertEquals("tensor<float>(x[1])", type1.getFieldType(new CompoundName("ranking.features.query(tensor_1)")).stringValue());
+ assertEquals(TensorType.fromSpec("tensor<float>(x[1])"),
+ type1.getFieldType(new CompoundName("ranking.features.query(tensor_1)")).asTensorType());
assertNull(type1.getFieldType(new CompoundName("ranking.features.query(tensor_2)")));
assertNull(type1.getFieldType(new CompoundName("ranking.features.query(tensor_3)")));
+ assertEquals(TensorType.fromSpec("tensor(key{})"),
+ type1.getFieldType(new CompoundName("ranking.features.query(tensor_4)")).asTensorType());
QueryProfileType type2 = registry.getTypeRegistry().getComponent("type2");
assertNull(type2.getFieldType(new CompoundName("ranking.features.query(tensor_1)")));
- assertEquals("tensor<float>(x[2])", type2.getFieldType(new CompoundName("ranking.features.query(tensor_2)")).stringValue());
- assertEquals("tensor<float>(x[3])", type2.getFieldType(new CompoundName("ranking.features.query(tensor_3)")).stringValue());
+ assertEquals(TensorType.fromSpec("tensor<float>(x[2])"),
+ type2.getFieldType(new CompoundName("ranking.features.query(tensor_2)")).asTensorType());
+ assertEquals(TensorType.fromSpec("tensor<float>(x[3])"),
+ type2.getFieldType(new CompoundName("ranking.features.query(tensor_3)")).asTensorType());
+
+ Query queryProfile1 = new Query.Builder().setQueryProfile(registry.getComponent("profile1"))
+ .setRequest("?query=test&ranking.features.query(tensor_1)=[1.200]")
+ .build();
+ assertEquals("tensor_1 received as a tensor tensor",
+ Tensor.from("tensor<float>(x[1]):[1.2]"),
+ queryProfile1.properties().get("ranking.features.query(tensor_1)"));
+ assertEquals("tensor_4 contained in the profile is a tensor",
+ Tensor.from("tensor(key{}):{key1:1.0}"),
+ queryProfile1.properties().get("ranking.features.query(tensor_4)"));
+
+ Query queryProfile2 = new Query.Builder().setQueryProfile(registry.getComponent("profile2"))
+ .setEmbedder(new MockEmbedder("text-to-embed",
+ Tensor.from("tensor(x[3]):[1, 2, 3]")))
+ .setRequest("?query=test&ranking.features.query(tensor_1)=[1.200]")
+ .build();
+ assertEquals("tensor_1 received as a string as it is not in type2",
+ "[1.200]",
+ queryProfile2.properties().get("ranking.features.query(tensor_1)"));
+ //assertEquals(Tensor.from("tensor(x[3]):[1, 2, 3]"),
+ // queryProfile2.properties().get("ranking.features.query(tensor_3)"));
+ }
- Query queryProfile1 = new Query("?query=test&ranking.features.query(tensor_1)=[1.200]", registry.getComponent("profile1"));
- assertEquals("Is received as a tensor tensor", "tensor<float>(x[1]):[1.2]", queryProfile1.properties().get("ranking.features.query(tensor_1)").toString());
+ private static final class MockEmbedder implements Embedder {
- Query queryProfile2 = new Query("?query=test&ranking.features.query(tensor_1)=[1.200]", registry.getComponent("profile2"));
- assertEquals("Is received as a string", "[1.200]", queryProfile2.properties().get("ranking.features.query(tensor_1)").toString());
- }
+ private final String expectedText;
+ private final Tensor tensorToReturn;
+ public MockEmbedder(String expectedText, Tensor tensorToReturn) {
+ this.expectedText = expectedText;
+ this.tensorToReturn = tensorToReturn;
+ }
+
+ @Override
+ public List<Integer> embed(String text, Embedder.Context context) {
+ fail("Unexpected call");
+ return null;
+ }
+
+ @Override
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+ assertEquals(expectedText, text);
+ assertEquals(tensorToReturn.type(), tensorType);
+ return tensorToReturn;
+ }
+
+ }
}
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml
index 3de0e219158..cb1260bad8e 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile1.xml
@@ -1,3 +1,5 @@
<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
<query-profile id="profile1" type="type1">
+ <field name="mykey">key1</field>
+ <field name="ranking.features.query(tensor_4)">{{key:"%{mykey}"}:1.0}</field>
</query-profile>
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml
index 6a2065d05eb..f3017327638 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/profile2.xml
@@ -1,3 +1,4 @@
<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
<query-profile id="profile2" type="type2">
+ <!-- <field name="ranking.features.query(tensor_3)">embed(%{mytext})</field> embed(...) at config time is not supported -->
</query-profile>
diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml
index a928c5f4deb..71e8e0f9d71 100644
--- a/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml
+++ b/container-search/src/test/java/com/yahoo/search/query/profile/config/test/tensortypes/types/type1.xml
@@ -1,4 +1,5 @@
<!-- Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -->
<query-profile-type id="type1">
<field name="ranking.features.query(tensor_1)" type="tensor&lt;float&gt;(x[1])" />
+ <field name="ranking.features.query(tensor_4)" type="tensor(key{})" />
</query-profile-type>