diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-05-14 06:29:58 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-05-14 06:29:58 +0000 |
commit | 409693b97750d60e3be4b9a216bbc5c9d9def4ab (patch) | |
tree | b96b3ae5c2571f4c048ba0f3fcb8dea81f166e3a /searchlib | |
parent | 52bbe79b9d91288a83876c79b870f1e90f022fc1 (diff) |
refactor tests
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp | 151 |
1 files changed, 79 insertions, 72 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index c125cb265ef..92ba11d2aa0 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -33,6 +33,8 @@ using search::AttributeGuard; using search::AttributeVector; using search::attribute::DistanceMetric; using search::attribute::HnswIndexParams; +using search::queryeval::NearestNeighborBlueprint; +using search::queryeval::GlobalFilter; using search::tensor::DefaultNearestNeighborIndexFactory; using search::tensor::DenseTensorAttribute; using search::tensor::DocVectorAccess; @@ -47,6 +49,7 @@ using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; using vespalib::tensor::DefaultTensorEngine; using vespalib::tensor::DenseTensor; +using vespalib::tensor::DenseTensorView; using vespalib::tensor::Tensor; using DoubleVector = std::vector<double>; @@ -76,14 +79,6 @@ Tensor::UP createTensor(const TensorSpec &spec) { return Tensor::UP(tensor); } -std::unique_ptr<DenseTensor<double>> createDenseTensor(const TensorSpec &spec) { - auto value = DefaultTensorEngine::ref().from_spec(spec); - DenseTensor<double> *tensor = dynamic_cast<DenseTensor<double> *>(value.get()); - ASSERT_TRUE(tensor != nullptr); - value.release(); - return std::unique_ptr<DenseTensor<double>>(tensor); -} - TensorSpec vec_2d(double x0, double x1) { @@ -667,69 +662,6 @@ public: DenseTensorAttributeMockIndex() : Fixture(vec_2d_spec, true, true, true) {} }; -TEST_F("blueprint takes global filter into account", DenseTensorAttributeMockIndex) -{ - using vespalib::tensor::DenseTensorView; - using search::queryeval::NearestNeighborBlueprint; - using search::queryeval::GlobalFilter; - - f.set_tensor(1, vec_2d(1, 1)); - f.set_tensor(2, vec_2d(2, 2)); - f.set_tensor(3, vec_2d(3, 3)); - f.set_tensor(4, vec_2d(4, 4)); - f.set_tensor(5, vec_2d(5, 5)); - f.set_tensor(6, vec_2d(6, 6)); - f.set_tensor(7, vec_2d(7, 7)); - f.set_tensor(8, vec_2d(8, 8)); - f.set_tensor(9, vec_2d(9, 9)); - f.set_tensor(10, vec_2d(0, 0)); - - search::queryeval::FieldSpec field("foo", 0, 0); - auto bp = std::make_unique<NearestNeighborBlueprint>(field, - f.as_dense_tensor(), - createDenseTensor(vec_2d(17, 42)), - 3, true, 5); - EXPECT_EQUAL(11u, bp->getState().estimate().estHits); - EXPECT_TRUE(bp->may_approximate()); - auto empty_filter = GlobalFilter::create(); - bp->set_global_filter(*empty_filter); - EXPECT_EQUAL(3u, bp->getState().estimate().estHits); - EXPECT_TRUE(bp->may_approximate()); - - bp = std::make_unique<NearestNeighborBlueprint>(field, - f.as_dense_tensor(), - createDenseTensor(vec_2d(17, 42)), - 3, true, 5); - EXPECT_EQUAL(11u, bp->getState().estimate().estHits); - EXPECT_TRUE(bp->may_approximate()); - auto filter = search::BitVector::create(11); - filter->setBit(3); - filter->invalidateCachedCount(); - auto strong_filter = GlobalFilter::create(std::move(filter)); - bp->set_global_filter(*strong_filter); - EXPECT_EQUAL(11u, bp->getState().estimate().estHits); - EXPECT_FALSE(bp->may_approximate()); - - bp = std::make_unique<NearestNeighborBlueprint>(field, - f.as_dense_tensor(), - createDenseTensor(vec_2d(17, 42)), - 3, true, 5); - EXPECT_EQUAL(11u, bp->getState().estimate().estHits); - EXPECT_TRUE(bp->may_approximate()); - filter = search::BitVector::create(11); - filter->setBit(1); - filter->setBit(3); - filter->setBit(5); - filter->setBit(7); - filter->setBit(9); - filter->setBit(11); - filter->invalidateCachedCount(); - auto weak_filter = GlobalFilter::create(std::move(filter)); - bp->set_global_filter(*weak_filter); - EXPECT_EQUAL(3u, bp->getState().estimate().estHits); - EXPECT_TRUE(bp->may_approximate()); -} - TEST_F("setTensor() updates nearest neighbor index", DenseTensorAttributeMockIndex) { auto& index = f.mock_index(); @@ -859,5 +791,80 @@ TEST_F("Nearest neighbor index type is added to attribute file header", DenseTen EXPECT_EQUAL("hnsw", header.getTag("nearest_neighbor_index").asString()); } -TEST_MAIN() { TEST_RUN_ALL(); } +class NearestNeighborBlueprintFixture : public DenseTensorAttributeMockIndex { +public: + using QueryTensor = DenseTensor<double>; + + NearestNeighborBlueprintFixture() { + set_tensor(1, vec_2d(1, 1)); + set_tensor(2, vec_2d(2, 2)); + set_tensor(3, vec_2d(3, 3)); + set_tensor(4, vec_2d(4, 4)); + set_tensor(5, vec_2d(5, 5)); + set_tensor(6, vec_2d(6, 6)); + set_tensor(7, vec_2d(7, 7)); + set_tensor(8, vec_2d(8, 8)); + set_tensor(9, vec_2d(9, 9)); + set_tensor(10, vec_2d(0, 0)); + } + + std::unique_ptr<QueryTensor> createDenseTensor(const TensorSpec &spec) { + auto value = DefaultTensorEngine::ref().from_spec(spec); + QueryTensor *tensor = dynamic_cast<QueryTensor *>(value.get()); + ASSERT_TRUE(tensor != nullptr); + value.release(); + return std::unique_ptr<QueryTensor>(tensor); + } + + std::unique_ptr<NearestNeighborBlueprint> make_blueprint() { + search::queryeval::FieldSpec field("foo", 0, 0); + auto bp = std::make_unique<NearestNeighborBlueprint>( + field, + as_dense_tensor(), + createDenseTensor(vec_2d(17, 42)), + 3, true, 5); + EXPECT_EQUAL(11u, bp->getState().estimate().estHits); + EXPECT_TRUE(bp->may_approximate()); + return bp; + } +}; +TEST_F("NN blueprint handles empty filter", NearestNeighborBlueprintFixture) +{ + auto bp = f.make_blueprint(); + auto empty_filter = GlobalFilter::create(); + bp->set_global_filter(*empty_filter); + EXPECT_EQUAL(3u, bp->getState().estimate().estHits); + EXPECT_TRUE(bp->may_approximate()); +} + +TEST_F("NN blueprint handles strong filter", NearestNeighborBlueprintFixture) +{ + auto bp = f.make_blueprint(); + auto filter = search::BitVector::create(11); + filter->setBit(3); + filter->invalidateCachedCount(); + auto strong_filter = GlobalFilter::create(std::move(filter)); + bp->set_global_filter(*strong_filter); + EXPECT_EQUAL(11u, bp->getState().estimate().estHits); + EXPECT_FALSE(bp->may_approximate()); +} + +TEST_F("NN blueprint handles weak filter", NearestNeighborBlueprintFixture) +{ + auto bp = f.make_blueprint(); + auto filter = search::BitVector::create(11); + filter->setBit(1); + filter->setBit(3); + filter->setBit(5); + filter->setBit(7); + filter->setBit(9); + filter->setBit(11); + filter->invalidateCachedCount(); + auto weak_filter = GlobalFilter::create(std::move(filter)); + bp->set_global_filter(*weak_filter); + EXPECT_EQUAL(3u, bp->getState().estimate().estHits); + EXPECT_TRUE(bp->may_approximate()); +} + +TEST_MAIN() { TEST_RUN_ALL(); } |