diff options
35 files changed, 264 insertions, 182 deletions
diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/AttributeSync.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/AttributeSync.java index cba3912ce49..16eb1a9b509 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/AttributeSync.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/AttributeSync.java @@ -1,5 +1,4 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.vespa.hosted.node.admin.task.util.file; import com.yahoo.vespa.hosted.node.admin.component.TaskContext; @@ -77,21 +76,21 @@ public class AttributeSync { context, "owner", owner, - () -> currentAttributes.get().owner(), + () -> currentAttributes.getOrThrow().owner(), path::setOwner); systemModified |= updateAttribute( context, "group", group, - () -> currentAttributes.get().group(), + () -> currentAttributes.getOrThrow().group(), path::setGroup); systemModified |= updateAttribute( context, "permissions", permissions, - () -> currentAttributes.get().permissions(), + () -> currentAttributes.getOrThrow().permissions(), path::setPermissions); return systemModified; diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCache.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCache.java index 12a9609f89c..a35bc844b8f 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCache.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCache.java @@ -1,5 +1,4 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.vespa.hosted.node.admin.task.util.file; import java.util.Optional; @@ -11,34 +10,23 @@ public class FileAttributesCache { private Optional<FileAttributes> attributes = Optional.empty(); public FileAttributesCache(UnixPath path) { - this.path = path; + this.path = path; } - public FileAttributes get() { - if (!attributes.isPresent()) { - attributes = Optional.of(path.getAttributes()); + public Optional<FileAttributes> get() { + if (attributes.isEmpty()) { + attributes = path.getAttributesIfExists(); } - return attributes.get(); + return attributes; } - public FileAttributes forceGet() { - attributes = Optional.empty(); - return get(); + public FileAttributes getOrThrow() { + return get().orElseThrow(); } - public boolean exists() { - if (attributes.isPresent()) { - return true; - } - - Optional<FileAttributes> attributes = path.getAttributesIfExists(); - if (attributes.isPresent()) { - // Might as well update this.attributes - this.attributes = attributes; - return true; - } else { - return false; - } + public Optional<FileAttributes> forceGet() { + attributes = Optional.empty(); + return get(); } } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileContentCache.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileContentCache.java index 43c9f7729ef..974ab68dc1c 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileContentCache.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileContentCache.java @@ -1,5 +1,4 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.vespa.hosted.node.admin.task.util.file; import java.time.Instant; @@ -21,7 +20,7 @@ class FileContentCache { } byte[] get(Instant lastModifiedTime) { - if (!value.isPresent() || lastModifiedTime.compareTo(modifiedTime.get()) > 0) { + if (modifiedTime.isEmpty() || lastModifiedTime.isAfter(modifiedTime.get())) { value = Optional.of(path.readBytes()); modifiedTime = Optional.of(lastModifiedTime); } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileSync.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileSync.java index dc13ea1c9ab..7a11ecdf6ab 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileSync.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileSync.java @@ -1,18 +1,14 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.vespa.hosted.node.admin.task.util.file; import com.yahoo.vespa.hosted.node.admin.component.TaskContext; -import java.nio.file.FileSystems; -import java.nio.file.Files; import java.nio.file.Path; +import java.time.Instant; import java.util.Arrays; import java.util.Optional; import java.util.logging.Logger; -import static com.yahoo.yolean.Exceptions.uncheck; - /** * Class to minimize resource usage with repetitive and mostly identical, idempotent, and * mutating file operations, e.g. setting file content, setting owner, etc. @@ -27,10 +23,12 @@ public class FileSync { private final UnixPath path; private final FileContentCache contentCache; + private final FileAttributesCache attributesCache; public FileSync(Path path) { this.path = new UnixPath(path); this.contentCache = new FileContentCache(this.path); + this.attributesCache = new FileAttributesCache(this.path); } public boolean convergeTo(TaskContext taskContext, PartialFileData partialFileData) { @@ -46,38 +44,41 @@ public class FileSync { * system is only modified if necessary (different). */ public boolean convergeTo(TaskContext taskContext, PartialFileData partialFileData, boolean atomicWrite) { - FileAttributesCache currentAttributes = new FileAttributesCache(path); + boolean modifiedSystem = false; - boolean modifiedSystem = maybeUpdateContent(taskContext, partialFileData.getContent(), currentAttributes, atomicWrite); + if (partialFileData.getContent().isPresent()) + modifiedSystem |= convergeTo(taskContext, partialFileData.getContent().get(), atomicWrite); AttributeSync attributeSync = new AttributeSync(path.toPath()).with(partialFileData); - modifiedSystem |= attributeSync.converge(taskContext, currentAttributes); + modifiedSystem |= attributeSync.converge(taskContext, this.attributesCache); return modifiedSystem; } - private boolean maybeUpdateContent(TaskContext taskContext, - Optional<byte[]> content, - FileAttributesCache currentAttributes, - boolean atomicWrite) { - if (!content.isPresent()) { - return false; - } + /** + * CPU, I/O, and memory usage is optimized for repeated calls with the same argument. + * + * @param atomicWrite Whether to write updates to a temporary file in the same directory, and atomically move it + * to path. Ensures the file cannot be read while in the middle of writing it. + * @return true if the content was written. Only modified if necessary (different). + */ + public boolean convergeTo(TaskContext taskContext, byte[] content, boolean atomicWrite) { + Optional<Instant> lastModifiedTime = attributesCache.forceGet().map(FileAttributes::lastModifiedTime); - if (!currentAttributes.exists()) { + if (lastModifiedTime.isEmpty()) { taskContext.recordSystemModification(logger, "Creating file " + path); path.createParents(); - writeBytes(content.get(), atomicWrite); - contentCache.updateWith(content.get(), currentAttributes.forceGet().lastModifiedTime()); + writeBytes(content, atomicWrite); + contentCache.updateWith(content, attributesCache.forceGet().orElseThrow().lastModifiedTime()); return true; } - if (Arrays.equals(content.get(), contentCache.get(currentAttributes.get().lastModifiedTime()))) { + if (Arrays.equals(content, contentCache.get(attributesCache.getOrThrow().lastModifiedTime()))) { return false; } else { taskContext.recordSystemModification(logger, "Patching file " + path); - writeBytes(content.get(), atomicWrite); - contentCache.updateWith(content.get(), currentAttributes.forceGet().lastModifiedTime()); + writeBytes(content, atomicWrite); + contentCache.updateWith(content, attributesCache.forceGet().orElseThrow().lastModifiedTime()); return true; } } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/MakeDirectory.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/MakeDirectory.java index 5f72ed7e9b8..e2366470f61 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/MakeDirectory.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/MakeDirectory.java @@ -19,12 +19,14 @@ public class MakeDirectory { private final UnixPath path; private final AttributeSync attributeSync; + private final FileAttributesCache attributesCache; private boolean createParents = false; public MakeDirectory(Path path) { this.path = new UnixPath(path); this.attributeSync = new AttributeSync(path); + this.attributesCache = new FileAttributesCache(this.path); } /** @@ -42,8 +44,8 @@ public class MakeDirectory { public boolean converge(TaskContext context) { boolean systemModified = false; - FileAttributesCache attributes = new FileAttributesCache(path); - if (attributes.exists()) { + Optional<FileAttributes> attributes = attributesCache.forceGet(); + if (attributes.isPresent()) { if (!attributes.get().isDirectory()) { throw new UncheckedIOException(new NotDirectoryException(path.toString())); } @@ -65,7 +67,7 @@ public class MakeDirectory { } } - systemModified |= attributeSync.converge(context, attributes); + systemModified |= attributeSync.converge(context, attributesCache); return systemModified; } diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/StoredInteger.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/StoredInteger.java index 79283983303..2d52622db0a 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/StoredInteger.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/StoredInteger.java @@ -3,10 +3,6 @@ package com.yahoo.vespa.hosted.node.admin.task.util.file; import com.yahoo.vespa.hosted.node.admin.component.TaskContext; -import java.io.IOException; -import java.io.UncheckedIOException; -import java.nio.file.Files; -import java.nio.file.NoSuchFileException; import java.nio.file.Path; import java.time.Instant; import java.util.Optional; @@ -23,50 +19,24 @@ public class StoredInteger implements Supplier<OptionalInt> { private static final Logger logger = Logger.getLogger(StoredInteger.class.getName()); - private final Path path; - private OptionalInt value; - private boolean hasBeenRead = false; - private Optional<Instant> lastModifiedTime; + private final UnixPath path; public StoredInteger(Path path) { - this.path = path; + this.path = new UnixPath(path); } @Override public OptionalInt get() { - if (!hasBeenRead) readValue(); - return value; + return path.readUtf8FileIfExists().stream().mapToInt(Integer::parseInt).findAny(); } public void write(TaskContext taskContext, int value) { - try { - Files.write(path, Integer.toString(value).getBytes()); - this.value = OptionalInt.of(value); - this.hasBeenRead = true; - this.lastModifiedTime = Optional.of(Instant.now()); - taskContext.log(logger, "Stored new integer in %s: %d", path, value); - } catch (IOException e) { - throw new UncheckedIOException("Failed to store integer in " + path, e); - } + path.writeUtf8File(Integer.toString(value)); + taskContext.log(logger, "Stored new integer in %s: %d", path, value); } public Optional<Instant> getLastModifiedTime() { - if (!hasBeenRead) readValue(); - return lastModifiedTime; - } - - private void readValue() { - try { - String value = new String(Files.readAllBytes(path)); - this.value = OptionalInt.of(Integer.parseInt(value)); - this.lastModifiedTime = Optional.of(Files.getLastModifiedTime(path).toInstant()); - } catch (NoSuchFileException e) { - this.value = OptionalInt.empty(); - this.lastModifiedTime = Optional.empty(); - } catch (IOException e) { - throw new UncheckedIOException("Failed to read integer in " + path, e); - } - hasBeenRead = true; + return path.getAttributesIfExists().map(FileAttributes::lastModifiedTime); } } diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCacheTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCacheTest.java index 06192c9f308..88b33a7f9a8 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCacheTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/task/util/file/FileAttributesCacheTest.java @@ -1,5 +1,4 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - package com.yahoo.vespa.hosted.node.admin.task.util.file; import org.junit.Test; @@ -22,17 +21,17 @@ public class FileAttributesCacheTest { FileAttributesCache cache = new FileAttributesCache(unixPath); when(unixPath.getAttributesIfExists()).thenReturn(Optional.empty()); - assertFalse(cache.exists()); + assertFalse(cache.get().isPresent()); verify(unixPath, times(1)).getAttributesIfExists(); verifyNoMoreInteractions(unixPath); FileAttributes attributes = mock(FileAttributes.class); when(unixPath.getAttributesIfExists()).thenReturn(Optional.of(attributes)); - assertTrue(cache.exists()); + assertTrue(cache.get().isPresent()); verify(unixPath, times(1 + 1)).getAttributesIfExists(); verifyNoMoreInteractions(unixPath); - assertEquals(attributes, cache.get()); + assertEquals(attributes, cache.getOrThrow()); verifyNoMoreInteractions(unixPath); } }
\ No newline at end of file diff --git a/searchcore/src/tests/proton/matching/query_test.cpp b/searchcore/src/tests/proton/matching/query_test.cpp index 03f20876adc..a80b08badf8 100644 --- a/searchcore/src/tests/proton/matching/query_test.cpp +++ b/searchcore/src/tests/proton/matching/query_test.cpp @@ -272,6 +272,7 @@ public: void visit(ProtonWandTerm &) override {} void visit(ProtonPredicateQuery &) override {} void visit(ProtonRegExpTerm &) override {} + void visit(ProtonNearestNeighborTerm &) override {} }; void Test::requireThatTermsAreLookedUp() { @@ -423,6 +424,7 @@ public: void visit(ProtonWandTerm &) override {} void visit(ProtonPredicateQuery &) override {} void visit(ProtonRegExpTerm &) override {} + void visit(ProtonNearestNeighborTerm &) override {} }; void Test::requireThatTermDataIsFilledIn() { diff --git a/searchcore/src/tests/proton/matching/unpacking_iterators_optimizer/unpacking_iterators_optimizer_test.cpp b/searchcore/src/tests/proton/matching/unpacking_iterators_optimizer/unpacking_iterators_optimizer_test.cpp index c0418d82359..9ecbd532389 100644 --- a/searchcore/src/tests/proton/matching/unpacking_iterators_optimizer/unpacking_iterators_optimizer_test.cpp +++ b/searchcore/src/tests/proton/matching/unpacking_iterators_optimizer/unpacking_iterators_optimizer_test.cpp @@ -65,6 +65,7 @@ struct DumpQuery : QueryVisitor { void visit(WandTerm &) override {} void visit(PredicateQuery &) override {} void visit(RegExpTerm &) override {} + void visit(NearestNeighborTerm &) override {} }; std::string dump_query(Node &root) { diff --git a/searchcore/src/vespa/searchcore/proton/matching/blueprintbuilder.cpp b/searchcore/src/vespa/searchcore/proton/matching/blueprintbuilder.cpp index 7e55c8f778c..c8c5a3a427b 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/blueprintbuilder.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/blueprintbuilder.cpp @@ -159,6 +159,7 @@ protected: void visit(ProtonSuffixTerm &n) override { buildTerm(n); } void visit(ProtonPredicateQuery &n) override { buildTerm(n); } void visit(ProtonRegExpTerm &n) override { buildTerm(n); } + void visit(ProtonNearestNeighborTerm &n) override { buildTerm(n); } public: BlueprintBuilderVisitor(const IRequestContext & requestContext, ISearchContext &context) : diff --git a/searchcore/src/vespa/searchcore/proton/matching/querynodes.h b/searchcore/src/vespa/searchcore/proton/matching/querynodes.h index 6454845b247..d7ec24edb8f 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/querynodes.h +++ b/searchcore/src/vespa/searchcore/proton/matching/querynodes.h @@ -137,6 +137,7 @@ typedef ProtonTerm<search::query::DotProduct> ProtonDotProduct; typedef ProtonTerm<search::query::WandTerm> ProtonWandTerm; typedef ProtonTerm<search::query::PredicateQuery> ProtonPredicateQuery; typedef ProtonTerm<search::query::RegExpTerm> ProtonRegExpTerm; +typedef ProtonTerm<search::query::NearestNeighborTerm> ProtonNearestNeighborTerm; struct ProtonNodeTypes { typedef ProtonAnd And; @@ -161,6 +162,7 @@ struct ProtonNodeTypes { typedef ProtonWandTerm WandTerm; typedef ProtonPredicateQuery PredicateQuery; typedef ProtonRegExpTerm RegExpTerm; + typedef ProtonNearestNeighborTerm NearestNeighborTerm; }; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/same_element_builder.cpp b/searchcore/src/vespa/searchcore/proton/matching/same_element_builder.cpp index 241ab53874f..1e5df97a659 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/same_element_builder.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/same_element_builder.cpp @@ -67,6 +67,7 @@ public: void visit(ProtonSuffixTerm &n) override { visitTerm(n); } void visit(ProtonPredicateQuery &) override {} void visit(ProtonRegExpTerm &n) override { visitTerm(n); } + void visit(ProtonNearestNeighborTerm &) override {} }; } // namespace proton::matching::<unnamed> diff --git a/searchcore/src/vespa/searchcore/proton/matching/termdatafromnode.cpp b/searchcore/src/vespa/searchcore/proton/matching/termdatafromnode.cpp index 3fd4000bf9f..3184b5cc061 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/termdatafromnode.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/termdatafromnode.cpp @@ -42,6 +42,7 @@ struct TermDataFromTermVisitor void visit(ProtonSuffixTerm &n) override { visitTerm(n); } void visit(ProtonPredicateQuery &) override { } void visit(ProtonRegExpTerm &n) override { visitTerm(n); } + void visit(ProtonNearestNeighborTerm &n) override { visitTerm(n); } }; } // namespace diff --git a/searchcore/src/vespa/searchcore/proton/matching/unpacking_iterators_optimizer.cpp b/searchcore/src/vespa/searchcore/proton/matching/unpacking_iterators_optimizer.cpp index af355452c73..eada88010dd 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/unpacking_iterators_optimizer.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/unpacking_iterators_optimizer.cpp @@ -56,6 +56,7 @@ struct TermExpander : QueryVisitor { void visit(WandTerm &) override {} void visit(PredicateQuery &) override {} void visit(RegExpTerm &) override {} + void visit(NearestNeighborTerm &) override {} void flush(Intermediate &parent) { for (Node::UP &term: terms) { parent.append(std::move(term)); diff --git a/searchcorespi/src/vespa/searchcorespi/index/indexcollection.cpp b/searchcorespi/src/vespa/searchcorespi/index/indexcollection.cpp index b71927c714f..12128a3df18 100644 --- a/searchcorespi/src/vespa/searchcorespi/index/indexcollection.cpp +++ b/searchcorespi/src/vespa/searchcorespi/index/indexcollection.cpp @@ -189,7 +189,6 @@ private: void visit(ONear &) override { } void visit(SameElement &) override { } - void visit(WeightedSetTerm &n) override { visitTerm(n); } void visit(DotProduct &n) override { visitTerm(n); } void visit(WandTerm &n) override { visitTerm(n); } @@ -203,6 +202,7 @@ private: void visit(SuffixTerm &n) override { visitTerm(n); } void visit(PredicateQuery &n) override { visitTerm(n); } void visit(RegExpTerm &n) override { visitTerm(n); } + void visit(NearestNeighborTerm &n) override { visitTerm(n); } public: CreateBlueprintVisitor(const IIndexCollection &indexes, diff --git a/searchlib/src/tests/query/customtypevisitor_test.cpp b/searchlib/src/tests/query/customtypevisitor_test.cpp index c5eeac8543d..3f7d57b7aa4 100644 --- a/searchlib/src/tests/query/customtypevisitor_test.cpp +++ b/searchlib/src/tests/query/customtypevisitor_test.cpp @@ -54,6 +54,7 @@ struct MyDotProduct : DotProduct { MyDotProduct() : DotProduct("view", 0, Weight struct MyWandTerm : WandTerm { MyWandTerm() : WandTerm("view", 0, Weight(42), 57, 67, 77.7) {} }; struct MyPredicateQuery : InitTerm<PredicateQuery> {}; struct MyRegExpTerm : InitTerm<RegExpTerm> {}; +struct MyNearestNeighborTerm : NearestNeighborTerm {}; struct MyQueryNodeTypes { typedef MyAnd And; @@ -78,6 +79,7 @@ struct MyQueryNodeTypes { typedef MyWandTerm WandTerm; typedef MyPredicateQuery PredicateQuery; typedef MyRegExpTerm RegExpTerm; + typedef MyNearestNeighborTerm NearestNeighborTerm; }; class MyCustomVisitor : public CustomTypeVisitor<MyQueryNodeTypes> @@ -113,6 +115,7 @@ public: void visit(MyWandTerm &) override { setVisited<MyWandTerm>(); } void visit(MyPredicateQuery &) override { setVisited<MyPredicateQuery>(); } void visit(MyRegExpTerm &) override { setVisited<MyRegExpTerm>(); } + void visit(MyNearestNeighborTerm &) override { setVisited<MyNearestNeighborTerm>(); } }; template <class T> diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp index f8922c54a4e..39e381c0942 100644 --- a/searchlib/src/tests/query/query_visitor_test.cpp +++ b/searchlib/src/tests/query/query_visitor_test.cpp @@ -65,6 +65,7 @@ public: void visit(WandTerm &) override { isVisited<WandTerm>() = true; } void visit(PredicateQuery &) override { isVisited<PredicateQuery>() = true; } void visit(RegExpTerm &) override { isVisited<RegExpTerm>() = true; } + void visit(NearestNeighborTerm &) override { isVisited<NearestNeighborTerm>() = true; } }; template <class T> @@ -98,6 +99,7 @@ void Test::requireThatAllNodesCanBeVisited() { checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0))); checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0))); checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0))); + checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123)); } } // namespace diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp index 6673a107d44..d47fb071d81 100644 --- a/searchlib/src/tests/query/querybuilder_test.cpp +++ b/searchlib/src/tests/query/querybuilder_test.cpp @@ -48,7 +48,7 @@ PredicateQueryTerm::UP getPredicateQueryTerm() { template <class NodeTypes> Node::UP createQueryTree() { QueryBuilder<NodeTypes> builder; - builder.addAnd(10); + builder.addAnd(11); { builder.addRank(2); { @@ -111,6 +111,7 @@ Node::UP createQueryTree() { builder.addStringTerm(str[5], view[5], id[5], weight[6]); builder.addStringTerm(str[6], view[6], id[6], weight[7]); } + builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7); } Node::UP node = builder.build(); ASSERT_TRUE(node.get()); @@ -140,6 +141,15 @@ bool checkTerm(const Term *term, const typename Term::Type &t, const string &f, EXPECT_EQUAL(use_position_data, term->usePositionData())); } +template <class NodeType> +NodeType* +as_node(Node* node) +{ + auto* result = dynamic_cast<NodeType*>(node); + ASSERT_TRUE(result != nullptr); + return result; +} + template <class NodeTypes> void checkQueryTreeTypes(Node *node) { typedef typename NodeTypes::And And; @@ -166,126 +176,114 @@ void checkQueryTreeTypes(Node *node) { typedef typename NodeTypes::RegExpTerm RegExpTerm; ASSERT_TRUE(node); - And *and_node = dynamic_cast<And *>(node); - ASSERT_TRUE(and_node); - EXPECT_EQUAL(10u, and_node->getChildren().size()); + auto* and_node = as_node<And>(node); + EXPECT_EQUAL(11u, and_node->getChildren().size()); - - Rank *rank = dynamic_cast<Rank *>(and_node->getChildren()[0]); - ASSERT_TRUE(rank); + auto* rank = as_node<Rank>(and_node->getChildren()[0]); EXPECT_EQUAL(2u, rank->getChildren().size()); - Near *near = dynamic_cast<Near *>(rank->getChildren()[0]); - ASSERT_TRUE(near); + auto* near = as_node<Near>(rank->getChildren()[0]); EXPECT_EQUAL(2u, near->getChildren().size()); EXPECT_EQUAL(distance, near->getDistance()); - StringTerm *string_term = dynamic_cast<StringTerm *>(near->getChildren()[0]); + auto* string_term = as_node<StringTerm>(near->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[0], view[0], id[0], weight[0])); - SubstringTerm *substring_term = dynamic_cast<SubstringTerm *>(near->getChildren()[1]); + auto* substring_term = as_node<SubstringTerm>(near->getChildren()[1]); EXPECT_TRUE(checkTerm(substring_term, str[1], view[1], id[1], weight[1])); - ONear *onear = dynamic_cast<ONear *>(rank->getChildren()[1]); - ASSERT_TRUE(onear); + auto* onear = as_node<ONear>(rank->getChildren()[1]); EXPECT_EQUAL(2u, onear->getChildren().size()); EXPECT_EQUAL(distance, onear->getDistance()); - SuffixTerm *suffix_term = dynamic_cast<SuffixTerm *>(onear->getChildren()[0]); + auto* suffix_term = as_node<SuffixTerm>(onear->getChildren()[0]); EXPECT_TRUE(checkTerm(suffix_term, str[2], view[2], id[2], weight[2])); - PrefixTerm *prefix_term = dynamic_cast<PrefixTerm *>(onear->getChildren()[1]); + auto* prefix_term = as_node<PrefixTerm>(onear->getChildren()[1]); EXPECT_TRUE(checkTerm(prefix_term, str[3], view[3], id[3], weight[3])); - - Or *or_node = dynamic_cast<Or *>(and_node->getChildren()[1]); - ASSERT_TRUE(or_node); + auto* or_node = as_node<Or>(and_node->getChildren()[1]); EXPECT_EQUAL(3u, or_node->getChildren().size()); - Phrase *phrase = dynamic_cast<Phrase *>(or_node->getChildren()[0]); - ASSERT_TRUE(phrase); + auto* phrase = as_node<Phrase>(or_node->getChildren()[0]); EXPECT_TRUE(phrase->isRanked()); EXPECT_EQUAL(weight[4].percent(), phrase->getWeight().percent()); EXPECT_EQUAL(3u, phrase->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[0]); + string_term = as_node<StringTerm>(phrase->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[1]); + string_term = as_node<StringTerm>(phrase->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[4])); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[2]); + string_term = as_node<StringTerm>(phrase->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[6], view[6], id[6], weight[4])); - phrase = dynamic_cast<Phrase *>(or_node->getChildren()[1]); - ASSERT_TRUE(phrase); + phrase = as_node<Phrase>(or_node->getChildren()[1]); EXPECT_TRUE(!phrase->isRanked()); EXPECT_EQUAL(weight[4].percent(), phrase->getWeight().percent()); EXPECT_EQUAL(2u, phrase->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[0]); + string_term = as_node<StringTerm>(phrase->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(phrase->getChildren()[1]); + string_term = as_node<StringTerm>(phrase->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[4])); - AndNot *and_not = dynamic_cast<AndNot *>(or_node->getChildren()[2]); - ASSERT_TRUE(and_not); + auto* and_not = as_node<AndNot>(or_node->getChildren()[2]); EXPECT_EQUAL(2u, and_not->getChildren().size()); - NumberTerm *integer_term = dynamic_cast<NumberTerm *>(and_not->getChildren()[0]); + auto* integer_term = as_node<NumberTerm>(and_not->getChildren()[0]); EXPECT_TRUE(checkTerm(integer_term, int1, view[7], id[7], weight[7])); - NumberTerm *float_term = dynamic_cast<NumberTerm *>(and_not->getChildren()[1]); + auto* float_term = as_node<NumberTerm>(and_not->getChildren()[1]); EXPECT_TRUE(checkTerm(float_term, float1, view[8], id[8], weight[8], false)); - - RangeTerm *range_term = dynamic_cast<RangeTerm *>(and_node->getChildren()[2]); - ASSERT_TRUE(range_term); + auto* range_term = as_node<RangeTerm>(and_node->getChildren()[2]); EXPECT_TRUE(checkTerm(range_term, range, view[9], id[9], weight[9])); - LocationTerm *loc_term = dynamic_cast<LocationTerm *>(and_node->getChildren()[3]); - ASSERT_TRUE(loc_term); + auto* loc_term = as_node<LocationTerm>(and_node->getChildren()[3]); EXPECT_TRUE(checkTerm(loc_term, location, view[10], id[10], weight[10])); - WeakAnd *wand = dynamic_cast<WeakAnd *>(and_node->getChildren()[4]); - ASSERT_TRUE(wand != 0); + auto* wand = as_node<WeakAnd>(and_node->getChildren()[4]); EXPECT_EQUAL(123u, wand->getMinHits()); EXPECT_EQUAL(2u, wand->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(wand->getChildren()[0]); + string_term = as_node<StringTerm>(wand->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(wand->getChildren()[1]); + string_term = as_node<StringTerm>(wand->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[5])); - PredicateQuery *predicateQuery = dynamic_cast<PredicateQuery *>(and_node->getChildren()[5]); - ASSERT_TRUE(predicateQuery); + auto* predicateQuery = as_node<PredicateQuery>(and_node->getChildren()[5]); PredicateQueryTerm::UP pqt(new PredicateQueryTerm); EXPECT_TRUE(checkTerm(predicateQuery, getPredicateQueryTerm(), view[3], id[3], weight[3])); - DotProduct *dotProduct = dynamic_cast<DotProduct *>(and_node->getChildren()[6]); - ASSERT_TRUE(dotProduct); + auto* dotProduct = as_node<DotProduct>(and_node->getChildren()[6]); EXPECT_EQUAL(3u, dotProduct->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[0]); + string_term = as_node<StringTerm>(dotProduct->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[3], view[3], id[3], weight[3])); - string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[1]); + string_term = as_node<StringTerm>(dotProduct->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[4])); - string_term = dynamic_cast<StringTerm *>(dotProduct->getChildren()[2]); + string_term = as_node<StringTerm>(dotProduct->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[5])); - WandTerm *wandTerm = dynamic_cast<WandTerm *>(and_node->getChildren()[7]); - ASSERT_TRUE(wandTerm); + auto* wandTerm = as_node<WandTerm>(and_node->getChildren()[7]); EXPECT_EQUAL(57u, wandTerm->getTargetNumHits()); EXPECT_EQUAL(67, wandTerm->getScoreThreshold()); EXPECT_EQUAL(77.7, wandTerm->getThresholdBoostFactor()); EXPECT_EQUAL(2u, wandTerm->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(wandTerm->getChildren()[0]); + string_term = as_node<StringTerm>(wandTerm->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[1], view[1], id[1], weight[1])); - string_term = dynamic_cast<StringTerm *>(wandTerm->getChildren()[1]); + string_term = as_node<StringTerm>(wandTerm->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[2], view[2], id[2], weight[2])); - RegExpTerm *regexp_term = dynamic_cast<RegExpTerm *>(and_node->getChildren()[8]); + auto* regexp_term = as_node<RegExpTerm>(and_node->getChildren()[8]); EXPECT_TRUE(checkTerm(regexp_term, str[5], view[5], id[5], weight[5])); - SameElement *same = dynamic_cast<SameElement *>(and_node->getChildren()[9]); - ASSERT_TRUE(same != nullptr); + auto* same = as_node<SameElement>(and_node->getChildren()[9]); EXPECT_EQUAL(view[4], same->getView()); EXPECT_EQUAL(3u, same->getChildren().size()); - string_term = dynamic_cast<StringTerm *>(same->getChildren()[0]); + string_term = as_node<StringTerm>(same->getChildren()[0]); EXPECT_TRUE(checkTerm(string_term, str[4], view[4], id[4], weight[5])); - string_term = dynamic_cast<StringTerm *>(same->getChildren()[1]); + string_term = as_node<StringTerm>(same->getChildren()[1]); EXPECT_TRUE(checkTerm(string_term, str[5], view[5], id[5], weight[6])); - string_term = dynamic_cast<StringTerm *>(same->getChildren()[2]); + string_term = as_node<StringTerm>(same->getChildren()[2]); EXPECT_TRUE(checkTerm(string_term, str[6], view[6], id[6], weight[7])); + auto* nearest_neighbor = as_node<NearestNeighborTerm>(and_node->getChildren()[10]); + EXPECT_EQUAL("query_tensor", nearest_neighbor->get_query_tensor_name()); + EXPECT_EQUAL("doc_tensor", nearest_neighbor->getView()); + EXPECT_EQUAL(id[3], nearest_neighbor->getId()); + EXPECT_EQUAL(weight[5].percent(), nearest_neighbor->getWeight().percent()); + EXPECT_EQUAL(7u, nearest_neighbor->get_target_num_hits()); } struct AbstractTypes { @@ -395,6 +393,12 @@ struct MyRegExpTerm : RegExpTerm { : RegExpTerm(t, f, i, w) { } }; +struct MyNearestNeighborTerm : NearestNeighborTerm { + MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t i, Weight w, uint32_t target_num_hits) + : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits) + {} +}; struct MyQueryNodeTypes { typedef MyAnd And; @@ -419,6 +423,7 @@ struct MyQueryNodeTypes { typedef MyWandTerm WandTerm; typedef MyPredicateQuery PredicateQuery; typedef MyRegExpTerm RegExpTerm; + typedef MyNearestNeighborTerm NearestNeighborTerm; }; TEST("require that Custom Query Trees Can Be Built") { diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 5261f568673..9c132527abc 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -592,6 +592,11 @@ public: createShallowWeightedSet(bp, n, _field, _attr.isIntegerType()); } } + void visit(query::NearestNeighborTerm &n) override { + (void) n; + // TODO (geirst): implement + setResult(std::make_unique<queryeval::EmptyBlueprint>(_field)); + } }; } // namespace diff --git a/searchlib/src/vespa/searchlib/diskindex/diskindex.cpp b/searchlib/src/vespa/searchlib/diskindex/diskindex.cpp index 7d8bcf032ba..ddcee50c219 100644 --- a/searchlib/src/vespa/searchlib/diskindex/diskindex.cpp +++ b/searchlib/src/vespa/searchlib/diskindex/diskindex.cpp @@ -398,6 +398,8 @@ public: handleNumberTermAsText(n); } + void not_supported(Node &) {} + void visit(LocationTerm &n) override { visitTerm(n); } void visit(PrefixTerm &n) override { visitTerm(n); } void visit(RangeTerm &n) override { visitTerm(n); } @@ -405,7 +407,8 @@ public: void visit(SubstringTerm &n) override { visitTerm(n); } void visit(SuffixTerm &n) override { visitTerm(n); } void visit(RegExpTerm &n) override { visitTerm(n); } - void visit(PredicateQuery &) override { } + void visit(PredicateQuery &n) override { not_supported(n); } + void visit(NearestNeighborTerm &n) override { not_supported(n); } }; Blueprint::UP diff --git a/searchlib/src/vespa/searchlib/memoryindex/memory_index.cpp b/searchlib/src/vespa/searchlib/memoryindex/memory_index.cpp index d3d3004100c..d8e48e84fb7 100644 --- a/searchlib/src/vespa/searchlib/memoryindex/memory_index.cpp +++ b/searchlib/src/vespa/searchlib/memoryindex/memory_index.cpp @@ -28,6 +28,7 @@ using index::IndexBuilder; using index::Schema; using index::SchemaUtil; using query::LocationTerm; +using query::NearestNeighborTerm; using query::Node; using query::NumberTerm; using query::PredicateQuery; @@ -163,6 +164,8 @@ public: setResult(fieldIndex->make_term_blueprint(termStr, _field, _fieldId)); } + void not_supported(Node &) {} + void visit(LocationTerm &n) override { visitTerm(n); } void visit(PrefixTerm &n) override { visitTerm(n); } void visit(RangeTerm &n) override { visitTerm(n); } @@ -170,7 +173,8 @@ public: void visit(SubstringTerm &n) override { visitTerm(n); } void visit(SuffixTerm &n) override { visitTerm(n); } void visit(RegExpTerm &n) override { visitTerm(n); } - void visit(PredicateQuery &) override { } + void visit(PredicateQuery &n) override { not_supported(n); } + void visit(NearestNeighborTerm &n) override { not_supported(n); } void visit(NumberTerm &n) override { handleNumberTermAsText(n); diff --git a/searchlib/src/vespa/searchlib/parsequery/parse.h b/searchlib/src/vespa/searchlib/parsequery/parse.h index 9c0e76d2441..83352b571c8 100644 --- a/searchlib/src/vespa/searchlib/parsequery/parse.h +++ b/searchlib/src/vespa/searchlib/parsequery/parse.h @@ -60,7 +60,8 @@ public: ITEM_PREDICATE_QUERY = 23, ITEM_REGEXP = 24, ITEM_WORD_ALTERNATIVES = 25, - ITEM_MAX = 26, // Indicates how long tables must be. + ITEM_NEAREST_NEIGHBOR = 26, + ITEM_MAX = 27, // Indicates how long tables must be. ITEM_UNDEF = 31, }; diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp index a70fe07cf81..70a3097ae05 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp @@ -270,6 +270,17 @@ SimpleQueryStackDumpIterator::next() } break; + case ParseItem::ITEM_NEAREST_NEIGHBOR: + try { + _curr_index_name = read_stringref(p); + _curr_term = read_stringref(p); // query_tensor_name + _currArg1 = readCompressedPositiveInt(p); // target_num_hits; + _currArity = 0; + } catch (...) { + return false; + } + break; + default: // Unknown item, so report that no more are available return false; diff --git a/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h b/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h index cdeebcaf9e5..3882bc41b2b 100644 --- a/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h +++ b/searchlib/src/vespa/searchlib/query/tree/customtypevisitor.h @@ -49,6 +49,7 @@ public: virtual void visit(typename NodeTypes::WandTerm &) = 0; virtual void visit(typename NodeTypes::PredicateQuery &) = 0; virtual void visit(typename NodeTypes::RegExpTerm &) = 0; + virtual void visit(typename NodeTypes::NearestNeighborTerm &) = 0; private: // Route QueryVisit requests to the correct custom type. @@ -75,6 +76,7 @@ private: typedef typename NodeTypes::WandTerm TWandTerm; typedef typename NodeTypes::PredicateQuery TPredicateQuery; typedef typename NodeTypes::RegExpTerm TRegExpTerm; + typedef typename NodeTypes::NearestNeighborTerm TNearestNeighborTerm; void visit(And &n) override { visit(static_cast<TAnd&>(n)); } void visit(AndNot &n) override { visit(static_cast<TAndNot&>(n)); } @@ -98,6 +100,7 @@ private: void visit(WandTerm &n) override { visit(static_cast<TWandTerm&>(n)); } void visit(PredicateQuery &n) override { visit(static_cast<TPredicateQuery&>(n)); } void visit(RegExpTerm &n) override { visit(static_cast<TRegExpTerm&>(n)); } + void visit(NearestNeighborTerm &n) override { visit(static_cast<TNearestNeighborTerm&>(n)); } }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h index a2ad8eae84b..797defc39f5 100644 --- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h +++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h @@ -203,6 +203,13 @@ createRegExpTerm(vespalib::stringref term, vespalib::stringref view, int32_t id, } template <class NodeTypes> +typename NodeTypes::NearestNeighborTerm * +create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) { + return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits); +} + +template <class NodeTypes> class QueryBuilder : public QueryBuilderBase { template <class T> T &addIntermediate(T *node, int child_count) { @@ -309,6 +316,11 @@ public: adjustWeight(weight); return addTerm(createRegExpTerm<NodeTypes>(term, view, id, weight)); } + typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) { + adjustWeight(weight); + return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits)); + } }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h index e7c3fd8c73b..d2249a53f18 100644 --- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h +++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h @@ -163,6 +163,11 @@ private: node.getTerm(), node.getView(), node.getId(), node.getWeight())); } + + void visit(NearestNeighborTerm &node) override { + replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(), + node.getId(), node.getWeight(), node.get_target_num_hits())); + } }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h b/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h index 0cb56f9127a..533e240e088 100644 --- a/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h +++ b/searchlib/src/vespa/searchlib/query/tree/queryvisitor.h @@ -26,6 +26,7 @@ class WandTerm; class PredicateQuery; class RegExpTerm; class SameElement; +class NearestNeighborTerm; struct QueryVisitor { virtual ~QueryVisitor() {} @@ -52,6 +53,7 @@ struct QueryVisitor { virtual void visit(WandTerm &) = 0; virtual void visit(PredicateQuery &) = 0; virtual void visit(RegExpTerm &) = 0; + virtual void visit(NearestNeighborTerm &) = 0; }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h index 707ed2aa0db..8663bede4d6 100644 --- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h +++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h @@ -103,31 +103,38 @@ struct SimpleRegExpTerm : RegExpTerm { : RegExpTerm(term, view, id, weight) { } }; +struct SimpleNearestNeighborTerm : NearestNeighborTerm { + SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) + : NearestNeighborTerm(query_tensor_name, field_name, id, weight, target_num_hits) + {} +}; struct SimpleQueryNodeTypes { - typedef SimpleAnd And; - typedef SimpleAndNot AndNot; - typedef SimpleEquiv Equiv; - typedef SimpleNumberTerm NumberTerm; - typedef SimpleLocationTerm LocationTerm; - typedef SimpleNear Near; - typedef SimpleONear ONear; - typedef SimpleOr Or; - typedef SimplePhrase Phrase; - typedef SimpleSameElement SameElement; - typedef SimplePrefixTerm PrefixTerm; - typedef SimpleRangeTerm RangeTerm; - typedef SimpleRank Rank; - typedef SimpleStringTerm StringTerm; - typedef SimpleSubstringTerm SubstringTerm; - typedef SimpleSuffixTerm SuffixTerm; - typedef SimpleWeakAnd WeakAnd; - typedef SimpleWeightedSetTerm WeightedSetTerm; - typedef SimpleDotProduct DotProduct; - typedef SimpleWandTerm WandTerm; - typedef SimplePredicateQuery PredicateQuery; - typedef SimpleRegExpTerm RegExpTerm; + using And = SimpleAnd; + using AndNot = SimpleAndNot; + using Equiv = SimpleEquiv; + using NumberTerm = SimpleNumberTerm; + using LocationTerm = SimpleLocationTerm; + using Near = SimpleNear; + using ONear = SimpleONear; + using Or = SimpleOr; + using Phrase = SimplePhrase; + using SameElement = SimpleSameElement; + using PrefixTerm = SimplePrefixTerm; + using RangeTerm = SimpleRangeTerm; + using Rank = SimpleRank; + using StringTerm = SimpleStringTerm; + using SubstringTerm = SimpleSubstringTerm; + using SuffixTerm = SimpleSuffixTerm; + using WeakAnd = SimpleWeakAnd; + using WeightedSetTerm = SimpleWeightedSetTerm; + using DotProduct = SimpleDotProduct; + using WandTerm = SimpleWandTerm; + using PredicateQuery = SimplePredicateQuery; + using RegExpTerm = SimpleRegExpTerm; + using NearestNeighborTerm = SimpleNearestNeighborTerm; }; } diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp index 645750b8576..63acf532144 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp @@ -196,8 +196,7 @@ class QueryNodeConverter : public QueryVisitor { template <typename T> void appendTerm(const TermBase<T> &node); - template <class Term> - void createTerm(const Term &node, size_t type) { + void createTermNode(const TermNode &node, size_t type) { uint8_t typefield = type | ParseItem::IF_WEIGHT | ParseItem::IF_UNIQUEID; uint8_t flags = 0; if (!node.isRanked()) { @@ -216,6 +215,11 @@ class QueryNodeConverter : public QueryVisitor { appendByte(flags); } appendString(node.getView()); + } + + template <class Term> + void createTerm(const Term &node, size_t type) { + createTermNode(node, type); appendTerm(node); } @@ -255,6 +259,12 @@ class QueryNodeConverter : public QueryVisitor { createTerm(node, ParseItem::ITEM_REGEXP); } + void visit(NearestNeighborTerm &node) override { + createTermNode(node, ParseItem::ITEM_NEAREST_NEIGHBOR); + appendString(node.get_query_tensor_name()); + appendCompressedPositiveNumber(node.get_target_num_hits()); + } + public: QueryNodeConverter() : _buf(4096) diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h index dfb0c75a695..a5f25d81400 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h @@ -109,6 +109,13 @@ private: pureTermView = vespalib::stringref(); } else if (type == ParseItem::ITEM_NOT) { builder.addAndNot(arity); + } else if (type == ParseItem::ITEM_NEAREST_NEIGHBOR) { + vespalib::stringref query_tensor_name = queryStack.getTerm(); + vespalib::stringref field_name = queryStack.getIndexName(); + uint32_t target_num_hits = queryStack.getArg1(); + int32_t id = queryStack.getUniqueId(); + Weight weight = queryStack.GetWeight(); + builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, target_num_hits); } else { vespalib::stringref term = queryStack.getTerm(); vespalib::stringref view = queryStack.getIndexName(); diff --git a/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h b/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h index 0cdaca82572..d1abc816838 100644 --- a/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h +++ b/searchlib/src/vespa/searchlib/query/tree/templatetermvisitor.h @@ -31,6 +31,7 @@ class TemplateTermVisitor : public CustomTypeTermVisitor<NodeTypes> { void visit(typename NodeTypes::SuffixTerm &n) override { myVisit(n); } void visit(typename NodeTypes::PredicateQuery &n) override { myVisit(n); } void visit(typename NodeTypes::RegExpTerm &n) override { myVisit(n); } + void visit(typename NodeTypes::NearestNeighborTerm &n) override { myVisit(n); } // Phrases are terms with children. This visitor will not visit // the phrase's children, unless this member function is diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h index 35c23dde985..a82b1e14d76 100644 --- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h +++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h @@ -113,5 +113,33 @@ public: virtual ~RegExpTerm() = 0; }; +/** + * Term matching the K nearest neighbors in a multi-dimensional vector space. + * + * The query point is specified as a dense tensor of order 1. + * This is found in fef::IQueryEnvironment using the query tensor name as key. + * The field name is the name of a dense document tensor of order 1. + * Both tensors are validated to have the same tensor type before the query is sent to the backend. + * + * Target num hits (K) is a hint to how many neighbors to return. + * The actual returned number might be higher (or lower if the query returns fewer hits). + */ +class NearestNeighborTerm : public QueryNodeMixin<NearestNeighborTerm, TermNode> { +private: + vespalib::string _query_tensor_name; + uint32_t _target_num_hits; + +public: + NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, + int32_t id, Weight weight, uint32_t target_num_hits) + : QueryNodeMixinType(field_name, id, weight), + _query_tensor_name(query_tensor_name), + _target_num_hits(target_num_hits) + {} + virtual ~NearestNeighborTerm() {} + const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; } + uint32_t get_target_num_hits() const { return _target_num_hits; } +}; + } diff --git a/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h b/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h index 84830111fde..4fd8f64cc99 100644 --- a/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h +++ b/searchlib/src/vespa/searchlib/queryeval/create_blueprint_visitor_helper.h @@ -41,6 +41,7 @@ public: void visitWeightedSetTerm(query::WeightedSetTerm &n); void visitDotProduct(query::DotProduct &n); void visitWandTerm(query::WandTerm &n); + void visitNearestNeighborTerm(query::NearestNeighborTerm &n); void handleNumberTermAsText(query::NumberTerm &n); @@ -71,6 +72,7 @@ public: void visit(query::SubstringTerm &n) override = 0; void visit(query::SuffixTerm &n) override = 0; void visit(query::RegExpTerm &n) override = 0; + void visit(query::NearestNeighborTerm &n) override = 0; }; } diff --git a/searchlib/src/vespa/searchlib/queryeval/fake_searchable.cpp b/searchlib/src/vespa/searchlib/queryeval/fake_searchable.cpp index 4c678a9902f..fc3a6399e00 100644 --- a/searchlib/src/vespa/searchlib/queryeval/fake_searchable.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/fake_searchable.cpp @@ -6,9 +6,10 @@ #include "create_blueprint_visitor_helper.h" #include <vespa/vespalib/objects/visit.h> -using search::query::NumberTerm; using search::query::LocationTerm; +using search::query::NearestNeighborTerm; using search::query::Node; +using search::query::NumberTerm; using search::query::PredicateQuery; using search::query::PrefixTerm; using search::query::RangeTerm; @@ -64,6 +65,7 @@ public: void visit(SuffixTerm &n) override { visitTerm(n); } void visit(PredicateQuery &n) override { visitTerm(n); } void visit(RegExpTerm &n) override { visitTerm(n); } + void visit(NearestNeighborTerm &n) override { visitTerm(n); } }; template <class Map> diff --git a/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp b/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp index 3829ea45e2b..7a97110713d 100644 --- a/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/termasstring.cpp @@ -14,28 +14,29 @@ LOG_SETUP(".termasstring"); using search::query::And; using search::query::AndNot; +using search::query::DotProduct; using search::query::Equiv; -using search::query::NumberTerm; using search::query::LocationTerm; using search::query::Near; +using search::query::NearestNeighborTerm; using search::query::Node; +using search::query::NumberTerm; using search::query::ONear; using search::query::Or; using search::query::Phrase; -using search::query::SameElement; using search::query::PredicateQuery; using search::query::PrefixTerm; using search::query::QueryVisitor; using search::query::RangeTerm; using search::query::Rank; using search::query::RegExpTerm; +using search::query::SameElement; using search::query::StringTerm; using search::query::SubstringTerm; using search::query::SuffixTerm; +using search::query::WandTerm; using search::query::WeakAnd; using search::query::WeightedSetTerm; -using search::query::DotProduct; -using search::query::WandTerm; using vespalib::string; namespace search::queryeval { @@ -101,6 +102,7 @@ struct TermAsStringVisitor : public QueryVisitor { void visit(SuffixTerm &n) override {visitTerm(n); } void visit(RegExpTerm &n) override {visitTerm(n); } void visit(PredicateQuery &) override {illegalVisit(); } + void visit(NearestNeighborTerm &) override { illegalVisit(); } }; } // namespace |