diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-03-30 17:22:08 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-03-30 17:22:08 +0200 |
commit | c0cd66e75bd5a8ff5cd94a72529cb951a1046746 (patch) | |
tree | 739856a9eb2ba3069f6427b9aa562ff6ad8376df /container-search/src/main/java/com/yahoo/search/query/profile | |
parent | 211fdb61ecf7379c8113e0da413a8cc16f72494d (diff) |
Support substitutions in tensors
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/query/profile')
7 files changed, 71 insertions, 47 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java index c30a78da57d..989f12172b3 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java @@ -477,22 +477,6 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable } } - /** Returns this value, or its corresponding substitution string if it contains substitutions */ - protected Object convertToSubstitutionString(Object value) { - if (value == null) return value; - if (value.getClass() != String.class) return value; - SubstituteString substituteString = SubstituteString.create((String)value); - if (substituteString == null) return value; - return substituteString; - } - - /** Returns the field description of this field, or null if it is not typed */ - protected FieldDescription getFieldDescription(CompoundName name, DimensionBinding binding) { - FieldDescriptionQueryProfileVisitor visitor = new FieldDescriptionQueryProfileVisitor(name.asList()); - accept(visitor, binding, null); - return visitor.result(); - } - /** * Returns true if this value is definitely overridable in this (set and not unoverridable), * false if it is declared unoverridable (in instance or type), and null if this profile has no @@ -620,6 +604,7 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable if (parentType != null && type == null && ! isFrozen()) type = parentType; + value = convertToSubstitutionString(value); value = checkAndConvertAssignment(localName, value, registry); localPut(localName, value, dimensionBinding); return this; @@ -841,7 +826,6 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable localName = type.unalias(localName); validateName(localName); - value = convertToSubstitutionString(value); if (dimensionBinding.isNull()) { Object combinedValue = value instanceof QueryProfile @@ -857,6 +841,15 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable } } + /** Returns this value, or its corresponding substitution string if it contains substitutions */ + static Object convertToSubstitutionString(Object value) { + if (value == null) return value; + if (value.getClass() != String.class) return value; + SubstituteString substituteString = SubstituteString.create((String)value); + if (substituteString == null) return value; + return substituteString; + } + private static final Pattern namePattern = Pattern.compile("[$a-zA-Z_/][-$a-zA-Z0-9_/()]*"); /** diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileVariants.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileVariants.java index 5be0fc9ea10..845c2cfd384 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileVariants.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileVariants.java @@ -455,7 +455,7 @@ public class QueryProfileVariants implements Freezable, Cloneable { public Object getValue() { return value; } /** Sets the value to use for this set of dimension values */ - public void setValue(Object value) { this.value=value; } + public void setValue(Object value) { this.value = value; } public boolean matches(DimensionValues givenDimensionValues) { return dimensionValues.matches(givenDimensionValues); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/SubstituteString.java b/container-search/src/main/java/com/yahoo/search/query/profile/SubstituteString.java index 2a3feb084db..5035a5ccd49 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/SubstituteString.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/SubstituteString.java @@ -68,7 +68,7 @@ public class SubstituteString { * @param context the content which is used to resolve profile variants when looking up substitution values * @param substitution the properties in which values to be substituted are looked up */ - public String substitute(Map<String, String> context, Properties substitution) { + public Object substitute(Map<String, String> context, Properties substitution) { StringBuilder b = new StringBuilder(); for (Component component : components) b.append(component.getValue(context, substitution)); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/CompiledQueryProfile.java b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/CompiledQueryProfile.java index a600389f4d5..ad9d3f4c1a5 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/CompiledQueryProfile.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/CompiledQueryProfile.java @@ -191,7 +191,7 @@ public class CompiledQueryProfile extends AbstractComponent implements Cloneable private Object substitute(Object value, Map<String, String> context, Properties substitution) { if (value == null) return value; if (substitution == null) return value; - if (value.getClass() != SubstituteString.class) return value; + if ( ! (value instanceof SubstituteString)) return value; return ((SubstituteString)value).substitute(context, substitution); } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java index c94987ce88e..70a08b9ab8a 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java @@ -47,7 +47,7 @@ public class QueryProfileXMLReader { if ( ! file.getName().endsWith(".xml")) continue; queryProfileReaders.add(new NamedReader(file.getName(), new FileReader(file))); } - File typeDir=new File(dir,"types"); + File typeDir = new File(dir,"types"); if (typeDir.isDirectory()) { for (File file : sortFiles(typeDir)) { if ( ! file.getName().endsWith(".xml")) continue; @@ -106,7 +106,7 @@ public class QueryProfileXMLReader { } String idString = root.getAttribute("id"); - if (idString == null || idString.equals("")) + if (idString.isEmpty()) throw new IllegalArgumentException("'" + reader.getName() + "' has no 'id' attribute in the root element"); ComponentId id = new ComponentId(idString); validateFileNameToId(reader.getName(), id,"query profile type"); @@ -129,7 +129,7 @@ public class QueryProfileXMLReader { } String idString = root.getAttribute("id"); - if (idString == null || idString.equals("")) + if (idString.isEmpty()) throw new IllegalArgumentException("Query profile '" + reader.getName() + "' has no 'id' attribute in the root element"); ComponentId id = new ComponentId(idString); @@ -137,7 +137,7 @@ public class QueryProfileXMLReader { QueryProfile queryProfile = new QueryProfile(id, reader.getName(), registry); String typeId = root.getAttribute("type"); - if (typeId != null && ! typeId.equals("")) { + if (! typeId.isEmpty()) { QueryProfileType type = registry.getType(typeId); if (type == null) throw new IllegalArgumentException("Query profile '" + reader.getName() + @@ -197,7 +197,7 @@ public class QueryProfileXMLReader { private void readInheritedTypes(Element element,QueryProfileType type, QueryProfileTypeRegistry registry) { String inheritedString = element.getAttribute("inherits"); - if (inheritedString == null || inheritedString.equals("")) return; + if (inheritedString.equals("")) return; for (String inheritedId : inheritedString.split(" ")) { inheritedId = inheritedId.trim(); if (inheritedId.equals("")) continue; @@ -211,10 +211,10 @@ public class QueryProfileXMLReader { private void readFieldDefinitions(Element element, QueryProfileType type, QueryProfileTypeRegistry registry) { for (Element field : XML.getChildren(element,"field")) { String name = field.getAttribute("name"); - if (name == null || name.equals("")) throw new IllegalArgumentException("A field has no 'name' attribute"); + if (name.isEmpty()) throw new IllegalArgumentException("A field has no 'name' attribute"); try { String fieldTypeName = field.getAttribute("type"); - if (fieldTypeName == null) throw new IllegalArgumentException("Field '" + field + "' has no 'type' attribute"); + if (fieldTypeName.isEmpty()) throw new IllegalArgumentException("Field '" + field + "' has no 'type' attribute"); FieldType fieldType = FieldType.fromString(fieldTypeName, registry); type.addField(new FieldDescription(name, fieldType, @@ -247,7 +247,7 @@ public class QueryProfileXMLReader { private void readInherited(Element element, QueryProfile profile, QueryProfileRegistry registry, DimensionValues dimensionValues, String sourceDescription) { String inheritedString = element.getAttribute("inherits"); - if (inheritedString == null || inheritedString.equals("")) return; + if (inheritedString.isEmpty()) return; for (String inheritedId : inheritedString.split(" ")) { inheritedId = inheritedId.trim(); if (inheritedId.equals("")) continue; @@ -265,7 +265,7 @@ public class QueryProfileXMLReader { List<KeyValue> properties = new ArrayList<>(); for (Element field : XML.getChildren(element,"field")) { String name = field.getAttribute("name"); - if (name == null || name.equals("")) + if (name.isEmpty()) throw new IllegalArgumentException("A field in " + sourceDescription + " has no 'name' attribute"); try { Boolean overridable = getBooleanAttribute("overridable", null, field); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/PrimitiveFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/PrimitiveFieldType.java index 564f90be422..2da72772e64 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/PrimitiveFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/PrimitiveFieldType.java @@ -1,7 +1,9 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.query.profile.types; +import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.SubstituteString; import static com.yahoo.text.Lowercase.toLowerCase; @@ -13,10 +15,10 @@ import static com.yahoo.text.Lowercase.toLowerCase; @SuppressWarnings("rawtypes") public class PrimitiveFieldType extends FieldType { - private Class primitiveClass; + private final Class primitiveClass; PrimitiveFieldType(Class primitiveClass) { - this.primitiveClass=primitiveClass; + this.primitiveClass = primitiveClass; } @Override @@ -43,20 +45,20 @@ public class PrimitiveFieldType extends FieldType { @Override public Object convertFrom(Object object, QueryProfileRegistry registry) { if (primitiveClass == object.getClass()) return object; + if (primitiveClass == String.class && object.getClass() == SubstituteString.class) return object; if (object.getClass() == String.class) return convertFromString((String)object); if (object instanceof Number) return convertFromNumber((Number)object); - return null; } private Object convertFromString(String string) { try { - if (primitiveClass==Integer.class) return Integer.valueOf(string); - if (primitiveClass==Double.class) return Double.valueOf(string); - if (primitiveClass==Float.class) return Float.valueOf(string); - if (primitiveClass==Long.class) return Long.valueOf(string); - if (primitiveClass==Boolean.class) return Boolean.valueOf(string); + if (primitiveClass == Integer.class) return Integer.valueOf(string); + if (primitiveClass == Double.class) return Double.valueOf(string); + if (primitiveClass == Float.class) return Float.valueOf(string); + if (primitiveClass == Long.class) return Long.valueOf(string); + if (primitiveClass == Boolean.class) return Boolean.valueOf(string); } catch (NumberFormatException e) { return null; // Handled in caller @@ -65,11 +67,11 @@ public class PrimitiveFieldType extends FieldType { } private Object convertFromNumber(Number number) { - if (primitiveClass==Integer.class) return number.intValue(); - if (primitiveClass==Double.class) return number.doubleValue(); - if (primitiveClass==Float.class) return number.floatValue(); - if (primitiveClass==Long.class) return number.longValue(); - if (primitiveClass==String.class) return String.valueOf(number); + if (primitiveClass == Integer.class) return number.intValue(); + if (primitiveClass == Double.class) return number.doubleValue(); + if (primitiveClass == Float.class) return number.floatValue(); + if (primitiveClass == Long.class) return number.longValue(); + if (primitiveClass == String.class) return String.valueOf(number); throw new RuntimeException("Programming error: Input type is " + number.getClass() + " primitiveClass is " + primitiveClass); } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index db6a58a4dd3..cc6b18af820 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -2,7 +2,9 @@ package com.yahoo.search.query.profile.types; import com.yahoo.language.process.Embedder; +import com.yahoo.processing.request.Properties; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.SubstituteString; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -51,6 +53,7 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { + if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type); Tensor tensor = toTensor(o, context); if (tensor == null) return null; if (! tensor.type().isAssignableTo(type)) @@ -60,12 +63,16 @@ public class TensorFieldType extends FieldType { private Tensor toTensor(Object o, ConversionContext context) { if (o instanceof Tensor) return (Tensor)o; - if (o instanceof String && ((String)o).startsWith("embed(")) return encode((String)o, context); + if (o instanceof String && isEmbed((String)o)) return embed((String)o, type, context); if (o instanceof String) return Tensor.from(type, (String)o); return null; } - private Tensor encode(String s, ConversionContext context) { + static boolean isEmbed(String value) { + return value.startsWith("embed("); + } + + static Tensor embed(String s, TensorType type, ConversionContext context) { if ( ! s.endsWith(")")) throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'"); String argument = s.substring("embed(".length(), s.length() - 1); @@ -78,14 +85,14 @@ public class TensorFieldType extends FieldType { argument = matcher.group(2); if (!context.embedders().containsKey(embedderId)) { throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(context.embedders())); + "Valid embedders are " + validEmbedders(context.embedders())); } embedder = context.embedders().get(embedderId); } else if (context.embedders().size() == 0) { throw new IllegalStateException("No embedders provided"); // should never happen } else if (context.embedders().size() > 1) { throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + - "Valid embedders are " + validEmbedders(context.embedders())); + "Valid embedders are " + validEmbedders(context.embedders())); } else { embedder = context.embedders().entrySet().stream().findFirst().get().getValue(); } @@ -110,7 +117,7 @@ public class TensorFieldType extends FieldType { return String.join(",", embedderIds); } - private Embedder.Context toEmbedderContext(ConversionContext context) { + private static Embedder.Context toEmbedderContext(ConversionContext context) { return new Embedder.Context(context.destination()).setLanguage(context.language()); } @@ -118,4 +125,26 @@ public class TensorFieldType extends FieldType { return new TensorFieldType(TensorType.fromSpec(s)); } + /** + * A substitute string that should become a tensor once the substitution is performed at lookup time. + * This is to support substitution strings in tensor values by parsing (only) such tensors at + * lookup time rather than at construction time. + */ + private static class SubstituteStringTensor extends SubstituteString { + + private final TensorType type; + + SubstituteStringTensor(SubstituteString string, TensorType type) { + super(string.components(), string.stringValue()); + this.type = type; + } + + @Override + public Object substitute(Map<String, String> context, Properties substitution) { + String substituted = super.substitute(context, substitution).toString(); + return Tensor.from(type, substituted); + } + + } + } |