diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-06-14 16:51:31 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-06-14 16:53:30 +0000 |
commit | 67b7cdb8fefde4dc5843d646ce9fbd33acb37f7f (patch) | |
tree | 2d6dc4b6c56863973a72c977a46d38941f181e74 /container-search/src/main/java/com/yahoo | |
parent | 05617553a3f8d563c1955b82e81c54f4cb2136b5 (diff) |
allow multiple tensor types for same name
* with several document types you can have fields with the
same name but different tensor types that all should be
allowed as target for nearestNeighbor operator.
Diffstat (limited to 'container-search/src/main/java/com/yahoo')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java | 44 |
1 files changed, 28 insertions, 16 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java index ca9d17cb656..d22dd2e6af6 100644 --- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java @@ -19,6 +19,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.yolean.chain.Before; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,15 +33,17 @@ import java.util.Optional; @Before(GroupingExecutor.COMPONENT_NAME) // Must happen before query.prepare() public class ValidateNearestNeighborSearcher extends Searcher { - private final Map<String, TensorType> validAttributes = new HashMap<>(); + private final Map<String, List<TensorType>> validAttributes = new HashMap<>(); public ValidateNearestNeighborSearcher(AttributesConfig attributesConfig) { for (AttributesConfig.Attribute a : attributesConfig.attribute()) { - TensorType tt = null; + if (! validAttributes.containsKey(a.name())) { + validAttributes.put(a.name(), new ArrayList<TensorType>()); + } if (a.datatype() == AttributesConfig.Attribute.Datatype.TENSOR) { - tt = TensorType.fromSpec(a.tensortype()); + TensorType tt = TensorType.fromSpec(a.tensortype()); + validAttributes.get(a.name()).add(tt); } - validAttributes.put(a.name(), tt); } } @@ -60,10 +63,10 @@ public class ValidateNearestNeighborSearcher extends Searcher { public Optional<ErrorMessage> errorMessage = Optional.empty(); - private final Map<String, TensorType> validAttributes; + private final Map<String, List<TensorType>> validAttributes; private final Query query; - public NNVisitor(RankProperties rankProperties, Map<String, TensorType> validAttributes, Query query) { + public NNVisitor(RankProperties rankProperties, Map<String, List<TensorType>> validAttributes, Query query) { this.validAttributes = validAttributes; this.query = query; } @@ -101,17 +104,26 @@ public class ValidateNearestNeighborSearcher extends Searcher { if (queryTensor.isEmpty()) return item + " requires a tensor rank feature " + queryFeatureName + " but this is not present"; - if ( ! validAttributes.containsKey(item.getIndexName())) + if ( ! validAttributes.containsKey(item.getIndexName())) { return item + " field is not an attribute"; - TensorType fieldType = validAttributes.get(item.getIndexName()); - if (fieldType == null) return item + " field is not a tensor"; - if ( ! isDenseVector(fieldType)) - return item + " tensor type " + fieldType + " is not a dense vector"; - - if ( ! isCompatible(fieldType, queryTensor.get().type())) - return item + " field type " + fieldType + " does not match query type " + queryTensor.get().type(); - - return null; + } + List<TensorType> allTensorTypes = validAttributes.get(item.getIndexName()); + for (TensorType fieldType : allTensorTypes) { + if (isDenseVector(fieldType) && isCompatible(fieldType, queryTensor.get().type())) { + return null; + } + } + for (TensorType fieldType : allTensorTypes) { + if (isDenseVector(fieldType) && ! isCompatible(fieldType, queryTensor.get().type())) { + return item + " field type " + fieldType + " does not match query type " + queryTensor.get().type(); + } + } + for (TensorType fieldType : allTensorTypes) { + if (! isDenseVector(fieldType)) { + return item + " tensor type " + fieldType + " is not a dense vector"; + } + } + return item + " field is not a tensor"; } @Override |