diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search')
5 files changed, 37 insertions, 47 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java index 70f6e405a92..bef766e7ef9 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/ConversionContext.java @@ -15,7 +15,6 @@ public class ConversionContext { private final String destination; private final CompiledQueryProfileRegistry registry; private final Map<String, Embedder> embedders; - private final Map<String, String> contextValues; private final Language language; public ConversionContext(String destination, CompiledQueryProfileRegistry registry, Embedder embedder, @@ -31,7 +30,6 @@ public class ConversionContext { this.embedders = embedders; this.language = context.containsKey("language") ? Language.fromLanguageTag(context.get("language")) : Language.UNKNOWN; - this.contextValues = context; } /** Returns the local name of the field which will receive the converted value (or null when this is empty) */ @@ -46,9 +44,6 @@ public class ConversionContext { /** Returns the language, which is never null but may be UNKNOWN */ Language language() { return language; } - /** Returns a read-only map of context key-values which can be looked up during conversion. */ - Map<String,String> contextValues() { return contextValues; } - /** Returns an empty context */ public static ConversionContext empty() { return new ConversionContext(null, null, Embedder.throwsOnUse.asMap(), Map.of()); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index e16f8e7b0cd..cfadd79de8f 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -48,8 +48,7 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type); - return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, - context.language(), context.contextValues()); + return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, context.language()); } public static TensorFieldType fromTypeString(String s) { diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java index 25a5c277dce..c9f935e5f52 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java @@ -44,8 +44,7 @@ public class RankProfileInputProperties extends Properties { value = tensorConverter.convertTo(expectedType, name.last(), value, - query.getModel().getLanguage(), - context); + query.getModel().getLanguage()); } } catch (IllegalArgumentException e) { diff --git a/container-search/src/main/java/com/yahoo/search/ranking/Normalizer.java b/container-search/src/main/java/com/yahoo/search/ranking/Normalizer.java index eb81d0555b3..0d86e1409c3 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/Normalizer.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/Normalizer.java @@ -3,14 +3,29 @@ package com.yahoo.search.ranking; abstract class Normalizer { - protected final double[] data; + protected double[] data; protected int size = 0; - Normalizer(int maxSize) { - this.data = new double[maxSize]; + private static int initialCapacity(int hint) { + for (int capacity = 64; capacity < 4096; capacity *= 2) { + if (hint <= capacity) { + return capacity; + } + } + return 4096; + } + + Normalizer(int sizeHint) { + this.data = new double[initialCapacity(sizeHint)]; } int addInput(double value) { + if (size == data.length) { + int newSize = size * 2; + var tmp = new double[newSize]; + System.arraycopy(data, 0, tmp, 0, size); + this.data = tmp; + } data[size] = value; return size++; } diff --git a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java index 94f92c7fd48..6da53ae699c 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java +++ b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java @@ -19,8 +19,7 @@ import java.util.regex.Pattern; */ public class TensorConverter { - private static final Pattern embedderArgumentAndQuotedTextRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*([\"'].*[\"'])"); - private static final Pattern embedderArgumentAndReferenceRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*(@.*)"); + private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); private final Map<String, Embedder> embedders; @@ -28,9 +27,8 @@ public class TensorConverter { this.embedders = embedders; } - public Tensor convertTo(TensorType type, String key, Object value, Language language, - Map<String, String> contextValues) { - var context = new Embedder.Context(key).setLanguage(language).setContextValues(contextValues); + public Tensor convertTo(TensorType type, String key, Object value, Language language) { + var context = new Embedder.Context(key).setLanguage(language); Tensor tensor = toTensor(type, value, context); if (tensor == null) return null; if (! tensor.type().isAssignableTo(type)) @@ -57,16 +55,16 @@ public class TensorConverter { String embedderId; // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") - Matcher matcher; - if (( matcher = embedderArgumentAndQuotedTextRegexp.matcher(argument)).matches()) { + Matcher matcher = embedderArgumentRegexp.matcher(argument); + if (matcher.matches()) { embedderId = matcher.group(1); - embedder = requireEmbedder(embedderId); argument = matcher.group(2); - } else if (( matcher = embedderArgumentAndReferenceRegexp.matcher(argument)).matches()) { - embedderId = matcher.group(1); - embedder = requireEmbedder(embedderId); - argument = matcher.group(2); - } else if (embedders.isEmpty()) { + if ( ! embedders.containsKey(embedderId)) { + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(embedders)); + } + embedder = embedders.get(embedderId); + } else if (embedders.size() == 0) { throw new IllegalStateException("No embedders provided"); // should never happen } else if (embedders.size() > 1) { throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + @@ -76,35 +74,19 @@ public class TensorConverter { embedderId = entry.getKey(); embedder = entry.getValue(); } - return embedder.embed(resolve(argument, embedderContext), embedderContext.copy().setEmbedderId(embedderId), type); + return embedder.embed(removeQuotes(argument), embedderContext.copy().setEmbedderId(embedderId), type); } - private Embedder requireEmbedder(String embedderId) { - if ( ! embedders.containsKey(embedderId)) - throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(embedders)); - return embedders.get(embedderId); - } - - private static String resolve(String s, Embedder.Context embedderContext) { - if (s.startsWith("'") && s.endsWith("'")) + private static String removeQuotes(String s) { + if (s.startsWith("'") && s.endsWith("'")) { return s.substring(1, s.length() - 1); - if (s.startsWith("\"") && s.endsWith("\"")) + } + if (s.startsWith("\"") && s.endsWith("\"")) { return s.substring(1, s.length() - 1); - if (s.startsWith("@")) - return resolveReference(s, embedderContext); + } return s; } - private static String resolveReference(String s, Embedder.Context embedderContext) { - String referenceKey = s.substring(1); - String referencedValue = embedderContext.getContextValues().get(referenceKey); - if (referencedValue == null) - throw new IllegalArgumentException("Could not resolve query parameter reference '" + referenceKey + - "' used in an embed() argument"); - return referencedValue; - } - private static String validEmbedders(Map<String, Embedder> embedders) { List<String> embedderIds = new ArrayList<>(); embedders.forEach((key, value) -> embedderIds.add(key)); |