summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2019-11-26 08:45:40 +0000
committerArne Juul <arnej@verizonmedia.com>2019-11-26 08:53:26 +0000
commitdc9e53f2916813591198912b36ede01470a948d9 (patch)
tree19d63446ef512280649f7e058e7b2ce177adb58e /searchlib
parente4053e53aa0f7a481fbe85095ed5f58449317dcc (diff)
refactor and check with non-null tensor
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp41
1 files changed, 23 insertions, 18 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 37f1b3af75f..8146bdfeaa5 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -38,6 +38,11 @@ std::unique_ptr<DenseTensorView> createTensor(const TensorSpec &spec) {
return std::unique_ptr<DenseTensorView>(tensor);
}
+std::unique_ptr<DenseTensorView> createTensor(double v1, double v2) {
+ return createTensor(TensorSpec(denseSpec).add({{"x", 0}}, v1)
+ .add({{"x", 1}}, v2));
+}
+
struct Fixture
{
using BasicType = search::attribute::BasicType;
@@ -84,22 +89,16 @@ struct Fixture
}
void setTensor(uint32_t docId, double v1, double v2) {
- auto t = createTensor(TensorSpec(denseSpec)
- .add({{"x", 0}}, v1)
- .add({{"x", 1}}, v2));
+ auto t = createTensor(v1, v2);
setTensor(docId, *t);
}
};
template <bool strict>
-SimpleResult find_matches(Fixture &env) {
+SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) {
auto md = MatchData::makeTestInstance(2, 2);
- auto qt = createTensor(TensorSpec(denseSpec));
-
auto &tfmd = *(md->resolveTermField(0));
- const DenseTensorView &qtv = *qt;
auto &attr = *(env._tensorAttr);
-
NearestNeighborDistanceHeap dh(2);
auto search = NearestNeighborIteratorFactory::createIterator(strict, tfmd, qtv, attr, dh);
if (strict) {
@@ -118,19 +117,24 @@ TEST("require that NearestNeighborIterator returns expected results") {
fixture.setTensor(4, 4.0, 3.0);
fixture.setTensor(5, 8.0, 6.0);
fixture.setTensor(6, 4.0, 3.0);
- SimpleResult result = find_matches<true>(fixture);
- SimpleResult expect({1,2,4,6});
- EXPECT_EQUAL(result, expect);
- result = find_matches<false>(fixture);
- EXPECT_EQUAL(result, expect);
+ auto nullTensor = createTensor(0.0, 0.0);
+ SimpleResult result = find_matches<true>(fixture, *nullTensor);
+ SimpleResult nullExpect({1,2,4,6});
+ EXPECT_EQUAL(result, nullExpect);
+ result = find_matches<false>(fixture, *nullTensor);
+ EXPECT_EQUAL(result, nullExpect);
+ auto farTensor = createTensor(9.0, 9.0);
+ SimpleResult farExpect({1,2,3,5});
+ result = find_matches<true>(fixture, *farTensor);
+ EXPECT_EQUAL(result, farExpect);
+ result = find_matches<false>(fixture, *farTensor);
+ EXPECT_EQUAL(result, farExpect);
}
template <bool strict>
-std::vector<feature_t> get_rawscores(Fixture &env) {
+std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) {
auto md = MatchData::makeTestInstance(2, 2);
- auto qt = createTensor(TensorSpec(denseSpec));
auto &tfmd = *(md->resolveTermField(0));
- const DenseTensorView &qtv = *qt;
auto &attr = *(env._tensorAttr);
NearestNeighborDistanceHeap dh(2);
auto search = NearestNeighborIteratorFactory::createIterator(strict, tfmd, qtv, attr, dh);
@@ -165,10 +169,11 @@ TEST("require that NearestNeighborIterator sets expected rawscore") {
fixture.setTensor(4, 5.0, 12.0);
fixture.setTensor(5, 8.0, 6.0);
fixture.setTensor(6, 4.0, 3.0);
- std::vector<feature_t> got = get_rawscores<true>(fixture);
+ auto nullTensor = createTensor(0.0, 0.0);
+ std::vector<feature_t> got = get_rawscores<true>(fixture, *nullTensor);
std::vector<feature_t> expected{5.0, 13.0, 10.0, 10.0, 5.0};
EXPECT_EQUAL(got, expected);
- got = get_rawscores<false>(fixture);
+ got = get_rawscores<false>(fixture, *nullTensor);
EXPECT_EQUAL(got, expected);
}