From f093b271f1f6aafa37079a889ae5d621db275dcb Mon Sep 17 00:00:00 2001 From: Geir Storli Date: Thu, 5 Dec 2019 10:04:52 +0000 Subject: Allow nearest neighbor operator where attribute tensor and query tensor have different cell types (float vs double). --- .../searchers/ValidateNearestNeighborTestCase.java | 53 ++++++++++++++++------ 1 file changed, 39 insertions(+), 14 deletions(-) (limited to 'container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java') diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index 1add8c09075..871d9285071 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -63,15 +63,20 @@ public class ValidateNearestNeighborTestCase { } private static TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])"); + private static TensorType tt_dense_dvector_2 = TensorType.fromSpec("tensor(x[2])"); private static TensorType tt_dense_fvector_3 = TensorType.fromSpec("tensor(x[3])"); private static TensorType tt_dense_matrix_xy = TensorType.fromSpec("tensor(x[3],y[1])"); private static TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})"); private Tensor makeTensor(TensorType tensorType) { + return makeTensor(tensorType, 3); + } + + private Tensor makeTensor(TensorType tensorType, int numCells) { Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); double dv = 1.0; String tensorDimension = "x"; - for (long label = 0; label < 3; label++) { + for (long label = 0; label < numCells; label++) { tensorBuilder.cell() .label(tensorDimension, label) .value(dv); @@ -94,22 +99,42 @@ public class ValidateNearestNeighborTestCase { return tensorBuilder.build(); } + private String makeQuery(String attributeTensor, String queryTensor) { + return "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(" + attributeTensor + ", " + queryTensor + ");"; + } + @Test - public void testValidQueryDV() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);"; + public void testValidQueryDoubleVectors() { + String q = makeQuery("dvector", "qvector"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); assertNull(r.hits().getError()); } @Test - public void testValidQueryFV() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);"; + public void testValidQueryFloatVectors() { + String q = makeQuery("fvector", "qvector"); + Tensor t = makeTensor(tt_dense_fvector_3); + Result r = doSearch(searcher, q, t); + assertNull(r.hits().getError()); + } + + @Test + public void testValidQueryDoubleVectorAgainstFloatVector() { + String q = makeQuery("dvector", "qvector"); Tensor t = makeTensor(tt_dense_fvector_3); Result r = doSearch(searcher, q, t); assertNull(r.hits().getError()); } + @Test + public void testValidQueryFloatVectorAgainstDoubleVector() { + String q = makeQuery("fvector", "qvector"); + Tensor t = makeTensor(tt_dense_dvector_3); + Result r = doSearch(searcher, q, t); + assertNull(r.hits().getError()); + } + private static void assertErrMsg(String message, Result r) { assertEquals(ErrorMessage.createIllegalQuery(message), r.hits().getError()); } @@ -124,7 +149,7 @@ public class ValidateNearestNeighborTestCase { @Test public void testMissingQueryTensor() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,foo);"; + String q = makeQuery("dvector", "foo"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=foo,targetNumHits=1} query tensor not found", r); @@ -132,7 +157,7 @@ public class ValidateNearestNeighborTestCase { @Test public void testQueryTensorWrongType() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(dvector,qvector);"; + String q = makeQuery("dvector", "qvector"); Result r = doSearch(searcher, q, "tensor string"); assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} query tensor should be a tensor, was: class java.lang.String", r); r = doSearch(searcher, q, null); @@ -141,15 +166,15 @@ public class ValidateNearestNeighborTestCase { @Test public void testWrongTensorType() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(fvector,qvector);"; - Tensor t = makeTensor(tt_dense_dvector_3); + String q = makeQuery("dvector", "qvector"); + Tensor t = makeTensor(tt_dense_dvector_2, 2); Result r = doSearch(searcher, q, t); - assertErrMsg("NEAREST_NEIGHBOR {field=fvector,queryTensorName=qvector,targetNumHits=1} field type tensor(x[3]) does not match query tensor type tensor(x[3])", r); + assertErrMsg("NEAREST_NEIGHBOR {field=dvector,queryTensorName=qvector,targetNumHits=1} field type tensor(x[3]) does not match query tensor type tensor(x[2])", r); } @Test public void testNotAttribute() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(foo,qvector);"; + String q = makeQuery("foo", "qvector"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); assertErrMsg("NEAREST_NEIGHBOR {field=foo,queryTensorName=qvector,targetNumHits=1} field is not an attribute", r); @@ -157,7 +182,7 @@ public class ValidateNearestNeighborTestCase { @Test public void testWrongAttributeType() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(simple,qvector);"; + String q = makeQuery("simple", "qvector"); Tensor t = makeTensor(tt_dense_dvector_3); Result r = doSearch(searcher, q, t); assertErrMsg("NEAREST_NEIGHBOR {field=simple,queryTensorName=qvector,targetNumHits=1} field is not a tensor", r); @@ -165,7 +190,7 @@ public class ValidateNearestNeighborTestCase { @Test public void testSparseTensor() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(sparse,qvector);"; + String q = makeQuery("sparse", "qvector"); Tensor t = makeTensor(tt_sparse_vector_x); Result r = doSearch(searcher, q, t); assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r); @@ -173,7 +198,7 @@ public class ValidateNearestNeighborTestCase { @Test public void testMatrix() { - String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(matrix,qvector);"; + String q = makeQuery("matrix", "qvector"); Tensor t = makeMatrix(tt_dense_matrix_xy); Result r = doSearch(searcher, q, t); assertErrMsg("NEAREST_NEIGHBOR {field=matrix,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x[3],y[1]) is not a dense vector", r); -- cgit v1.2.3