aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-03-23 14:13:42 +0000
committerArne Juul <arnej@verizonmedia.com>2020-03-23 14:30:36 +0000
commit1d3fdf021dab6ab47736feea0cd439e79360acbf (patch)
tree779d9ea9c3b8a121cbee2b55b7fb7507b680bf5b /searchlib
parent63368ac60b322be2830868bd76570184c39938ed (diff)
use distance function from index if available
* convert query tensor to same cell type as attribute
* use DistanceFunction to calculate abstract distances for NNS
* use DistanceFunction to convert abstract distances to rawscore
* if no index is available, use a fallback DistanceFunction
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp55
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h3
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp29
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h11
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp19
-rw-r--r--searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h4
6 files changed, 87 insertions, 34 deletions
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index d3b2925e075..4035da5f435 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -6,10 +6,47 @@
#include "nns_index_iterator.h"
#include <vespa/searchlib/fef/termfieldmatchdataarray.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>
+#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
+#include <vespa/searchlib/tensor/distance_function_factory.h>
+
+using vespalib::tensor::DenseTensorView;
+using vespalib::tensor::DenseTensor;
namespace search::queryeval {
+namespace {
+
+template<typename LCT, typename RCT> // LCT: cell type the tensor currently holds; RCT: cell type to convert to
+void
+convert_cells(std::unique_ptr<DenseTensorView> &original, vespalib::eval::ValueType want_type) // replaces 'original' in place with a DenseTensor<RCT> holding the converted cells
+{
+ auto old_cells = original->cellsRef().typify<LCT>(); // view the existing cells as LCT
+ std::vector<RCT> new_cells;
+ new_cells.reserve(old_cells.size()); // cell count is unchanged by the conversion
+ for (LCT value : old_cells) {
+ RCT conv = value; // implicit LCT -> RCT conversion; narrowing (e.g. double -> float) can lose precision
+ new_cells.push_back(conv);
+ }
+ original = std::make_unique<DenseTensor<RCT>>(want_type, std::move(new_cells)); // the previously held tensor is destroyed by this assignment; want_type becomes the type of the new tensor
+}
+
+template<>
+void
+convert_cells<float,float>(std::unique_ptr<DenseTensorView> &, vespalib::eval::ValueType) {} // same cell type on both sides: leave the tensor untouched
+
+template<>
+void
+convert_cells<double,double>(std::unique_ptr<DenseTensorView> &, vespalib::eval::ValueType) {} // same cell type on both sides: leave the tensor untouched
+
+struct ConvertCellsSelector // passed as the template argument to vespalib::tensor::select_2 in the blueprint constructor
+{
+ template <typename LCT, typename RCT>
+ static auto get_fun() { return convert_cells<LCT, RCT>; } // yields a pointer to the convert_cells instantiation for this (LCT, RCT) pair
+};
+
+} // namespace <unnamed>
+
NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field,
const tensor::DenseTensorAttribute& attr_tensor,
std::unique_ptr<vespalib::tensor::DenseTensorView> query_tensor,
@@ -20,11 +57,23 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f
_target_num_hits(target_num_hits),
_approximate(approximate),
_explore_additional_hits(explore_additional_hits),
+ _fallback_dist_fun(),
_distance_heap(target_num_hits),
_found_hits()
{
+ auto lct = _query_tensor->cellsRef().type;
+ auto rct = _attr_tensor.getTensorType().cell_type();
+ auto fixup_fun = vespalib::tensor::select_2<ConvertCellsSelector>(lct, rct);
+ fixup_fun(_query_tensor, _attr_tensor.getTensorType());
+ auto def_dm = search::attribute::DistanceMetric::Euclidean;
+ _fallback_dist_fun = search::tensor::make_distance_function(def_dm, rct);
+ _dist_fun = _fallback_dist_fun.get();
+ auto nns_index = _attr_tensor.nearest_neighbor_index();
+ if (nns_index) {
+ _dist_fun = nns_index->distance_function();
+ }
uint32_t est_hits = _attr_tensor.getNumDocs();
- if (_attr_tensor.nearest_neighbor_index()) {
+ if (_approximate && nns_index) {
est_hits = std::min(target_num_hits, est_hits);
}
setEstimate(HitEstimate(est_hits, false));
@@ -61,10 +110,10 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData
assert(tfmda.size() == 1);
fef::TermFieldMatchData &tfmd = *tfmda[0]; // always search in only one field
if (strict && ! _found_hits.empty()) {
- return NnsIndexIterator::create(tfmd, _found_hits);
+ return NnsIndexIterator::create(tfmd, _found_hits, _dist_fun);
}
const vespalib::tensor::DenseTensorView &qT = *_query_tensor;
- return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, _distance_heap);
+ return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, _distance_heap, _dist_fun);
}
void
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
index 39165b066be..c6b6f3f7fae 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h
@@ -3,6 +3,7 @@
#include "blueprint.h"
#include "nearest_neighbor_distance_heap.h"
+#include <vespa/searchlib/tensor/distance_function.h>
#include <vespa/searchlib/tensor/nearest_neighbor_index.h>
namespace vespalib::tensor { class DenseTensorView; }
@@ -23,6 +24,8 @@ private:
uint32_t _target_num_hits;
bool _approximate;
uint32_t _explore_additional_hits;
+ search::tensor::DistanceFunction::UP _fallback_dist_fun;
+ search::tensor::DistanceFunction *_dist_fun;
mutable NearestNeighborDistanceHeap _distance_heap;
std::vector<search::tensor::NearestNeighborIndex::Neighbor> _found_hits;
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
index c20d539d4cd..cd8fd76e988 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
@@ -18,7 +18,7 @@ bool
is_compatible(const vespalib::eval::ValueType& lhs,
const vespalib::eval::ValueType& rhs)
{
- return (lhs.dimensions() == rhs.dimensions());
+ return (lhs == rhs);
}
}
@@ -36,7 +36,7 @@ public:
NearestNeighborImpl(Params params_in)
: NearestNeighborIterator(params_in),
- _lhs(params().queryTensor.cellsRef().template typify<LCT>()),
+ _lhs(params().queryTensor.cellsRef()),
_fieldTensor(params().tensorAttribute.getTensorType()),
_lastScore(0.0)
{
@@ -64,8 +64,7 @@ public:
}
void doUnpack(uint32_t docId) override {
- double d = sqrt(_lastScore);
- double score = 1.0 / (1.0 + d);
+ double score = params().distanceFunction->to_rawscore(_lastScore); // rawscore mapping is delegated to the distance function instead of the hard-coded 1/(1+sqrt(d)); _lastScore is passed unmodified
params().tfmd.setRawScore(docId, score);
params().distanceHeap.used(_lastScore);
}
@@ -73,23 +72,13 @@ public:
Trinary is_strict() const override { return strict ? Trinary::True : Trinary::False ; }
private:
- static double computeSum(ConstArrayRef<LCT> lhs, ConstArrayRef<RCT> rhs, double limit) {
- double sum = 0.0;
- size_t sz = lhs.size();
- assert(sz == rhs.size());
- for (size_t i = 0; i < sz && sum <= limit; ++i) {
- double diff = lhs[i] - rhs[i];
- sum += diff*diff;
- }
- return sum;
- }
-
double computeDistance(uint32_t docId, double limit) {
params().tensorAttribute.getTensor(docId, _fieldTensor);
- return computeSum(_lhs, _fieldTensor.cellsRef().template typify<RCT>(), limit);
+ auto rhs = _fieldTensor.cellsRef(); // cells of the tensor just loaded into _fieldTensor for docId, passed on without typify<>
+ return params().distanceFunction->calc_with_limit(_lhs, rhs, limit); // distance between query cells (_lhs) and document cells (rhs); how 'limit' is honoured is up to calc_with_limit (the removed computeSum stopped accumulating once the sum exceeded it)
}
- ConstArrayRef<LCT> _lhs;
+ TypedCells _lhs;
MutableDenseTensorView _fieldTensor;
double _lastScore;
};
@@ -141,9 +130,11 @@ NearestNeighborIterator::create(
fef::TermFieldMatchData &tfmd,
const vespalib::tensor::DenseTensorView &queryTensor,
const search::tensor::DenseTensorAttribute &tensorAttribute,
- NearestNeighborDistanceHeap &distanceHeap)
+ NearestNeighborDistanceHeap &distanceHeap,
+ search::tensor::DistanceFunction *dist_fun)
+
{
- Params params(tfmd, queryTensor, tensorAttribute, distanceHeap);
+ Params params(tfmd, queryTensor, tensorAttribute, distanceHeap, dist_fun);
return resolve_strict_LCT_RCT(strict, params);
}
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
index 34eb547fe39..fae1d4e752f 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
@@ -8,6 +8,7 @@
#include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h>
#include <vespa/searchlib/fef/termfieldmatchdata.h>
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
+#include <vespa/searchlib/tensor/distance_function.h>
#include <vespa/vespalib/util/priority_queue.h>
#include <cmath>
@@ -24,15 +25,18 @@ public:
const DenseTensorView &queryTensor;
const DenseTensorAttribute &tensorAttribute;
NearestNeighborDistanceHeap &distanceHeap;
+ search::tensor::DistanceFunction *distanceFunction;
Params(fef::TermFieldMatchData &tfmd_in,
const DenseTensorView &queryTensor_in,
const DenseTensorAttribute &tensorAttribute_in,
- NearestNeighborDistanceHeap &distanceHeap_in)
+ NearestNeighborDistanceHeap &distanceHeap_in,
+ search::tensor::DistanceFunction *distanceFunction_in) // distance function the iterator uses for calc_with_limit / to_rawscore
: tfmd(tfmd_in),
queryTensor(queryTensor_in),
tensorAttribute(tensorAttribute_in),
- distanceHeap(distanceHeap_in)
+ distanceHeap(distanceHeap_in),
+ distanceFunction(distanceFunction_in) // stored as a raw pointer; NOTE(review): presumably non-owning and expected to outlive the iterator - confirm against the blueprint
{}
};
@@ -45,7 +49,8 @@ public:
fef::TermFieldMatchData &tfmd,
const vespalib::tensor::DenseTensorView &queryTensor,
const search::tensor::DenseTensorAttribute &tensorAttribute,
- NearestNeighborDistanceHeap &distanceHeap);
+ NearestNeighborDistanceHeap &distanceHeap,
+ search::tensor::DistanceFunction *dist_fun);
const Params& params() const { return _params; }
private:
diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp
index 48df28e7f3e..b047f5b68b7 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.cpp
@@ -18,15 +18,18 @@ class NeighborVectorIterator : public NnsIndexIterator
private:
fef::TermFieldMatchData &_tfmd;
const std::vector<Neighbor> &_hits;
+ search::tensor::DistanceFunction * const _dist_fun;
uint32_t _idx;
- double _last_sq_dist;
+ double _last_abstract_dist;
public:
NeighborVectorIterator(fef::TermFieldMatchData &tfmd,
- const std::vector<Neighbor> &hits)
+ const std::vector<Neighbor> &hits,
+ search::tensor::DistanceFunction *dist_fun) // used by doUnpack to turn a hit's distance into a rawscore
: _tfmd(tfmd),
_hits(hits),
+ _dist_fun(dist_fun), // kept as a raw pointer; dereferenced without a null check in doUnpack
_idx(0),
- _last_sq_dist(0.0)
+ _last_abstract_dist(0.0) // distance of the current hit, copied from _hits[_idx].distance when a hit is selected
{}
void initRange(uint32_t begin_id, uint32_t end_id) override {
@@ -41,7 +44,7 @@ public:
++_idx;
} else if (hit_id < getEndId()) {
setDocId(hit_id);
- _last_sq_dist = _hits[_idx].distance;
+ _last_abstract_dist = _hits[_idx].distance;
return;
} else {
_idx = _hits.size();
@@ -51,8 +54,7 @@ public:
}
void doUnpack(uint32_t docId) override {
- double d = sqrt(_last_sq_dist);
- double score = 1.0 / (1.0 + d);
+ double score = _dist_fun->to_rawscore(_last_abstract_dist); // rawscore mapping is delegated to the distance function instead of the hard-coded 1/(1+sqrt(d))
_tfmd.setRawScore(docId, score);
}
@@ -62,9 +64,10 @@ public:
std::unique_ptr<NnsIndexIterator>
NnsIndexIterator::create(
fef::TermFieldMatchData &tfmd,
- const std::vector<Neighbor> &hits)
+ const std::vector<Neighbor> &hits,
+ search::tensor::DistanceFunction *dist_fun) // distance function used to convert hit distances to rawscores
{
- return std::make_unique<NeighborVectorIterator>(tfmd, hits);
+ return std::make_unique<NeighborVectorIterator>(tfmd, hits, dist_fun); // dist_fun is forwarded unchanged; the iterator dereferences it in doUnpack without a null check
}
} // namespace
diff --git a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h
index 9ffd0df94eb..9b173f86229 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nns_index_iterator.h
@@ -4,6 +4,7 @@
#include "searchiterator.h"
#include <vespa/searchlib/fef/termfieldmatchdata.h>
+#include <vespa/searchlib/tensor/distance_function.h>
#include <vespa/searchlib/tensor/nearest_neighbor_index.h>
namespace search::queryeval {
@@ -14,7 +15,8 @@ public:
using Hit = search::tensor::NearestNeighborIndex::Neighbor;
static std::unique_ptr<NnsIndexIterator> create(
fef::TermFieldMatchData &tfmd,
- const std::vector<Hit> &hits);
+ const std::vector<Hit> &hits,
+ search::tensor::DistanceFunction *dist_fun); // distance function passed through to the created iterator
};
} // namespace