summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Arne H Juul <arnej27959@users.noreply.github.com> 2019-11-29 11:45:47 +0100
committer GitHub <noreply@github.com> 2019-11-29 11:45:47 +0100
commit 69001649f409038c48827faf1c51e02d200b54b4 (patch)
tree cdcf138d207c257f94017afdfe0592444d743944
parent 907e4dfc9b711cc112fe251098ea84a08ffaff98 (diff)
parent f71b23be6b9cf741d878fe70baa1d645c7f3c40e (diff)
Merge pull request #11450 from vespa-engine/arnej/template-nns-fully
template cell types also
-rw-r--r-- searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp 4
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp 2
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp 154
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h 44
4 files changed, 126 insertions, 78 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 6a96b7720b1..25ff459c005 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -100,7 +100,7 @@ SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) {
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
NearestNeighborDistanceHeap dh(2);
- auto search = NearestNeighborIteratorFactory::createIterator(strict, tfmd, qtv, attr, dh);
+ auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh);
if (strict) {
return SimpleResult().searchStrict(*search, attr.getNumDocs());
} else {
@@ -137,7 +137,7 @@ std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) {
auto &tfmd = *(md->resolveTermField(0));
auto &attr = *(env._tensorAttr);
NearestNeighborDistanceHeap dh(2);
- auto search = NearestNeighborIteratorFactory::createIterator(strict, tfmd, qtv, attr, dh);
+ auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh);
uint32_t limit = attr.getNumDocs();
uint32_t docid = 1;
search->initRange(docid, limit);
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
index 572707323c9..6a844a6bec0 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp
@@ -30,7 +30,7 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData
fef::TermFieldMatchData &tfmd = *tfmda[0]; // always search in only one field
const vespalib::tensor::DenseTensorView &qT = *_query_tensor;
- return NearestNeighborIteratorFactory::createIterator(strict, tfmd, qT, _attr_tensor, _distance_heap);
+ return NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, _distance_heap);
}
void
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
index d17ed024fce..4617bb0e374 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
@@ -2,29 +2,13 @@
#include "nearest_neighbor_iterator.h"
+using search::tensor::DenseTensorAttribute;
using vespalib::ConstArrayRef;
+using vespalib::tensor::DenseTensorView;
+using vespalib::tensor::MutableDenseTensorView;
using vespalib::tensor::TypedCells;
-namespace {
-
-struct SumSquaredDiff
-{
- template <typename LCT, typename RCT>
- static double
- call(const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs)
- {
- double sum = 0.0;
- size_t sz = lhs.size();
- assert(sz == rhs.size());
- for (size_t i = 0; i < sz; ++i) {
- double diff = lhs[i] - rhs[i];
- sum += diff*diff;
- }
- return sum;
- }
-};
-
-}
+using CellType = vespalib::eval::ValueType::CellType;
namespace search::queryeval {
@@ -34,32 +18,24 @@ namespace search::queryeval {
* Keeps a heap of the K best hit distances.
* Currently always does brute-force scanning, which is very expensive.
**/
-template <bool strict>
-class NearestNeighborIterator : public SearchIterator
+template <bool strict, typename LCT, typename RCT>
+class NearestNeighborImpl : public NearestNeighborIterator
{
public:
- using DenseTensorView = vespalib::tensor::DenseTensorView;
- using DenseTensorAttribute = search::tensor::DenseTensorAttribute;
- using MutableDenseTensorView = vespalib::tensor::MutableDenseTensorView;
-
- NearestNeighborIterator(fef::TermFieldMatchData &tfmd,
- const DenseTensorView &queryTensor,
- const DenseTensorAttribute &tensorAttribute,
- NearestNeighborDistanceHeap &distanceHeap)
- : _tfmd(tfmd),
- _queryTensor(queryTensor),
- _tensorAttribute(tensorAttribute),
- _fieldTensor(_tensorAttribute.getTensorType()),
- _distanceHeap(distanceHeap),
+
+ NearestNeighborImpl(Params params_in)
+ : NearestNeighborIterator(params_in),
+ _lhs(params().queryTensor.cellsRef().typify<LCT>()),
+ _fieldTensor(params().tensorAttribute.getTensorType()),
_lastScore(0.0)
{
- assert(_fieldTensor.fast_type() == _queryTensor.fast_type());
+ assert(_fieldTensor.fast_type() == params().queryTensor.fast_type());
}
- ~NearestNeighborIterator();
+ ~NearestNeighborImpl();
void doSeek(uint32_t docId) override {
- double distanceLimit = _distanceHeap.distanceLimit();
+ double distanceLimit = params().distanceHeap.distanceLimit();
while (__builtin_expect((docId < getEndId()), true)) {
double d = computeDistance(docId);
if (d <= distanceLimit) {
@@ -77,53 +53,97 @@ public:
}
void doUnpack(uint32_t docId) override {
- _tfmd.setRawScore(docId, sqrt(_lastScore));
- _distanceHeap.used(_lastScore);
+ params().tfmd.setRawScore(docId, sqrt(_lastScore));
+ params().distanceHeap.used(_lastScore);
}
Trinary is_strict() const override { return strict ? Trinary::True : Trinary::False ; }
private:
- double computeDistance(uint32_t docId);
-
- fef::TermFieldMatchData &_tfmd;
- const DenseTensorView &_queryTensor;
- const DenseTensorAttribute &_tensorAttribute;
- MutableDenseTensorView _fieldTensor;
- NearestNeighborDistanceHeap &_distanceHeap;
- double _lastScore;
+ static double computeSum(ConstArrayRef<LCT> lhs, ConstArrayRef<RCT> rhs) {
+ double sum = 0.0;
+ size_t sz = lhs.size();
+ assert(sz == rhs.size());
+ for (size_t i = 0; i < sz; ++i) {
+ double diff = lhs[i] - rhs[i];
+ sum += diff*diff;
+ }
+ return sum;
+ }
+
+ double computeDistance(uint32_t docId) {
+ params().tensorAttribute.getTensor(docId, _fieldTensor);
+ return computeSum(_lhs, _fieldTensor.cellsRef().typify<RCT>());
+ }
+
+ ConstArrayRef<LCT> _lhs;
+ MutableDenseTensorView _fieldTensor;
+ double _lastScore;
};
-template <bool strict>
-NearestNeighborIterator<strict>::~NearestNeighborIterator() = default;
+template <bool strict, typename LCT, typename RCT>
+NearestNeighborImpl<strict, LCT, RCT>::~NearestNeighborImpl() = default;
-template <bool strict>
-double
-NearestNeighborIterator<strict>::computeDistance(uint32_t docId)
+namespace {
+
+template<bool strict, typename LCT, typename RCT>
+std::unique_ptr<NearestNeighborIterator>
+create_impl(const NearestNeighborIterator::Params &params)
{
- _tensorAttribute.getTensor(docId, _fieldTensor);
- TypedCells lhsCells = _queryTensor.cellsRef();
- TypedCells rhsCells = _fieldTensor.cellsRef();
- return vespalib::tensor::dispatch_2<SumSquaredDiff>(lhsCells, rhsCells);
+ using NNI = NearestNeighborImpl<strict, LCT, RCT>;
+ return std::make_unique<NNI>(params);
}
+template<bool strict, typename LCT>
+std::unique_ptr<NearestNeighborIterator>
+resolve_RCT(const NearestNeighborIterator::Params &params)
+{
+ CellType ct = params.tensorAttribute.getTensorType().cell_type();
+ if (ct == CellType::FLOAT) {
+ return create_impl<strict, LCT, float>(params);
+ }
+ if (ct == CellType::DOUBLE) {
+ return create_impl<strict, LCT, double>(params);
+ }
+ abort();
+}
-std::unique_ptr<SearchIterator>
-NearestNeighborIteratorFactory::createIterator(
+template<bool strict>
+std::unique_ptr<NearestNeighborIterator>
+resolve_LCT_RCT(const NearestNeighborIterator::Params &params)
+{
+ CellType ct = params.queryTensor.fast_type().cell_type();
+ if (ct == CellType::FLOAT) {
+ return resolve_RCT<strict, float>(params);
+ }
+ if (ct == CellType::DOUBLE) {
+ return resolve_RCT<strict, double>(params);
+ }
+ abort();
+}
+
+std::unique_ptr<NearestNeighborIterator>
+resolve_strict_LCT_RCT(bool strict, const NearestNeighborIterator::Params &params)
+{
+ if (strict) {
+ return resolve_LCT_RCT<true>(params);
+ } else {
+ return resolve_LCT_RCT<false>(params);
+ }
+}
+
+} // namespace <unnamed>
+
+std::unique_ptr<NearestNeighborIterator>
+NearestNeighborIterator::create(
bool strict,
fef::TermFieldMatchData &tfmd,
const vespalib::tensor::DenseTensorView &queryTensor,
const search::tensor::DenseTensorAttribute &tensorAttribute,
NearestNeighborDistanceHeap &distanceHeap)
{
- using StrictNN = NearestNeighborIterator<true>;
- using UnStrict = NearestNeighborIterator<false>;
-
- if (strict) {
- return std::make_unique<StrictNN>(tfmd, queryTensor, tensorAttribute, distanceHeap);
- } else {
- return std::make_unique<UnStrict>(tfmd, queryTensor, tensorAttribute, distanceHeap);
- }
+ Params params(tfmd, queryTensor, tensorAttribute, distanceHeap);
+ return resolve_strict_LCT_RCT(strict, params);
}
} // namespace
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
index 140e92ad37d..34eb547fe39 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h
@@ -13,15 +13,43 @@
namespace search::queryeval {
-class NearestNeighborIteratorFactory
+class NearestNeighborIterator : public SearchIterator
{
public:
- static std::unique_ptr<SearchIterator> createIterator(
- bool strict,
- fef::TermFieldMatchData &tfmd,
- const vespalib::tensor::DenseTensorView &queryTensor,
- const search::tensor::DenseTensorAttribute &tensorAttribute,
- NearestNeighborDistanceHeap &distanceHeap);
+ using DenseTensorAttribute = search::tensor::DenseTensorAttribute;
+ using DenseTensorView = vespalib::tensor::DenseTensorView;
+
+ struct Params {
+ fef::TermFieldMatchData &tfmd;
+ const DenseTensorView &queryTensor;
+ const DenseTensorAttribute &tensorAttribute;
+ NearestNeighborDistanceHeap &distanceHeap;
+
+ Params(fef::TermFieldMatchData &tfmd_in,
+ const DenseTensorView &queryTensor_in,
+ const DenseTensorAttribute &tensorAttribute_in,
+ NearestNeighborDistanceHeap &distanceHeap_in)
+ : tfmd(tfmd_in),
+ queryTensor(queryTensor_in),
+ tensorAttribute(tensorAttribute_in),
+ distanceHeap(distanceHeap_in)
+ {}
+ };
+
+ NearestNeighborIterator(Params params_in)
+ : _params(params_in)
+ {}
+
+ static std::unique_ptr<NearestNeighborIterator> create(
+ bool strict,
+ fef::TermFieldMatchData &tfmd,
+ const vespalib::tensor::DenseTensorView &queryTensor,
+ const search::tensor::DenseTensorAttribute &tensorAttribute,
+ NearestNeighborDistanceHeap &distanceHeap);
+
+ const Params& params() const { return _params; }
+private:
+ Params _params;
};
-}
+} // namespace