aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-05-14 06:29:58 +0000
committerArne Juul <arnej@verizonmedia.com>2020-05-14 06:29:58 +0000
commit409693b97750d60e3be4b9a216bbc5c9d9def4ab (patch)
treeb96b3ae5c2571f4c048ba0f3fcb8dea81f166e3a /searchlib
parent52bbe79b9d91288a83876c79b870f1e90f022fc1 (diff)
refactor tests
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp151
1 files changed, 79 insertions, 72 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index c125cb265ef..92ba11d2aa0 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -33,6 +33,8 @@ using search::AttributeGuard;
using search::AttributeVector;
using search::attribute::DistanceMetric;
using search::attribute::HnswIndexParams;
+using search::queryeval::NearestNeighborBlueprint;
+using search::queryeval::GlobalFilter;
using search::tensor::DefaultNearestNeighborIndexFactory;
using search::tensor::DenseTensorAttribute;
using search::tensor::DocVectorAccess;
@@ -47,6 +49,7 @@ using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
using vespalib::tensor::DefaultTensorEngine;
using vespalib::tensor::DenseTensor;
+using vespalib::tensor::DenseTensorView;
using vespalib::tensor::Tensor;
using DoubleVector = std::vector<double>;
@@ -76,14 +79,6 @@ Tensor::UP createTensor(const TensorSpec &spec) {
return Tensor::UP(tensor);
}
-std::unique_ptr<DenseTensor<double>> createDenseTensor(const TensorSpec &spec) {
- auto value = DefaultTensorEngine::ref().from_spec(spec);
- DenseTensor<double> *tensor = dynamic_cast<DenseTensor<double> *>(value.get());
- ASSERT_TRUE(tensor != nullptr);
- value.release();
- return std::unique_ptr<DenseTensor<double>>(tensor);
-}
-
TensorSpec
vec_2d(double x0, double x1)
{
@@ -667,69 +662,6 @@ public:
DenseTensorAttributeMockIndex() : Fixture(vec_2d_spec, true, true, true) {}
};
-TEST_F("blueprint takes global filter into account", DenseTensorAttributeMockIndex)
-{
- using vespalib::tensor::DenseTensorView;
- using search::queryeval::NearestNeighborBlueprint;
- using search::queryeval::GlobalFilter;
-
- f.set_tensor(1, vec_2d(1, 1));
- f.set_tensor(2, vec_2d(2, 2));
- f.set_tensor(3, vec_2d(3, 3));
- f.set_tensor(4, vec_2d(4, 4));
- f.set_tensor(5, vec_2d(5, 5));
- f.set_tensor(6, vec_2d(6, 6));
- f.set_tensor(7, vec_2d(7, 7));
- f.set_tensor(8, vec_2d(8, 8));
- f.set_tensor(9, vec_2d(9, 9));
- f.set_tensor(10, vec_2d(0, 0));
-
- search::queryeval::FieldSpec field("foo", 0, 0);
- auto bp = std::make_unique<NearestNeighborBlueprint>(field,
- f.as_dense_tensor(),
- createDenseTensor(vec_2d(17, 42)),
- 3, true, 5);
- EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
- EXPECT_TRUE(bp->may_approximate());
- auto empty_filter = GlobalFilter::create();
- bp->set_global_filter(*empty_filter);
- EXPECT_EQUAL(3u, bp->getState().estimate().estHits);
- EXPECT_TRUE(bp->may_approximate());
-
- bp = std::make_unique<NearestNeighborBlueprint>(field,
- f.as_dense_tensor(),
- createDenseTensor(vec_2d(17, 42)),
- 3, true, 5);
- EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
- EXPECT_TRUE(bp->may_approximate());
- auto filter = search::BitVector::create(11);
- filter->setBit(3);
- filter->invalidateCachedCount();
- auto strong_filter = GlobalFilter::create(std::move(filter));
- bp->set_global_filter(*strong_filter);
- EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
- EXPECT_FALSE(bp->may_approximate());
-
- bp = std::make_unique<NearestNeighborBlueprint>(field,
- f.as_dense_tensor(),
- createDenseTensor(vec_2d(17, 42)),
- 3, true, 5);
- EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
- EXPECT_TRUE(bp->may_approximate());
- filter = search::BitVector::create(11);
- filter->setBit(1);
- filter->setBit(3);
- filter->setBit(5);
- filter->setBit(7);
- filter->setBit(9);
- filter->setBit(11);
- filter->invalidateCachedCount();
- auto weak_filter = GlobalFilter::create(std::move(filter));
- bp->set_global_filter(*weak_filter);
- EXPECT_EQUAL(3u, bp->getState().estimate().estHits);
- EXPECT_TRUE(bp->may_approximate());
-}
-
TEST_F("setTensor() updates nearest neighbor index", DenseTensorAttributeMockIndex)
{
auto& index = f.mock_index();
@@ -859,5 +791,80 @@ TEST_F("Nearest neighbor index type is added to attribute file header", DenseTen
EXPECT_EQUAL("hnsw", header.getTag("nearest_neighbor_index").asString());
}
-TEST_MAIN() { TEST_RUN_ALL(); }
+class NearestNeighborBlueprintFixture : public DenseTensorAttributeMockIndex {
+public:
+ using QueryTensor = DenseTensor<double>;
+
+ NearestNeighborBlueprintFixture() {
+ set_tensor(1, vec_2d(1, 1));
+ set_tensor(2, vec_2d(2, 2));
+ set_tensor(3, vec_2d(3, 3));
+ set_tensor(4, vec_2d(4, 4));
+ set_tensor(5, vec_2d(5, 5));
+ set_tensor(6, vec_2d(6, 6));
+ set_tensor(7, vec_2d(7, 7));
+ set_tensor(8, vec_2d(8, 8));
+ set_tensor(9, vec_2d(9, 9));
+ set_tensor(10, vec_2d(0, 0));
+ }
+
+ std::unique_ptr<QueryTensor> createDenseTensor(const TensorSpec &spec) {
+ auto value = DefaultTensorEngine::ref().from_spec(spec);
+ QueryTensor *tensor = dynamic_cast<QueryTensor *>(value.get());
+ ASSERT_TRUE(tensor != nullptr);
+ value.release();
+ return std::unique_ptr<QueryTensor>(tensor);
+ }
+
+ std::unique_ptr<NearestNeighborBlueprint> make_blueprint() {
+ search::queryeval::FieldSpec field("foo", 0, 0);
+ auto bp = std::make_unique<NearestNeighborBlueprint>(
+ field,
+ as_dense_tensor(),
+ createDenseTensor(vec_2d(17, 42)),
+ 3, true, 5);
+ EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
+ EXPECT_TRUE(bp->may_approximate());
+ return bp;
+ }
+};
+TEST_F("NN blueprint handles empty filter", NearestNeighborBlueprintFixture)
+{
+ auto bp = f.make_blueprint();
+ auto empty_filter = GlobalFilter::create();
+ bp->set_global_filter(*empty_filter);
+ EXPECT_EQUAL(3u, bp->getState().estimate().estHits);
+ EXPECT_TRUE(bp->may_approximate());
+}
+
+TEST_F("NN blueprint handles strong filter", NearestNeighborBlueprintFixture)
+{
+ auto bp = f.make_blueprint();
+ auto filter = search::BitVector::create(11);
+ filter->setBit(3);
+ filter->invalidateCachedCount();
+ auto strong_filter = GlobalFilter::create(std::move(filter));
+ bp->set_global_filter(*strong_filter);
+ EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
+ EXPECT_FALSE(bp->may_approximate());
+}
+
+TEST_F("NN blueprint handles weak filter", NearestNeighborBlueprintFixture)
+{
+ auto bp = f.make_blueprint();
+ auto filter = search::BitVector::create(11);
+ filter->setBit(1);
+ filter->setBit(3);
+ filter->setBit(5);
+ filter->setBit(7);
+ filter->setBit(9);
+ filter->setBit(11);
+ filter->invalidateCachedCount();
+ auto weak_filter = GlobalFilter::create(std::move(filter));
+ bp->set_global_filter(*weak_filter);
+ EXPECT_EQUAL(3u, bp->getState().estimate().estHits);
+ EXPECT_TRUE(bp->may_approximate());
+}
+
+TEST_MAIN() { TEST_RUN_ALL(); }