diff options
author | Geir Storli <geirst@verizonmedia.com> | 2019-12-05 16:00:59 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-12-05 16:00:59 +0100 |
commit | 7e0f43569a6c257fec2ece0eca8e638963c34f38 (patch) | |
tree | 8a49cfc9ca76a1e9afcd477c58ca1018f91650f0 | |
parent | ecc0bc3a84f470d4262f0be7e60b531bec966333 (diff) | |
parent | f093b271f1f6aafa37079a889ae5d621db275dcb (diff) |
Merge pull request #11514 from vespa-engine/geirst/nearest-neighbor-with-different-cell-types
Allow nearest neighbor operator where attribute tensor and query tens…
6 files changed, 115 insertions, 37 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java index 25c65783821..8cae081cada 100644 --- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java @@ -85,6 +85,10 @@ public class ValidateNearestNeighborSearcher extends Searcher { errorMessage = Optional.of(ErrorMessage.createIllegalQuery(description)); } + private static boolean isCompatible(TensorType lhs, TensorType rhs) { + return lhs.dimensions().equals(rhs.dimensions()); + } + private static boolean isDenseVector(TensorType tt) { List<TensorType.Dimension> dims = tt.dimensions(); if (dims.size() != 1) return false; @@ -126,7 +130,7 @@ public class ValidateNearestNeighborSearcher extends Searcher { setError(item.toString() + " field is not a tensor"); return; } - if (! fTensorType.equals(qTensorType)) { + if (! isCompatible(fTensorType, qTensorType)) { setError(item.toString() + " field type "+fTensorType+" does not match query tensor type "+qTensorType); return; } diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index 1add8c09075..871d9285071 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -63,15 +63,20 @@ public class ValidateNearestNeighborTestCase { } private static TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])"); + private static TensorType tt_dense_dvector_2 = TensorType.fromSpec("tensor(x[2])"); private static TensorType tt_dense_fvector_3 = TensorType.fromSpec("tensor<float>(x[3])"); private static TensorType tt_dense_matrix_xy = TensorType.fromSpec("tensor(x[3],y[1])"); private static TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})"); private Tensor makeTensor(TensorType tensorType) { + return makeTensor(tensorType, 3); + } + + private Tensor makeTensor(TensorType tensorType, int numCells) { Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); double dv = 1.0; String tensorDimension = "x"; - for (long label = 0; label < 3; label++) { + for (long label = 0; label < numCells; label++) { tensorBuilder.cell() .label(tensorDimension, label) .value(dv); @@ -94,22 +99,42 @@ public class ValidateNearestNeighborTestCase { return tensorBuilder.build(); } + private String makeQuery(String attributeTensor, String queryTensor) { + return "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ");"; + } + @Test - public void testValidQueryDV() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);"; + public void testValidQueryDoubleVectors() { + String q = makeQuery("dvector", "qvector"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); assertNull(r.hits().getError()); } @Test - public void testValidQueryFV() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);"; + public void testValidQueryFloatVectors() { + String q = makeQuery("fvector", "qvector"); + Tensor t = makeTensor(tt_dense_fvector_3); + Result r = doSearch(searcher, q, t); + assertNull(r.hits().getError()); + } + + @Test + public void testValidQueryDoubleVectorAgainstFloatVector() { + String q = makeQuery("dvector", "qvector"); Tensor t = makeTensor(tt_dense_fvector_3); Result r = doSearch(searcher, q, t); assertNull(r.hits().getError()); } + @Test + public void testValidQueryFloatVectorAgainstDoubleVector() { + String q = makeQuery("fvector", "qvector"); + Tensor t = makeTensor(tt_dense_dvector_3); + Result r = doSearch(searcher, q, t); + assertNull(r.hits().getError()); + } + private static void assertErrMsg(String message, Result r) { assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError()); } @@ -124,7 +149,7 @@ public class ValidateNearestNeighborTestCase { @Test public void testMissingQueryTensor() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,foo);"; + String q = makeQuery("dvector", "foo"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r); @@ -132,7 +157,7 @@ public class ValidateNearestNeighborTestCase { @Test public void testQueryTensorWrongType() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);"; + String q = makeQuery("dvector", "qvector"); Result r = doSearch(searcher, q, "tensor string"); assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r); r = doSearch(searcher, q, null); @@ -141,15 +166,15 @@ public class ValidateNearestNeighborTestCase { @Test public void testWrongTensorType() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);"; - Tensor t = makeTensor(tt_dense_dvector_3); + String q = makeQuery("dvector", "qvector"); + Tensor t = makeTensor(tt_dense_dvector_2, 2); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=fvector,queryTensorName=qvector,targetNumHits=1} field type tensor<float>(x[3]) does not match query tensor type tensor(x[3])", r); + assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} field type tensor(x[3]) does not match query tensor type tensor(x[2])", r); } @Test public void testNotAttribute() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(foo,qvector);"; + String q = makeQuery("foo", "qvector"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r); @@ -157,7 +182,7 @@ public class ValidateNearestNeighborTestCase { @Test public void testWrongAttributeType() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(simple,qvector);"; + String q = makeQuery("simple", "qvector"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r); @@ -165,7 +190,7 @@ public class ValidateNearestNeighborTestCase { @Test public void testSparseTensor() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(sparse,qvector);"; + String q = makeQuery("sparse", "qvector"); Tensor t = makeTensor(tt_sparse_vector_x); Result r = doSearch(searcher, q, t); assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r); @@ -173,7 +198,7 @@ public class ValidateNearestNeighborTestCase { @Test public void testMatrix() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(matrix,qvector);"; + String q = makeQuery("matrix", "qvector"); Tensor t = makeMatrix(tt_dense_matrix_xy); Result r = doSearch(searcher, q, t); assertErrMsg("NEAREST_NEIGHBOR {field=matrix,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x[3],y[1]) is not a dense vector", r); diff --git a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp index 1418d1f0c97..8a63f20822f 100644 --- a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp +++ b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp @@ -295,19 +295,30 @@ public: } }; -TEST(AttributeBlueprintTest, nearest_neighbor_blueprint_is_created_by_attribute_blueprint_factory) +void +expect_nearest_neighbor_blueprint(const vespalib::string& attribute_tensor_type_spec, const TensorSpec& query_tensor) { - NearestNeighborFixture f(make_tensor_attribute(field, "tensor(x[2])")); - TensorSpec dense_x_2 = TensorSpec("tensor(x[2])").add({{"x", 0}}, 3).add({{"x", 1}}, 5); - f.set_query_tensor(dense_x_2); + NearestNeighborFixture f(make_tensor_attribute(field, attribute_tensor_type_spec)); + f.set_query_tensor(query_tensor); auto result = f.create_blueprint(); const auto& nearest = as_type<NearestNeighborBlueprint>(*result); - EXPECT_EQ("tensor(x[2])", nearest.get_attribute_tensor().getTensorType().to_spec()); - EXPECT_EQ(dense_x_2, DefaultTensorEngine::ref().to_spec(nearest.get_query_tensor())); + EXPECT_EQ(attribute_tensor_type_spec, nearest.get_attribute_tensor().getTensorType().to_spec()); + EXPECT_EQ(query_tensor, DefaultTensorEngine::ref().to_spec(nearest.get_query_tensor())); EXPECT_EQ(7u, nearest.get_target_num_hits()); } +TEST(AttributeBlueprintTest, nearest_neighbor_blueprint_is_created_by_attribute_blueprint_factory) +{ + TensorSpec x_2_double = TensorSpec("tensor(x[2])").add({{"x", 0}}, 3).add({{"x", 1}}, 5); + TensorSpec x_2_float = TensorSpec("tensor<float>(x[2])").add({{"x", 0}}, 3).add({{"x", 1}}, 5); + + expect_nearest_neighbor_blueprint("tensor(x[2])", x_2_double); + expect_nearest_neighbor_blueprint("tensor<float>(x[2])", x_2_float); + expect_nearest_neighbor_blueprint("tensor(x[2])", x_2_float); + expect_nearest_neighbor_blueprint("tensor<float>(x[2])", x_2_double); +} + void expect_empty_blueprint(AttributeVector::SP attr, const TensorSpec& query_tensor, bool insert_query_tensor = true) { @@ -335,7 +346,7 @@ TEST(AttributeBlueprintTest, empty_blueprint_is_created_when_nearest_neighbor_te expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2],y[2])")); // tensor type is not of order 1 expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2])")); // query tensor not found expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2])"), sparse_x); // query tensor is not dense - expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2])"), dense_y_2); // tensor types are not equal + expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2])"), dense_y_2); // tensor types are not compatible expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2])"), dense_x_3); // tensor types are not same size } diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index 25ff459c005..7bc582ab442 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -28,7 +28,8 @@ using vespalib::tensor::DefaultTensorEngine; using namespace search::fef; using namespace search::queryeval; -vespalib::string denseSpec("tensor(x[2])"); +vespalib::string denseSpecDouble("tensor(x[2])"); +vespalib::string denseSpecFloat("tensor<float>(x[2])"); std::unique_ptr<DenseTensorView> createTensor(const TensorSpec &spec) { auto value = DefaultTensorEngine::ref().from_spec(spec); @@ -38,8 +39,8 @@ std::unique_ptr<DenseTensorView> createTensor(const TensorSpec &spec) { return std::unique_ptr<DenseTensorView>(tensor); } -std::unique_ptr<DenseTensorView> createTensor(double v1, double v2) { - return createTensor(TensorSpec(denseSpec).add({{"x", 0}}, v1) +std::unique_ptr<DenseTensorView> createTensor(const vespalib::string& type_spec, double v1, double v2) { + return createTensor(TensorSpec(type_spec).add({{"x", 0}}, v1) .add({{"x", 1}}, v2)); } @@ -89,7 +90,7 @@ struct Fixture } void setTensor(uint32_t docId, double v1, double v2) { - auto t = createTensor(v1, v2); + auto t = createTensor(_typeSpec, v1, v2); setTensor(docId, *t); } }; @@ -108,8 +109,11 @@ SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) { } } -TEST("require that NearestNeighborIterator returns expected results") { - Fixture fixture(denseSpec); +void +verify_iterator_returns_expected_results(const vespalib::string& attribute_tensor_type_spec, + const vespalib::string& query_tensor_type_spec) +{ + Fixture fixture(attribute_tensor_type_spec); fixture.ensureSpace(6); fixture.setTensor(1, 3.0, 4.0); fixture.setTensor(2, 6.0, 8.0); @@ -117,13 +121,13 @@ TEST("require that NearestNeighborIterator returns expected results") { fixture.setTensor(4, 4.0, 3.0); fixture.setTensor(5, 8.0, 6.0); fixture.setTensor(6, 4.0, 3.0); - auto nullTensor = createTensor(0.0, 0.0); + auto nullTensor = createTensor(query_tensor_type_spec, 0.0, 0.0); SimpleResult result = find_matches<true>(fixture, *nullTensor); SimpleResult nullExpect({1,2,4,6}); EXPECT_EQUAL(result, nullExpect); result = find_matches<false>(fixture, *nullTensor); EXPECT_EQUAL(result, nullExpect); - auto farTensor = createTensor(9.0, 9.0); + auto farTensor = createTensor(query_tensor_type_spec, 9.0, 9.0); SimpleResult farExpect({1,2,3,5}); result = find_matches<true>(fixture, *farTensor); EXPECT_EQUAL(result, farExpect); @@ -131,6 +135,13 @@ TEST("require that NearestNeighborIterator returns expected results") { EXPECT_EQUAL(result, farExpect); } +TEST("require that NearestNeighborIterator returns expected results") { + TEST_DO(verify_iterator_returns_expected_results(denseSpecDouble, denseSpecDouble)); + TEST_DO(verify_iterator_returns_expected_results(denseSpecFloat, denseSpecFloat)); + TEST_DO(verify_iterator_returns_expected_results(denseSpecDouble, denseSpecFloat)); + TEST_DO(verify_iterator_returns_expected_results(denseSpecFloat, denseSpecDouble)); +} + template <bool strict> std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) { auto md = MatchData::makeTestInstance(2, 2); @@ -152,8 +163,11 @@ std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) { return rv; } -TEST("require that NearestNeighborIterator sets expected rawscore") { - Fixture fixture(denseSpec); +void +verify_iterator_sets_expected_rawscore(const vespalib::string& attribute_tensor_type_spec, + const vespalib::string& query_tensor_type_spec) +{ + Fixture fixture(attribute_tensor_type_spec); fixture.ensureSpace(6); fixture.setTensor(1, 3.0, 4.0); fixture.setTensor(2, 5.0, 12.0); @@ -161,7 +175,7 @@ TEST("require that NearestNeighborIterator sets expected rawscore") { fixture.setTensor(4, 5.0, 12.0); fixture.setTensor(5, 8.0, 6.0); fixture.setTensor(6, 4.0, 3.0); - auto nullTensor = createTensor(0.0, 0.0); + auto nullTensor = createTensor(query_tensor_type_spec, 0.0, 0.0); std::vector<feature_t> got = get_rawscores<true>(fixture, *nullTensor); std::vector<feature_t> expected{5.0, 13.0, 10.0, 10.0, 5.0}; EXPECT_EQUAL(got, expected); @@ -169,4 +183,11 @@ TEST("require that NearestNeighborIterator sets expected rawscore") { EXPECT_EQUAL(got, expected); } +TEST("require that NearestNeighborIterator sets expected rawscore") { + TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecDouble, denseSpecDouble)); + TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecFloat, denseSpecFloat)); + TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecDouble, denseSpecFloat)); + TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecFloat, denseSpecDouble)); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 596e6894482..29298a05d6e 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -438,6 +438,13 @@ bool check_valid_diversity_attr(const IAttributeVector *attr) { return (attr->hasEnum() || attr->isIntegerType() || attr->isFloatingPointType()); } +bool +is_compatible_for_nearest_neighbor(const vespalib::eval::ValueType& lhs, + const vespalib::eval::ValueType& rhs) +{ + return (lhs.dimensions() == rhs.dimensions()); +} + //----------------------------------------------------------------------------- @@ -630,9 +637,8 @@ public: return fail_nearest_neighbor_term(n, make_string("Query tensor is not a dense tensor (type=%s)", query_tensor->type().to_spec().c_str())); } - if (dense_attr_tensor->getTensorType() != dense_query_tensor->type()) { - // TODO: consider allowing different data types (float vs double). - return fail_nearest_neighbor_term(n, make_string("Attribute tensor type (%s) and query tensor type (%s) are not equal", + if (!is_compatible_for_nearest_neighbor(dense_attr_tensor->getTensorType(), dense_query_tensor->type())) { + return fail_nearest_neighbor_term(n, make_string("Attribute tensor type (%s) and query tensor type (%s) are not compatible", dense_attr_tensor->getTensorType().to_spec().c_str(), dense_query_tensor->type().to_spec().c_str())); } std::unique_ptr<DenseTensorView> dense_query_tensor_up(dense_query_tensor); diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp index dcf599daa6a..212985a4b0a 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp @@ -12,6 +12,17 @@ using CellType = vespalib::eval::ValueType::CellType; namespace search::queryeval { +namespace { + +bool +is_compatible(const vespalib::eval::ValueType& lhs, + const vespalib::eval::ValueType& rhs) +{ + return (lhs.dimensions() == rhs.dimensions()); +} + +} + /** * Search iterator for K nearest neighbor matching. * Uses unpack() as feedback mechanism to track which matches actually became hits. @@ -29,7 +40,7 @@ public: _fieldTensor(params().tensorAttribute.getTensorType()), _lastScore(0.0) { - assert(_fieldTensor.fast_type() == params().queryTensor.fast_type()); + assert(is_compatible(_fieldTensor.fast_type(), params().queryTensor.fast_type())); } ~NearestNeighborImpl(); |