diff options
author | Arne Juul <arnej@verizonmedia.com> | 2019-11-26 08:45:40 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2019-11-26 08:53:26 +0000 |
commit | dc9e53f2916813591198912b36ede01470a948d9 (patch) | |
tree | 19d63446ef512280649f7e058e7b2ce177adb58e /searchlib | |
parent | e4053e53aa0f7a481fbe85095ed5f58449317dcc (diff) |
refactor and check with non-null tensor
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp | 41 |
1 files changed, 23 insertions, 18 deletions
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 37f1b3af75f..8146bdfeaa5 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -38,6 +38,11 @@ std::unique_ptr<DenseTensorView> createTensor(const TensorSpec &spec) { return std::unique_ptr<DenseTensorView>(tensor); } +std::unique_ptr<DenseTensorView> createTensor(double v1, double v2) { + return createTensor(TensorSpec(denseSpec).add({{"x", 0}}, v1) + .add({{"x", 1}}, v2)); +} + struct Fixture { using BasicType = search::attribute::BasicType; @@ -84,22 +89,16 @@ struct Fixture } void setTensor(uint32_t docId, double v1, double v2) { - auto t = createTensor(TensorSpec(denseSpec) - .add({{"x", 0}}, v1) - .add({{"x", 1}}, v2)); + auto t = createTensor(v1, v2); setTensor(docId, *t); } }; template <bool strict> -SimpleResult find_matches(Fixture &env) { +SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) { auto md = MatchData::makeTestInstance(2, 2); - auto qt = createTensor(TensorSpec(denseSpec)); - auto &tfmd = *(md->resolveTermField(0)); - const DenseTensorView &qtv = *qt; auto &attr = *(env._tensorAttr); - NearestNeighborDistanceHeap dh(2); auto search = NearestNeighborIteratorFactory::createIterator(strict, tfmd, qtv, attr, dh); if (strict) { @@ -118,19 +117,24 @@ TEST("require that NearestNeighborIterator returns expected results") { fixture.setTensor(4, 4.0, 3.0); fixture.setTensor(5, 8.0, 6.0); fixture.setTensor(6, 4.0, 3.0); - SimpleResult result = find_matches<true>(fixture); - SimpleResult expect({1,2,4,6}); - EXPECT_EQUAL(result, expect); - result = find_matches<false>(fixture); - EXPECT_EQUAL(result, expect); + auto nullTensor = createTensor(0.0, 0.0); + SimpleResult result = find_matches<true>(fixture, *nullTensor); + SimpleResult nullExpect({1,2,4,6}); + EXPECT_EQUAL(result, nullExpect); + result = find_matches<false>(fixture, *nullTensor); + EXPECT_EQUAL(result, nullExpect); + auto farTensor = createTensor(9.0, 9.0); + SimpleResult farExpect({1,2,3,5}); + result = find_matches<true>(fixture, *farTensor); + EXPECT_EQUAL(result, farExpect); + result = find_matches<false>(fixture, *farTensor); + EXPECT_EQUAL(result, farExpect); } template <bool strict> -std::vector<feature_t> get_rawscores(Fixture &env) { +std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) { auto md = MatchData::makeTestInstance(2, 2); - auto qt = createTensor(TensorSpec(denseSpec)); auto &tfmd = *(md->resolveTermField(0)); - const DenseTensorView &qtv = *qt; auto &attr = *(env._tensorAttr); NearestNeighborDistanceHeap dh(2); auto search = NearestNeighborIteratorFactory::createIterator(strict, tfmd, qtv, attr, dh); @@ -165,10 +169,11 @@ TEST("require that NearestNeighborIterator sets expected rawscore") { fixture.setTensor(4, 5.0, 12.0); fixture.setTensor(5, 8.0, 6.0); fixture.setTensor(6, 4.0, 3.0); - std::vector<feature_t> got = get_rawscores<true>(fixture); + auto nullTensor = createTensor(0.0, 0.0); + std::vector<feature_t> got = get_rawscores<true>(fixture, *nullTensor); std::vector<feature_t> expected{5.0, 13.0, 10.0, 10.0, 5.0}; EXPECT_EQUAL(got, expected); - got = get_rawscores<false>(fixture); + got = get_rawscores<false>(fixture, *nullTensor); EXPECT_EQUAL(got, expected); } |