summaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java')
-rw-r--r--container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java52
1 files changed, 17 insertions, 35 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java
index 94f92c7fd48..6da53ae699c 100644
--- a/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java
+++ b/container-search/src/main/java/com/yahoo/search/schema/internal/TensorConverter.java
@@ -19,8 +19,7 @@ import java.util.regex.Pattern;
*/
public class TensorConverter {
- private static final Pattern embedderArgumentAndQuotedTextRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*([\"'].*[\"'])");
- private static final Pattern embedderArgumentAndReferenceRegexp = Pattern.compile("^([A-Za-z0-9_@\\-.]+),\\s*(@.*)");
+ private static final Pattern embedderArgumentRegexp = Pattern.compile("^([A-Za-z0-9_\\-.]+),\\s*([\"'].*[\"'])");
private final Map<String, Embedder> embedders;
@@ -28,9 +27,8 @@ public class TensorConverter {
this.embedders = embedders;
}
- public Tensor convertTo(TensorType type, String key, Object value, Language language,
- Map<String, String> contextValues) {
- var context = new Embedder.Context(key).setLanguage(language).setContextValues(contextValues);
+ public Tensor convertTo(TensorType type, String key, Object value, Language language) {
+ var context = new Embedder.Context(key).setLanguage(language);
Tensor tensor = toTensor(type, value, context);
if (tensor == null) return null;
if (! tensor.type().isAssignableTo(type))
@@ -57,16 +55,16 @@ public class TensorConverter {
String embedderId;
// Check if arguments specifies an embedder with the format embed(embedder, "text to encode")
- Matcher matcher;
- if (( matcher = embedderArgumentAndQuotedTextRegexp.matcher(argument)).matches()) {
+ Matcher matcher = embedderArgumentRegexp.matcher(argument);
+ if (matcher.matches()) {
embedderId = matcher.group(1);
- embedder = requireEmbedder(embedderId);
argument = matcher.group(2);
- } else if (( matcher = embedderArgumentAndReferenceRegexp.matcher(argument)).matches()) {
- embedderId = matcher.group(1);
- embedder = requireEmbedder(embedderId);
- argument = matcher.group(2);
- } else if (embedders.isEmpty()) {
+ if ( ! embedders.containsKey(embedderId)) {
+ throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " +
+ "Valid embedders are " + validEmbedders(embedders));
+ }
+ embedder = embedders.get(embedderId);
+ } else if (embedders.size() == 0) {
throw new IllegalStateException("No embedders provided"); // should never happen
} else if (embedders.size() > 1) {
throw new IllegalArgumentException("Multiple embedders are provided but no embedder id is given. " +
@@ -76,35 +74,19 @@ public class TensorConverter {
embedderId = entry.getKey();
embedder = entry.getValue();
}
- return embedder.embed(resolve(argument, embedderContext), embedderContext.copy().setEmbedderId(embedderId), type);
+ return embedder.embed(removeQuotes(argument), embedderContext.copy().setEmbedderId(embedderId), type);
}
- private Embedder requireEmbedder(String embedderId) {
- if ( ! embedders.containsKey(embedderId))
- throw new IllegalArgumentException("Can't find embedder '" + embedderId + "'. " +
- "Valid embedders are " + validEmbedders(embedders));
- return embedders.get(embedderId);
- }
-
- private static String resolve(String s, Embedder.Context embedderContext) {
- if (s.startsWith("'") && s.endsWith("'"))
+ private static String removeQuotes(String s) {
+ if (s.startsWith("'") && s.endsWith("'")) {
return s.substring(1, s.length() - 1);
- if (s.startsWith("\"") && s.endsWith("\""))
+ }
+ if (s.startsWith("\"") && s.endsWith("\"")) {
return s.substring(1, s.length() - 1);
- if (s.startsWith("@"))
- return resolveReference(s, embedderContext);
+ }
return s;
}
- private static String resolveReference(String s, Embedder.Context embedderContext) {
- String referenceKey = s.substring(1);
- String referencedValue = embedderContext.getContextValues().get(referenceKey);
- if (referencedValue == null)
- throw new IllegalArgumentException("Could not resolve query parameter reference '" + referenceKey +
- "' used in an embed() argument");
- return referencedValue;
- }
-
private static String validEmbedders(Map<String, Embedder> embedders) {
List<String> embedderIds = new ArrayList<>();
embedders.forEach((key, value) -> embedderIds.add(key));