From 39471a1bf1652a2736ffa1c6d2287c571a4d89f2 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Wed, 20 Nov 2019 21:02:56 +0000 Subject: update unit test after review --- .../searchers/ValidateNearestNeighborTestCase.java | 37 +++++++++++++++++++--- 1 file changed, 33 insertions(+), 4 deletions(-) (limited to 'container-search') diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index cd1849a3586..1add8c09075 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -1,3 +1,4 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.prelude.searcher; @@ -43,7 +44,7 @@ public class ValidateNearestNeighborTestCase { searcher = new ValidateNearestNeighborSearcher( ConfigGetter.getConfig(AttributesConfig.class, "raw:", - new RawSource("attribute[4]\n" + + new RawSource("attribute[5]\n" + "attribute[0].name simple\n" + "attribute[0].datatype INT32\n" + "attribute[1].name dvector\n" + @@ -54,20 +55,40 @@ public class ValidateNearestNeighborTestCase { "attribute[2].tensortype tensor(x[3])\n" + "attribute[3].name sparse\n" + "attribute[3].datatype TENSOR\n" + - "attribute[3].tensortype tensor(x{})" + "attribute[3].tensortype tensor(x{})\n" + + "attribute[4].name matrix\n" + + "attribute[4].datatype TENSOR\n" + + "attribute[4].tensortype tensor(x[3],y[1])\n" ))); } private static TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])"); private static TensorType tt_dense_fvector_3 = TensorType.fromSpec("tensor(x[3])"); + private static TensorType tt_dense_matrix_xy = TensorType.fromSpec("tensor(x[3],y[1])"); private static TensorType tt_sparse_vector_x = TensorType.fromSpec("tensor(x{})"); private Tensor makeTensor(TensorType tensorType) { Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); double dv = 1.0; String tensorDimension = "x"; - for (int label = 0; label < 3; label++) { - tensorBuilder.cell().label(tensorDimension, Integer.toString(label)).value(dv); + for (long label = 0; label < 3; label++) { + tensorBuilder.cell() + .label(tensorDimension, label) + .value(dv); + dv += 1.0; + } + return tensorBuilder.build(); + } + + private Tensor makeMatrix(TensorType tensorType) { + Tensor.Builder tensorBuilder = Tensor.Builder.of(tensorType); + double dv = 1.0; + String tensorDimension = "x"; + for (long label = 0; label < 3; label++) { + tensorBuilder.cell() + .label("y", 0L) + .label(tensorDimension, label) + .value(dv); dv += 1.0; } return tensorBuilder.build(); @@ -150,6 +171,14 @@ public class ValidateNearestNeighborTestCase { assertErrMsg("NEAREST_NEIGHBOR {field=sparse,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x{}) is not a dense vector", r); } + @Test + public void testMatrix() { + String q = "select * from sources * where [{\"targetNumHits\":1}]nearestNeighbor(matrix,qvector);"; + Tensor t = makeMatrix(tt_dense_matrix_xy); + Result r = doSearch(searcher, q, t); + assertErrMsg("NEAREST_NEIGHBOR {field=matrix,queryTensorName=qvector,targetNumHits=1} tensor type tensor(x[3],y[1]) is not a dense vector", r); + } + private static Result doSearch(ValidateNearestNeighborSearcher searcher, String yqlQuery, Object qTensor) { QueryTree queryTree = new YqlParser(new ParserEnvironment()).parse(new Parsable().setQuery(yqlQuery)); Query query = new Query(); -- cgit v1.2.3