about summary refs log tree commit diff stats
path: root/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java')
-rw-r--r-- container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java | 74
1 files changed, 2 insertions, 72 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
index cc6b18af820..e0dea744075 100644
--- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
+++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java
@@ -1,18 +1,14 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.query.profile.types;
-import com.yahoo.language.process.Embedder;
import com.yahoo.processing.request.Properties;
+import com.yahoo.search.config.internal.TensorConverter;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.search.query.profile.SubstituteString;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import java.util.ArrayList;
-import java.util.List;
import java.util.Map;
-import java.util.regex.Matcher;
-import java.util.regex.Pattern;
/**
* A tensor field type in a query profile
@@ -21,8 +17,6 @@ import java.util.regex.Pattern;
*/
public class TensorFieldType extends FieldType {
- private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])");
-
private final TensorType type;
/** Creates a tensor field type with information about the kind of tensor this will hold */
@@ -54,71 +48,7 @@ public class TensorFieldType extends FieldType {
@Override
public Object convertFrom(Object o, ConversionContext context) {
if (o instanceof SubstituteString) return new SubstituteStringTensor((SubstituteString) o, type);
- Tensor tensor = toTensor(o, context);
- if (tensor == null) return null;
- if (! tensor.type().isAssignableTo(type))
- throw new IllegalArgumentException("Require a tensor of type " + type);
- return tensor;
- }
-
- private Tensor toTensor(Object o, ConversionContext context) {
- if (o instanceof Tensor) return (Tensor)o;
- if (o instanceof String && isEmbed((String)o)) return embed((String)o, type, context);
- if (o instanceof String) return Tensor.from(type, (String)o);
- return null;
- }
-
- static boolean isEmbed(String value) {
- return value.startsWith("embed(");
- }
-
- static Tensor embed(String s, TensorType type, ConversionContext context) {
- if ( ! s.endsWith(")"))
- throw new IllegalArgumentException("Expected any string enclosed in embed(), but the argument does not end by ')'");
- String argument = s.substring("embed(".length(), s.length() - 1);
- Embedder embedder;
-
- // Check if arguments specifies an embedder with the format embed(embedder, "text to encode")
- Matcher matcher = embedderArgumentRegexp.matcher(argument);
- if (matcher.matches()) {
- String embedderId = matcher.group(1);
- argument = matcher.group(2);
- if (!context.embedders().containsKey(embedderId)) {
- throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " +
- "Valid embedders are " + validEmbedders(context.embedders()));
- }
- embedder = context.embedders().get(embedderId);
- } else if (context.embedders().size() == 0) {
- throw new IllegalStateException("No embedders provided"); // should never happen
- } else if (context.embedders().size() > 1) {
- throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " +
- "Valid embedders are " + validEmbedders(context.embedders()));
- } else {
- embedder = context.embedders().entrySet().stream().findFirst().get().getValue();
- }
-
- return embedder.embed(removeQuotes(argument), toEmbedderContext(context), type);
- }
-
- private static String removeQuotes(String s) {
- if (s.startsWith("'") && s.endsWith("'")) {
- return s.substring(1, s.length() - 1);
- }
- if (s.startsWith("\"") && s.endsWith("\"")) {
- return s.substring(1, s.length() - 1);
- }
- return s;
- }
-
- private static String validEmbedders(Map<String, Embedder> embedders) {
- List<String> embedderIds = new ArrayList<>();
- embedders.forEach((key, value) -> embedderIds.add(key));
- embedderIds.sort(null);
- return String.join(",", embedderIds);
- }
-
- private static Embedder.Context toEmbedderContext(ConversionContext context) {
- return new Embedder.Context(context.destination()).setLanguage(context.language());
+ return new TensorConverter(context.embedders()).convertTo(type, context.destination(), o, context.language());
}
public static TensorFieldType fromTypeString(String s) {