diff options
6 files changed, 46 insertions, 48 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java index 25d0e184588..06c551692fb 100644 --- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java @@ -79,17 +79,32 @@ public class ValidateNearestNeighborSearcher extends Searcher { return true; } - private static boolean isCompatible(TensorType lhs, TensorType rhs) { - return lhs.dimensions().equals(rhs.dimensions()); + private static boolean isCompatible(TensorType fieldTensorType, TensorType queryTensorType) { + // Precondition: isTensorTypeThatSupportsHnswIndex(fieldTensorType) + var queryDimensions = queryTensorType.dimensions(); + if (queryDimensions.size() == 1) { + var queryDimension = queryDimensions.get(0); + var fieldDimensions = fieldTensorType.dimensions(); + for (var fieldDimension : fieldDimensions) { + if (fieldDimension.isIndexed()) { + return fieldDimension.equals(queryDimension); + } + } + } + return false; } - private static boolean isDenseVector(TensorType tt) { + private static boolean isTensorTypeThatSupportsHnswIndex(TensorType tt) { List<TensorType.Dimension> dims = tt.dimensions(); - if (dims.size() != 1) return false; - for (var d : dims) { - if (d.type() != TensorType.Dimension.Type.indexedBound) return false; + if (dims.size() == 1) { + return dims.get(0).isIndexed(); } - return true; + if (dims.size() == 2) { + var dims0 = dims.get(0); + var dims1 = dims.get(1); + return ((dims0.isMapped() && dims1.isIndexed()) || (dims0.isIndexed() && dims1.isMapped())); + } + return false; } /** Returns an error message if this is invalid, or null if it is valid */ @@ -107,18 +122,18 @@ public class ValidateNearestNeighborSearcher extends Searcher { } List<TensorType> allTensorTypes = validAttributes.get(item.getIndexName()); for (TensorType fieldType : allTensorTypes) { - if (isDenseVector(fieldType) && isCompatible(fieldType, queryTensor.get().type())) { + if (isTensorTypeThatSupportsHnswIndex(fieldType) && isCompatible(fieldType, queryTensor.get().type())) { return null; } } for (TensorType fieldType : allTensorTypes) { - if (isDenseVector(fieldType) && ! isCompatible(fieldType, queryTensor.get().type())) { + if (isTensorTypeThatSupportsHnswIndex(fieldType) && ! isCompatible(fieldType, queryTensor.get().type())) { return item + " field type " + fieldType + " does not match query type " + queryTensor.get().type(); } } for (TensorType fieldType : allTensorTypes) { - if (! isDenseVector(fieldType)) { - return item + " tensor type " + fieldType + " is not a dense vector"; + if (! isTensorTypeThatSupportsHnswIndex(fieldType)) { + return item + " field type " + fieldType + " is not supported by nearest neighbor searcher"; } } return item + " field is not a tensor"; diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index b8be0b3dd43..9a7e3915d19 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -33,7 +33,7 @@ public class ValidateNearestNeighborTestCase { searcher = new ValidateNearestNeighborSearcher( ConfigGetter.getConfig(AttributesConfig.class, "raw:" + - "attribute[5]\n" + + "attribute[9]\n" + "attribute[0].name simple\n" + "attribute[0].datatype INT32\n" + "attribute[1].name dvector\n" + @@ -56,7 +56,10 @@ public class ValidateNearestNeighborTestCase { "attribute[6].tensortype tensor(x[3])\n" + "attribute[7].name threetypes\n" + "attribute[7].datatype TENSOR\n" + - "attribute[7].tensortype tensor(x{})\n" + "attribute[7].tensortype tensor(x{})\n" + + "attribute[8].name mixeddvector\n" + + "attribute[8].datatype TENSOR\n" + + "attribute[8].tensortype tensor(a{},x[3])\n" )); } @@ -134,6 +137,14 @@ public class ValidateNearestNeighborTestCase { assertNull(r.hits().getError()); } + @Test + void testvalidQueryMixedFieldTensor() { + String q = makeQuery("mixeddvector", "qvector"); + Tensor t = makeTensor(tt_dense_dvector_3); + Result r = doSearch(searcher, q, t); + assertNull(r.hits().getError()); + } + private static void assertErrMsg(String message, Result r) { assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError()); } @@ -210,7 +221,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("sparse", "qvector"); Tensor t = makeTensor(tt_sparse_vector_x); Result r = doSearch(searcher, q, t); - assertErrMsg(desc("sparse", "qvector", 1, "tensor type tensor(x{}) is not a dense vector"), r); + assertErrMsg(desc("sparse", "qvector", 1, "field type tensor(x{}) is not supported by nearest neighbor searcher"), r); } @Test @@ -218,7 +229,7 @@ public class ValidateNearestNeighborTestCase { String q = makeQuery("matrix", "qvector"); Tensor t = makeMatrix(tt_dense_matrix_xy); Result r = doSearch(searcher, q, t); - assertErrMsg(desc("matrix", "qvector", 1, "tensor type tensor(x[3],y[1]) is not a dense vector"), r); + assertErrMsg(desc("matrix", "qvector", 1, "field type tensor(x[3],y[1]) is not supported by nearest neighbor searcher"), r); } private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Tensor qTensor) { diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp index b3f8195676d..92c9a21db83 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp @@ -13,17 +13,6 @@ using vespalib::eval::CellType; namespace search::queryeval { -namespace { - -bool -is_compatible(const vespalib::eval::ValueType& lhs, - const vespalib::eval::ValueType& rhs) -{ - return (lhs.dimensions() == rhs.dimensions()); -} - -} - /** * Search iterator for K nearest neighbor matching. * Uses unpack() as feedback mechanism to track which matches actually became hits. @@ -39,8 +28,6 @@ public: : NearestNeighborIterator(params_in), _lastScore(0.0) { - assert(is_compatible(params().distance_calc.attribute_tensor().getTensorType(), - params().distance_calc.query_tensor().type())); } ~NearestNeighborImpl(); diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp index ca09c0e58d9..157cbd41a95 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp @@ -1,15 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "dense_tensor_attribute.h" -#include "nearest_neighbor_index.h" -#include "tensor_attribute_loader.h" -#include "tensor_attribute_saver.h" -#include <vespa/eval/eval/value.h> #include <vespa/searchcommon/attribute/config.h> -#include <vespa/vespalib/data/slime/inserter.h> - -using vespalib::eval::Value; -using vespalib::slime::ObjectInserter; namespace search::tensor { @@ -37,17 +29,6 @@ DenseTensorAttribute::extract_cells_ref(DocId docId) const return _denseTensorStore.get_typed_cells(ref); } -void -DenseTensorAttribute::get_state(const vespalib::slime::Inserter& inserter) const -{ - auto& object = inserter.insertObject(); - populate_state(object); - if (_index) { - ObjectInserter index_inserter(object, "nearest_neighbor_index"); - _index->get_state(index_inserter); - } -} - vespalib::eval::TypedCells DenseTensorAttribute::get_vector(uint32_t docid, uint32_t subspace) const { diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h index 89f0fd1bd06..8bda4bdacd7 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h @@ -24,7 +24,6 @@ public: // Implements AttributeVector and ITensorAttribute vespalib::eval::TypedCells extract_cells_ref(DocId docId) const override; bool supports_extract_cells_ref() const override { return true; } - void get_state(const vespalib::slime::Inserter& inserter) const override; // Implements DocVectorAccess vespalib::eval::TypedCells get_vector(uint32_t docid, uint32_t subspace) const override; diff --git a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp index 5c50b2d83a2..b63034acfc4 100644 --- a/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/tensor_attribute.cpp @@ -30,6 +30,7 @@ using vespalib::eval::FastValueBuilderFactory; using vespalib::eval::TensorSpec; using vespalib::eval::Value; using vespalib::eval::ValueType; +using vespalib::slime::ObjectInserter; namespace search::tensor { @@ -224,6 +225,10 @@ TensorAttribute::populate_state(vespalib::slime::Cursor& object) const object.setObject("ref_vector").setObject("memory_usage")); StateExplorerUtils::memory_usage_to_slime(_tensorStore.getMemoryUsage(), object.setObject("tensor_store").setObject("memory_usage")); + if (_index) { + ObjectInserter index_inserter(object, "nearest_neighbor_index"); + _index->get_state(index_inserter); + } } void |