aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src/tests/ann/verify-top-k.h
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/ann/verify-top-k.h')
-rw-r--r--eval/src/tests/ann/verify-top-k.h27
1 files changed, 27 insertions, 0 deletions
diff --git a/eval/src/tests/ann/verify-top-k.h b/eval/src/tests/ann/verify-top-k.h
new file mode 100644
index 00000000000..220c273d017
--- /dev/null
+++ b/eval/src/tests/ann/verify-top-k.h
@@ -0,0 +1,27 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+int verify_top_k(const TopK &perfect, const TopK &result, uint32_t sk, uint32_t qid) {
+ int recall = perfect.recall(result);
+ EXPECT_TRUE(recall > 40);
+ double sum_error = 0.0;
+ double c_factor = 1.0;
+ for (size_t i = 0; i < result.K; ++i) {
+ double factor = (result.hits[i].distance / perfect.hits[i].distance);
+ if (factor < 0.99 || factor > 25) {
+ fprintf(stderr, "hit[%zu] got distance %.3f, expected %.3f\n",
+ i, result.hits[i].distance, perfect.hits[i].distance);
+ }
+ sum_error += factor;
+ c_factor = std::max(c_factor, factor);
+ }
+ EXPECT_TRUE(c_factor < 1.5);
+ fprintf(stderr, "quality sk=%u: query %u: recall %d c2-factor %.3f avg c2: %.3f\n",
+ sk, qid, recall, c_factor, sum_error / result.K);
+ return recall;
+}
+
+int verify_nns_quality(uint32_t sk, NNS_API &nns, uint32_t qid) {
+ TopK perfect = bruteforceResults[qid];
+ TopK result = find_with_nns(sk, nns, qid);
+ return verify_top_k(perfect, result, sk, qid);
+}