aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorGeir Storli <geirst@yahooinc.com>2022-07-06 09:55:18 +0000
committerGeir Storli <geirst@yahooinc.com>2022-07-06 09:55:18 +0000
commita7ab4ec6149b691f4a3cf2241f70aa66f0b72861 (patch)
tree8a4ac016651e2e8cad0fbe3fe1941baa044bef71 /searchlib
parent8dae227258dde84db5116922fbc616dc1d70d3a7 (diff)
Refactor validation code for setting up a distance calculator for re-use in rank features.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp13
-rw-r--r--searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp49
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp15
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h5
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp47
-rw-r--r--searchlib/src/vespa/searchlib/tensor/distance_calculator.h11
6 files changed, 87 insertions, 53 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index b6953ec5dca..b93398e16a1 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -46,8 +46,8 @@ using search::queryeval::NearestNeighborBlueprint;
using search::tensor::DefaultNearestNeighborIndexFactory;
using search::tensor::DenseTensorAttribute;
using search::tensor::DirectTensorAttribute;
+using search::tensor::DistanceCalculator;
using search::tensor::DocVectorAccess;
-using search::tensor::SerializedFastValueAttribute;
using search::tensor::HnswIndex;
using search::tensor::HnswNode;
using search::tensor::NearestNeighborIndex;
@@ -55,13 +55,14 @@ using search::tensor::NearestNeighborIndexFactory;
using search::tensor::NearestNeighborIndexLoader;
using search::tensor::NearestNeighborIndexSaver;
using search::tensor::PrepareResult;
+using search::tensor::SerializedFastValueAttribute;
using search::tensor::TensorAttribute;
using vespalib::datastore::CompactionStrategy;
-using vespalib::eval::TensorSpec;
using vespalib::eval::CellType;
-using vespalib::eval::ValueType;
-using vespalib::eval::Value;
using vespalib::eval::SimpleValue;
+using vespalib::eval::TensorSpec;
+using vespalib::eval::Value;
+using vespalib::eval::ValueType;
using DoubleVector = std::vector<double>;
using generation_t = vespalib::GenerationHandler::generation_t;
@@ -1072,8 +1073,8 @@ public:
search::queryeval::FieldSpec field("foo", 0, 0);
auto bp = std::make_unique<NearestNeighborBlueprint>(
field,
- this->as_dense_tensor(),
- create_query_tensor(vec_2d(17, 42)),
+ std::make_unique<DistanceCalculator>(this->as_dense_tensor(),
+ create_query_tensor(vec_2d(17, 42))),
3, approximate, 5,
100100.25,
global_filter_lower_limit, 1.0);
diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
index d9db50ae816..7af2186ed1e 100644
--- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
@@ -599,13 +599,6 @@ bool check_valid_diversity_attr(const IAttributeVector *attr) {
return (attr->hasEnum() || attr->isIntegerType() || attr->isFloatingPointType());
}
-bool
-is_compatible_for_nearest_neighbor(const vespalib::eval::ValueType& lhs,
- const vespalib::eval::ValueType& rhs)
-{
- return (lhs.dimensions() == rhs.dimensions());
-}
-
//-----------------------------------------------------------------------------
@@ -760,40 +753,24 @@ public:
setResult(std::make_unique<queryeval::EmptyBlueprint>(_field));
}
void visit(query::NearestNeighborTerm &n) override {
- const ITensorAttribute *tensor_attr = _attr.asTensorAttribute();
- if (tensor_attr == nullptr) {
- return fail_nearest_neighbor_term(n, "Attribute is not a tensor");
- }
- const auto & ta_type = tensor_attr->getTensorType();
- if ((! ta_type.is_dense()) || (ta_type.dimensions().size() != 1)) {
- return fail_nearest_neighbor_term(n, make_string("Attribute tensor type (%s) is not a dense tensor of order 1",
- ta_type.to_spec().c_str()));
- }
const auto* query_tensor = getRequestContext().get_query_tensor(n.get_query_tensor_name());
if (query_tensor == nullptr) {
return fail_nearest_neighbor_term(n, "Query tensor was not found in request context");
}
- const auto & qt_type = query_tensor->type();
- if (! qt_type.is_dense()) {
- return fail_nearest_neighbor_term(n, make_string("Query tensor is not a dense tensor (type=%s)",
- qt_type.to_spec().c_str()));
- }
- if (!is_compatible_for_nearest_neighbor(ta_type, qt_type)) {
- return fail_nearest_neighbor_term(n, make_string("Attribute tensor type (%s) and query tensor type (%s) are not compatible",
- ta_type.to_spec().c_str(), qt_type.to_spec().c_str()));
- }
- if (tensor_attr->supports_extract_cells_ref() == false) {
- return fail_nearest_neighbor_term(n, make_string("Attribute does not support access to tensor data (type=%s)",
- ta_type.to_spec().c_str()));
+ try {
+ auto calc = tensor::DistanceCalculator::make_with_validation(_attr, *query_tensor);
+ setResult(std::make_unique<queryeval::NearestNeighborBlueprint>(_field,
+ std::move(calc),
+ n.get_target_num_hits(),
+ n.get_allow_approximate(),
+ n.get_explore_additional_hits(),
+ n.get_distance_threshold(),
+ getRequestContext().get_attribute_blueprint_params().global_filter_lower_limit,
+ getRequestContext().get_attribute_blueprint_params().global_filter_upper_limit));
+ } catch (const vespalib::IllegalArgumentException& ex) {
+ return fail_nearest_neighbor_term(n, ex.getMessage());
+
}
- setResult(std::make_unique<queryeval::NearestNeighborBlueprint>(_field, *tensor_attr,
- *query_tensor,
- n.get_target_num_hits(),
- n.get_allow_approximate(),
- n.get_explore_additional_hits(),
- n.get_distance_threshold(),
- getRequestContext().get_attribute_blueprint_params().global_filter_lower_limit,
- getRequestContext().get_attribute_blueprint_params().global_filter_upper_limit));
}
void visit(query::FuzzyTerm &n) override { visitTerm(n); }
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index a36a0006c76..6a891341afd 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -34,8 +34,7 @@ to_string(NearestNeighborBlueprint::Algorithm algorithm)
} // namespace <unnamed>
NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field,
- const tensor::ITensorAttribute& attr_tensor,
- const Value& query_tensor,
+ std::unique_ptr<search::tensor::DistanceCalculator> distance_calc,
uint32_t target_hits,
bool approximate,
uint32_t explore_additional_hits,
@@ -43,9 +42,9 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
double global_filter_lower_limit,
double global_filter_upper_limit)
: ComplexLeafBlueprint(field),
- _attr_tensor(attr_tensor),
- _distance_calc(_attr_tensor, query_tensor),
- _query_tensor(_distance_calc.query_tensor()),
+ _distance_calc(std::move(distance_calc)),
+ _attr_tensor(_distance_calc->attribute_tensor()),
+ _query_tensor(_distance_calc->query_tensor()),
_target_hits(target_hits),
_adjusted_target_hits(target_hits),
_approximate(approximate),
@@ -62,7 +61,7 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
_global_filter_hit_ratio()
{
if (distance_threshold < std::numeric_limits<double>::max()) {
- _distance_threshold = _distance_calc.function().convert_threshold(distance_threshold);
+ _distance_threshold = _distance_calc->function().convert_threshold(distance_threshold);
_distance_heap.set_distance_threshold(_distance_threshold);
}
uint32_t est_hits = _attr_tensor.get_num_docs();
@@ -127,11 +126,11 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData
switch (_algorithm) {
case Algorithm::INDEX_TOP_K_WITH_FILTER:
case Algorithm::INDEX_TOP_K:
- return NnsIndexIterator::create(tfmd, _found_hits, _distance_calc.function());
+ return NnsIndexIterator::create(tfmd, _found_hits, _distance_calc->function());
default:
;
}
- return NearestNeighborIterator::create(strict, tfmd, _distance_calc,
+ return NearestNeighborIterator::create(strict, tfmd, *_distance_calc,
_distance_heap, _global_filter->filter());
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index 9948cce1407..3dd03291b97 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -28,8 +28,8 @@ public:
INDEX_TOP_K_WITH_FILTER
};
private:
+ std::unique_ptr<search::tensor::DistanceCalculator> _distance_calc;
const tensor::ITensorAttribute& _attr_tensor;
- search::tensor::DistanceCalculator _distance_calc;
const vespalib::eval::Value& _query_tensor;
uint32_t _target_hits;
uint32_t _adjusted_target_hits;
@@ -49,8 +49,7 @@ private:
void perform_top_k(const search::tensor::NearestNeighborIndex* nns_index);
public:
NearestNeighborBlueprint(const queryeval::FieldSpec& field,
- const tensor::ITensorAttribute& attr_tensor,
- const vespalib::eval::Value& query_tensor,
+ std::unique_ptr<search::tensor::DistanceCalculator> distance_calc,
uint32_t target_hits, bool approximate, uint32_t explore_additional_hits,
double distance_threshold,
double global_filter_lower_limit,
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
index adfa5b7ee4a..d6d5433ff15 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.cpp
@@ -4,12 +4,17 @@
#include "distance_function_factory.h"
#include "nearest_neighbor_index.h"
#include <vespa/eval/eval/fast_value.h>
+#include <vespa/searchcommon/attribute/iattributevector.h>
+#include <vespa/vespalib/util/exceptions.h>
+#include <vespa/vespalib/util/stringfmt.h>
+using vespalib::IllegalArgumentException;
using vespalib::eval::CellType;
using vespalib::eval::FastValueBuilderFactory;
using vespalib::eval::TypedCells;
using vespalib::eval::Value;
using vespalib::eval::ValueType;
+using vespalib::make_string;
namespace {
@@ -42,6 +47,13 @@ struct ConvertCellsSelector
}
};
+bool
+is_compatible(const vespalib::eval::ValueType& lhs,
+ const vespalib::eval::ValueType& rhs)
+{
+ return (lhs.dimensions() == rhs.dimensions());
+}
+
}
namespace search::tensor {
@@ -86,5 +98,40 @@ DistanceCalculator::DistanceCalculator(const tensor::ITensorAttribute& attr_tens
DistanceCalculator::~DistanceCalculator() = default;
+namespace {
+
+
+
+}
+
+std::unique_ptr<DistanceCalculator>
+DistanceCalculator::make_with_validation(const search::attribute::IAttributeVector& attr,
+ const vespalib::eval::Value& query_tensor_in)
+{
+ const ITensorAttribute* attr_tensor = attr.asTensorAttribute();
+ if (attr_tensor == nullptr) {
+ throw IllegalArgumentException("Attribute is not a tensor");
+ }
+ const auto& at_type = attr_tensor->getTensorType();
+ if ((!at_type.is_dense()) || (at_type.dimensions().size() != 1)) {
+ throw IllegalArgumentException(make_string("Attribute tensor type (%s) is not a dense tensor of order 1",
+ at_type.to_spec().c_str()));
+ }
+ const auto& qt_type = query_tensor_in.type();
+ if (!qt_type.is_dense()) {
+ throw IllegalArgumentException(make_string("Query tensor type (%s) is not a dense tensor",
+ qt_type.to_spec().c_str()));
+ }
+ if (!is_compatible(at_type, qt_type)) {
+ throw IllegalArgumentException(make_string("Attribute tensor type (%s) and query tensor type (%s) are not compatible",
+ at_type.to_spec().c_str(), qt_type.to_spec().c_str()));
+ }
+ if (!attr_tensor->supports_extract_cells_ref()) {
+ throw IllegalArgumentException(make_string("Attribute tensor does not support access to tensor data (type=%s)",
+ at_type.to_spec().c_str()));
+ }
+ return std::make_unique<DistanceCalculator>(*attr_tensor, query_tensor_in);
+}
+
}
diff --git a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
index f1cc7feb9df..1bd9586a2bb 100644
--- a/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
+++ b/searchlib/src/vespa/searchlib/tensor/distance_calculator.h
@@ -6,6 +6,8 @@
namespace vespalib::eval { struct Value; }
+namespace search::attribute { class IAttributeVector; }
+
namespace search::tensor {
/**
@@ -43,6 +45,15 @@ public:
double calc_with_limit(uint32_t docid, double limit) const {
return _dist_fun->calc_with_limit(_query_tensor_cells, _attr_tensor.extract_cells_ref(docid), limit);
}
+
+ /**
+ * Create a calculator for the given attribute tensor and query tensor, if possible.
+ *
+ * Throws vespalib::IllegalArgumentException if the inputs are not supported or incompatible.
+ */
+ static std::unique_ptr<DistanceCalculator> make_with_validation(const search::attribute::IAttributeVector& attr,
+ const vespalib::eval::Value& query_tensor_in);
+
};
}