summaryrefslogtreecommitdiffstats
path: root/searchlib/src/tests/tensor
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-12 11:22:25 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-12 11:22:25 +0000
commit05391413632841596bd0cd8e40389a185461a0af (patch)
tree4ed11de18d5a2ace1f4822e6edfb039e554ab4eb /searchlib/src/tests/tensor
parent791c4b163669d5ef8ea671be1efacb89655d3935 (diff)
fix NNS distance for new cell types
This reverts commit f167fe4362c5e4e20a7605b99205cfbee77c569a.
Diffstat (limited to 'searchlib/src/tests/tensor')
-rw-r--r--searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp6
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp3
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp3
3 files changed, 8 insertions, 4 deletions
diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
index ee0a2aff80e..6f8c9b4c885 100644
--- a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
+++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp
@@ -10,6 +10,7 @@
LOG_SETUP("distance_function_test");
using namespace search::tensor;
+using vespalib::eval::Int8Float;
using vespalib::eval::TypedCells;
using search::attribute::DistanceMetric;
@@ -212,6 +213,11 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score)
EXPECT_DOUBLE_EQ(threshold, 0.5);
threshold = hamming->convert_threshold(1.0);
EXPECT_DOUBLE_EQ(threshold, 1.0);
+
+ std::vector<Int8Float> bytes_a = { 0, 1, 2, 4, 8, 16, 32, 64, -128, 0, 1, 2, 4, 8, 16, 32, 64, -128, 0, 1, 2 };
+ std::vector<Int8Float> bytes_b = { 1, 2, 2, 4, 8, 16, 32, 65, -128, 0, 1, 0, 4, 8, 16, 32, 64, -128, 0, 1, -1 };
+ // expect diff: 1 2 1 1 7
+ EXPECT_EQ(hamming->calc(TypedCells(bytes_a), TypedCells(bytes_b)), 12.0);
}
TEST(GeoDegreesTest, gives_expected_score)
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 20dc55df329..6ffe118aa65 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -51,7 +51,6 @@ struct LevelGenerator : public RandomLevelGenerator {
};
using FloatVectors = MyDocVectorAccess<float>;
-using FloatSqEuclideanDistance = SquaredEuclideanDistance<float>;
using HnswIndexUP = std::unique_ptr<HnswIndex>;
class HnswIndexTest : public ::testing::Test {
@@ -79,7 +78,7 @@ public:
void init(bool heuristic_select_neighbors) {
auto generator = std::make_unique<LevelGenerator>();
level_generator = generator.get();
- index = std::make_unique<HnswIndex>(vectors, std::make_unique<FloatSqEuclideanDistance>(),
+ index = std::make_unique<HnswIndex>(vectors, std::make_unique<SquaredEuclideanDistance>(),
std::move(generator),
HnswIndex::Config(5, 2, 10, 0, heuristic_select_neighbors));
}
diff --git a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
index 8e3bb95a776..7acdb4df983 100644
--- a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
@@ -117,7 +117,6 @@ public:
}
};
-using FloatSqEuclideanDistance = SquaredEuclideanDistance<float>;
using HnswIndexUP = std::unique_ptr<HnswIndex>;
class Stressor : public ::testing::Test {
@@ -232,7 +231,7 @@ public:
void init() {
uint32_t m = 16;
- index = std::make_unique<HnswIndex>(vectors, std::make_unique<FloatSqEuclideanDistance>(),
+ index = std::make_unique<HnswIndex>(vectors, std::make_unique<SquaredEuclideanDistance>(),
std::make_unique<InvLogLevelGenerator>(m),
HnswIndex::Config(2*m, m, 200, 10, true));
}