aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp')
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp22
1 files changed, 17 insertions, 5 deletions
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index e1bd47af358..cbdb2c9bd22 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -13,6 +13,7 @@
#include <vespa/searchlib/tensor/dense_tensor_attribute.h>
#include <vespa/searchlib/tensor/direct_tensor_attribute.h>
#include <vespa/searchlib/tensor/doc_vector_access.h>
+#include <vespa/searchlib/tensor/distance_functions.h>
#include <vespa/searchlib/tensor/hnsw_index.h>
#include <vespa/searchlib/tensor/nearest_neighbor_index.h>
#include <vespa/searchlib/tensor/nearest_neighbor_index_factory.h>
@@ -206,24 +207,32 @@ public:
_index_value = (reinterpret_cast<const int*>(buf.buffer()))[0];
return true;
}
- std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k) const override {
+ std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k,
+ double distance_threshold) const override
+ {
(void) k;
(void) vector;
(void) explore_k;
+ (void) distance_threshold;
return std::vector<Neighbor>();
}
std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector,
- const search::BitVector& filter, uint32_t explore_k) const override
+ const search::BitVector& filter, uint32_t explore_k,
+ double distance_threshold) const override
{
(void) k;
(void) vector;
(void) explore_k;
(void) filter;
+ (void) distance_threshold;
return std::vector<Neighbor>();
}
- const search::tensor::DistanceFunction *distance_function() const override { return nullptr; }
+ const search::tensor::DistanceFunction *distance_function() const override {
+ static search::tensor::SquaredEuclideanDistance<double> my_dist_fun;
+ return &my_dist_fun;
+ }
};
class MockNearestNeighborIndexFactory : public NearestNeighborIndexFactory {
@@ -563,7 +572,7 @@ void
Fixture::testCompaction()
{
if ((_traits.use_dense_tensor_attribute && _denseTensors) ||
- _traits.use_direct_tensor_attribute)
+ ! _traits.use_dense_tensor_attribute)
{
LOG(info, "Skipping compaction test for tensor '%s' which is using free-lists", _cfg.tensorType().to_spec().c_str());
return;
@@ -914,9 +923,12 @@ public:
field,
as_dense_tensor(),
createDenseTensor(vec_2d(17, 42)),
- 3, true, 5, brute_force_limit);
+ 3, true, 5,
+ 100100.25,
+ brute_force_limit);
EXPECT_EQUAL(11u, bp->getState().estimate().estHits);
EXPECT_TRUE(bp->may_approximate());
+ EXPECT_EQUAL(100100.25 * 100100.25, bp->get_distance_threshold());
return bp;
}
};