diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2019-11-29 11:45:47 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-11-29 11:45:47 +0100 |
commit | 69001649f409038c48827faf1c51e02d200b54b4 (patch) | |
tree | cdcf138d207c257f94017afdfe0592444d743944 | |
parent | 907e4dfc9b711cc112fe251098ea84a08ffaff98 (diff) | |
parent | f71b23be6b9cf741d878fe70baa1d645c7f3c40e (diff) |
Merge pull request #11450 from vespa-engine/arnej/template-nns-fully
template cell types also
4 files changed, 126 insertions, 78 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 6a96b7720b1..25ff459c005 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -100,7 +100,7 @@ SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) { auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); NearestNeighborDistanceHeap dh(2); - auto search = NearestNeighborIteratorFactory::createIterator(strict, tfmd, qtv, attr, dh); + auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh); if (strict) { return SimpleResult().searchStrict(*search, attr.getNumDocs()); } else { @@ -137,7 +137,7 @@ std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) { auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); NearestNeighborDistanceHeap dh(2); - auto search = NearestNeighborIteratorFactory::createIterator(strict, tfmd, qtv, attr, dh); + auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh); uint32_t limit = attr.getNumDocs(); uint32_t docid = 1; search->initRange(docid, limit); diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index 572707323c9..6a844a6bec0 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -30,7 +30,7 @@ NearestNeighborBlueprint::createLeafSearch(const search::fef::TermFieldMatchData fef::TermFieldMatchData &tfmd = *tfmda[0]; // always search in only one field const vespalib::tensor::DenseTensorView &qT = *_query_tensor; - return NearestNeighborIteratorFactory::createIterator(strict, tfmd, qT, _attr_tensor, _distance_heap); + return 
NearestNeighborIterator::create(strict, tfmd, qT, _attr_tensor, _distance_heap); } void diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp index d17ed024fce..4617bb0e374 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp @@ -2,29 +2,13 @@ #include "nearest_neighbor_iterator.h" +using search::tensor::DenseTensorAttribute; using vespalib::ConstArrayRef; +using vespalib::tensor::DenseTensorView; +using vespalib::tensor::MutableDenseTensorView; using vespalib::tensor::TypedCells; -namespace { - -struct SumSquaredDiff -{ - template <typename LCT, typename RCT> - static double - call(const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs) - { - double sum = 0.0; - size_t sz = lhs.size(); - assert(sz == rhs.size()); - for (size_t i = 0; i < sz; ++i) { - double diff = lhs[i] - rhs[i]; - sum += diff*diff; - } - return sum; - } -}; - -} +using CellType = vespalib::eval::ValueType::CellType; namespace search::queryeval { @@ -34,32 +18,24 @@ namespace search::queryeval { * Keeps a heap of the K best hit distances. * Currently always does brute-force scanning, which is very expensive. 
**/ -template <bool strict> -class NearestNeighborIterator : public SearchIterator +template <bool strict, typename LCT, typename RCT> +class NearestNeighborImpl : public NearestNeighborIterator { public: - using DenseTensorView = vespalib::tensor::DenseTensorView; - using DenseTensorAttribute = search::tensor::DenseTensorAttribute; - using MutableDenseTensorView = vespalib::tensor::MutableDenseTensorView; - - NearestNeighborIterator(fef::TermFieldMatchData &tfmd, - const DenseTensorView &queryTensor, - const DenseTensorAttribute &tensorAttribute, - NearestNeighborDistanceHeap &distanceHeap) - : _tfmd(tfmd), - _queryTensor(queryTensor), - _tensorAttribute(tensorAttribute), - _fieldTensor(_tensorAttribute.getTensorType()), - _distanceHeap(distanceHeap), + + NearestNeighborImpl(Params params_in) + : NearestNeighborIterator(params_in), + _lhs(params().queryTensor.cellsRef().typify<LCT>()), + _fieldTensor(params().tensorAttribute.getTensorType()), _lastScore(0.0) { - assert(_fieldTensor.fast_type() == _queryTensor.fast_type()); + assert(_fieldTensor.fast_type() == params().queryTensor.fast_type()); } - ~NearestNeighborIterator(); + ~NearestNeighborImpl(); void doSeek(uint32_t docId) override { - double distanceLimit = _distanceHeap.distanceLimit(); + double distanceLimit = params().distanceHeap.distanceLimit(); while (__builtin_expect((docId < getEndId()), true)) { double d = computeDistance(docId); if (d <= distanceLimit) { @@ -77,53 +53,97 @@ public: } void doUnpack(uint32_t docId) override { - _tfmd.setRawScore(docId, sqrt(_lastScore)); - _distanceHeap.used(_lastScore); + params().tfmd.setRawScore(docId, sqrt(_lastScore)); + params().distanceHeap.used(_lastScore); } Trinary is_strict() const override { return strict ? 
Trinary::True : Trinary::False ; } private: - double computeDistance(uint32_t docId); - - fef::TermFieldMatchData &_tfmd; - const DenseTensorView &_queryTensor; - const DenseTensorAttribute &_tensorAttribute; - MutableDenseTensorView _fieldTensor; - NearestNeighborDistanceHeap &_distanceHeap; - double _lastScore; + static double computeSum(ConstArrayRef<LCT> lhs, ConstArrayRef<RCT> rhs) { + double sum = 0.0; + size_t sz = lhs.size(); + assert(sz == rhs.size()); + for (size_t i = 0; i < sz; ++i) { + double diff = lhs[i] - rhs[i]; + sum += diff*diff; + } + return sum; + } + + double computeDistance(uint32_t docId) { + params().tensorAttribute.getTensor(docId, _fieldTensor); + return computeSum(_lhs, _fieldTensor.cellsRef().typify<RCT>()); + } + + ConstArrayRef<LCT> _lhs; + MutableDenseTensorView _fieldTensor; + double _lastScore; }; -template <bool strict> -NearestNeighborIterator<strict>::~NearestNeighborIterator() = default; +template <bool strict, typename LCT, typename RCT> +NearestNeighborImpl<strict, LCT, RCT>::~NearestNeighborImpl() = default; -template <bool strict> -double -NearestNeighborIterator<strict>::computeDistance(uint32_t docId) +namespace { + +template<bool strict, typename LCT, typename RCT> +std::unique_ptr<NearestNeighborIterator> +create_impl(const NearestNeighborIterator::Params &params) { - _tensorAttribute.getTensor(docId, _fieldTensor); - TypedCells lhsCells = _queryTensor.cellsRef(); - TypedCells rhsCells = _fieldTensor.cellsRef(); - return vespalib::tensor::dispatch_2<SumSquaredDiff>(lhsCells, rhsCells); + using NNI = NearestNeighborImpl<strict, LCT, RCT>; + return std::make_unique<NNI>(params); } +template<bool strict, typename LCT> +std::unique_ptr<NearestNeighborIterator> +resolve_RCT(const NearestNeighborIterator::Params &params) +{ + CellType ct = params.tensorAttribute.getTensorType().cell_type(); + if (ct == CellType::FLOAT) { + return create_impl<strict, LCT, float>(params); + } + if (ct == CellType::DOUBLE) { + return create_impl<strict, 
LCT, double>(params); + } + abort(); +} -std::unique_ptr<SearchIterator> -NearestNeighborIteratorFactory::createIterator( +template<bool strict> +std::unique_ptr<NearestNeighborIterator> +resolve_LCT_RCT(const NearestNeighborIterator::Params &params) +{ + CellType ct = params.queryTensor.fast_type().cell_type(); + if (ct == CellType::FLOAT) { + return resolve_RCT<strict, float>(params); + } + if (ct == CellType::DOUBLE) { + return resolve_RCT<strict, double>(params); + } + abort(); +} + +std::unique_ptr<NearestNeighborIterator> +resolve_strict_LCT_RCT(bool strict, const NearestNeighborIterator::Params &params) +{ + if (strict) { + return resolve_LCT_RCT<true>(params); + } else { + return resolve_LCT_RCT<false>(params); + } +} + +} // namespace <unnamed> + +std::unique_ptr<NearestNeighborIterator> +NearestNeighborIterator::create( bool strict, fef::TermFieldMatchData &tfmd, const vespalib::tensor::DenseTensorView &queryTensor, const search::tensor::DenseTensorAttribute &tensorAttribute, NearestNeighborDistanceHeap &distanceHeap) { - using StrictNN = NearestNeighborIterator<true>; - using UnStrict = NearestNeighborIterator<false>; - - if (strict) { - return std::make_unique<StrictNN>(tfmd, queryTensor, tensorAttribute, distanceHeap); - } else { - return std::make_unique<UnStrict>(tfmd, queryTensor, tensorAttribute, distanceHeap); - } + Params params(tfmd, queryTensor, tensorAttribute, distanceHeap); + return resolve_strict_LCT_RCT(strict, params); } } // namespace diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h index 140e92ad37d..34eb547fe39 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.h @@ -13,15 +13,43 @@ namespace search::queryeval { -class NearestNeighborIteratorFactory +class NearestNeighborIterator : public SearchIterator { public: - static 
std::unique_ptr<SearchIterator> createIterator( - bool strict, - fef::TermFieldMatchData &tfmd, - const vespalib::tensor::DenseTensorView &queryTensor, - const search::tensor::DenseTensorAttribute &tensorAttribute, - NearestNeighborDistanceHeap &distanceHeap); + using DenseTensorAttribute = search::tensor::DenseTensorAttribute; + using DenseTensorView = vespalib::tensor::DenseTensorView; + + struct Params { + fef::TermFieldMatchData &tfmd; + const DenseTensorView &queryTensor; + const DenseTensorAttribute &tensorAttribute; + NearestNeighborDistanceHeap &distanceHeap; + + Params(fef::TermFieldMatchData &tfmd_in, + const DenseTensorView &queryTensor_in, + const DenseTensorAttribute &tensorAttribute_in, + NearestNeighborDistanceHeap &distanceHeap_in) + : tfmd(tfmd_in), + queryTensor(queryTensor_in), + tensorAttribute(tensorAttribute_in), + distanceHeap(distanceHeap_in) + {} + }; + + NearestNeighborIterator(Params params_in) + : _params(params_in) + {} + + static std::unique_ptr<NearestNeighborIterator> create( + bool strict, + fef::TermFieldMatchData &tfmd, + const vespalib::tensor::DenseTensorView &queryTensor, + const search::tensor::DenseTensorAttribute &tensorAttribute, + NearestNeighborDistanceHeap &distanceHeap); + + const Params& params() const { return _params; } +private: + Params _params; }; -} +} // namespace |