diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java | 52 |
1 files changed, 17 insertions, 35 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java index 94f92c7fd48..6da53ae699c 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java +++ b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java @@ -19,8 +19,7 @@ import java.util.regex.Pattern; */ public class TensorConverter { - private static final Pattern embedderArgumentAndQuotedTextRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*([\"'].*[\"'])"); - private static final Pattern embedderArgumentAndReferenceRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*(@.*)"); + private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); private final Map<String, Embedder> embedders; @@ -28,9 +27,8 @@ public class TensorConverter { this.embedders = embedders; } - public Tensor convertTo(TensorType type, String key, Object value, Language language, - Map<String, String> contextValues) { - var context = new Embedder.Context(key).setLanguage(language).setContextValues(contextValues); + public Tensor convertTo(TensorType type, String key, Object value, Language language) { + var context = new Embedder.Context(key).setLanguage(language); Tensor tensor = toTensor(type, value, context); if (tensor == null) return null; if (! tensor.type().isAssignableTo(type)) @@ -57,16 +55,16 @@ public class TensorConverter { String embedderId; // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") - Matcher matcher; - if (( matcher = embedderArgumentAndQuotedTextRegexp.matcher(argument)).matches()) { + Matcher matcher = embedderArgumentRegexp.matcher(argument); + if (matcher.matches()) { embedderId = matcher.group(1); - embedder = requireEmbedder(embedderId); argument = matcher.group(2); - } else if (( matcher = embedderArgumentAndReferenceRegexp.matcher(argument)).matches()) { - embedderId = matcher.group(1); - embedder = requireEmbedder(embedderId); - argument = matcher.group(2); - } else if (embedders.isEmpty()) { + if ( ! embedders.containsKey(embedderId)) { + throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + + "Valid embedders are " + validEmbedders(embedders)); + } + embedder = embedders.get(embedderId); + } else if (embedders.size() == 0) { throw new IllegalStateException("No embedders provided"); // should never happen } else if (embedders.size() > 1) { throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + @@ -76,35 +74,19 @@ public class TensorConverter { embedderId = entry.getKey(); embedder = entry.getValue(); } - return embedder.embed(resolve(argument, embedderContext), embedderContext.copy().setEmbedderId(embedderId), type); + return embedder.embed(removeQuotes(argument), embedderContext.copy().setEmbedderId(embedderId), type); } - private Embedder requireEmbedder(String embedderId) { - if ( ! embedders.containsKey(embedderId)) - throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(embedders)); - return embedders.get(embedderId); - } - - private static String resolve(String s, Embedder.Context embedderContext) { - if (s.startsWith("'") && s.endsWith("'")) + private static String removeQuotes(String s) { + if (s.startsWith("'") && s.endsWith("'")) { return s.substring(1, s.length() - 1); - if (s.startsWith("\"") && s.endsWith("\"")) + } + if (s.startsWith("\"") && s.endsWith("\"")) { return s.substring(1, s.length() - 1); - if (s.startsWith("@")) - return resolveReference(s, embedderContext); + } return s; } - private static String resolveReference(String s, Embedder.Context embedderContext) { - String referenceKey = s.substring(1); - String referencedValue = embedderContext.getContextValues().get(referenceKey); - if (referencedValue == null) - throw new IllegalArgumentException("Could not resolve query parameter reference '" + referenceKey + - "' used in an embed() argument"); - return referencedValue; - } - private static String validEmbedders(Map<String, Embedder> embedders) { List<String> embedderIds = new ArrayList<>(); embedders.forEach((key, value) -> embedderIds.add(key)); |