diff options
author | Tor Egge <Tor.Egge@online.no> | 2021-05-21 17:27:35 +0200 |
---|---|---|
committer | Tor Egge <Tor.Egge@online.no> | 2021-05-21 18:56:04 +0200 |
commit | 6589ff74ba9b337ddc0bd04a9703da17baf4cb05 (patch) | |
tree | da83b186e39194f951f779ed206be85dad6ece73 /ann_benchmark | |
parent | cdef79c90ca43d9a2e620fda0178caca832055cb (diff) |
Add an extra constructor argument to the HNSW fixture to normalize vectors.
This allows the inner product (dot product) distance metric to be used instead of
the angular distance metric.
Diffstat (limited to 'ann_benchmark')
3 files changed, 74 insertions, 8 deletions
diff --git a/ann_benchmark/src/tests/ann_benchmark/test_angular.py b/ann_benchmark/src/tests/ann_benchmark/test_angular.py new file mode 100644 index 00000000000..3e48a6bd970 --- /dev/null +++ b/ann_benchmark/src/tests/ann_benchmark/test_angular.py @@ -0,0 +1,41 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import pytest +import sys +import os +import math +sys.path.insert(0, os.path.abspath("../../vespa/ann_benchmark")) +from vespa_ann_benchmark import DistanceMetric, HnswIndexParams, HnswIndex + +class Fixture: + def __init__(self, normalize): + metric = DistanceMetric.InnerProduct if normalize else DistanceMetric.Angular + self.index = HnswIndex(2, HnswIndexParams(16, 200, metric, False), normalize) + self.index.set_vector(0, [1, 0]) + self.index.set_vector(1, [10, 10]) + + def find(self, k, value): + return self.index.find_top_k(k, value, k + 200) + + def run_test(self): + top = self.find(10, [1, 1]) + assert [top[0][0], top[1][0]] == [0, 1] + # Allow some rounding errors + epsilon = 5e-8 + assert abs((1 - top[0][1]) - math.sqrt(0.5)) < epsilon + assert abs((1 - top[1][1]) - 1) < epsilon + top2 = self.find(10, [0, 2]) + # Result is not sorted by distance + assert [top2[0][0], top2[1][0]] == [0, 1] + assert abs((1 - top2[0][1]) - 0) < epsilon + assert abs((1 - top2[1][1]) - math.sqrt(0.5)) < epsilon + assert 1 == self.find(1, [1, 1])[0][0] + assert 0 == self.find(1, [1, -1])[0][0] + +def test_find_angular(): + f = Fixture(False) + f.run_test() + +def test_find_angular_normalized(): + f = Fixture(True) + f.run_test() diff --git a/ann_benchmark/src/tests/ann_benchmark/test_euclidean.py b/ann_benchmark/src/tests/ann_benchmark/test_euclidean.py index 5bda19dbc0e..b6fb06cb029 100644 --- a/ann_benchmark/src/tests/ann_benchmark/test_euclidean.py +++ b/ann_benchmark/src/tests/ann_benchmark/test_euclidean.py @@ -9,7 +9,7 @@ from vespa_ann_benchmark import DistanceMetric, HnswIndexParams, HnswIndex 
class Fixture: def __init__(self): - self.index = HnswIndex(2, HnswIndexParams(16, 200, DistanceMetric.Euclidean, False)) + self.index = HnswIndex(2, HnswIndexParams(16, 200, DistanceMetric.Euclidean, False), False) def set(self, lid, value): self.index.set_vector(lid, value) diff --git a/ann_benchmark/src/vespa/ann_benchmark/vespa_ann_benchmark.cpp b/ann_benchmark/src/vespa/ann_benchmark/vespa_ann_benchmark.cpp index 470dd1939f7..3304e598862 100644 --- a/ann_benchmark/src/vespa/ann_benchmark/vespa_ann_benchmark.cpp +++ b/ann_benchmark/src/vespa/ann_benchmark/vespa_ann_benchmark.cpp @@ -65,11 +65,13 @@ class HnswIndex TensorAttribute* _tensor_attribute; const NearestNeighborIndex* _nearest_neighbor_index; size_t _dim_size; + bool _normalize_vectors; bool check_lid(uint32_t lid); bool check_value(const char *op, const std::vector<float>& value); + TypedCells get_typed_cells(const std::vector<float>& value, std::vector<float>& normalized_value); public: - HnswIndex(uint32_t dim_size, const HnswIndexParams &hnsw_index_params); + HnswIndex(uint32_t dim_size, const HnswIndexParams &hnsw_index_params, bool normalize_vectors); virtual ~HnswIndex(); void set_vector(uint32_t lid, const std::vector<float>& value); std::vector<float> get_vector(uint32_t lid); @@ -77,13 +79,14 @@ public: TopKResult find_top_k(uint32_t k, const std::vector<float>& value, uint32_t explore_k); }; -HnswIndex::HnswIndex(uint32_t dim_size, const HnswIndexParams &hnsw_index_params) +HnswIndex::HnswIndex(uint32_t dim_size, const HnswIndexParams &hnsw_index_params, bool normalize_vectors) : _tensor_type(ValueType::error_type()), _hnsw_index_params(hnsw_index_params), _attribute(), _tensor_attribute(nullptr), _nearest_neighbor_index(nullptr), - _dim_size(0u) + _dim_size(0u), + _normalize_vectors(normalize_vectors) { Config cfg(BasicType::TENSOR, CollectionType::SINGLE); _tensor_type = ValueType::from_spec(make_tensor_spec(dim_size)); @@ -122,6 +125,25 @@ HnswIndex::check_value(const char *op, const 
std::vector<float>& value) return true; } +TypedCells +HnswIndex::get_typed_cells(const std::vector<float>& value, std::vector<float>& normalized_value) +{ + if (!_normalize_vectors) { + return {&value[0], CellType::FLOAT, value.size()}; + } + float sum_of_squared = 0.0f; + for (auto elem : value) { + sum_of_squared += elem * elem; + } + float factor = 1.0f / (sqrtf(sum_of_squared) + 1e-40f); + normalized_value.reserve(value.size()); + normalized_value.clear(); + for (auto elem : value) { + normalized_value.emplace_back(elem * factor); + } + return {&normalized_value[0], CellType::FLOAT, normalized_value.size()}; +} + void HnswIndex::set_vector(uint32_t lid, const std::vector<float>& value) { @@ -134,7 +156,8 @@ HnswIndex::set_vector(uint32_t lid, const std::vector<float>& value) /* * Not thread safe against concurrent set_vector(). */ - TypedCells typed_cells(&value[0], CellType::FLOAT, value.size()); + std::vector<float> normalized_value; + auto typed_cells = get_typed_cells(value, normalized_value); DenseValueView tensor_view(_tensor_type, typed_cells); while (size_t(lid + lid_bias) >= _attribute->getNumDocs()) { uint32_t new_lid = 0; @@ -180,7 +203,8 @@ HnswIndex::find_top_k(uint32_t k, const std::vector<float>& value, uint32_t expl * read guard is not taken here. 
*/ TopKResult result; - TypedCells typed_cells(&value[0], CellType::FLOAT, value.size()); + std::vector<float> normalized_value; + auto typed_cells = get_typed_cells(value, normalized_value); auto raw_result = _nearest_neighbor_index->find_top_k(k, typed_cells, explore_k, std::numeric_limits<double>::max()); result.reserve(raw_result.size()); switch (_hnsw_index_params.distance_metric()) { @@ -207,13 +231,14 @@ PYBIND11_MODULE(vespa_ann_benchmark, m) { py::enum_<DistanceMetric>(m, "DistanceMetric") .value("Euclidean", DistanceMetric::Euclidean) - .value("Angular", DistanceMetric::Angular); + .value("Angular", DistanceMetric::Angular) + .value("InnerProduct", DistanceMetric::InnerProduct); py::class_<HnswIndexParams>(m, "HnswIndexParams") .def(py::init<uint32_t, uint32_t, DistanceMetric, bool>()); py::class_<HnswIndex>(m, "HnswIndex") - .def(py::init<uint32_t, const HnswIndexParams&>()) + .def(py::init<uint32_t, const HnswIndexParams&, bool>()) .def("set_vector", &HnswIndex::set_vector) .def("get_vector", &HnswIndex::get_vector) .def("clear_vector", &HnswIndex::clear_vector) |