summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-06-14 16:51:31 +0000
committerArne Juul <arnej@verizonmedia.com>2021-06-14 16:53:30 +0000
commit67b7cdb8fefde4dc5843d646ce9fbd33acb37f7f (patch)
tree2d6dc4b6c56863973a72c977a46d38941f181e74 /container-search
parent05617553a3f8d563c1955b82e81c54f4cb2136b5 (diff)
allow multiple tensor types for same name
* with several document types you can have fields with the same name but different tensor types that all should be allowed as target for nearestNeighbor operator.
Diffstat (limited to 'container-search')
-rw-r--r--container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java44
-rw-r--r--container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java26
2 files changed, 53 insertions, 17 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
index ca9d17cb656..d22dd2e6af6 100644
--- a/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
+++ b/container-search/src/main/java/com/yahoo/search/searchers/ValidateNearestNeighborSearcher.java
@@ -19,6 +19,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.AttributesConfig;
import com.yahoo.yolean.chain.Before;
+import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -32,15 +33,17 @@ import java.util.Optional;
@Before(GroupingExecutor.COMPONENT_NAME) // Must happen before query.prepare()
public class ValidateNearestNeighborSearcher extends Searcher {
- private final Map<String, TensorType> validAttributes = new HashMap<>();
+ private final Map<String, List<TensorType>> validAttributes = new HashMap<>();
public ValidateNearestNeighborSearcher(AttributesConfig attributesConfig) {
for (AttributesConfig.Attribute a : attributesConfig.attribute()) {
- TensorType tt = null;
+ if (! validAttributes.containsKey(a.name())) {
+ validAttributes.put(a.name(), new ArrayList<TensorType>());
+ }
if (a.datatype() == AttributesConfig.Attribute.Datatype.TENSOR) {
- tt = TensorType.fromSpec(a.tensortype());
+ TensorType tt = TensorType.fromSpec(a.tensortype());
+ validAttributes.get(a.name()).add(tt);
}
- validAttributes.put(a.name(), tt);
}
}
@@ -60,10 +63,10 @@ public class ValidateNearestNeighborSearcher extends Searcher {
public Optional<ErrorMessage> errorMessage = Optional.empty();
- private final Map<String, TensorType> validAttributes;
+ private final Map<String, List<TensorType>> validAttributes;
private final Query query;
- public NNVisitor(RankProperties rankProperties, Map<String, TensorType> validAttributes, Query query) {
+ public NNVisitor(RankProperties rankProperties, Map<String, List<TensorType>> validAttributes, Query query) {
this.validAttributes = validAttributes;
this.query = query;
}
@@ -101,17 +104,26 @@ public class ValidateNearestNeighborSearcher extends Searcher {
if (queryTensor.isEmpty())
return item + " requires a tensor rank feature " + queryFeatureName + " but this is not present";
- if ( ! validAttributes.containsKey(item.getIndexName()))
+ if ( ! validAttributes.containsKey(item.getIndexName())) {
return item + " field is not an attribute";
- TensorType fieldType = validAttributes.get(item.getIndexName());
- if (fieldType == null) return item + " field is not a tensor";
- if ( ! isDenseVector(fieldType))
- return item + " tensor type " + fieldType + " is not a dense vector";
-
- if ( ! isCompatible(fieldType, queryTensor.get().type()))
- return item + " field type " + fieldType + " does not match query type " + queryTensor.get().type();
-
- return null;
+ }
+ List<TensorType> allTensorTypes = validAttributes.get(item.getIndexName());
+ for (TensorType fieldType : allTensorTypes) {
+ if (isDenseVector(fieldType) && isCompatible(fieldType, queryTensor.get().type())) {
+ return null;
+ }
+ }
+ for (TensorType fieldType : allTensorTypes) {
+ if (isDenseVector(fieldType) && ! isCompatible(fieldType, queryTensor.get().type())) {
+ return item + " field type " + fieldType + " does not match query type " + queryTensor.get().type();
+ }
+ }
+ for (TensorType fieldType : allTensorTypes) {
+ if (! isDenseVector(fieldType)) {
+ return item + " tensor type " + fieldType + " is not a dense vector";
+ }
+ }
+ return item + " field is not a tensor";
}
@Override
diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
index 72956b5b6eb..e5ed6f89fd4 100644
--- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
+++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java
@@ -51,10 +51,20 @@ public class ValidateNearestNeighborTestCase {
"attribute[3].tensortype tensor(x{})\n" +
"attribute[4].name matrix\n" +
"attribute[4].datatype TENSOR\n" +
- "attribute[4].tensortype tensor(x[3],y[1])\n"
+ "attribute[4].tensortype tensor(x[3],y[1])\n" +
+ "attribute[5].name threetypes\n" +
+ "attribute[5].datatype TENSOR\n" +
+ "attribute[5].tensortype tensor(x[42])\n" +
+ "attribute[6].name threetypes\n" +
+ "attribute[6].datatype TENSOR\n" +
+ "attribute[6].tensortype tensor(x[3])\n" +
+ "attribute[7].name threetypes\n" +
+ "attribute[7].datatype TENSOR\n" +
+ "attribute[7].tensortype tensor(x{})\n"
)));
}
+ private static TensorType tt_dense_dvector_42 = TensorType.fromSpec("tensor(x[42])");
private static TensorType tt_dense_dvector_3 = TensorType.fromSpec("tensor(x[3])");
private static TensorType tt_dense_dvector_2 = TensorType.fromSpec("tensor(x[2])");
private static TensorType tt_dense_fvector_3 = TensorType.fromSpec("tensor<float>(x[3])");
@@ -186,6 +196,20 @@ public class ValidateNearestNeighborTestCase {
}
@Test
+ public void testSeveralAttributesWithSameName() {
+ String q = makeQuery("threetypes", "qvector");
+ Tensor t1 = makeTensor(tt_dense_fvector_3);
+ Result r1 = doSearch(searcher, q, t1);
+ assertNull(r1.hits().getError());
+ Tensor t2 = makeTensor(tt_dense_dvector_42, 42);
+ Result r2 = doSearch(searcher, q, t2);
+ assertNull(r2.hits().getError());
+ Tensor t3 = makeTensor(tt_dense_dvector_2, 2);
+ Result r3 = doSearch(searcher, q, t3);
+ assertErrMsg(desc("threetypes", "qvector", 1, "field type tensor(x[42]) does not match query type tensor(x[2])"), r3);
+ }
+
+ @Test
public void testSparseTensor() {
String q = makeQuery("sparse", "qvector");
Tensor t = makeTensor(tt_sparse_vector_x);