aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
diff options
context:
space:
mode:
authorGeir Storli <geirst@verizonmedia.com>2019-12-05 10:04:52 +0000
committerGeir Storli <geirst@verizonmedia.com>2019-12-05 13:30:46 +0000
commitf093b271f1f6aafa37079a889ae5d621db275dcb (patch)
tree840482cbb42ecffb42bdef1bea0d7647cf25984f /container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
parentc6484309cfe178a5d2610405460cfb0d4a89db4c (diff)
Allow nearest neighbor operator where attribute tensor and query tensor have different cell types (float vs double).
Diffstat (limited to 'container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java')
-rw-r--r--container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java53
1 files changed, 39 insertions, 14 deletions
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
index 1add8c09075..871d9285071 100644
--- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -63,15 +63,20 @@ public class ValidateNearestNeighborTestCase {
}
private static TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])");
+ private static TensorType tt_dense_dvector_2 = TensorType.fromSpec("tensor(x[2])");
private static TensorType tt_dense_fvector_3 = TensorType.fromSpec("tensor<float>(x[3])");
private static TensorType tt_dense_matrix_xy = TensorType.fromSpec("tensor(x[3],y[1])");
private static TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})");
private Tensor makeTensor(TensorType tensorType) {
+ return makeTensor(tensorType, 3);
+ }
+
+ private Tensor makeTensor(TensorType tensorType, int numCells) {
Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType);
double dv = 1.0;
String tensorDimension = "x";
- for (long label = 0; label < 3; label++) {
+ for (long label = 0; label < numCells; label++) {
tensorBuilder.cell()
.label(tensorDimension, label)
.value(dv);
@@ -94,22 +99,42 @@ public class ValidateNearestNeighborTestCase {
return tensorBuilder.build();
}
+ private String makeQuery(String attributeTensor, String queryTensor) {
+ return "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ");";
+ }
+
@Test
- public void testValidQueryDV() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);";
+ public void testValidQueryDoubleVectors() {
+ String q = makeQuery("dvector", "qvector");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
assertNull(r.hits().getError());
}
@Test
- public void testValidQueryFV() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);";
+ public void testValidQueryFloatVectors() {
+ String q = makeQuery("fvector", "qvector");
+ Tensor t = makeTensor(tt_dense_fvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertNull(r.hits().getError());
+ }
+
+ @Test
+ public void testValidQueryDoubleVectorAgainstFloatVector() {
+ String q = makeQuery("dvector", "qvector");
Tensor t = makeTensor(tt_dense_fvector_3);
Result r = doSearch(searcher, q, t);
assertNull(r.hits().getError());
}
+ @Test
+ public void testValidQueryFloatVectorAgainstDoubleVector() {
+ String q = makeQuery("fvector", "qvector");
+ Tensor t = makeTensor(tt_dense_dvector_3);
+ Result r = doSearch(searcher, q, t);
+ assertNull(r.hits().getError());
+ }
+
private static void assertErrMsg(String message, Result r) {
assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError());
}
@@ -124,7 +149,7 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testMissingQueryTensor() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,foo);";
+ String q = makeQuery("dvector", "foo");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r);
@@ -132,7 +157,7 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testQueryTensorWrongType() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);";
+ String q = makeQuery("dvector", "qvector");
Result r = doSearch(searcher, q, "tensor string");
assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r);
r = doSearch(searcher, q, null);
@@ -141,15 +166,15 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testWrongTensorType() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);";
- Tensor t = makeTensor(tt_dense_dvector_3);
+ String q = makeQuery("dvector", "qvector");
+ Tensor t = makeTensor(tt_dense_dvector_2, 2);
Result r = doSearch(searcher, q, t);
- assertErrMsg("NEAREST_NEIGHBOR {field=fvector,queryTensorName=qvector,targetNumHits=1} field type tensor<float>(x[3]) does not match query tensor type tensor(x[3])", r);
+ assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} field type tensor(x[3]) does not match query tensor type tensor(x[2])", r);
}
@Test
public void testNotAttribute() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(foo,qvector);";
+ String q = makeQuery("foo", "qvector");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r);
@@ -157,7 +182,7 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testWrongAttributeType() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(simple,qvector);";
+ String q = makeQuery("simple", "qvector");
Tensor t = makeTensor(tt_dense_dvector_3);
Result r = doSearch(searcher, q, t);
assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r);
@@ -165,7 +190,7 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testSparseTensor() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(sparse,qvector);";
+ String q = makeQuery("sparse", "qvector");
Tensor t = makeTensor(tt_sparse_vector_x);
Result r = doSearch(searcher, q, t);
assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r);
@@ -173,7 +198,7 @@ public class ValidateNearestNeighborTestCase {
@Test
public void testMatrix() {
- String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(matrix,qvector);";
+ String q = makeQuery("matrix", "qvector");
Tensor t = makeMatrix(tt_dense_matrix_xy);
Result r = doSearch(searcher, q, t);
assertErrMsg("NEAREST_NEIGHBOR {field=matrix,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x[3],y[1]) is not a dense vector", r);