diff options
12 files changed, 306 insertions, 88 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java index 25d184d148a..1885ce94ba3 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomAdminV4Builder.java @@ -113,7 +113,7 @@ public class DomAdminV4Builder extends DomAdminBuilderBase { private void addLogHandler(ContainerCluster cluster) { Handler<?> logHandler = Handler.fromClassName("com.yahoo.container.handler.LogHandler"); - logHandler.addServerBindings("http://*/logs/", "https://*/logs/"); + logHandler.addServerBindings("http://*/logs", "https://*/logs"); cluster.addComponent(logHandler); } diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityRequestFilterChainTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityRequestFilterChainTest.java new file mode 100644 index 00000000000..2e072f29039 --- /dev/null +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityRequestFilterChainTest.java @@ -0,0 +1,145 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.filter; + +import com.yahoo.jdisc.AbstractResource; +import com.yahoo.jdisc.Response; +import com.yahoo.jdisc.handler.CompletionHandler; +import com.yahoo.jdisc.handler.ContentChannel; +import com.yahoo.jdisc.handler.ResponseDispatch; +import com.yahoo.jdisc.handler.ResponseHandler; +import com.yahoo.jdisc.http.HttpRequest; +import com.yahoo.jdisc.test.TestDriver; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; + +import static org.testng.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class SecurityRequestFilterChainTest { + + + private static HttpRequest newRequest(URI uri, HttpRequest.Method method, HttpRequest.Version version) { + InetSocketAddress address = new InetSocketAddress("java.corp.yahoo.com", 69); + TestDriver driver = TestDriver.newSimpleApplicationInstanceWithoutOsgi(); + driver.activateContainer(driver.newContainerBuilder()); + HttpRequest request = HttpRequest.newServerRequest(driver, uri, method, version, address); + request.release(); + Assert.assertTrue(driver.close()); + return request; + } + + @Test + public void testFilterChainConstruction() { + SecurityRequestFilterChain chain = (SecurityRequestFilterChain)SecurityRequestFilterChain.newInstance(); + assertEquals(chain.getFilters().size(),0); + + List<SecurityRequestFilter> requestFilters = new ArrayList<SecurityRequestFilter>(); + chain = (SecurityRequestFilterChain)SecurityRequestFilterChain.newInstance(); + + chain = (SecurityRequestFilterChain)SecurityRequestFilterChain.newInstance(new RequestHeaderFilter("abc", "xyz"), + new RequestHeaderFilter("pqr", "def")); + + assertEquals(chain instanceof SecurityRequestFilterChain, true); + } + + + @Test + public void testFilterChainRun() { + RequestFilter chain = SecurityRequestFilterChain.newInstance(new RequestHeaderFilter("abc", "xyz"), + new RequestHeaderFilter("pqr", "def")); + + assertEquals(chain instanceof SecurityRequestFilterChain, true); + ResponseHandler handler = newResponseHandler(); + HttpRequest request = newRequest(URI.create("http://test/test"), HttpRequest.Method.GET, HttpRequest.Version.HTTP_1_1); + chain.filter(request, handler); + Assert.assertTrue(request.headers().contains("abc", "xyz")); + Assert.assertTrue(request.headers().contains("pqr", "def")); + } + + @Test + public void testFilterChainResponds() { + RequestFilter chain = SecurityRequestFilterChain.newInstance( + new MyFilter(), + new RequestHeaderFilter("abc", "xyz"), + new RequestHeaderFilter("pqr", "def")); + + assertEquals(chain instanceof SecurityRequestFilterChain, true); + ResponseHandler handler = newResponseHandler(); + HttpRequest request = newRequest(URI.create("http://test/test"), HttpRequest.Method.GET, HttpRequest.Version.HTTP_1_1); + chain.filter(request, handler); + Response response = getResponse(handler); + Assert.assertNotNull(response); + Assert.assertTrue(!request.headers().contains("abc", "xyz")); + Assert.assertTrue(!request.headers().contains("pqr", "def")); + } + + private class RequestHeaderFilter extends AbstractResource implements SecurityRequestFilter { + + private final String key; + private final String val; + + public RequestHeaderFilter(String key, String val) { + this.key = key; + this.val = val; + } + + @Override + public void filter(DiscFilterRequest request, ResponseHandler handler) { + request.setHeaders(key, val); + } + } + + private class MyFilter extends AbstractResource implements SecurityRequestFilter { + + @Override + public void filter(DiscFilterRequest request, ResponseHandler handler) { + ResponseDispatch.newInstance(Response.Status.FORBIDDEN).dispatch(handler); + } + } + + private static ResponseHandler newResponseHandler() { + return new NonWorkingResponseHandler(); + } + + private static Response getResponse(ResponseHandler handler) { + return ((NonWorkingResponseHandler) handler).getResponse(); + } + + private static class NonWorkingResponseHandler implements ResponseHandler { + + private Response response = null; + + @Override + public ContentChannel handleResponse(Response response) { + this.response = response; + return new NonWorkingContentChannel(); + } + + public Response getResponse() { + return response; + } + } + + private static class NonWorkingContentChannel implements ContentChannel { + + @Override + public void close(CompletionHandler handler) { + + } + + @Override + public void write(ByteBuffer buf, CompletionHandler handler) { + + } + + } + +} diff --git a/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityResponseFilterChainTest.java b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityResponseFilterChainTest.java new file mode 100644 index 00000000000..b38ca240a78 --- /dev/null +++ b/jdisc_http_service/src/test/java/com/yahoo/jdisc/http/filter/SecurityResponseFilterChainTest.java @@ -0,0 +1,75 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.jdisc.http.filter; + +import com.yahoo.jdisc.AbstractResource; +import com.yahoo.jdisc.Response; +import com.yahoo.jdisc.http.HttpRequest; +import com.yahoo.jdisc.http.HttpResponse; +import com.yahoo.jdisc.test.TestDriver; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.net.InetSocketAddress; +import java.net.URI; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertTrue; + +/** + * @author bjorncs + */ +public class SecurityResponseFilterChainTest { + private static HttpRequest newRequest(URI uri, HttpRequest.Method method, HttpRequest.Version version) { + InetSocketAddress address = new InetSocketAddress("java.corp.yahoo.com", 69); + TestDriver driver = TestDriver.newSimpleApplicationInstanceWithoutOsgi(); + driver.activateContainer(driver.newContainerBuilder()); + HttpRequest request = HttpRequest.newServerRequest(driver, uri, method, version, address); + request.release(); + Assert.assertTrue(driver.close()); + return request; + } + + @Test + public void testFilterChainConstruction() { + SecurityResponseFilterChain chain = (SecurityResponseFilterChain)SecurityResponseFilterChain.newInstance(); + assertEquals(chain.getFilters().size(),0); + + chain = (SecurityResponseFilterChain)SecurityResponseFilterChain.newInstance(new ResponseHeaderFilter("abc", "xyz"), + new ResponseHeaderFilter("pqr", "def")); + + assertEquals(chain instanceof SecurityResponseFilterChain, true); + } + + @Test + public void testFilterChainRun() { + URI uri = URI.create("http://localhost:8080/echo"); + HttpRequest request = newRequest(uri, HttpRequest.Method.GET, HttpRequest.Version.HTTP_1_1); + Response response = HttpResponse.newInstance(Response.Status.OK); + + ResponseFilter chain = SecurityResponseFilterChain.newInstance(new ResponseHeaderFilter("abc", "xyz"), + new ResponseHeaderFilter("pqr", "def")); + chain.filter(response, null); + assertTrue(response.headers().contains("abc", "xyz")); + assertTrue(response.headers().contains("pqr", "def")); + } + + private class ResponseHeaderFilter extends AbstractResource implements SecurityResponseFilter { + + private final String key; + private final String val; + + public ResponseHeaderFilter(String key, String val) { + this.key = key; + this.val = val; + } + + @Override + public void filter(DiscFilterResponse response, RequestView request) { + response.setHeaders(key, val); + } + + } + + + +} diff --git a/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp b/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp index 93153a920cf..38309284e54 100644 --- a/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp +++ b/searchcore/src/vespa/searchcore/grouping/groupingcontext.cpp @@ -22,8 +22,6 @@ GroupingContext::deserialize(const char *groupSpec, uint32_t groupSpecLen) for (size_t i = 0; i < numGroupings; i++) { GroupingPtr grouping(new search::aggregation::Grouping); grouping->deserialize(nis); - aggregation::Attribute2AttributeKeyed attr2AttrKeyed; - grouping->select(attr2AttrKeyed, attr2AttrKeyed); grouping->setClock(&_clock); grouping->setTimeOfDoom(_timeOfDoom); _groupingList.push_back(grouping); diff --git a/searchlib/src/tests/expression/attributenode/attribute_node_test.cpp b/searchlib/src/tests/expression/attributenode/attribute_node_test.cpp index c92b5fc4808..ec1a86bca69 100644 --- a/searchlib/src/tests/expression/attributenode/attribute_node_test.cpp +++ b/searchlib/src/tests/expression/attributenode/attribute_node_test.cpp @@ -10,7 +10,7 @@ #include <vespa/searchlib/attribute/integerbase.h> #include <vespa/searchlib/attribute/stringbase.h> #include <vespa/searchlib/expression/attributenode.h> -#include <vespa/searchlib/expression/attribute_keyed_node.h> +#include <vespa/searchlib/expression/attribute_map_lookup_node.h> #include <vespa/searchlib/expression/resultvector.h> #include <vespa/vespalib/test/insertion_operators.h> #include <vespa/vespalib/testkit/testapp.h> @@ -31,7 +31,7 @@ using search::attribute::Config; using search::attribute::IAttributeVector; using search::attribute::getUndefined; using search::expression::AttributeNode; -using search::expression::AttributeKeyedNode; +using search::expression::AttributeMapLookupNode; using search::expression::EnumResultNode; using search::expression::EnumResultNodeVector; using search::expression::FloatResultNode; @@ -60,6 +60,31 @@ vespalib::string stringValue(const ResultNode &result, const IAttributeVector &a return vespalib::string(sbuf.c_str(), sbuf.c_str() + sbuf.size()); } +vespalib::string indirectKeyMarker("attribute("); + +std::unique_ptr<AttributeNode> +makeAttributeMapLookupNode(const vespalib::string attributeName) +{ + vespalib::asciistream keyName; + vespalib::asciistream valueName; + auto leftBracePos = attributeName.find('{'); + auto baseName = attributeName.substr(0, leftBracePos); + auto rightBracePos = attributeName.rfind('}'); + keyName << baseName << ".key"; + valueName << baseName << ".value" << attributeName.substr(rightBracePos + 1); + if (rightBracePos != vespalib::string::npos && rightBracePos > leftBracePos) { + if (attributeName[leftBracePos + 1] == '"' && attributeName[rightBracePos - 1] == '"') { + vespalib::string key = attributeName.substr(leftBracePos + 2, rightBracePos - leftBracePos - 3); + return std::make_unique<AttributeMapLookupNode>(attributeName, keyName.str(), valueName.str(), key, ""); + } else if (attributeName.substr(leftBracePos + 1, indirectKeyMarker.size()) == indirectKeyMarker && attributeName[rightBracePos - 1] == ')') { + auto startPos = leftBracePos + 1 + indirectKeyMarker.size(); + vespalib::string keySourceAttributeName = attributeName.substr(startPos, rightBracePos - 1 - startPos); + return std::make_unique<AttributeMapLookupNode>(attributeName, keyName.str(), valueName.str(), "", keySourceAttributeName); + } + } + return std::unique_ptr<AttributeNode>(); +} + struct AttributeManagerFixture { AttributeManager mgr; @@ -220,7 +245,7 @@ Fixture::makeNode(const vespalib::string &attributeName, bool useEnumOptimizatio if (attributeName.find('{') == vespalib::string::npos) { node = std::make_unique<AttributeNode>(attributeName); } else { - node = std::make_unique<AttributeKeyedNode>(attributeName); + node = makeAttributeMapLookupNode(attributeName); } if (useEnumOptimization) { node->useEnumOptimization(); diff --git a/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp b/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp index fe71484c4e6..5ffad122c7d 100644 --- a/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp +++ b/searchlib/src/vespa/searchlib/aggregation/modifiers.cpp @@ -4,7 +4,7 @@ #include "grouping.h" #include <vespa/searchlib/expression/multiargfunctionnode.h> #include <vespa/searchlib/expression/attributenode.h> -#include <vespa/searchlib/expression/attribute_keyed_node.h> +#include <vespa/searchlib/expression/attribute_map_lookup_node.h> #include <vespa/searchlib/expression/documentfieldnode.h> using namespace search::expression; @@ -64,18 +64,6 @@ Attribute2DocumentAccessor::getReplacementNode(const AttributeNode &attributeNod return std::make_unique<DocumentFieldNode>(attributeNode.getAttributeName()); } -std::unique_ptr<ExpressionNode> -Attribute2AttributeKeyed::getReplacementNode(const AttributeNode &attributeNode) -{ - const vespalib::string &attributeName = attributeNode.getAttributeName(); - auto lBracePos = attributeName.find('{'); - if (attributeNode.isKeyed() || lBracePos == vespalib::string::npos) { - return std::unique_ptr<ExpressionNode>(); - } else { - return std::make_unique<AttributeKeyedNode>(attributeName); - } -} - } // this function was added by ../../forcelink.sh diff --git a/searchlib/src/vespa/searchlib/aggregation/modifiers.h b/searchlib/src/vespa/searchlib/aggregation/modifiers.h index 6ffda313904..0120cb4eac9 100644 --- a/searchlib/src/vespa/searchlib/aggregation/modifiers.h +++ b/searchlib/src/vespa/searchlib/aggregation/modifiers.h @@ -28,10 +28,4 @@ private: std::unique_ptr<search::expression::ExpressionNode> getReplacementNode(const search::expression::AttributeNode &attributeNode) override; }; -class Attribute2AttributeKeyed : public AttributeNodeReplacer -{ -private: - std::unique_ptr<search::expression::ExpressionNode> getReplacementNode(const search::expression::AttributeNode &attributeNode) override; -}; - } diff --git a/searchlib/src/vespa/searchlib/common/identifiable.h b/searchlib/src/vespa/searchlib/common/identifiable.h index 5a64e29ddf3..35e49b5cddf 100644 --- a/searchlib/src/vespa/searchlib/common/identifiable.h +++ b/searchlib/src/vespa/searchlib/common/identifiable.h @@ -148,6 +148,7 @@ #define CID_search_expression_AggregationRefNode SEARCHLIB_CID(142) #define CID_search_expression_NormalizeSubjectFunctionNode SEARCHLIB_CID(143) #define CID_search_expression_DebugWaitFunctionNode SEARCHLIB_CID(144) +#define CID_search_expression_AttributeMapLookupNode SEARCHLIB_CID(145) #define CID_search_QueryNode SEARCHLIB_CID(150) #define CID_search_Query SEARCHLIB_CID(151) diff --git a/searchlib/src/vespa/searchlib/expression/CMakeLists.txt b/searchlib/src/vespa/searchlib/expression/CMakeLists.txt index 944bc6f63df..652fa5a3b01 100644 --- a/searchlib/src/vespa/searchlib/expression/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/expression/CMakeLists.txt @@ -1,7 +1,7 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_library(searchlib_expression OBJECT SOURCES - attribute_keyed_node.cpp + attribute_map_lookup_node.cpp attributenode.cpp attributeresult.cpp enumattributeresult.cpp diff --git a/searchlib/src/vespa/searchlib/expression/attribute_keyed_node.cpp b/searchlib/src/vespa/searchlib/expression/attribute_map_lookup_node.cpp index da6ed363b17..8a851b043aa 100644 --- a/searchlib/src/vespa/searchlib/expression/attribute_keyed_node.cpp +++ b/searchlib/src/vespa/searchlib/expression/attribute_map_lookup_node.cpp @@ -1,6 +1,6 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include "attribute_keyed_node.h" +#include "attribute_map_lookup_node.h" #include <vespa/vespalib/stllike/asciistream.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/searchcommon/attribute/attributecontent.h> @@ -11,11 +11,15 @@ using search::attribute::AttributeContent; using search::attribute::IAttributeVector; using search::attribute::BasicType; using search::attribute::getUndefined; +using vespalib::Deserializer; +using vespalib::Serializer; using EnumHandle = IAttributeVector::EnumHandle; namespace search::expression { -class AttributeKeyedNode::KeyHandler +IMPLEMENT_EXPRESSIONNODE(AttributeMapLookupNode, AttributeNode); + +class AttributeMapLookupNode::KeyHandler { protected: const IAttributeVector &_attribute; @@ -34,7 +38,7 @@ namespace { vespalib::string indirectKeyMarker("attribute("); -class BadKeyHandler : public AttributeKeyedNode::KeyHandler +class BadKeyHandler : public AttributeMapLookupNode::KeyHandler { public: BadKeyHandler(const IAttributeVector &attribute) @@ -70,7 +74,7 @@ EnumHandle convertKey<EnumHandle>(const IAttributeVector &attribute, const vespa } template <typename T, typename KeyType = T> -class KeyHandlerT : public AttributeKeyedNode::KeyHandler +class KeyHandlerT : public AttributeMapLookupNode::KeyHandler { AttributeContent<T> _keys; KeyType _key; @@ -119,7 +123,7 @@ matchingKey<const char *>(const char *lhs, const char *rhs) } template <typename T> -class IndirectKeyHandlerT : public AttributeKeyedNode::KeyHandler +class IndirectKeyHandlerT : public AttributeMapLookupNode::KeyHandler { const IAttributeVector &_keySourceAttribute; AttributeContent<T> _keys; @@ -157,9 +161,9 @@ using IndirectStringKeyHandler = IndirectKeyHandlerT<const char *>; class ValueHandler : public AttributeNode::Handler { protected: - std::unique_ptr<AttributeKeyedNode::KeyHandler> _keyHandler; + std::unique_ptr<AttributeMapLookupNode::KeyHandler> _keyHandler; const IAttributeVector &_attribute; - ValueHandler(std::unique_ptr<AttributeKeyedNode::KeyHandler> keyHandler, const IAttributeVector &attribute) + ValueHandler(std::unique_ptr<AttributeMapLookupNode::KeyHandler> keyHandler, const IAttributeVector &attribute) : _keyHandler(std::move(keyHandler)), _attribute(attribute) { @@ -173,7 +177,7 @@ class ValueHandlerT : public ValueHandler ResultNodeType &_result; T _undefinedValue; public: - ValueHandlerT(std::unique_ptr<AttributeKeyedNode::KeyHandler> keyHandler, const IAttributeVector &attribute, ResultNodeType &result, T undefinedValue) + ValueHandlerT(std::unique_ptr<AttributeMapLookupNode::KeyHandler> keyHandler, const IAttributeVector &attribute, ResultNodeType &result, T undefinedValue) : ValueHandler(std::move(keyHandler), attribute), _values(), _result(result), @@ -183,7 +187,7 @@ public: void handle(const AttributeResult & r) override { uint32_t docId = r.getDocId(); uint32_t keyIdx = _keyHandler->handle(docId); - if (keyIdx != AttributeKeyedNode::KeyHandler::noKeyIdx()) { + if (keyIdx != AttributeMapLookupNode::KeyHandler::noKeyIdx()) { _values.fill(_attribute, docId); if (keyIdx < _values.size()) { _result = _values[keyIdx]; @@ -228,7 +232,7 @@ IAttributeVector::largeint_t getUndefinedValue(BasicType::Type basicType) } -AttributeKeyedNode::AttributeKeyedNode() +AttributeMapLookupNode::AttributeMapLookupNode() : AttributeNode(), _keyAttributeName(), _valueAttributeName(), @@ -239,58 +243,35 @@ AttributeKeyedNode::AttributeKeyedNode() { } -AttributeKeyedNode::AttributeKeyedNode(const AttributeKeyedNode &) = default; +AttributeMapLookupNode::AttributeMapLookupNode(const AttributeMapLookupNode &) = default; -AttributeKeyedNode::AttributeKeyedNode(vespalib::stringref name) +AttributeMapLookupNode::AttributeMapLookupNode(vespalib::stringref name, vespalib::stringref keyAttributeName, vespalib::stringref valueAttributeName, vespalib::stringref key, vespalib::stringref keySourceAttributeName) : AttributeNode(name), - _keyAttributeName(), - _valueAttributeName(), - _key(), - _keySourceAttributeName(), + _keyAttributeName(keyAttributeName), + _valueAttributeName(valueAttributeName), + _key(key), + _keySourceAttributeName(keySourceAttributeName), _keyAttribute(nullptr), _keySourceAttribute(nullptr) { - setupAttributeNames(); } -AttributeKeyedNode::~AttributeKeyedNode() = default; +AttributeMapLookupNode::~AttributeMapLookupNode() = default; -AttributeKeyedNode & -AttributeKeyedNode::operator=(const AttributeKeyedNode &rhs) = default; - -void -AttributeKeyedNode::setupAttributeNames() -{ - vespalib::asciistream keyName; - vespalib::asciistream valueName; - auto leftBracePos = _attributeName.find('{'); - auto baseName = _attributeName.substr(0, leftBracePos); - auto rightBracePos = _attributeName.rfind('}'); - keyName << baseName << ".key"; - valueName << baseName << ".value" << _attributeName.substr(rightBracePos + 1); - _keyAttributeName = keyName.str(); - _valueAttributeName = valueName.str(); - if (rightBracePos != vespalib::string::npos && rightBracePos > leftBracePos) { - if (_attributeName[leftBracePos + 1] == '"' && _attributeName[rightBracePos - 1] == '"') { - _key = _attributeName.substr(leftBracePos + 2, rightBracePos - leftBracePos - 3); - } else if (_attributeName.substr(leftBracePos + 1, indirectKeyMarker.size()) == indirectKeyMarker && _attributeName[rightBracePos - 1] == ')') { - auto startPos = leftBracePos + 1 + indirectKeyMarker.size(); - _keySourceAttributeName = _attributeName.substr(startPos, rightBracePos - 1 - startPos); - } - } -} +AttributeMapLookupNode & +AttributeMapLookupNode::operator=(const AttributeMapLookupNode &rhs) = default; template <typename ResultNodeType> void -AttributeKeyedNode::prepareIntValues(std::unique_ptr<KeyHandler> keyHandler, const IAttributeVector &attribute, IAttributeVector::largeint_t undefinedValue) +AttributeMapLookupNode::prepareIntValues(std::unique_ptr<KeyHandler> keyHandler, const IAttributeVector &attribute, IAttributeVector::largeint_t undefinedValue) { auto resultNode = std::make_unique<ResultNodeType>(); _handler = std::make_unique<IntegerValueHandler<ResultNodeType>>(std::move(keyHandler), attribute, *resultNode, undefinedValue); setResultType(std::move(resultNode)); } -std::unique_ptr<AttributeKeyedNode::KeyHandler> -AttributeKeyedNode::makeKeyHandlerHelper() +std::unique_ptr<AttributeMapLookupNode::KeyHandler> +AttributeMapLookupNode::makeKeyHandlerHelper() { const IAttributeVector &attribute = *_keyAttribute; if (_keySourceAttribute != nullptr) { @@ -318,8 +299,8 @@ AttributeKeyedNode::makeKeyHandlerHelper() } } -std::unique_ptr<AttributeKeyedNode::KeyHandler> -AttributeKeyedNode::makeKeyHandler() +std::unique_ptr<AttributeMapLookupNode::KeyHandler> +AttributeMapLookupNode::makeKeyHandler() { try { return makeKeyHandlerHelper(); @@ -329,7 +310,7 @@ AttributeKeyedNode::makeKeyHandler() } void -AttributeKeyedNode::onPrepare(bool preserveAccurateTypes) +AttributeMapLookupNode::onPrepare(bool preserveAccurateTypes) { auto keyHandler = makeKeyHandler(); const IAttributeVector * attribute = _scratchResult->getAttribute(); @@ -380,7 +361,7 @@ AttributeKeyedNode::onPrepare(bool preserveAccurateTypes) } void -AttributeKeyedNode::cleanup() +AttributeMapLookupNode::cleanup() { _keyAttribute = nullptr; _keySourceAttribute = nullptr; @@ -388,7 +369,7 @@ AttributeKeyedNode::cleanup() } void -AttributeKeyedNode::wireAttributes(const search::attribute::IAttributeContext &attrCtx) +AttributeMapLookupNode::wireAttributes(const search::attribute::IAttributeContext &attrCtx) { auto valueAttribute = findAttribute(attrCtx, _useEnumOptimization, _valueAttributeName); _hasMultiValue = false; @@ -399,8 +380,20 @@ AttributeKeyedNode::wireAttributes(const search::attribute::IAttributeContext &a } } +Serializer & AttributeMapLookupNode::onSerialize(Serializer & os) const +{ + AttributeNode::onSerialize(os); + return os << _keyAttributeName << _valueAttributeName << _key << _keySourceAttributeName; +} + +Deserializer & AttributeMapLookupNode::onDeserialize(Deserializer & is) +{ + AttributeNode::onDeserialize(is); + return is >> _keyAttributeName >> _valueAttributeName >> _key >> _keySourceAttributeName; +} + void -AttributeKeyedNode::visitMembers(vespalib::ObjectVisitor &visitor) const +AttributeMapLookupNode::visitMembers(vespalib::ObjectVisitor &visitor) const { AttributeNode::visitMembers(visitor); visit(visitor, "keyAttributeName", _keyAttributeName); diff --git a/searchlib/src/vespa/searchlib/expression/attribute_keyed_node.h b/searchlib/src/vespa/searchlib/expression/attribute_map_lookup_node.h index e2cf8943aae..2f9c6328969 100644 --- a/searchlib/src/vespa/searchlib/expression/attribute_keyed_node.h +++ b/searchlib/src/vespa/searchlib/expression/attribute_map_lookup_node.h @@ -9,7 +9,7 @@ namespace search::expression { * Extract map value from attribute for the map key specified in the * grouping expression. */ -class AttributeKeyedNode : public AttributeNode +class AttributeMapLookupNode : public AttributeNode { public: using IAttributeVector = search::attribute::IAttributeVector; @@ -22,7 +22,6 @@ private: const IAttributeVector *_keyAttribute; const IAttributeVector *_keySourceAttribute; - void setupAttributeNames(); template <typename ResultNodeType> void prepareIntValues(std::unique_ptr<KeyHandler> keyHandler, const IAttributeVector &attribute, IAttributeVector::largeint_t undefinedValue); std::unique_ptr<KeyHandler> makeKeyHandlerHelper(); @@ -31,15 +30,16 @@ private: void wireAttributes(const search::attribute::IAttributeContext & attrCtx) override; void onPrepare(bool preserveAccurateTypes) override; public: - AttributeKeyedNode(); - AttributeKeyedNode(vespalib::stringref name); - AttributeKeyedNode(const AttributeKeyedNode &); - AttributeKeyedNode(AttributeKeyedNode &&) = delete; - ~AttributeKeyedNode() override; - AttributeKeyedNode &operator=(const AttributeKeyedNode &rhs); - AttributeKeyedNode &operator=(AttributeKeyedNode &&rhs) = delete; + DECLARE_NBO_SERIALIZE; + DECLARE_EXPRESSIONNODE(AttributeMapLookupNode); + AttributeMapLookupNode(); + AttributeMapLookupNode(vespalib::stringref name, vespalib::stringref keyAttributeName, vespalib::stringref valueAttributeName, vespalib::stringref key, vespalib::stringref keySourceAttributeName); + AttributeMapLookupNode(const AttributeMapLookupNode &); + AttributeMapLookupNode(AttributeMapLookupNode &&) = delete; + ~AttributeMapLookupNode() override; + AttributeMapLookupNode &operator=(const AttributeMapLookupNode &rhs); + AttributeMapLookupNode &operator=(AttributeMapLookupNode &&rhs) = delete; void visitMembers(vespalib::ObjectVisitor &visitor) const override; - bool isKeyed() const override { return true; } }; } diff --git a/searchlib/src/vespa/searchlib/expression/attributenode.h b/searchlib/src/vespa/searchlib/expression/attributenode.h index e12b5490955..472267f4b5c 100644 --- a/searchlib/src/vespa/searchlib/expression/attributenode.h +++ b/searchlib/src/vespa/searchlib/expression/attributenode.h @@ -55,7 +55,6 @@ public: void useEnumOptimization(bool use=true) { _useEnumOptimization = use; } bool hasMultiValue() const { return _hasMultiValue; } - virtual bool isKeyed() const { return false; } public: class Handler { |