about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Geir Storli <geirst@verizonmedia.com> 2019-12-05 10:04:52 +0000
committer Geir Storli <geirst@verizonmedia.com> 2019-12-05 13:30:46 +0000
commit f093b271f1f6aafa37079a889ae5d621db275dcb (patch)
tree 840482cbb42ecffb42bdef1bea0d7647cf25984f
parent c6484309cfe178a5d2610405460cfb0d4a89db4c (diff)
Allow nearest neighbor operator where attribute tensor and query tensor have different cell types (float vs double).
-rw-r--r-- container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java 6
-rw-r--r-- container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java 53
-rw-r--r-- searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp 25
-rw-r--r-- searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp 43
-rw-r--r-- searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp 12
-rw-r--r-- searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp 13
6 files changed, 115 insertions, 37 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
index 25c65783821..8cae081cada 100644
--- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
+++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
@@ -85,6 +85,10 @@ public class ValidateNearestNeighborSearcher extends Searcher {
errorMessage = Optional.of(ErrorMessage.createIllegalQuery(description));
}
+ private static boolean isCompatible(TensorType lhs, TensorType rhs) {
+ return lhs.dimensions().equals(rhs.dimensions());
+ }
+
private static boolean isDenseVector(TensorType tt) {
List<TensorType.Dimension> dims = tt.dimensions();
if (dims.size() != 1) return false;
@@ -126,7 +130,7 @@ public class ValidateNearestNeighborSearcher extends Searcher {
setError(item.toString() + " field is not a tensor");
return;
}
- if (! fTensorType.equals(qTensorType)) {
+ if (! isCompatible(fTensorType, qTensorType)) {
setError(item.toString() + " field type "+fTensorType+" does not match query tensor type "+qTensorType);
return;
}
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
index 1add8c09075..871d9285071 100644
--- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -63,15 +63,20 @@ public class ValidateNearestNeighborTestCase {
}
private static TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])");
+ private static TensorType tt_dense_dvector_2 = TensorType.fromSpec("tensor(x[2])");
private static TensorType tt_dense_fvector_3 = TensorType.fromSpec("tensor<float>(x[3])");
private static TensorType tt_dense_matrix_xy = TensorType.fromSpec("tensor(x[3],y[1])");
private static TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})");
private Tensor makeTensor(TensorType tensorType) {
+ return makeTensor(tensorType, 3);
+ }
+
+ private Tensor makeTensor(TensorType tensorType, int numCells) {
Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
double dv = 1.0;
String tensorDimension = "x";
- for (long label = 0; label < 3; label++) {
+ for (long label = 0; label < numCells; label++) {
tensorBuilder.cell()
.label(tensorDimension, label)
.value(dv);
@@ -94,22 +99,42 @@ public class ValidateNearestNeighborTestCase {
return tensorBuilder.build();
}
+ private String makeQuery(String attributeTensor, String queryTensor) {
+ return "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ");";
+ }
+
@Test
- public void testValidQueryDV() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);";
+ public void testValidQueryDoubleVectors() {
+ String q = makeQuery("dvector", "qvector");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
assertNull(r.hits().getError());
}
@Test
- public void testValidQueryFV() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);";
+ public void testValidQueryFloatVectors() {
+ String q = makeQuery("fvector", "qvector");
+ Tensor t = makeTensor(tt_dense_fvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertNull(r.hits().getError());
+ }
+
+ @Test
+ public void testValidQueryDoubleVectorAgainstFloatVector() {
+ String q = makeQuery("dvector", "qvector");
Tensor t = makeTensor(tt_dense_fvector_3);
Result r = doSearch(searcher, q, t);
assertNull(r.hits().getError());
}
+ @Test
+ public void testValidQueryFloatVectorAgainstDoubleVector() {
+ String q = makeQuery("fvector", "qvector");
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertNull(r.hits().getError());
+ }
+
private static void assertErrMsg(String message, Result r) {
assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError());
}
@@ -124,7 +149,7 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testMissingQueryTensor() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,foo);";
+ String q = makeQuery("dvector", "foo");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r);
@@ -132,7 +157,7 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testQueryTensorWrongType() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);";
+ String q = makeQuery("dvector", "qvector");
Result r = doSearch(searcher, q, "tensor string");
assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r);
r = doSearch(searcher, q, null);
@@ -141,15 +166,15 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testWrongTensorType() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);";
- Tensor t = makeTensor(tt_dense_dvector_3);
+ String q = makeQuery("dvector", "qvector");
+ Tensor t = makeTensor(tt_dense_dvector_2, 2);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=fvector,queryTensorName=qvector,targetNumHits=1} field type tensor<float>(x[3]) does not match query tensor type tensor(x[3])", r);
+ assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} field type tensor(x[3]) does not match query tensor type tensor(x[2])", r);
}
@Test
public void testNotAttribute() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(foo,qvector);";
+ String q = makeQuery("foo", "qvector");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r);
@@ -157,7 +182,7 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testWrongAttributeType() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(simple,qvector);";
+ String q = makeQuery("simple", "qvector");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r);
@@ -165,7 +190,7 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testSparseTensor() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(sparse,qvector);";
+ String q = makeQuery("sparse", "qvector");
Tensor t = makeTensor(tt_sparse_vector_x);
Result r = doSearch(searcher, q, t);
assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r);
@@ -173,7 +198,7 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testMatrix() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(matrix,qvector);";
+ String q = makeQuery("matrix", "qvector");
Tensor t = makeMatrix(tt_dense_matrix_xy);
Result r = doSearch(searcher, q, t);
assertErrMsg("NEAREST_NEIGHBOR {field=matrix,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x[3],y[1]) is not a dense vector", r);
diff --git a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
index 1418d1f0c97..8a63f20822f 100644
--- a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
+++ b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp
@@ -295,19 +295,30 @@ public:
}
};
-TEST(AttributeBlueprintTest, nearest_neighbor_blueprint_is_created_by_attribute_blueprint_factory)
+void
+expect_nearest_neighbor_blueprint(const vespalib::string& attribute_tensor_type_spec, const TensorSpec& query_tensor)
{
- NearestNeighborFixture f(make_tensor_attribute(field, "tensor(x[2])"));
- TensorSpec dense_x_2 = TensorSpec("tensor(x[2])").add({{"x", 0}}, 3).add({{"x", 1}}, 5);
- f.set_query_tensor(dense_x_2);
+ NearestNeighborFixture f(make_tensor_attribute(field, attribute_tensor_type_spec));
+ f.set_query_tensor(query_tensor);
auto result = f.create_blueprint();
const auto& nearest = as_type<NearestNeighborBlueprint>(*result);
- EXPECT_EQ("tensor(x[2])", nearest.get_attribute_tensor().getTensorType().to_spec());
- EXPECT_EQ(dense_x_2, DefaultTensorEngine::ref().to_spec(nearest.get_query_tensor()));
+ EXPECT_EQ(attribute_tensor_type_spec, nearest.get_attribute_tensor().getTensorType().to_spec());
+ EXPECT_EQ(query_tensor, DefaultTensorEngine::ref().to_spec(nearest.get_query_tensor()));
EXPECT_EQ(7u, nearest.get_target_num_hits());
}
+TEST(AttributeBlueprintTest, nearest_neighbor_blueprint_is_created_by_attribute_blueprint_factory)
+{
+ TensorSpec x_2_double = TensorSpec("tensor(x[2])").add({{"x", 0}}, 3).add({{"x", 1}}, 5);
+ TensorSpec x_2_float = TensorSpec("tensor<float>(x[2])").add({{"x", 0}}, 3).add({{"x", 1}}, 5);
+
+ expect_nearest_neighbor_blueprint("tensor(x[2])", x_2_double);
+ expect_nearest_neighbor_blueprint("tensor<float>(x[2])", x_2_float);
+ expect_nearest_neighbor_blueprint("tensor(x[2])", x_2_float);
+ expect_nearest_neighbor_blueprint("tensor<float>(x[2])", x_2_double);
+}
+
void
expect_empty_blueprint(AttributeVector::SP attr, const TensorSpec& query_tensor, bool insert_query_tensor = true)
{
@@ -335,7 +346,7 @@ TEST(AttributeBlueprintTest, empty_blueprint_is_created_when_nearest_neighbor_te
expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2],y[2])")); // tensor type is not of order 1
expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2])")); // query tensor not found
expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2])"), sparse_x); // query tensor is not dense
- expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2])"), dense_y_2); // tensor types are not equal
+ expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2])"), dense_y_2); // tensor types are not compatible
expect_empty_blueprint(make_tensor_attribute(field, "tensor(x[2])"), dense_x_3); // tensor types are not same size
}
diff --git a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
index 25ff459c005..7bc582ab442 100644
--- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
+++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp
@@ -28,7 +28,8 @@ using vespalib::tensor::DefaultTensorEngine;
using namespace search::fef;
using namespace search::queryeval;
-vespalib::string denseSpec("tensor(x[2])");
+vespalib::string denseSpecDouble("tensor(x[2])");
+vespalib::string denseSpecFloat("tensor<float>(x[2])");
std::unique_ptr<DenseTensorView> createTensor(const TensorSpec &spec) {
auto value = DefaultTensorEngine::ref().from_spec(spec);
@@ -38,8 +39,8 @@ std::unique_ptr<DenseTensorView> createTensor(const TensorSpec &spec) {
return std::unique_ptr<DenseTensorView>(tensor);
}
-std::unique_ptr<DenseTensorView> createTensor(double v1, double v2) {
- return createTensor(TensorSpec(denseSpec).add({{"x", 0}}, v1)
+std::unique_ptr<DenseTensorView> createTensor(const vespalib::string& type_spec, double v1, double v2) {
+ return createTensor(TensorSpec(type_spec).add({{"x", 0}}, v1)
.add({{"x", 1}}, v2));
}
@@ -89,7 +90,7 @@ struct Fixture
}
void setTensor(uint32_t docId, double v1, double v2) {
- auto t = createTensor(v1, v2);
+ auto t = createTensor(_typeSpec, v1, v2);
setTensor(docId, *t);
}
};
@@ -108,8 +109,11 @@ SimpleResult find_matches(Fixture &env, const DenseTensorView &qtv) {
}
}
-TEST("require that NearestNeighborIterator returns expected results") {
- Fixture fixture(denseSpec);
+void
+verify_iterator_returns_expected_results(const vespalib::string& attribute_tensor_type_spec,
+ const vespalib::string& query_tensor_type_spec)
+{
+ Fixture fixture(attribute_tensor_type_spec);
fixture.ensureSpace(6);
fixture.setTensor(1, 3.0, 4.0);
fixture.setTensor(2, 6.0, 8.0);
@@ -117,13 +121,13 @@ TEST("require that NearestNeighborIterator returns expected results") {
fixture.setTensor(4, 4.0, 3.0);
fixture.setTensor(5, 8.0, 6.0);
fixture.setTensor(6, 4.0, 3.0);
- auto nullTensor = createTensor(0.0, 0.0);
+ auto nullTensor = createTensor(query_tensor_type_spec, 0.0, 0.0);
SimpleResult result = find_matches<true>(fixture, *nullTensor);
SimpleResult nullExpect({1,2,4,6});
EXPECT_EQUAL(result, nullExpect);
result = find_matches<false>(fixture, *nullTensor);
EXPECT_EQUAL(result, nullExpect);
- auto farTensor = createTensor(9.0, 9.0);
+ auto farTensor = createTensor(query_tensor_type_spec, 9.0, 9.0);
SimpleResult farExpect({1,2,3,5});
result = find_matches<true>(fixture, *farTensor);
EXPECT_EQUAL(result, farExpect);
@@ -131,6 +135,13 @@ TEST("require that NearestNeighborIterator returns expected results") {
EXPECT_EQUAL(result, farExpect);
}
+TEST("require that NearestNeighborIterator returns expected results") {
+ TEST_DO(verify_iterator_returns_expected_results(denseSpecDouble, denseSpecDouble));
+ TEST_DO(verify_iterator_returns_expected_results(denseSpecFloat, denseSpecFloat));
+ TEST_DO(verify_iterator_returns_expected_results(denseSpecDouble, denseSpecFloat));
+ TEST_DO(verify_iterator_returns_expected_results(denseSpecFloat, denseSpecDouble));
+}
+
template <bool strict>
std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) {
auto md = MatchData::makeTestInstance(2, 2);
@@ -152,8 +163,11 @@ std::vector<feature_t> get_rawscores(Fixture &env, const DenseTensorView &qtv) {
return rv;
}
-TEST("require that NearestNeighborIterator sets expected rawscore") {
- Fixture fixture(denseSpec);
+void
+verify_iterator_sets_expected_rawscore(const vespalib::string& attribute_tensor_type_spec,
+ const vespalib::string& query_tensor_type_spec)
+{
+ Fixture fixture(attribute_tensor_type_spec);
fixture.ensureSpace(6);
fixture.setTensor(1, 3.0, 4.0);
fixture.setTensor(2, 5.0, 12.0);
@@ -161,7 +175,7 @@ TEST("require that NearestNeighborIterator sets expected rawscore") {
fixture.setTensor(4, 5.0, 12.0);
fixture.setTensor(5, 8.0, 6.0);
fixture.setTensor(6, 4.0, 3.0);
- auto nullTensor = createTensor(0.0, 0.0);
+ auto nullTensor = createTensor(query_tensor_type_spec, 0.0, 0.0);
std::vector<feature_t> got = get_rawscores<true>(fixture, *nullTensor);
std::vector<feature_t> expected{5.0, 13.0, 10.0, 10.0, 5.0};
EXPECT_EQUAL(got, expected);
@@ -169,4 +183,11 @@ TEST("require that NearestNeighborIterator sets expected rawscore") {
EXPECT_EQUAL(got, expected);
}
+TEST("require that NearestNeighborIterator sets expected rawscore") {
+ TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecDouble, denseSpecDouble));
+ TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecFloat, denseSpecFloat));
+ TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecDouble, denseSpecFloat));
+ TEST_DO(verify_iterator_sets_expected_rawscore(denseSpecFloat, denseSpecDouble));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
index 596e6894482..29298a05d6e 100644
--- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
+++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp
@@ -438,6 +438,13 @@ bool check_valid_diversity_attr(const IAttributeVector *attr) {
return (attr->hasEnum() || attr->isIntegerType() || attr->isFloatingPointType());
}
+bool
+is_compatible_for_nearest_neighbor(const vespalib::eval::ValueType& lhs,
+ const vespalib::eval::ValueType& rhs)
+{
+ return (lhs.dimensions() == rhs.dimensions());
+}
+
//-----------------------------------------------------------------------------
@@ -630,9 +637,8 @@ public:
return fail_nearest_neighbor_term(n, make_string("Query tensor is not a dense tensor (type=%s)",
query_tensor->type().to_spec().c_str()));
}
- if (dense_attr_tensor->getTensorType() != dense_query_tensor->type()) {
- // TODO: consider allowing different data types (float vs double).
- return fail_nearest_neighbor_term(n, make_string("Attribute tensor type (%s) and query tensor type (%s) are not equal",
+ if (!is_compatible_for_nearest_neighbor(dense_attr_tensor->getTensorType(), dense_query_tensor->type())) {
+ return fail_nearest_neighbor_term(n, make_string("Attribute tensor type (%s) and query tensor type (%s) are not compatible",
dense_attr_tensor->getTensorType().to_spec().c_str(), dense_query_tensor->type().to_spec().c_str()));
}
std::unique_ptr<DenseTensorView> dense_query_tensor_up(dense_query_tensor);
diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
index dcf599daa6a..212985a4b0a 100644
--- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
+++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_iterator.cpp
@@ -12,6 +12,17 @@ using CellType = vespalib::eval::ValueType::CellType;
namespace search::queryeval {
+namespace {
+
+bool
+is_compatible(const vespalib::eval::ValueType& lhs,
+ const vespalib::eval::ValueType& rhs)
+{
+ return (lhs.dimensions() == rhs.dimensions());
+}
+
+}
+
/**
* Search iterator for K nearest neighbor matching.
* Uses unpack() as feedback mechanism to track which matches actually became hits.
@@ -29,7 +40,7 @@ public:
_fieldTensor(params().tensorAttribute.getTensorType()),
_lastScore(0.0)
{
- assert(_fieldTensor.fast_type() == params().queryTensor.fast_type());
+ assert(is_compatible(_fieldTensor.fast_type(), params().queryTensor.fast_type()));
}
~NearestNeighborImpl();