From 67b7cdb8fefde4dc5843d646ce9fbd33acb37f7f Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Mon, 14 Jun 2021 16:51:31 +0000 Subject: allow multiple tensor types for same name * with several document types you can have fields with the same name but different tensor types that all should be allowed as target for nearestNeighbor operator. --- .../searchers/ValidateNearestNeighborSearcher.java | 44 ++++++++++++++-------- .../searchers/ValidateNearestNeighborTestCase.java | 26 ++++++++++++- 2 files changed, 53 insertions(+), 17 deletions(-) (limited to 'container-search') diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java index ca9d17cb656..d22dd2e6af6 100644 --- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java +++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java @@ -19,6 +19,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.AttributesConfig; import com.yahoo.yolean.chain.Before; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,15 +33,17 @@ import java.util.Optional; @Before(GroupingExecutor.COMPONENT_NAME) // Must happen before query.prepare() public class ValidateNearestNeighborSearcher extends Searcher { - private final Map validAttributes = new HashMap<>(); + private final Map> validAttributes = new HashMap<>(); public ValidateNearestNeighborSearcher(AttributesConfig attributesConfig) { for (AttributesConfig.Attribute a : attributesConfig.attribute()) { - TensorType tt = null; + if (! validAttributes.containsKey(a.name())) { + validAttributes.put(a.name(), new ArrayList()); + } if (a.datatype() == AttributesConfig.Attribute.Datatype.TENSOR) { - tt = TensorType.fromSpec(a.tensortype()); + TensorType tt = TensorType.fromSpec(a.tensortype()); + validAttributes.get(a.name()).add(tt); } - validAttributes.put(a.name(), tt); } } @@ -60,10 +63,10 @@ public class ValidateNearestNeighborSearcher extends Searcher { public Optional errorMessage = Optional.empty(); - private final Map validAttributes; + private final Map> validAttributes; private final Query query; - public NNVisitor(RankProperties rankProperties, Map validAttributes, Query query) { + public NNVisitor(RankProperties rankProperties, Map> validAttributes, Query query) { this.validAttributes = validAttributes; this.query = query; } @@ -101,17 +104,26 @@ public class ValidateNearestNeighborSearcher extends Searcher { if (queryTensor.isEmpty()) return item + " requires a tensor rank feature " + queryFeatureName + " but this is not present"; - if ( ! validAttributes.containsKey(item.getIndexName())) + if ( ! validAttributes.containsKey(item.getIndexName())) { return item + " field is not an attribute"; - TensorType fieldType = validAttributes.get(item.getIndexName()); - if (fieldType == null) return item + " field is not a tensor"; - if ( ! isDenseVector(fieldType)) - return item + " tensor type " + fieldType + " is not a dense vector"; - - if ( ! isCompatible(fieldType, queryTensor.get().type())) - return item + " field type " + fieldType + " does not match query type " + queryTensor.get().type(); - - return null; + } + List allTensorTypes = validAttributes.get(item.getIndexName()); + for (TensorType fieldType : allTensorTypes) { + if (isDenseVector(fieldType) && isCompatible(fieldType, queryTensor.get().type())) { + return null; + } + } + for (TensorType fieldType : allTensorTypes) { + if (isDenseVector(fieldType) && ! isCompatible(fieldType, queryTensor.get().type())) { + return item + " field type " + fieldType + " does not match query type " + queryTensor.get().type(); + } + } + for (TensorType fieldType : allTensorTypes) { + if (! isDenseVector(fieldType)) { + return item + " tensor type " + fieldType + " is not a dense vector"; + } + } + return item + " field is not a tensor"; } @Override diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index 72956b5b6eb..e5ed6f89fd4 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -51,10 +51,20 @@ public class ValidateNearestNeighborTestCase { "attribute[3].tensortype tensor(x{})\n" + "attribute[4].name matrix\n" + "attribute[4].datatype TENSOR\n" + - "attribute[4].tensortype tensor(x[3],y[1])\n" + "attribute[4].tensortype tensor(x[3],y[1])\n" + + "attribute[5].name threetypes\n" + + "attribute[5].datatype TENSOR\n" + + "attribute[5].tensortype tensor(x[42])\n" + + "attribute[6].name threetypes\n" + + "attribute[6].datatype TENSOR\n" + + "attribute[6].tensortype tensor(x[3])\n" + + "attribute[7].name threetypes\n" + + "attribute[7].datatype TENSOR\n" + + "attribute[7].tensortype tensor(x{})\n" ))); } + private static TensorType tt_dense_dvector_42 = TensorType.fromSpec("tensor(x[42])"); private static TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])"); private static TensorType tt_dense_dvector_2 = TensorType.fromSpec("tensor(x[2])"); private static TensorType tt_dense_fvector_3 = TensorType.fromSpec("tensor(x[3])"); @@ -185,6 +195,20 @@ public class ValidateNearestNeighborTestCase { assertErrMsg(desc("simple", "qvector", 1, "field is not a tensor"), r); } + @Test + public void testSeveralAttributesWithSameName() { + String q = makeQuery("threetypes", "qvector"); + Tensor t1 = makeTensor(tt_dense_fvector_3); + Result r1 = doSearch(searcher, q, t1); + assertNull(r1.hits().getError()); + Tensor t2 = makeTensor(tt_dense_dvector_42, 42); + Result r2 = doSearch(searcher, q, t2); + assertNull(r2.hits().getError()); + Tensor t3 = makeTensor(tt_dense_dvector_2, 2); + Result r3 = doSearch(searcher, q, t3); + assertErrMsg(desc("threetypes", "qvector", 1, "field type tensor(x[42]) does not match query type tensor(x[2])"), r3); + } + @Test public void testSparseTensor() { String q = makeQuery("sparse", "qvector"); -- cgit v1.2.3