diff options
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java | 74 |
1 files changed, 2 insertions, 72 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index cc6b18af820..e0dea744075 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -1,18 +1,14 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.query.profile.types; -import com.yahoo.language.process.Embedder; import com.yahoo.processing.request.Properties; +import com.yahoo.search.config.internal.TensorConverter; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.search.query.profile.SubstituteString; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; -import java.util.ArrayList; -import java.util.List; import java.util.Map; -import java.util.regex.Matcher; -import java.util.regex.Pattern; /** * A tensor field type in a query profile @@ -21,8 +17,6 @@ import java.util.regex.Pattern; */ public class TensorFieldType extends FieldType { - private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])"); - private final TensorType type; /** Creates a tensor field type with information about the kind of tensor this will hold */ @@ -54,71 +48,7 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type); - Tensor tensor = toTensor(o, context); - if (tensor == null) return null; - if (! tensor.type().isAssignableTo(type)) - throw new IllegalArgumentException("Require a tensor of type " + type); - return tensor; - } - - private Tensor toTensor(Object o, ConversionContext context) { - if (o instanceof Tensor) return (Tensor)o; - if (o instanceof String && isEmbed((String)o)) return embed((String)o, type, context); - if (o instanceof String) return Tensor.from(type, (String)o); - return null; - } - - static boolean isEmbed(String value) { - return value.startsWith("embed("); - } - - static Tensor embed(String s, TensorType type, ConversionContext context) { - if ( ! s.endsWith(")")) - throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'"); - String argument = s.substring("embed(".length(), s.length() - 1); - Embedder embedder; - - // Check if arguments specifies an embedder with the format embed(embedder, "text to encode") - Matcher matcher = embedderArgumentRegexp.matcher(argument); - if (matcher.matches()) { - String embedderId = matcher.group(1); - argument = matcher.group(2); - if (!context.embedders().containsKey(embedderId)) { - throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " + - "Valid embedders are " + validEmbedders(context.embedders())); - } - embedder = context.embedders().get(embedderId); - } else if (context.embedders().size() == 0) { - throw new IllegalStateException("No embedders provided"); // should never happen - } else if (context.embedders().size() > 1) { - throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " + - "Valid embedders are " + validEmbedders(context.embedders())); - } else { - embedder = context.embedders().entrySet().stream().findFirst().get().getValue(); - } - - return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type); - } - - private static String removeQuotes(String s) { - if (s.startsWith("'") && s.endsWith("'")) { - return s.substring(1, s.length() - 1); - } - if (s.startsWith("\"") && s.endsWith("\"")) { - return s.substring(1, s.length() - 1); - } - return s; - } - - private static String validEmbedders(Map<String, Embedder> embedders) { - List<String> embedderIds = new ArrayList<>(); - embedders.forEach((key, value) -> embedderIds.add(key)); - embedderIds.sort(null); - return String.join(",", embedderIds); - } - - private static Embedder.Context toEmbedderContext(ConversionContext context) { - return new Embedder.Context(context.destination()).setLanguage(context.language()); + return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, context.language()); } public static TensorFieldType fromTypeString(String s) { |