diff options
author | Geir Storli <geirst@yahooinc.com> | 2022-07-06 09:55:18 +0000 |
---|---|---|
committer | Geir Storli <geirst@yahooinc.com> | 2022-07-06 09:55:18 +0000 |
commit | a7ab4ec6149b691f4a3cf2241f70aa66f0b72861 (patch) | |
tree | 8a4ac016651e2e8cad0fbe3fe1941baa044bef71 /searchlib/src | |
parent | 8dae227258dde84db5116922fbc616dc1d70d3a7 (diff) |
Refactor validation code for setting up a distance calculator for re-use in rank features.
Diffstat (limited to 'searchlib/src')
6 files changed, 87 insertions, 53 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index b6953ec5dca..b93398e16a1 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -46,8 +46,8 @@ using search::queryeval::NearestNeighborBlueprint; using search::tensor::DefaultNearestNeighborIndexFactory; using search::tensor::DenseTensorAttribute; using search::tensor::DirectTensorAttribute; +using search::tensor::DistanceCalculator; using search::tensor::DocVectorAccess; -using search::tensor::SerializedFastValueAttribute; using search::tensor::HnswIndex; using search::tensor::HnswNode; using search::tensor::NearestNeighborIndex; @@ -55,13 +55,14 @@ using search::tensor::NearestNeighborIndexFactory; using search::tensor::NearestNeighborIndexLoader; using search::tensor::NearestNeighborIndexSaver; using search::tensor::PrepareResult; +using search::tensor::SerializedFastValueAttribute; using search::tensor::TensorAttribute; using vespalib::datastore::CompactionStrategy; -using vespalib::eval::TensorSpec; using vespalib::eval::CellType; -using vespalib::eval::ValueType; -using vespalib::eval::Value; using vespalib::eval::SimpleValue; +using vespalib::eval::TensorSpec; +using vespalib::eval::Value; +using vespalib::eval::ValueType; using DoubleVector = std::vector<double>; using generation_t = vespalib::GenerationHandler::generation_t; @@ -1072,8 +1073,8 @@ public: search::queryeval::FieldSpec field("foo", 0, 0); auto bp = std::make_unique<NearestNeighborBlueprint>( field, - this->as_dense_tensor(), - create_query_tensor(vec_2d(17, 42)), + std::make_unique<DistanceCalculator>(this->as_dense_tensor(), + create_query_tensor(vec_2d(17, 42))), 3, approximate, 5, 100100.25, global_filter_lower_limit, 1.0); diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index d9db50ae816..7af2186ed1e 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -599,13 +599,6 @@ bool check_valid_diversity_attr(const IAttributeVector *attr) { return (attr->hasEnum() || attr->isIntegerType() || attr->isFloatingPointType()); } -bool -is_compatible_for_nearest_neighbor(const vespalib::eval::ValueType& lhs, - const vespalib::eval::ValueType& rhs) -{ - return (lhs.dimensions() == rhs.dimensions()); -} - //----------------------------------------------------------------------------- @@ -760,40 +753,24 @@ public: setResult(std::make_unique<queryeval::EmptyBlueprint>(_field)); } void visit(query::NearestNeighborTerm &n) override { - const ITensorAttribute *tensor_attr = _attr.asTensorAttribute(); - if (tensor_attr == nullptr) { - return fail_nearest_neighbor_term(n, "Attribute is not a tensor"); - } - const auto & ta_type = tensor_attr->getTensorType(); - if ((! ta_type.is_dense()) || (ta_type.dimensions().size() != 1)) { - return fail_nearest_neighbor_term(n, make_string("Attribute tensor type (%s) is not a dense tensor of order 1", - ta_type.to_spec().c_str())); - } const auto* query_tensor = getRequestContext().get_query_tensor(n.get_query_tensor_name()); if (query_tensor == nullptr) { return fail_nearest_neighbor_term(n, "Query tensor was not found in request context"); } - const auto & qt_type = query_tensor->type(); - if (! qt_type.is_dense()) { - return fail_nearest_neighbor_term(n, make_string("Query tensor is not a dense tensor (type=%s)", - qt_type.to_spec().c_str())); - } - if (!is_compatible_for_nearest_neighbor(ta_type, qt_type)) { - return fail_nearest_neighbor_term(n, make_string("Attribute tensor type (%s) and query tensor type (%s) are not compatible", - ta_type.to_spec().c_str(), qt_type.to_spec().c_str())); - } - if (tensor_attr->supports_extract_cells_ref() == false) { - return fail_nearest_neighbor_term(n, make_string("Attribute does not support access to tensor data (type=%s)", - ta_type.to_spec().c_str())); + try { + auto calc = tensor::DistanceCalculator::make_with_validation(_attr, *query_tensor); + setResult(std::make_unique<queryeval::NearestNeighborBlueprint>(_field, + std::move(calc), + n.get_target_num_hits(), + n.get_allow_approximate(), + n.get_explore_additional_hits(), + n.get_distance_threshold(), + getRequestContext().get_attribute_blueprint_params().global_filter_lower_limit, + getRequestContext().get_attribute_blueprint_params().global_filter_upper_limit)); + } catch (const vespalib::IllegalArgumentException& ex) { + return fail_nearest_neighbor_term(n, ex.getMessage()); + } - setResult(std::make_unique<queryeval::NearestNeighborBlueprint>(_field, *tensor_attr, - *query_tensor, - n.get_target_num_hits(), - n.get_allow_approximate(), - n.get_explore_additional_hits(), - n.get_distance_threshold(), - getRequestContext().get_attribute_blueprint_params().global_filter_lower_limit, - getRequestContext().get_attribute_blueprint_params().global_filter_upper_limit)); } void visit(query::FuzzyTerm &n) override { visitTerm(n); } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index a36a0006c76..6a891341afd 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -34,8 +34,7 @@ to_string(NearestNeighborBlueprint::Algorithm algorithm) } // namespace <unnamed> NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field, - const tensor::ITensorAttribute& attr_tensor, - const Value& query_tensor, + std::unique_ptr<search::tensor::DistanceCalculator> distance_calc, uint32_t target_hits, bool approximate, uint32_t explore_additional_hits, @@ -43,9 +42,9 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f double global_filter_lower_limit, double global_filter_upper_limit) : ComplexLeafBlueprint(field), - _attr_tensor(attr_tensor), - _distance_calc(_attr_tensor, query_tensor), - _query_tensor(_distance_calc.query_tensor()), + _distance_calc(std::move(distance_calc)), + _attr_tensor(_distance_calc->attribute_tensor()), + _query_tensor(_distance_calc->query_tensor()), _target_hits(target_hits), _adjusted_target_hits(target_hits), _approximate(approximate), @@ -62,7 +61,7 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f _global_filter_hit_ratio() { if (distance_threshold < std::numeric_limits<double>::max()) { - _distance_threshold = _distance_calc.function().convert_threshold(distance_threshold); + _distance_threshold = _distance_calc->function().convert_threshold(distance_threshold); _distance_heap.set_distance_threshold(_distance_threshold); } uint32_t est_hits = _attr_tensor.get_num_docs(); @@ -127,11 +126,11 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData switch (_algorithm) { case Algorithm::INDEX_TOP_K_WITH_FILTER: case Algorithm::INDEX_TOP_K: - return NnsIndexIterator::create(tfmd, _found_hits, _distance_calc.function()); + return NnsIndexIterator::create(tfmd, _found_hits, _distance_calc->function()); default: ; } - return NearestNeighborIterator::create(strict, tfmd, _distance_calc, + return NearestNeighborIterator::create(strict, tfmd, *_distance_calc, _distance_heap, _global_filter->filter()); } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index 9948cce1407..3dd03291b97 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -28,8 +28,8 @@ public: INDEX_TOP_K_WITH_FILTER }; private: + std::unique_ptr<search::tensor::DistanceCalculator> _distance_calc; const tensor::ITensorAttribute& _attr_tensor; - search::tensor::DistanceCalculator _distance_calc; const vespalib::eval::Value& _query_tensor; uint32_t _target_hits; uint32_t _adjusted_target_hits; @@ -49,8 +49,7 @@ private: void perform_top_k(const search::tensor::NearestNeighborIndex* nns_index); public: NearestNeighborBlueprint(const queryeval::FieldSpec& field, - const tensor::ITensorAttribute& attr_tensor, - const vespalib::eval::Value& query_tensor, + std::unique_ptr<search::tensor::DistanceCalculator> distance_calc, uint32_t target_hits, bool approximate, uint32_t explore_additional_hits, double distance_threshold, double global_filter_lower_limit, diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp index adfa5b7ee4a..d6d5433ff15 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp +++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp @@ -4,12 +4,17 @@ #include "distance_function_factory.h" #include "nearest_neighbor_index.h" #include <vespa/eval/eval/fast_value.h> +#include <vespa/searchcommon/attribute/iattributevector.h> +#include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/util/stringfmt.h> +using vespalib::IllegalArgumentException; using vespalib::eval::CellType; using vespalib::eval::FastValueBuilderFactory; using vespalib::eval::TypedCells; using vespalib::eval::Value; using vespalib::eval::ValueType; +using vespalib::make_string; namespace { @@ -42,6 +47,13 @@ struct ConvertCellsSelector } }; +bool +is_compatible(const vespalib::eval::ValueType& lhs, + const vespalib::eval::ValueType& rhs) +{ + return (lhs.dimensions() == rhs.dimensions()); +} + } namespace search::tensor { @@ -86,5 +98,40 @@ DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tens DistanceCalculator::~DistanceCalculator() = default; +namespace { + + + +} + +std::unique_ptr<DistanceCalculator> +DistanceCalculator::make_with_validation(const search::attribute::IAttributeVector& attr, + const vespalib::eval::Value& query_tensor_in) +{ + const ITensorAttribute* attr_tensor = attr.asTensorAttribute(); + if (attr_tensor == nullptr) { + throw IllegalArgumentException("Attribute is not a tensor"); + } + const auto& at_type = attr_tensor->getTensorType(); + if ((!at_type.is_dense()) || (at_type.dimensions().size() != 1)) { + throw IllegalArgumentException(make_string("Attribute tensor type (%s) is not a dense tensor of order 1", + at_type.to_spec().c_str())); + } + const auto& qt_type = query_tensor_in.type(); + if (!qt_type.is_dense()) { + throw IllegalArgumentException(make_string("Query tensor type (%s) is not a dense tensor", + qt_type.to_spec().c_str())); + } + if (!is_compatible(at_type, qt_type)) { + throw IllegalArgumentException(make_string("Attribute tensor type (%s) and query tensor type (%s) are not compatible", + at_type.to_spec().c_str(), qt_type.to_spec().c_str())); + } + if (!attr_tensor->supports_extract_cells_ref()) { + throw IllegalArgumentException(make_string("Attribute tensor does not support access to tensor data (type=%s)", + at_type.to_spec().c_str())); + } + return std::make_unique<DistanceCalculator>(*attr_tensor, query_tensor_in); +} + } diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h index f1cc7feb9df..1bd9586a2bb 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h @@ -6,6 +6,8 @@ namespace vespalib::eval { struct Value; } +namespace search::attribute { class IAttributeVector; } + namespace search::tensor { /** @@ -43,6 +45,15 @@ public: double calc_with_limit(uint32_t docid, double limit) const { return _dist_fun->calc_with_limit(_query_tensor_cells, _attr_tensor.extract_cells_ref(docid), limit); } + + /** + * Create a calculator for the given attribute tensor and query tensor, if possible. + * + * Throws vespalib::IllegalArgumentException if the inputs are not supported or incompatible. + */ + static std::unique_ptr<DistanceCalculator> make_with_validation(const search::attribute::IAttributeVector& attr, + const vespalib::eval::Value& query_tensor_in); + }; } |