summaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
diff options
context:
space:
mode:
Diffstat (limited to 'container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java')
-rw-r--r--container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java37
1 files changed, 26 insertions, 11 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
index 25d0e184588..06c551692fb 100644
--- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
+++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
@@ -79,17 +79,32 @@ public class ValidateNearestNeighborSearcher extends Searcher {
return true;
}
- private static boolean isCompatible(TensorType lhs, TensorType rhs) {
- return lhs.dimensions().equals(rhs.dimensions());
+ private static boolean isCompatible(TensorType fieldTensorType, TensorType queryTensorType) {
+ // Precondition: isTensorTypeThatSupportsHnswIndex(fieldTensorType)
+ var queryDimensions = queryTensorType.dimensions();
+ if (queryDimensions.size() == 1) {
+ var queryDimension = queryDimensions.get(0);
+ var fieldDimensions = fieldTensorType.dimensions();
+ for (var fieldDimension : fieldDimensions) {
+ if (fieldDimension.isIndexed()) {
+ return fieldDimension.equals(queryDimension);
+ }
+ }
+ }
+ return false;
}
- private static boolean isDenseVector(TensorType tt) {
+ private static boolean isTensorTypeThatSupportsHnswIndex(TensorType tt) {
List<TensorType.Dimension> dims = tt.dimensions();
- if (dims.size() != 1) return false;
- for (var d : dims) {
- if (d.type() != TensorType.Dimension.Type.indexedBound) return false;
+ if (dims.size() == 1) {
+ return dims.get(0).isIndexed();
}
- return true;
+ if (dims.size() == 2) {
+ var dims0 = dims.get(0);
+ var dims1 = dims.get(1);
+ return ((dims0.isMapped() && dims1.isIndexed()) || (dims0.isIndexed() && dims1.isMapped()));
+ }
+ return false;
}
/** Returns an error message if this is invalid, or null if it is valid */
@@ -107,18 +122,18 @@ public class ValidateNearestNeighborSearcher extends Searcher {
}
List<TensorType> allTensorTypes = validAttributes.get(item.getIndexName());
for (TensorType fieldType : allTensorTypes) {
- if (isDenseVector(fieldType) && isCompatible(fieldType, queryTensor.get().type())) {
+ if (isTensorTypeThatSupportsHnswIndex(fieldType) && isCompatible(fieldType, queryTensor.get().type())) {
return null;
}
}
for (TensorType fieldType : allTensorTypes) {
- if (isDenseVector(fieldType) && ! isCompatible(fieldType, queryTensor.get().type())) {
+ if (isTensorTypeThatSupportsHnswIndex(fieldType) && ! isCompatible(fieldType, queryTensor.get().type())) {
return item + " field type " + fieldType + " does not match query type " + queryTensor.get().type();
}
}
for (TensorType fieldType : allTensorTypes) {
- if (! isDenseVector(fieldType)) {
- return item + " tensor type " + fieldType + " is not a dense vector";
+ if (! isTensorTypeThatSupportsHnswIndex(fieldType)) {
+ return item + " field type " + fieldType + " is not supported by nearest neighbor searcher";
}
}
return item + " field is not a tensor";