diff options
author | Harald Musum <musum@verizonmedia.com> | 2021-03-11 10:47:12 +0100 |
---|---|---|
committer | Harald Musum <musum@verizonmedia.com> | 2021-03-11 10:47:12 +0100 |
commit | 6fd7c4952a8fb3fa53de827a3731ca1dd9803cc1 (patch) | |
tree | 973304ccf5de51e11f5192c093de2fd07430432a | |
parent | 35390b007c102370bbad7490cc1cb5542dc1ad17 (diff) | |
parent | b310bcb0d382dcb2f5c481902772c591a77197d8 (diff) |
Merge branch 'master' into jonmv/cluster-controller-migration-cleanup-2
311 files changed, 5143 insertions, 2132 deletions
diff --git a/NOTICES b/NOTICES new file mode 100644 index 00000000000..e61f713521b --- /dev/null +++ b/NOTICES @@ -0,0 +1,368 @@ +Vespa, the open big data serving engine +Copyright Verizon Media. Code licensed under the Apache 2.0 license. + +Open-Source Vespa uses several third-party Open-Source libraries +with permissive licenses. These are included here. + +---------------------------------------------------------------------- +Boost C++ Libraries (https://www.boost.org) + +Boost Software License - Version 1.0 - August 17th, 2003 + +Permission is hereby granted, free of charge, to any person or organization +obtaining a copy of the software and accompanying documentation covered by +this license (the "Software") to use, reproduce, display, distribute, +execute, and transmit the Software, and to prepare derivative works of the +Software, and to permit third-parties to whom the Software is furnished to +do so, all subject to the following: + +The copyright notices in the Software and this entire statement, including +the above license grant, this restriction and the following disclaimer, +must be included in all copies of the Software, in whole or in part, and +all derivative works of the Software, unless such copies or derivative +works are solely in the form of machine-executable object code generated by +a source language processor. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE, TITLE AND NON-INFRINGEMENT. IN NO EVENT +SHALL THE COPYRIGHT HOLDERS OR ANYONE DISTRIBUTING THE SOFTWARE BE LIABLE +FOR ANY DAMAGES OR OTHER LIABILITY, WHETHER IN CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. + +---------------------------------------------------------------------- +Google RE2 (https://github.com/google/re2) + +Copyright (c) 2009 The RE2 Authors. 
All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +---------------------------------------------------------------------- +OpenBLAS (https://github.com/xianyi/OpenBLAS) + +Copyright (c) 2011-2014, The OpenBLAS Project +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + 1. Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + 2. 
Redistributions in binary form must reproduce the above copyright + notice, this list of conditions and the following disclaimer in + the documentation and/or other materials provided with the + distribution. + 3. Neither the name of the OpenBLAS project nor the names of + its contributors may be used to endorse or promote products + derived from this software without specific prior written + permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE +USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +---------------------------------------------------------------------- +OpenSSL (https://www.openssl.org) + +OpenSSL is licensed under the Apache License 2.0. See LICENSE under the +Vespa root for the full license text. + +---------------------------------------------------------------------- +LLVM (https://www.llvm.org) + +LLVM is licensed under the Apache License 2.0 with LLVM exceptions. +See LICENSE under the Vespa root for the full license text. 
Exceptions: + +---- LLVM Exceptions to the Apache 2.0 License ---- + +As an exception, if, as a result of your compiling your source code, portions +of this Software are embedded into an Object form of such source code, you +may redistribute such embedded portions in such Object form without complying +with the conditions of Sections 4(a), 4(b) and 4(d) of the License. + +In addition, if you combine or link compiled forms of this Software with +software that is licensed under the GPLv2 ("Combined Software") and if a +court of competent jurisdiction determines that the patent provision (Section +3), the indemnity provision (Section 9) or other Section of the License +conflicts with the conditions of the GPLv2, you may retroactively and +prospectively choose to deem waived or otherwise exclude such Section(s) of +the License, but only in their entirety and only with respect to the Combined +Software. + +---------------------------------------------------------------------- +Googletest (https://github.com/google/googletest) + +Copyright 2008, Google Inc. +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +---------------------------------------------------------------------- +Protocol Buffers (https://github.com/protocolbuffers/protobuf/) + +Copyright 2008 Google Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +Code generated by the Protocol Buffer compiler is owned by the owner +of the input file used when generating it. This code is not +standalone and requires a support library to be linked with it. This +support library is itself covered by the above license. + +---------------------------------------------------------------------- +xxHash (https://github.com/Cyan4973/xxHash) + +xxHash Library +Copyright (c) 2012-present, Yann Collet +All rights reserved. + +BSD 2-Clause License (https://www.opensource.org/licenses/bsd-license.php) + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +---------------------------------------------------------------------- +LZ4 (https://github.com/lz4/lz4) + +LZ4 Library +Copyright (c) 2011-2020, Yann Collet +All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this + list of conditions and the following disclaimer in the documentation and/or + other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +---------------------------------------------------------------------- +Zstandard (https://github.com/facebook/zstd) + +BSD License + +For Zstandard software + +Copyright (c) 2016-present, Facebook, Inc. All rights reserved. + +Redistribution and use in source and binary forms, with or without modification, +are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + * Neither the name Facebook nor the names of its contributors may be used to + endorse or promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON +ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +---------------------------------------------------------------------- +zlib (https://www.zlib.net/) + + Copyright (C) 1995-2017 Jean-loup Gailly and Mark Adler + + This software is provided 'as-is', without any express or implied + warranty. 
In no event will the authors be held liable for any damages + arising from the use of this software. + + Permission is granted to anyone to use this software for any purpose, + including commercial applications, and to alter it and redistribute it + freely, subject to the following restrictions: + + 1. The origin of this software must not be misrepresented; you must not + claim that you wrote the original software. If you use this software + in a product, an acknowledgment in the product documentation would be + appreciated but is not required. + 2. Altered source versions must be plainly marked as such, and must not be + misrepresented as being the original software. + 3. This notice may not be removed or altered from any source distribution. + +---------------------------------------------------------------------- +ICU (https://github.com/unicode-org/icu) + +ICU License - ICU 1.8.1 and later + +COPYRIGHT AND PERMISSION NOTICE + +Copyright (c) 1995-2013 International Business Machines Corporation and +others + +All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the +"Software"), to deal in the Software without restriction, including +without limitation the rights to use, copy, modify, merge, publish, +distribute, and/or sell copies of the Software, and to permit persons +to whom the Software is furnished to do so, provided that the above +copyright notice(s) and this permission notice appear in all copies of +the Software and that both the above copyright notice(s) and this +permission notice appear in supporting documentation. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT +OF THIRD PARTY RIGHTS. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +HOLDERS INCLUDED IN THIS NOTICE BE LIABLE FOR ANY CLAIM, OR ANY SPECIAL +INDIRECT OR CONSEQUENTIAL DAMAGES, OR ANY DAMAGES WHATSOEVER RESULTING +FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT, +NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION +WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. + +Except as contained in this notice, the name of a copyright holder +shall not be used in advertising or otherwise to promote the sale, use +or other dealings in this Software without prior written authorization +of the copyright holder. + +---------------------------------------------------------------------- +ONNX Runtime (https://github.com/microsoft/onnxruntime) + +MIT License + +Copyright (c) Microsoft Corporation + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+ diff --git a/clustercontroller-core/pom.xml b/clustercontroller-core/pom.xml index ee503d91540..8d62acb0fb4 100644 --- a/clustercontroller-core/pom.xml +++ b/clustercontroller-core/pom.xml @@ -73,6 +73,12 @@ <scope>provided</scope> </dependency> <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>zookeeper-server-common</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> <!-- Not used by this module, but compilation fails without it because zookeeper uses these annotations. Provided scoped here to avoid dependents getting it transitively. --> <groupId>com.github.spotbugs</groupId> diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java index e8222d73a63..5d5ffb917d2 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/FleetController.java @@ -70,6 +70,7 @@ public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd private final AtomicBoolean running = new AtomicBoolean(true); private FleetControllerOptions options; private FleetControllerOptions nextOptions; + private final int configuredIndex; private final List<SystemStateListener> systemStateListeners = new CopyOnWriteArrayList<>(); private boolean processingCycle = false; private boolean wantedStateChanged = false; @@ -125,6 +126,7 @@ public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd MetricUpdater metricUpdater, FleetControllerOptions options) { log.info("Starting up cluster controller " + options.fleetControllerIndex + " for cluster " + cluster.getName()); + this.configuredIndex = options.fleetControllerIndex; this.timer = timer; this.monitor = timer; this.eventLog = eventLog; @@ -284,7 +286,7 @@ 
public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd } log.log(Level.INFO, "Fleetcontroller done shutting down event thread."); controllerThreadId = Thread.currentThread().getId(); - database.shutdown(this); + database.shutdown(databaseContext); if (statusPageServer != null) { statusPageServer.shutdown(); @@ -436,7 +438,13 @@ public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd */ public void lostDatabaseConnection() { verifyInControllerThread(); + boolean wasMaster = masterElectionHandler.isMaster(); masterElectionHandler.lostDatabaseConnection(); + if (wasMaster) { + // Enforce that we re-fetch all state information from ZooKeeper upon the next tick if we're still master. + dropLeadershipState(); + metricUpdater.updateMasterState(false); + } } private void failAllVersionDependentTasks() { @@ -481,6 +489,7 @@ public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd /** This is called when the options field has been set to a new set of options */ private void propagateOptions() { verifyInControllerThread(); + selfTerminateIfConfiguredNodeIndexHasChanged(); if (changesConfiguredNodeSet(options.nodes)) { // Force slobrok node re-fetch in case of changes to the set of configured nodes @@ -501,8 +510,8 @@ public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd cluster.setPollingFrequency(options.statePollingFrequency); cluster.setDistribution(options.storageDistribution); cluster.setNodes(options.nodes); - database.setZooKeeperAddress(options.zooKeeperServerAddress); - database.setZooKeeperSessionTimeout(options.zooKeeperSessionTimeout); + database.setZooKeeperAddress(options.zooKeeperServerAddress, databaseContext); + database.setZooKeeperSessionTimeout(options.zooKeeperSessionTimeout, databaseContext); stateGatherer.setMaxSlobrokDisconnectGracePeriod(options.maxSlobrokDisconnectGracePeriod); 
stateGatherer.setNodeStateRequestTimeout(options.nodeStateRequestTimeoutMS); @@ -538,6 +547,16 @@ public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd nextConfigGeneration = -1; } + private void selfTerminateIfConfiguredNodeIndexHasChanged() { + if (options.fleetControllerIndex != configuredIndex) { + log.warning(String.format("Got new configuration where CC index has changed from %d to %d. We do not support "+ + "doing this live; immediately exiting now to force new configuration", + configuredIndex, options.fleetControllerIndex)); + prepareShutdownEdge(); + System.exit(1); + } + } + public StatusPageResponse fetchStatusPage(StatusPageServer.HttpRequest httpRequest) { verifyInControllerThread(); StatusPageResponse.ResponseCode responseCode; @@ -1066,18 +1085,22 @@ public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd wantedStateChanged = false; } } else { - if (isMaster) { - eventLog.add(new ClusterEvent(ClusterEvent.Type.MASTER_ELECTION, "This node is no longer fleetcontroller master.", timer.getCurrentTimeInMillis())); - firstAllowedStateBroadcast = Long.MAX_VALUE; - failAllVersionDependentTasks(); - } - wantedStateChanged = false; - isMaster = false; + dropLeadershipState(); } metricUpdater.updateMasterState(isMaster); return didWork; } + private void dropLeadershipState() { + if (isMaster) { + eventLog.add(new ClusterEvent(ClusterEvent.Type.MASTER_ELECTION, "This node is no longer fleetcontroller master.", timer.getCurrentTimeInMillis())); + firstAllowedStateBroadcast = Long.MAX_VALUE; + failAllVersionDependentTasks(); + } + wantedStateChanged = false; + isMaster = false; + } + public void run() { controllerThreadId = Thread.currentThread().getId(); try { @@ -1093,12 +1116,16 @@ public class FleetController implements NodeStateOrHostInfoChangeHandler, NodeAd synchronized (monitor) { running.set(false); } System.exit(1); } finally { - running.set(false); - failAllVersionDependentTasks(); - synchronized 
(monitor) { monitor.notifyAll(); } + prepareShutdownEdge(); } } + private void prepareShutdownEdge() { + running.set(false); + failAllVersionDependentTasks(); + synchronized (monitor) { monitor.notifyAll(); } + } + public DatabaseHandler.Context databaseContext = new DatabaseHandler.Context() { @Override public ContentCluster getCluster() { return cluster; } diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/DatabaseHandler.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/DatabaseHandler.java index 5d5bb674d4f..4f02e31b426 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/DatabaseHandler.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/DatabaseHandler.java @@ -43,11 +43,15 @@ public class DatabaseHandler { ClusterStateBundle clusterStateBundle; void clear() { - masterVote = null; + clearNonClusterStateFields(); lastSystemStateVersion = null; + clusterStateBundle = null; + } + + void clearNonClusterStateFields() { + masterVote = null; wantedStates = null; startTimestamps = null; - clusterStateBundle = null; } } private class DatabaseListener implements Database.DatabaseListener { @@ -96,7 +100,7 @@ public class DatabaseHandler { this.nodeIndex = ourIndex; pendingStore.masterVote = ourIndex; // To begin with we'll vote for ourselves. 
this.monitor = monitor; - // TODO: Require non-null, not possible now since at least ClusterFeedBlockTest usese null address + // TODO: Require non-null, not possible now since at least ClusterFeedBlockTest uses null address this.zooKeeperAddress = zooKeeperAddress; } @@ -106,8 +110,8 @@ public class DatabaseHandler { } } - public void shutdown(FleetController fleetController) { - relinquishDatabaseConnectivity(fleetController); + public void shutdown(Context context) { + relinquishDatabaseConnectivity(context); } public boolean isClosed() { return database == null || database.isClosed(); } @@ -116,7 +120,7 @@ public class DatabaseHandler { return lastKnownStateBundleVersionWrittenBySelf; } - public void reset() { + public void reset(Context context) { final boolean wasRunning; synchronized (databaseMonitor) { wasRunning = database != null; @@ -126,37 +130,46 @@ public class DatabaseHandler { database = null; } } - clearSessionMetaData(); + clearSessionMetaData(true); + context.getFleetController().lostDatabaseConnection(); if (wasRunning) { log.log(Level.INFO, "Fleetcontroller " + nodeIndex + ": Done resetting database state"); } } - private void clearSessionMetaData() { + private void clearSessionMetaData(boolean clearPendingStateWrites) { // Preserve who we want to vote for Integer currentVote = (pendingStore.masterVote != null ? pendingStore.masterVote : currentlyStored.masterVote); currentlyStored.clear(); - pendingStore.clear(); + if (clearPendingStateWrites) { + pendingStore.clear(); + } else { + // If we have pending cluster state writes we cannot drop these on the floor, as otherwise the + // core CC logic may keep thinking it has persisted writes it really has not. Clearing pending + // state writes would also prevent the controller from detecting itself being out of sync by + // triggering CaS violations upon znode writes. 
+ pendingStore.clearNonClusterStateFields(); + } pendingStore.masterVote = currentVote; log.log(Level.FINE, "Cleared session metadata. Pending master vote is now " + pendingStore.masterVote); } - public void setZooKeeperAddress(String address) { + public void setZooKeeperAddress(String address, Context context) { if (address == null && zooKeeperAddress == null) return; if (address != null && address.equals(zooKeeperAddress)) return; if (zooKeeperAddress != null) { log.log(Level.INFO, "Fleetcontroller " + nodeIndex + ": " + (address == null ? "Stopped using ZooKeeper." : "Got new ZooKeeper address to use: " + address)); } zooKeeperAddress = address; - reset(); + reset(context); } - public void setZooKeeperSessionTimeout(int timeout) { + public void setZooKeeperSessionTimeout(int timeout, Context context) { if (timeout == zooKeeperSessionTimeout) return; log.log(Level.FINE, "Fleetcontroller " + nodeIndex + ": Got new ZooKeeper session timeout of " + timeout + " milliseconds."); zooKeeperSessionTimeout = timeout; - reset(); + reset(context); } private boolean usingZooKeeper() { return (zooKeeperAddress != null); } @@ -169,7 +182,9 @@ public class DatabaseHandler { database.close(); } // We still hold the database lock while calling this, we want to block callers. - clearSessionMetaData(); + // Don't clear pending state writes in case they were attempted prior to connect() + // being called, but after receiving a database loss event. + clearSessionMetaData(false); log.log(Level.INFO, "Fleetcontroller " + nodeIndex + ": Setting up new ZooKeeper session at " + zooKeeperAddress); DatabaseFactory.Params params = new DatabaseFactory.Params() @@ -247,7 +262,7 @@ public class DatabaseHandler { "has likely taken over ownership: %s", e.getMessage())); // Clear DB and master election state. This shall trigger a full re-fetch of all // version and election-related metadata. 
- relinquishDatabaseConnectivity(context.getFleetController()); + relinquishDatabaseConnectivity(context); } return didWork; } @@ -257,9 +272,9 @@ public class DatabaseHandler { return zooKeeperAddress != null; } - private void relinquishDatabaseConnectivity(FleetController fleetController) { - reset(); - fleetController.lostDatabaseConnection(); + private void relinquishDatabaseConnectivity(Context context) { + // reset() will handle both session clearing and trigger a database loss callback into the CC. + reset(context); } private boolean performZooKeeperWrites() throws InterruptedException { diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/ZooKeeperDatabase.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/ZooKeeperDatabase.java index 60835d5bb4f..4cdbb49dedc 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/ZooKeeperDatabase.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/database/ZooKeeperDatabase.java @@ -238,6 +238,9 @@ public class ZooKeeperDatabase extends Database { } catch (InterruptedException e) { throw (InterruptedException) new InterruptedException("Interrupted").initCause(e); } catch (Exception e) { + // If we return a default, empty version, writes dependent on this bundle should only + // succeed if the previous znode version is 0, i.e. not yet created. + lastKnownStateVersionZNodeVersion = 0; maybeLogExceptionWarning(e, "Failed to retrieve latest system state version used. Returning null"); return null; } @@ -390,6 +393,9 @@ public class ZooKeeperDatabase extends Database { maybeLogExceptionWarning(e, "Failed to retrieve last published cluster state bundle from " + "ZooKeeper, will use an empty state as baseline"); } + // If we return a default, empty bundle, writes dependent on this bundle should only + // succeed if the previous znode version is 0, i.e. not yet created. 
+ lastKnownStateBundleZNodeVersion = 0; return ClusterStateBundle.ofBaselineOnly(AnnotatedClusterState.emptyState()); } diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/ClusterStateGeneratorTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/ClusterStateGeneratorTest.java index ba3a1dd7d26..d5d6c4623f2 100644 --- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/ClusterStateGeneratorTest.java +++ b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/ClusterStateGeneratorTest.java @@ -7,18 +7,17 @@ import com.yahoo.vdslib.state.Node; import com.yahoo.vdslib.state.NodeState; import com.yahoo.vdslib.state.NodeType; import com.yahoo.vdslib.state.State; -import com.yahoo.vespa.jdk8compat.Set; import org.junit.Test; import java.util.Optional; +import java.util.Set; -import static com.yahoo.vespa.clustercontroller.core.matchers.HasStateReasonForNode.hasStateReasonForNode; -import static com.yahoo.vespa.clustercontroller.core.ClusterFixture.storageNode; import static com.yahoo.vespa.clustercontroller.core.ClusterFixture.distributorNode; - +import static com.yahoo.vespa.clustercontroller.core.ClusterFixture.storageNode; +import static com.yahoo.vespa.clustercontroller.core.matchers.HasStateReasonForNode.hasStateReasonForNode; import static org.hamcrest.CoreMatchers.not; -import static org.hamcrest.core.IsEqual.equalTo; import static org.hamcrest.core.Is.is; +import static org.hamcrest.core.IsEqual.equalTo; import static org.junit.Assert.assertThat; public class ClusterStateGeneratorTest { diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java index c59746f8332..c4eeb41a663 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java +++ 
b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java @@ -83,7 +83,7 @@ public interface ModelContext { @ModelFeatureFlag(owners = {"baldersheim"}) default boolean useBucketExecutorForLidSpaceCompact() { throw new UnsupportedOperationException("TODO specify default value"); } @ModelFeatureFlag(owners = {"baldersheim"}) default boolean useBucketExecutorForBucketMove() { throw new UnsupportedOperationException("TODO specify default value"); } @ModelFeatureFlag(owners = {"musum", "mpolden"}, comment = "Revisit in February 2021", removeAfter = "7.370") default boolean reconfigurableZookeeperServer() { return true; } - @ModelFeatureFlag(owners = {"geirst"}) default boolean enableFeedBlockInDistributor() { return false; } + @ModelFeatureFlag(owners = {"geirst"}) default boolean enableFeedBlockInDistributor() { return true; } @ModelFeatureFlag(owners = {"baldersheim", "geirst", "toregge"}) default double maxDeadBytesRatio() { return 0.2; } @ModelFeatureFlag(owners = {"hmusum"}) default int clusterControllerMaxHeapSizeInMb() { return 256; } @ModelFeatureFlag(owners = {"hmusum"}) default int metricsProxyMaxHeapSizeInMb(ClusterSpec.Type type) { return 256; } diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java index b7b6824eb9f..c1d45c213fb 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java @@ -53,7 +53,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea private double feedConcurrency = 0.5; private boolean useBucketExecutorForLidSpaceCompact; private boolean useBucketExecutorForBucketMove; - private boolean enableFeedBlockInDistributor = false; + private boolean enableFeedBlockInDistributor = true; private double maxDeadBytesRatio = 0.2; private int clusterControllerMaxHeapSizeInMb = 
256; private int metricsProxyMaxHeapSizeInMb = 256; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java index 0f006c31959..3f346f20144 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/Admin.java @@ -105,6 +105,7 @@ public class Admin extends AbstractConfigProducer<Admin> implements Serializable return metricsProxyCluster; } + /** Used by model amenders */ public void setAdditionalDefaultMetrics(MetricSet additionalDefaultMetrics) { if (additionalDefaultMetrics == null) return; this.additionalDefaultMetrics = additionalDefaultMetrics; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/MetricSet.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/MetricSet.java index 30797f27789..9c969d4f11e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/MetricSet.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/MetricSet.java @@ -43,13 +43,13 @@ public class MetricSet { /** * Returns all metrics in this set, including all metrics in any contained metric sets. - * <br> + * * Joins this set's metrics with its child sets into a named flat map of metrics. * In the case of duplicate metrics, the metrics directly defined in this set * takes precedence with respect to output name, description and dimension value * (even if they are empty), while new dimensions from the children will be added. * - * @return All metrics contained in this set. 
+ * @return all the metrics contained in this set */ public final Map<String, Metric> getMetrics() { return unmodifiableMap(flatten(metrics, children)); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validator.java index c926c1f13a0..f3bebbe7fb9 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/Validator.java @@ -8,7 +8,6 @@ import com.yahoo.vespa.model.VespaModel; * Abstract superclass of all application package validators. * * @author hmusum - * @since 2010-01-29 */ public abstract class Validator { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java index 0de111e459e..6f7709efc24 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/content/ContentSearchCluster.java @@ -325,10 +325,6 @@ public class ContentSearchCluster extends AbstractConfigProducer<SearchCluster> indexedCluster.setSearchableCopies(redundancy.readyCopies()); } this.redundancy = redundancy; - for (SearchNode node : getSearchNodes()) { - node.setRedundancy(redundancy.finalRedundancy()); - node.setSearchableCopies(redundancy.readyCopies()); - } } private Optional<StreamingSearchCluster> findStreamingCluster(String docType) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/NodeResourcesTuning.java b/config-model/src/main/java/com/yahoo/vespa/model/search/NodeResourcesTuning.java index 7367850bb42..3fadd1f2efd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/NodeResourcesTuning.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/search/NodeResourcesTuning.java @@ -16,23 +16,19 @@ 
public class NodeResourcesTuning implements ProtonConfig.Producer { final static long MB = 1024 * 1024; public final static long GB = MB * 1024; + // This is an approximate number base on observation of a node using 33G memory with 765M docs + private final static long MEMORY_COST_PER_DOCUMENT_STORE_ONLY = 46L; private final NodeResources resources; - private final int redundancy; - private final int searchableCopies; private final int threadsPerSearch; private final boolean combined; - // "Reserve" 1GB of memory for other processes running on the content node (config-proxy, cluster-controller, metrics-proxy). - public static final double reservedMemoryGb = 1; + // "Reserve" 0.5GB of memory for other processes running on the content node (config-proxy, metrics-proxy). + public static final double reservedMemoryGb = 0.5; public NodeResourcesTuning(NodeResources resources, - int redundancy, - int searchableCopies, int threadsPerSearch, boolean combined) { this.resources = resources; - this.redundancy = redundancy; - this.searchableCopies = searchableCopies; this.threadsPerSearch = threadsPerSearch; this.combined = combined; } @@ -56,8 +52,8 @@ public class NodeResourcesTuning implements ProtonConfig.Producer { private void getConfig(ProtonConfig.Documentdb.Builder builder) { ProtonConfig.Documentdb dbCfg = builder.build(); if (dbCfg.mode() != ProtonConfig.Documentdb.Mode.Enum.INDEX) { - long numDocs = (long)usableMemoryGb() * GB / 64L; - builder.allocation.initialnumdocs(numDocs/Math.max(searchableCopies, redundancy)); + long numDocs = (long)usableMemoryGb() * GB / MEMORY_COST_PER_DOCUMENT_STORE_ONLY; + builder.allocation.initialnumdocs(numDocs); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java b/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java index 16302ddff49..9f129f65281 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java +++ 
b/config-model/src/main/java/com/yahoo/vespa/model/search/SearchNode.java @@ -62,8 +62,6 @@ public class SearchNode extends AbstractService implements private final boolean flushOnShutdown; private NodeSpec nodeSpec; private int distributionKey; - private int redundancy = 1; - private int searchableCopies = 1; private final String clusterName; private TransactionLogServer tls; private AbstractService serviceLayerService; @@ -160,16 +158,6 @@ public class SearchNode extends AbstractService implements private String getBaseDir() { return getDefaults().underVespaHome("var/db/vespa/search/cluster." + getClusterName()) + "/n" + distributionKey; } - public void setSearchableCopies(int searchableCopies) { - this.searchableCopies = searchableCopies; - } - public void setRedundancy(int redundancy) { - this.redundancy = redundancy; - } - - void updatePartition(int partitionId) { - nodeSpec = new NodeSpec(nodeSpec.groupIndex(), partitionId); - } @Override public NodeSpec getNodeSpec() { @@ -286,8 +274,6 @@ public class SearchNode extends AbstractService implements } if (getHostResource() != null && ! 
getHostResource().realResources().isUnspecified()) { var nodeResourcesTuning = new NodeResourcesTuning(getHostResource().realResources(), - redundancy, - searchableCopies, tuning.map(Tuning::threadsPerSearch).orElse(1), combined); nodeResourcesTuning.getConfig(builder); diff --git a/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java b/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java index a76feb77c75..5a06379a1c2 100644 --- a/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/provision/ModelProvisioningTest.java @@ -1990,7 +1990,7 @@ public class ModelProvisioningTest { ProtonConfig cfg = getProtonConfig(model, cluster.getSearchNodes().get(0).getConfigId()); assertEquals(2000, cfg.flush().memory().maxtlssize()); // from config override assertEquals(1000, cfg.flush().memory().maxmemory()); // from explicit tuning - assertEquals((long) (128 - reservedMemoryGb) * GB / 8, cfg.flush().memory().each().maxmemory()); // from default node flavor tuning + assertEquals((long) ((128 - reservedMemoryGb) * GB / 8), cfg.flush().memory().each().maxmemory()); // from default node flavor tuning assertEquals(0.92, cfg.writefilter().memorylimit(), 0.0001); // from explicit resource-limits } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/ContentSearchClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/ContentSearchClusterTest.java index 50495a211e2..5fd4885a1f2 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/ContentSearchClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/ContentSearchClusterTest.java @@ -117,13 +117,13 @@ public class ContentSearchClusterTest { @Test public void requireThatOnlyDiskLimitCanBeSet() throws Exception { - assertProtonResourceLimits(0.88, 0.8, + assertProtonResourceLimits(0.88, 0.9, new 
ContentClusterBuilder().protonDiskLimit(0.88).getXml()); } @Test public void requireThatOnlyMemoryLimitCanBeSet() throws Exception { - assertProtonResourceLimits(0.8, 0.77, + assertProtonResourceLimits(0.9, 0.77, new ContentClusterBuilder().protonMemoryLimit(0.77).getXml()); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/FleetControllerClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/FleetControllerClusterTest.java index 32b84d95cc8..b1bd44d93b4 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/content/FleetControllerClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/FleetControllerClusterTest.java @@ -28,7 +28,7 @@ public class FleetControllerClusterTest { } private ClusterControllerConfig parse(String xml) { - return parse(xml, false); + return parse(xml, true); } @Test @@ -158,6 +158,6 @@ public class FleetControllerClusterTest { } private FleetcontrollerConfig getConfigForBasicCluster() { - return getConfigForBasicCluster(false); + return getConfigForBasicCluster(true); } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/search/NodeResourcesTuningTest.java b/config-model/src/test/java/com/yahoo/vespa/model/search/NodeResourcesTuningTest.java index 1f9c9b1e07a..00b4ce02411 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/search/NodeResourcesTuningTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/search/NodeResourcesTuningTest.java @@ -10,6 +10,7 @@ import org.junit.Test; import java.util.Arrays; import java.util.List; +import static com.yahoo.vespa.model.search.NodeResourcesTuning.reservedMemoryGb; import static org.junit.Assert.assertEquals; import static com.yahoo.vespa.model.search.NodeResourcesTuning.MB; import static com.yahoo.vespa.model.search.NodeResourcesTuning.GB; @@ -19,9 +20,8 @@ import static com.yahoo.vespa.model.search.NodeResourcesTuning.GB; */ public class NodeResourcesTuningTest { - private static double delta = 
0.00001; - private static double combinedFactor = 1 - 17.0/100; - private static int reservedMemoryGb = (int)NodeResourcesTuning.reservedMemoryGb; + private static final double delta = 0.00001; + private static final double combinedFactor = 1 - 17.0/100; @Test public void require_that_hwinfo_disk_size_is_set() { @@ -36,11 +36,11 @@ public class NodeResourcesTuningTest { } @Test - public void reserved_memory_on_content_node_is_1_gb() { - assertEquals(1.0, NodeResourcesTuning.reservedMemoryGb, delta); + public void reserved_memory_on_content_node_is_0_5_gb() { + assertEquals(0.5, reservedMemoryGb, delta); } - private ProtonConfig getProtonMemoryConfig(List<Pair<String, String>> sdAndMode, int gb, int redundancy, int searchableCopies) { + private ProtonConfig getProtonMemoryConfig(List<Pair<String, String>> sdAndMode, double gb, int redundancy, int searchableCopies) { ProtonConfig.Builder builder = new ProtonConfig.Builder(); for (Pair<String, String> sdMode : sdAndMode) { builder.documentdb.add(new ProtonConfig.Documentdb.Builder() @@ -48,18 +48,17 @@ public class NodeResourcesTuningTest { .configid("some/config/id/" + sdMode.getFirst()) .mode(ProtonConfig.Documentdb.Mode.Enum.valueOf(sdMode.getSecond()))); } - return configFromMemorySetting(gb, builder, redundancy, searchableCopies); + return configFromMemorySetting(gb, builder); } private void verify_that_initial_numdocs_is_dependent_of_mode(int redundancy, int searchablecopies) { - int divisor = Math.max(redundancy, searchablecopies); ProtonConfig cfg = getProtonMemoryConfig(Arrays.asList(new Pair<>("a", "INDEX"), new Pair<>("b", "STREAMING"), new Pair<>("c", "STORE_ONLY")), 24 + reservedMemoryGb, redundancy, searchablecopies); assertEquals(3, cfg.documentdb().size()); assertEquals(1024, cfg.documentdb(0).allocation().initialnumdocs()); assertEquals("a", cfg.documentdb(0).inputdoctypename()); - assertEquals(24 * GB / 64 / divisor, cfg.documentdb(1).allocation().initialnumdocs()); + assertEquals(24 * GB / 46, 
cfg.documentdb(1).allocation().initialnumdocs()); assertEquals("b", cfg.documentdb(1).inputdoctypename()); - assertEquals(24 * GB / 64 / divisor, cfg.documentdb(2).allocation().initialnumdocs()); + assertEquals(24 * GB / 46, cfg.documentdb(2).allocation().initialnumdocs()); assertEquals("c", cfg.documentdb(2).inputdoctypename()); } @@ -206,13 +205,13 @@ public class NodeResourcesTuningTest { return getConfig(new FlavorsConfig.Flavor.Builder().minDiskAvailableGb(diskGb), false); } - private static ProtonConfig configFromMemorySetting(int memoryGb, boolean combined) { + private static ProtonConfig configFromMemorySetting(double memoryGb, boolean combined) { return getConfig(new FlavorsConfig.Flavor.Builder().minMainMemoryAvailableGb(memoryGb), combined); } - private static ProtonConfig configFromMemorySetting(int memoryGb, ProtonConfig.Builder builder, int redundancy, int searchableCopies) { + private static ProtonConfig configFromMemorySetting(double memoryGb, ProtonConfig.Builder builder) { return getConfig(new FlavorsConfig.Flavor.Builder() - .minMainMemoryAvailableGb(memoryGb), builder, redundancy, searchableCopies, false); + .minMainMemoryAvailableGb(memoryGb), builder, false); } private static ProtonConfig configFromNumCoresSetting(double numCores) { @@ -221,7 +220,7 @@ public class NodeResourcesTuningTest { private static ProtonConfig configFromNumCoresSetting(double numCores, int numThreadsPerSearch) { return getConfig(new FlavorsConfig.Flavor.Builder().minCpuCores(numCores), - new ProtonConfig.Builder(), 1, 1, numThreadsPerSearch, false); + new ProtonConfig.Builder(), numThreadsPerSearch, false); } private static ProtonConfig configFromEnvironmentType(boolean docker) { @@ -233,25 +232,17 @@ public class NodeResourcesTuningTest { return getConfig(flavorBuilder, new ProtonConfig.Builder(), combined); } - private static ProtonConfig getConfig(FlavorsConfig.Flavor.Builder flavorBuilder, - ProtonConfig.Builder protonBuilder, boolean combined) { - return 
getConfig(flavorBuilder, protonBuilder, 1, 1, combined); - } - - private static ProtonConfig getConfig(FlavorsConfig.Flavor.Builder flavorBuilder, ProtonConfig.Builder protonBuilder, - int redundancy, int searchableCopies, boolean combined) { + private static ProtonConfig getConfig(FlavorsConfig.Flavor.Builder flavorBuilder, ProtonConfig.Builder protonBuilder, boolean combined) { flavorBuilder.name("my_flavor"); - NodeResourcesTuning tuning = new NodeResourcesTuning(new Flavor(new FlavorsConfig.Flavor(flavorBuilder)).resources(), - redundancy, searchableCopies, 1, combined); + NodeResourcesTuning tuning = new NodeResourcesTuning(new Flavor(new FlavorsConfig.Flavor(flavorBuilder)).resources(), 1, combined); tuning.getConfig(protonBuilder); return new ProtonConfig(protonBuilder); } private static ProtonConfig getConfig(FlavorsConfig.Flavor.Builder flavorBuilder, ProtonConfig.Builder protonBuilder, - int redundancy, int searchableCopies, int numThreadsPerSearch, boolean combined) { + int numThreadsPerSearch, boolean combined) { flavorBuilder.name("my_flavor"); - NodeResourcesTuning tuning = new NodeResourcesTuning(new Flavor(new FlavorsConfig.Flavor(flavorBuilder)).resources(), - redundancy, searchableCopies, numThreadsPerSearch, combined); + NodeResourcesTuning tuning = new NodeResourcesTuning(new Flavor(new FlavorsConfig.Flavor(flavorBuilder)).resources(), numThreadsPerSearch, combined); tuning.getConfig(protonBuilder); return new ProtonConfig(protonBuilder); } diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java index 12010042c79..7f563b876a7 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java @@ -104,10 +104,10 @@ public class NodeResources { } public NodeResources(double vcpu, double memoryGb, double diskGb, double 
bandwidthGbps, DiskSpeed diskSpeed, StorageType storageType) { - this.vcpu = vcpu; - this.memoryGb = memoryGb; - this.diskGb = diskGb; - this.bandwidthGbps = bandwidthGbps; + this.vcpu = validate(vcpu, "vcpu"); + this.memoryGb = validate(memoryGb, "memory"); + this.diskGb = validate(diskGb, "disk"); + this.bandwidthGbps = validate(bandwidthGbps, "bandwith"); this.diskSpeed = diskSpeed; this.storageType = storageType; } @@ -214,8 +214,8 @@ public class NodeResources { } private static StringBuilder appendDouble(StringBuilder sb, double d) { - long x10 = Math.round(d*10); - sb.append(x10/10).append('.').append(x10%10); + long x10 = Math.round(d * 10); + sb.append(x10 / 10).append('.').append(x10 % 10); return sb; } @@ -310,4 +310,10 @@ public class NodeResources { return new NodeResources(cpu, mem, dsk, 0.3, DiskSpeed.getDefault(), StorageType.getDefault()); } + private double validate(double value, String valueName) { + if (Double.isNaN(value)) throw new IllegalArgumentException(valueName + " cannot be NaN"); + if (Double.isInfinite(value)) throw new IllegalArgumentException(valueName + " cannot be infinite"); + return value; + } + } diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeType.java b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeType.java index 56882f8676f..b1e9fe3ea05 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeType.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeType.java @@ -49,6 +49,16 @@ public enum NodeType { return !childNodeTypes.isEmpty(); } + /** either config server or controller */ + public boolean isConfigServerLike() { + return this == config || this == controller; + } + + /** either config server host or controller host */ + public boolean isConfigServerHostLike() { + return this == confighost || this == controllerhost; + } + /** Returns whether this supports host sharing */ public boolean isSharable() { return this == NodeType.host; 
diff --git a/config-provisioning/src/test/java/com/yahoo/config/provision/NodeResourcesTest.java b/config-provisioning/src/test/java/com/yahoo/config/provision/NodeResourcesTest.java index 044afa72a5d..18eed0deecc 100644 --- a/config-provisioning/src/test/java/com/yahoo/config/provision/NodeResourcesTest.java +++ b/config-provisioning/src/test/java/com/yahoo/config/provision/NodeResourcesTest.java @@ -3,7 +3,10 @@ package com.yahoo.config.provision; import org.junit.Test; +import java.util.function.Supplier; + import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * @author bratseth @@ -17,18 +20,17 @@ public class NodeResourcesTest { assertEquals("[vcpu: 0.3, memory: 3.3 Gb, disk 33.3 Gb, bandwidth: 0.3 Gbps]", new NodeResources(1/3., 10/3., 100/3., 0.3).toString()); assertEquals("[vcpu: 0.7, memory: 9.0 Gb, disk 66.7 Gb, bandwidth: 0.7 Gbps]", - new NodeResources(2/3., 8.97, 200/3., 0.67).toString()); + new NodeResources(2/3., 8.97, 200/3., 0.67).toString()); } - private long runTest(NodeResources [] resouces, int num) { - long sum = 0; - for (int i = 0; i < num; i++) { - for (NodeResources ns :resouces) { - sum += ns.toString().length(); - } - } - return sum; + @Test + public void testInvalid() { + assertInvalid("vcpu", () -> new NodeResources(Double.NaN, 1.0, 1.0, 1.0)); + assertInvalid("memory", () -> new NodeResources(1.0, Double.NaN, 1.0, 1.0)); + assertInvalid("disk", () -> new NodeResources(1.0, 1.0, Double.NaN, 1.0)); + assertInvalid("bandwith", () -> new NodeResources(1.0, 1.0, 1.0, Double.NaN)); } + @Test public void benchmark() { NodeResources [] resouces = new NodeResources[100]; @@ -44,4 +46,24 @@ public class NodeResourcesTest { assertEquals(warmup, benchmark); } + private void assertInvalid(String valueName, Supplier<NodeResources> nodeResources) { + try { + nodeResources.get(); + fail("Expected exception"); + } + catch (IllegalArgumentException e) { + assertEquals(valueName + " cannot be NaN", e.getMessage()); + } + 
} + + private long runTest(NodeResources [] resouces, int num) { + long sum = 0; + for (int i = 0; i < num; i++) { + for (NodeResources ns :resouces) { + sum += ns.toString().length(); + } + } + return sum; + } + } diff --git a/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/RpcConfigSourceClient.java b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/RpcConfigSourceClient.java index 095bde76c39..3b9b101a8c6 100644 --- a/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/RpcConfigSourceClient.java +++ b/config-proxy/src/main/java/com/yahoo/vespa/config/proxy/RpcConfigSourceClient.java @@ -179,9 +179,13 @@ class RpcConfigSourceClient implements ConfigSourceClient, Runnable { @Override public void cancel() { + log.log(Level.FINE, "shutdownSourceConnections"); shutdownSourceConnections(); + log.log(Level.FINE, "delayedResponsesFuture.cancel"); delayedResponsesFuture.cancel(true); + log.log(Level.FINE, "delayedResponsesFuture.shutdownNow"); delayedResponsesScheduler.shutdownNow(); + log.log(Level.FINE, "supervisor.transport().shutdown().join()"); supervisor.transport().shutdown().join(); } @@ -190,10 +194,14 @@ class RpcConfigSourceClient implements ConfigSourceClient, Runnable { */ @Override public void shutdownSourceConnections() { + log.log(Level.FINE, "Subscriber::cancel"); activeSubscribers.values().forEach(Subscriber::cancel); activeSubscribers.clear(); + log.log(Level.FINE, "nextConfigFuture.cancel"); nextConfigFuture.cancel(true); + log.log(Level.FINE, "nextConfigScheduler.shutdownNow"); nextConfigScheduler.shutdownNow(); + log.log(Level.FINE, "requester.close"); requester.close(); } diff --git a/config/src/main/java/com/yahoo/config/subscription/impl/JRTManagedConnectionPools.java b/config/src/main/java/com/yahoo/config/subscription/impl/JRTManagedConnectionPools.java index 32d2d962e4d..f0e0a6b8481 100644 --- a/config/src/main/java/com/yahoo/config/subscription/impl/JRTManagedConnectionPools.java +++ 
b/config/src/main/java/com/yahoo/config/subscription/impl/JRTManagedConnectionPools.java @@ -21,12 +21,14 @@ public class JRTManagedConnectionPools { } } private static class CountedPool { - long count; final JRTConnectionPool pool; - final ScheduledThreadPoolExecutor scheduler = new ScheduledThreadPoolExecutor(1, new JRTSourceThreadFactory()); + final ScheduledThreadPoolExecutor scheduler; + long count; CountedPool(JRTConnectionPool requester) { - this.pool = requester; + pool = requester; + scheduler = new ScheduledThreadPoolExecutor(1, new JRTSourceThreadFactory()); count = 0; + scheduler.setExecuteExistingDelayedTasksAfterShutdownPolicy(false); } } @@ -56,9 +58,9 @@ public class JRTManagedConnectionPools { } countedPool.pool.close(); - countedPool.scheduler.shutdown(); + countedPool.scheduler.shutdownNow(); try { - countedPool.scheduler.awaitTermination(30, TimeUnit.SECONDS); + countedPool.scheduler.awaitTermination(1, TimeUnit.SECONDS); } catch (InterruptedException e) { throw new RuntimeException("Failed shutting down scheduler:", e); } diff --git a/configd/src/apps/sentinel/sentinel.cpp b/configd/src/apps/sentinel/sentinel.cpp index 329173d6f8c..1bca6f72d10 100644 --- a/configd/src/apps/sentinel/sentinel.cpp +++ b/configd/src/apps/sentinel/sentinel.cpp @@ -80,8 +80,7 @@ main(int argc, char **argv) return EXIT_FAILURE; } - struct timeval lastTv; - gettimeofday(&lastTv, nullptr); + vespalib::steady_time lastTime = vespalib::steady_clock::now(); while (!stop()) { try { vespalib::SignalHandler::CHLD.clear(); @@ -106,18 +105,16 @@ main(int argc, char **argv) handler.updateActiveFdset(&fds, &maxNum); struct timeval tv; - tv.tv_sec = 1; - tv.tv_usec = 0; + tv.tv_sec = 0; + tv.tv_usec = 100000; //0.1s select(maxNum, &fds, nullptr, nullptr, &tv); - gettimeofday(&tv, nullptr); - double delta = tv.tv_sec - lastTv.tv_sec - + 1e-6 * (tv.tv_usec - lastTv.tv_usec); - if (delta < 0.01) { - usleep(12500); // Avoid busy looping; + vespalib::steady_time now = 
vespalib::steady_clock::now(); + if ((now - lastTime) < 10ms) { + std::this_thread::sleep_for(12ms); // Avoid busy looping; } - lastTv = tv; + lastTime = now; } EV_STOPPING("config-sentinel", "normal exit"); diff --git a/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java b/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java index a833d4d1608..9252d6a4073 100644 --- a/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java +++ b/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java @@ -12,7 +12,6 @@ import com.yahoo.vespa.flags.FetchVector; import com.yahoo.vespa.flags.FlagId; import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.flags.UnboundBooleanFlag; -import com.yahoo.vespa.jdk8compat.List; import com.yahoo.yolean.Exceptions; import org.junit.Test; @@ -20,6 +19,7 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.InputStream; import java.nio.charset.StandardCharsets; +import java.util.List; import java.util.stream.Collectors; import java.util.stream.Stream; diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java index 98d0c32c5d5..5d46f3dc240 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ApplicationRepository.java @@ -29,6 +29,7 @@ import com.yahoo.docproc.jdisc.metric.NullMetric; import com.yahoo.io.IOUtils; import com.yahoo.jdisc.Metric; import com.yahoo.path.Path; +import com.yahoo.slime.Slime; import com.yahoo.transaction.NestedTransaction; import com.yahoo.transaction.Transaction; import com.yahoo.vespa.config.server.application.Application; @@ -663,10 +664,10 @@ public class ApplicationRepository implements 
com.yahoo.config.provision.Deploye } } - public Set<ApplicationId> listApplications() { + public List<ApplicationId> listApplications() { return tenantRepository.getAllTenants().stream() .flatMap(tenant -> tenant.getApplicationRepo().activeApplications().stream()) - .collect(Collectors.toSet()); + .collect(Collectors.toList()); } private boolean isFileLastModifiedBefore(File fileReference, Instant instant) { @@ -702,9 +703,9 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye : applicationSet.get().getAllVersions(applicationId); } - public HttpResponse validateSecretStore(ApplicationId applicationId, TenantSecretStore tenantSecretStore, String tenantSecretName) { + public HttpResponse validateSecretStore(ApplicationId applicationId, SystemName systemName, Slime slime) { Application application = getApplication(applicationId); - return secretStoreValidator.validateSecretStore(application, tenantSecretStore, tenantSecretName); + return secretStoreValidator.validateSecretStore(application, systemName, slime); } // ---------------- Convergence ---------------------------------------------------------------- @@ -867,8 +868,7 @@ public class ApplicationRepository implements com.yahoo.config.provision.Deploye public void deleteExpiredLocalSessions() { Map<Tenant, Collection<LocalSession>> sessionsPerTenant = new HashMap<>(); tenantRepository.getAllTenants() - .forEach(tenant -> sessionsPerTenant.put(tenant, - List.copyOf(tenant.getSessionRepository().getLocalSessions()))); + .forEach(tenant -> sessionsPerTenant.put(tenant, tenant.getSessionRepository().getLocalSessions())); Set<ApplicationId> applicationIds = new HashSet<>(); sessionsPerTenant.values() diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java index 3275dc42477..be26e880440 100644 --- 
a/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ConfigServerBootstrap.java @@ -17,10 +17,11 @@ import com.yahoo.yolean.Exceptions; import java.time.Duration; import java.time.Instant; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; -import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -203,7 +204,8 @@ public class ConfigServerBootstrap extends AbstractComponent implements Runnable private boolean redeployAllApplications() throws InterruptedException { Instant end = Instant.now().plus(maxDurationOfRedeployment); - Set<ApplicationId> applicationsNotRedeployed = applicationRepository.listApplications(); + List<ApplicationId> applicationsNotRedeployed = applicationRepository.listApplications(); + Collections.shuffle(applicationsNotRedeployed); long failCount = 0; do { applicationsNotRedeployed = redeployApplications(applicationsNotRedeployed); @@ -225,7 +227,7 @@ public class ConfigServerBootstrap extends AbstractComponent implements Runnable } // Returns the set of applications that failed to redeploy - private Set<ApplicationId> redeployApplications(Set<ApplicationId> applicationIds) throws InterruptedException { + private List<ApplicationId> redeployApplications(List<ApplicationId> applicationIds) throws InterruptedException { ExecutorService executor = Executors.newFixedThreadPool(configserverConfig.numParallelTenantLoaders(), new DaemonThreadFactory("redeploy apps")); // Keep track of deployment per application @@ -235,12 +237,12 @@ public class ConfigServerBootstrap extends AbstractComponent implements Runnable executor.submit(() -> applicationRepository.deployFromLocalActive(appId, true /* bootstrap */) .ifPresent(Deployment::activate)))); - 
Set<ApplicationId> failedDeployments = + List<ApplicationId> failedDeployments = deployments.entrySet().stream() .map(entry -> checkDeployment(entry.getKey(), entry.getValue())) .filter(Optional::isPresent) .map(Optional::get) - .collect(Collectors.toSet()); + .collect(Collectors.toList()); executor.shutdown(); executor.awaitTermination(365, TimeUnit.DAYS); // Timeout should never happen diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/Application.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/Application.java index 96dc334ba45..8e98bf5884b 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/Application.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/Application.java @@ -146,7 +146,8 @@ public class Application implements ModelResult { } catch (ConfigurationRuntimeException e) { // This can happen in cases where services ask for config that no longer exist before they have been able // to reconfigure themselves - log.log(Level.INFO, "Error resolving instance for builder '" + builder.getClass().getName() + + log.log(Level.INFO, TenantRepository.logPre(getId()) + + ": Error resolving instance for builder '" + builder.getClass().getName() + "', returning empty config: " + Exceptions.toMessageString(e)); payload = ConfigPayload.fromBuilder(new ConfigPayloadBuilder()); } @@ -178,7 +179,7 @@ public class Application implements ModelResult { } private void debug(String message) { - log.log(Level.FINE, TenantRepository.logPre(getId())+message); + log.log(Level.FINE, TenantRepository.logPre(getId()) + message); } private ConfigDefinition getTargetDef(GetConfigRequest req) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationCuratorDatabase.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationCuratorDatabase.java index 2edd0d9fc2a..519e53272dc 100644 --- 
a/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationCuratorDatabase.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/ApplicationCuratorDatabase.java @@ -77,6 +77,8 @@ public class ApplicationCuratorDatabase { if ( ! id.tenant().equals(tenant)) throw new IllegalArgumentException("Cannot write application id '" + id + "' for tenant '" + tenant + "'"); try (Lock lock = lock(id)) { + if (curator.exists(applicationPath(id))) return; + curator.create(applicationPath(id)); modifyReindexing(id, ApplicationReindexing.empty(), UnaryOperator.identity()); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java index 980134f8884..05e73e7f454 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java @@ -219,7 +219,8 @@ public class Deployment implements com.yahoo.config.provision.Deployment { .ignoreValidationErrors(ignoreValidationErrors) .isBootstrap(isBootstrap) .force(force) - .waitForResourcesInPrepare(waitForResourcesInPrepare); + .waitForResourcesInPrepare(waitForResourcesInPrepare) + .tenantSecretStores(session.getTenantSecretStores()); session.getDockerImageRepository().ifPresent(params::dockerImageRepository); session.getAthenzDomain().ifPresent(params::athenzDomain); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/SecretStoreValidator.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/SecretStoreValidator.java index 72f5747e15c..71eed450955 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/SecretStoreValidator.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/SecretStoreValidator.java @@ -3,12 +3,16 @@ package com.yahoo.vespa.config.server.http; import 
ai.vespa.util.http.VespaHttpClientBuilder; import com.yahoo.config.model.api.HostInfo; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.TenantName; +import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.secretstore.SecretStore; import com.yahoo.slime.Slime; import com.yahoo.slime.SlimeUtils; import com.yahoo.vespa.config.server.application.Application; import com.yahoo.config.model.api.TenantSecretStore; +import com.yahoo.vespa.config.server.tenant.SecretStoreExternalIdRetriever; import com.yahoo.yolean.Exceptions; import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.ByteArrayEntity; @@ -22,6 +26,7 @@ import static com.yahoo.yolean.Exceptions.uncheck; /** * @author olaa + * Takes the payload received from the controller, adds external ID and acts as a proxy for the AwsParameterStoreValidationHandler result */ public class SecretStoreValidator { @@ -35,22 +40,12 @@ public class SecretStoreValidator { this.secretStore = secretStore; } - public HttpResponse validateSecretStore(Application application, TenantSecretStore tenantSecretStore, String tenantSecretName) { - var slime = toSlime(tenantSecretStore, tenantSecretName); + public HttpResponse validateSecretStore(Application application, SystemName system, Slime slime) { + addExternalId(application.getId().tenant(), system, slime); var uri = getUri(application); return postRequest(uri, slime); } - private Slime toSlime(TenantSecretStore tenantSecretStore, String tenantSecretName) { - var slime = new Slime(); - var cursor = slime.setObject(); - cursor.setString("externalId", secretStore.getSecret(tenantSecretName)); - cursor.setString("awsId", tenantSecretStore.getAwsId()); - cursor.setString("name", tenantSecretStore.getName()); - cursor.setString("role", tenantSecretStore.getRole()); - return slime; - } - private URI getUri(Application application) { var hostname = 
application.getModel().getHosts() .stream() @@ -78,4 +73,11 @@ public class SecretStoreValidator { } } + private void addExternalId(TenantName tenantName, SystemName system, Slime slime) { + var data = slime.get(); + var name = data.field("name").asString(); + var secretName = SecretStoreExternalIdRetriever.secretName(tenantName, system, name); + data.setString("externalId", secretStore.getSecret(secretName)); + } + } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java index ba2989164ee..3634a6825a3 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandler.java @@ -29,7 +29,6 @@ import com.yahoo.vespa.config.server.http.HttpHandler; import com.yahoo.vespa.config.server.http.JSONResponse; import com.yahoo.vespa.config.server.http.NotFoundException; import com.yahoo.vespa.config.server.tenant.Tenant; -import com.yahoo.config.model.api.TenantSecretStore; import java.io.IOException; import java.net.URLDecoder; @@ -47,7 +46,6 @@ import java.util.stream.Stream; import static com.yahoo.yolean.Exceptions.uncheck; import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.Map.Entry.comparingByKey; import static java.util.stream.Collectors.toList; /** @@ -70,7 +68,7 @@ public class ApplicationHandler extends HttpHandler { "http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*/metrics/*", "http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*/metrics/*", "http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*/logs", - "http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*/validate-secret-store/*", + 
"http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*/validate-secret-store", "http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*/tester/*/*", "http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*/tester/*", "http://*/application/v2/tenant/*/application/*/environment/*/region/*/instance/*/quota", @@ -231,9 +229,8 @@ public class ApplicationHandler extends HttpHandler { } if (isValidateSecretStoreRequest(request)) { - var tenantSecretStore = tenantSecretStoreFromRequest(request); - var tenantSecretName = tenantSecretNameFromRequest(request); - return applicationRepository.validateSecretStore(applicationId, tenantSecretStore, tenantSecretName); + var slime = uncheck(() -> SlimeUtils.jsonToSlime(request.getData().readAllBytes())); + return applicationRepository.validateSecretStore(applicationId, zone.system(), slime); } throw new NotFoundException("Illegal POST request '" + request.getUri() + "'"); @@ -360,8 +357,8 @@ public class ApplicationHandler extends HttpHandler { } private static boolean isValidateSecretStoreRequest(HttpRequest request) { - return getBindingMatch(request).groupCount() == 8 && - request.getUri().getPath().contains("/validate-secret-store/"); + return getBindingMatch(request).groupCount() == 7 && + request.getUri().getPath().endsWith("/validate-secret-store"); } private static boolean isServiceConvergeListRequest(HttpRequest request) { @@ -424,11 +421,6 @@ public class ApplicationHandler extends HttpHandler { return bm.group(8); } - private static String tenantSecretNameFromRequest(HttpRequest req) { - BindingMatch<?> bm = getBindingMatch(req); - return bm.group(7); - } - private static ApplicationId getApplicationIdFromRequest(HttpRequest req) { // Two bindings for this: with full app id or only application name BindingMatch<?> bm = getBindingMatch(req); @@ -533,14 +525,6 @@ public class ApplicationHandler extends HttpHandler { } - private TenantSecretStore 
tenantSecretStoreFromRequest(HttpRequest httpRequest) { - var data = uncheck(() -> SlimeUtils.jsonToSlime(httpRequest.getData().readAllBytes()).get()); - var awsId = data.field("awsId").asString(); - var name = data.field("name").asString(); - var role = data.field("role").asString(); - return new TenantSecretStore(name, awsId, role); - } - private static JSONResponse createMessageResponse(String message) { return new JSONResponse(Response.Status.OK) { { object.setString("message", message); } }; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java index fc8fbe4d7d5..5a3e0311db9 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java @@ -205,9 +205,14 @@ public final class PrepareParams { } public Builder tenantSecretStores(String serialized) { - this.tenantSecretStores = (serialized == null) + List<TenantSecretStore> secretStores = (serialized == null) ? 
List.of() : TenantSecretStoreSerializer.listFromSlime(SlimeUtils.jsonToSlime(serialized).get()); + return tenantSecretStores(secretStores); + } + + public Builder tenantSecretStores(List<TenantSecretStore> tenantSecretStores) { + this.tenantSecretStores = tenantSecretStores; return this; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/Session.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/Session.java index 9b74eac5631..f1044b28049 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/Session.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/Session.java @@ -6,6 +6,7 @@ import com.yahoo.config.FileReference; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationMetaData; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.api.TenantSecretStore; import com.yahoo.config.provision.AllocatedHosts; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.AthenzDomain; @@ -17,6 +18,7 @@ import com.yahoo.vespa.config.server.application.ApplicationSet; import com.yahoo.vespa.config.server.tenant.TenantRepository; import java.time.Instant; +import java.util.List; import java.util.Optional; /** @@ -131,6 +133,10 @@ public abstract class Session implements Comparable<Session> { sessionZooKeeperClient.writeAthenzDomain(athenzDomain); } + public void setTenantSecretStores(List<TenantSecretStore> tenantSecretStores) { + sessionZooKeeperClient.writeTenantSecretStores(tenantSecretStores); + } + /** Returns application id read from ZooKeeper. 
Will throw RuntimeException if not found */ public ApplicationId getApplicationId() { return sessionZooKeeperClient.readApplicationId() @@ -162,6 +168,10 @@ public abstract class Session implements Comparable<Session> { return createSetStatusTransaction(Status.DEACTIVATE); } + public List<TenantSecretStore> getTenantSecretStores() { + return sessionZooKeeperClient.readTenantSecretStores(); + } + private Transaction createSetStatusTransaction(Status status) { return sessionZooKeeperClient.createWriteStatusTransaction(status); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java index d2260683495..f56e74c2869 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java @@ -196,8 +196,9 @@ public class SessionRepository { return localSessionCache.get(sessionId); } + /** Returns a copy of local sessions */ public Collection<LocalSession> getLocalSessions() { - return localSessionCache.values(); + return List.copyOf(localSessionCache.values()); } private void loadLocalSessions(ExecutorService executor) { @@ -255,6 +256,7 @@ public class SessionRepository { session.setVespaVersion(existingSession.getVespaVersion()); session.setDockerImageRepository(existingSession.getDockerImageRepository()); session.setAthenzDomain(existingSession.getAthenzDomain()); + session.setTenantSecretStores(existingSession.getTenantSecretStores()); return session; } @@ -297,8 +299,7 @@ public class SessionRepository { } private void deleteAllSessions() { - List<LocalSession> sessions = new ArrayList<>(localSessionCache.values()); - for (LocalSession session : sessions) { + for (LocalSession session : getLocalSessions()) { deleteLocalSession(session); } } @@ -309,6 +310,11 @@ public class SessionRepository { return 
remoteSessionCache.get(sessionId); } + /** Returns a copy of remote sessions */ + public Collection<RemoteSession> getRemoteSessions() { + return List.copyOf(remoteSessionCache.values()); + } + public List<Long> getRemoteSessionsFromZooKeeper() { return getSessionList(curator.getChildren(sessionsPath)); } @@ -524,9 +530,7 @@ public class SessionRepository { private void nodeChanged() { zkWatcherExecutor.execute(() -> { Multiset<Session.Status> sessionMetrics = HashMultiset.create(); - for (Session session : remoteSessionCache.values()) { - sessionMetrics.add(session.getStatus()); - } + getRemoteSessions().forEach(session -> sessionMetrics.add(session.getStatus())); metricUpdater.setNewSessions(sessionMetrics.count(Session.Status.NEW)); metricUpdater.setPreparedSessions(sessionMetrics.count(Session.Status.PREPARE)); metricUpdater.setActivatedSessions(sessionMetrics.count(Session.Status.ACTIVATE)); @@ -556,7 +560,7 @@ public class SessionRepository { log.log(Level.FINE, () -> "Purging old sessions for tenant '" + tenantName + "'"); Set<LocalSession> toDelete = new HashSet<>(); try { - for (LocalSession candidate : List.copyOf(localSessionCache.values())) { + for (LocalSession candidate : getLocalSessions()) { Instant createTime = candidate.getCreateTime(); log.log(Level.FINE, () -> "Candidate session for deletion: " + candidate.getSessionId() + ", created: " + createTime); @@ -629,7 +633,7 @@ public class SessionRepository { sessionZKClient.createNewSession(clock.instant()); Curator.CompletionWaiter waiter = sessionZKClient.getUploadWaiter(); LocalSession session = new LocalSession(tenantName, sessionId, app, sessionZKClient); - waiter.awaitCompletion(timeoutBudget.timeLeft()); + waiter.awaitCompletion(Duration.ofSeconds(Math.min(60, timeoutBudget.timeLeft().getSeconds()))); addLocalSession(session); return session; } catch (Exception e) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/SecretStoreExternalIdRetriever.java 
b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/SecretStoreExternalIdRetriever.java index cd2ae9d9d0c..0c254606169 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/SecretStoreExternalIdRetriever.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/SecretStoreExternalIdRetriever.java @@ -25,7 +25,7 @@ public class SecretStoreExternalIdRetriever { .collect(Collectors.toList()); } - private static String secretName(TenantName tenant, SystemName system, String storeName) { + public static String secretName(TenantName tenant, SystemName system, String storeName) { return String.format(SECRET_NAME_FORMAT, tenantSecretGroup(system), tenant.value(), storeName); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantDebugger.java b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantDebugger.java index d378f6e9235..334e53ee11f 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantDebugger.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantDebugger.java @@ -1,13 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.vespa.config.server.tenant; -import java.util.logging.Level; import com.yahoo.vespa.curator.Curator; import org.apache.curator.framework.CuratorFramework; import org.apache.curator.framework.recipes.cache.TreeCache; import org.apache.curator.framework.recipes.cache.TreeCacheEvent; import org.apache.curator.framework.recipes.cache.TreeCacheListener; +import java.util.logging.Level; import java.util.logging.Logger; /** @@ -19,6 +19,7 @@ public class TenantDebugger implements TreeCacheListener { private static final Logger log = Logger.getLogger(TenantDebugger.class.getName()); + @SuppressWarnings("deprecation") // TreeCache is deprecated, and recommended replacement is CuratorCache public TenantDebugger(Curator curator) throws Exception { TreeCache cache = new TreeCache(curator.framework(), "/config/v2/tenants"); cache.getListenable().addListener(this); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/MockSecretStoreValidator.java b/configserver/src/test/java/com/yahoo/vespa/config/server/MockSecretStoreValidator.java index c464af404d9..a02e8a3f3a6 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/MockSecretStoreValidator.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/MockSecretStoreValidator.java @@ -1,9 +1,12 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.vespa.config.server; +import com.yahoo.config.provision.SystemName; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.container.jdisc.SecretStoreProvider; +import com.yahoo.restapi.SlimeJsonResponse; import com.yahoo.restapi.StringResponse; +import com.yahoo.slime.Slime; import com.yahoo.vespa.config.server.application.Application; import com.yahoo.config.model.api.TenantSecretStore; import com.yahoo.vespa.config.server.http.SecretStoreValidator; @@ -17,7 +20,7 @@ public class MockSecretStoreValidator extends SecretStoreValidator { super(new SecretStoreProvider().get()); } - public HttpResponse validateSecretStore(Application application, TenantSecretStore tenantSecretStore, String tenantSecretName) { - return new StringResponse(tenantSecretStore.toString() + " - " + tenantSecretName); + public HttpResponse validateSecretStore(Application application, SystemName system, Slime slime) { + return new SlimeJsonResponse(slime); } } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/HostedDeployTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/HostedDeployTest.java index 11e14192c74..f1404b30f30 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/HostedDeployTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/deploy/HostedDeployTest.java @@ -9,6 +9,7 @@ import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.api.ModelCreateResult; import com.yahoo.config.model.api.ModelFactory; import com.yahoo.config.model.api.ServiceInfo; +import com.yahoo.config.model.api.TenantSecretStore; import com.yahoo.config.model.api.ValidationParameters; import com.yahoo.config.model.provision.Host; import com.yahoo.config.model.provision.Hosts; @@ -111,6 +112,21 @@ public class HostedDeployTest { } @Test + public void testRedeployWithTenantSecretStores() throws IOException { + List<TenantSecretStore> tenantSecretStores = List.of(new 
TenantSecretStore("foo", "123", "role")); + DeployTester tester = new DeployTester.Builder() + .modelFactory(createHostedModelFactory(Version.fromString("4.5.6"), Clock.systemUTC())) + .configserverConfig(createConfigserverConfig()).build(); + tester.deployApp("src/test/apps/hosted/", Instant.now(), new PrepareParams.Builder() + .tenantSecretStores(tenantSecretStores)); + + Optional<com.yahoo.config.provision.Deployment> deployment = tester.redeployFromLocalActive(tester.applicationId()); + assertTrue(deployment.isPresent()); + deployment.get().activate(); + assertEquals(tenantSecretStores, ((Deployment) deployment.get()).session().getTenantSecretStores()); + } + + @Test public void testDeployMultipleVersions() throws IOException { List<ModelFactory> modelFactories = List.of(createHostedModelFactory(Version.fromString("6.1.0")), createHostedModelFactory(Version.fromString("6.2.0")), diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/SecretStoreValidatorTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/SecretStoreValidatorTest.java index d308cd72b55..8e5dd16bc05 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/SecretStoreValidatorTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/SecretStoreValidatorTest.java @@ -4,9 +4,13 @@ import com.github.tomakehurst.wiremock.junit.WireMockRule; import com.yahoo.config.model.api.HostInfo; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ServiceInfo; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.TenantName; import com.yahoo.container.jdisc.secretstore.SecretStore; +import com.yahoo.slime.SlimeUtils; import com.yahoo.vespa.config.server.application.Application; -import com.yahoo.config.model.api.TenantSecretStore; +import com.yahoo.vespa.config.server.tenant.SecretStoreExternalIdRetriever; import org.junit.Rule; import 
org.junit.Test; @@ -35,20 +39,26 @@ public class SecretStoreValidatorTest { @Test public void createsCorrectRequestData() throws IOException { var app = mockApplication(); - var tenantSecretStore = new TenantSecretStore("store", "123", "role"); - var tenantSecretName = "some-secret"; - when(secretStore.getSecret(tenantSecretName)).thenReturn("some-secret-value"); - + var requestBody = SlimeUtils.jsonToSlime("{\"awsId\":\"123\"," + + "\"name\":\"store\"," + + "\"role\":\"role\"," + + "\"region\":\"some-region\"," + + "\"parameterName\":\"some-parameter\"" + + "}"); + var expectedSecretName = SecretStoreExternalIdRetriever.secretName(TenantName.defaultName(), SystemName.PublicCd, "store"); + when(secretStore.getSecret(expectedSecretName)).thenReturn("some-secret-value"); stubFor(post(urlEqualTo("/validate-secret-store")) - .withRequestBody(equalToJson("{\"externalId\":\"some-secret-value\"," + - "\"awsId\":\"123\"," + + .withRequestBody(equalToJson("{\"awsId\":\"123\"," + "\"name\":\"store\"," + - "\"role\":\"role\"" + + "\"role\":\"role\"," + + "\"region\":\"some-region\"," + + "\"parameterName\":\"some-parameter\"," + + "\"externalId\":\"some-secret-value\"" + "}")) .willReturn(aResponse() .withStatus(200) .withBody("is ok"))); - var response = secretStoreValidator.validateSecretStore(app, tenantSecretStore, tenantSecretName); + var response = secretStoreValidator.validateSecretStore(app, SystemName.PublicCd, requestBody); var body = new ByteArrayOutputStream(); response.render(body); assertEquals("is ok", body.toString()); @@ -60,6 +70,7 @@ public class SecretStoreValidatorTest { var hostList = createHostList(); when(app.getModel()).thenReturn(model); when(model.getHosts()).thenReturn(hostList); + when(app.getId()).thenReturn(ApplicationId.defaultId()); return app; } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java 
index ab9457c5d2a..d364785f415 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java @@ -401,17 +401,17 @@ public class ApplicationHandlerTest { @Test public void testValidateSecretStore() throws IOException { applicationRepository.deploy(new File("src/test/apps/app-logserver-with-container"), prepareParams(applicationId)); - var url = toUrlPath(applicationId, Zone.defaultZone(), true) + "/validate-secret-store/some-secret-name"; + var url = toUrlPath(applicationId, Zone.defaultZone(), true) + "/validate-secret-store"; var mockHandler = createApplicationHandler(); - var requestData = new ByteArrayInputStream("{\"name\": \"store\", \"awsId\":\"aws-id\", \"role\":\"role\"}".getBytes(StandardCharsets.UTF_8)); + var requestString = "{\"name\":\"store\",\"awsId\":\"aws-id\",\"role\":\"role\",\"region\":\"us-west-1\",\"parameterName\":\"some-parameter\"}"; + var requestData = new ByteArrayInputStream(requestString.getBytes(StandardCharsets.UTF_8)); var response = mockHandler.handle(createTestRequest(url, POST, requestData)); assertEquals(200, response.getStatus()); - // MockSecretStoreValidator returns response on format tenantSecretStore.toString() - tenantSecretName - var expectedResponse = "TenantSecretStore{name='store', awsId='aws-id', role='role'} - some-secret-name"; - assertEquals(expectedResponse, getRenderedString(response)); + // MockSecretStoreValidator simply returns the request body + assertEquals(requestString, getRenderedString(response)); } @Test diff --git a/container-core/src/main/java/com/yahoo/container/handler/metrics/PrometheusV1Handler.java b/container-core/src/main/java/com/yahoo/container/handler/metrics/PrometheusV1Handler.java index 00fb488489e..33c6fbefa71 100644 --- a/container-core/src/main/java/com/yahoo/container/handler/metrics/PrometheusV1Handler.java +++ 
b/container-core/src/main/java/com/yahoo/container/handler/metrics/PrometheusV1Handler.java @@ -5,18 +5,19 @@ import com.google.inject.Inject; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.restapi.Path; import com.yahoo.restapi.StringResponse; -import com.yahoo.vespa.jdk8compat.List; import com.yahoo.yolean.Exceptions; -import java.io.IOException; -import java.net.URI; -import java.util.Optional; -import java.util.concurrent.Executor; import org.apache.http.client.HttpClient; import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.HttpGet; import org.apache.http.impl.client.BasicResponseHandler; import org.apache.http.impl.client.CloseableHttpClient; +import java.io.IOException; +import java.net.URI; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.Executor; + import static com.yahoo.container.handler.metrics.MetricsV2Handler.consumerQuery; import static com.yahoo.jdisc.Response.Status.INTERNAL_SERVER_ERROR; diff --git a/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java b/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java index 0dd284e21cf..a38772ab5dd 100644 --- a/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java +++ b/container-core/src/main/java/com/yahoo/container/jdisc/state/StateHandler.java @@ -74,7 +74,7 @@ public class StateHandler extends AbstractRequestHandler { } @Override - public ContentChannel handleRequest(final Request request, ResponseHandler handler) { + public ContentChannel handleRequest(Request request, ResponseHandler handler) { new ResponseDispatch() { @Override diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/aws/NoopRoleService.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/aws/NoopRoleService.java index d967ad3dca4..719f948eaa9 100644 --- 
a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/aws/NoopRoleService.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/aws/NoopRoleService.java @@ -1,7 +1,6 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.aws; -import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.TenantName; import java.util.Collections; @@ -27,7 +26,7 @@ public class NoopRoleService implements RoleService { } @Override - public void deleteTenantPolicy(TenantName tenant, String policyName) { } + public void deleteTenantPolicy(TenantName tenant, String policyName, String role) { } @Override public void maintainRoles(List<TenantName> tenants) { } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/aws/RoleService.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/aws/RoleService.java index 4219ad35612..ac499a0def3 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/aws/RoleService.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/aws/RoleService.java @@ -1,7 +1,6 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.api.integration.aws; -import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.TenantName; import java.util.List; @@ -18,7 +17,7 @@ public interface RoleService { String createTenantPolicy(TenantName tenant, String policyName, String awsId, String role); - void deleteTenantPolicy(TenantName tenant, String policyName); + void deleteTenantPolicy(TenantName tenant, String policyName, String role); /* * Maintain roles for the tenants in the system. Create missing roles, update trust. 
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java index 315665dbffc..ed545dc35d1 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java @@ -151,6 +151,6 @@ public interface ConfigServer { void setSuspension(DeploymentId deploymentId, boolean suspend); /** Validates secret store configuration. */ - String validateSecretStore(DeploymentId deploymentId, TenantSecretStore tenantSecretStore); + String validateSecretStore(DeploymentId deploymentId, TenantSecretStore tenantSecretStore, String region, String parameterName); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java index 77aefc90481..e1587d250ec 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/InternalStepRunner.java @@ -299,7 +299,7 @@ public class InternalStepRunner implements StepRunner { } List<Node> nodes = controller.serviceRegistry().configServer().nodeRepository().list(id.type().zone(controller.system()), id.application(), - ImmutableSet.of(active, reserved)); + Set.of(active)); List<Node> parents = controller.serviceRegistry().configServer().nodeRepository().list(id.type().zone(controller.system()), nodes.stream().map(node -> node.parentHostname().get()).collect(toList())); NodeList nodeList = NodeList.of(nodes, parents, services.get()); diff --git 
a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index 23f62ee3cf5..6aeb30c3f09 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -33,6 +33,7 @@ import com.yahoo.restapi.SlimeJsonResponse; import com.yahoo.security.KeyUtils; import com.yahoo.slime.Cursor; import com.yahoo.slime.Inspector; +import com.yahoo.slime.JsonParseException; import com.yahoo.slime.Slime; import com.yahoo.slime.SlimeUtils; import com.yahoo.vespa.hosted.controller.Controller; @@ -223,7 +224,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { if (path.matches("/application/v4/tenant")) return tenants(request); if (path.matches("/application/v4/tenant/{tenant}")) return tenant(path.get("tenant"), request); if (path.matches("/application/v4/tenant/{tenant}/info")) return tenantInfo(path.get("tenant"), request); - if (path.matches("/application/v4/tenant/{tenant}/secret-store/{name}/validate")) return validateSecretStore(path.get("tenant"), path.get("name")); + if (path.matches("/application/v4/tenant/{tenant}/secret-store/{name}/region/{region}/parameter-name/{parameter-name}/validate")) return validateSecretStore(path.get("tenant"), path.get("name"), path.get("region"), path.get("parameter-name")); if (path.matches("/application/v4/tenant/{tenant}/application")) return applications(path.get("tenant"), Optional.empty(), request); if (path.matches("/application/v4/tenant/{tenant}/application/{application}")) return application(path.get("tenant"), path.get("application"), request); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/compile-version")) return 
compileVersion(path.get("tenant"), path.get("application")); @@ -584,7 +585,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { } - private HttpResponse validateSecretStore(String tenantName, String name) { + private HttpResponse validateSecretStore(String tenantName, String name, String region, String parameterName) { var tenant = TenantName.from(tenantName); if (controller.tenants().require(tenant).type() != Tenant.Type.cloud) return ErrorResponse.badRequest("Tenant '" + tenant + "' is not a cloud tenant"); @@ -601,8 +602,18 @@ public class ApplicationApiHandler extends LoggingRequestHandler { if (tenantSecretStore.isEmpty()) return ErrorResponse.notFoundError("No secret store '" + name + "' configured for tenant '" + tenantName + "'"); - var response = controller.serviceRegistry().configServer().validateSecretStore(deployment.get(), tenantSecretStore.get()); - return new MessageResponse(response); + var response = controller.serviceRegistry().configServer().validateSecretStore(deployment.get(), tenantSecretStore.get(), region, parameterName); + try { + var responseRoot = new Slime(); + var responseCursor = responseRoot.setObject(); + responseCursor.setString("target", deployment.get().toString()); + var responseResultCursor = responseCursor.setObject("result"); + var responseSlime = SlimeUtils.jsonToSlime(response); + SlimeUtils.copyObject(responseSlime.get(), responseResultCursor); + return new SlimeJsonResponse(responseRoot); + } catch (JsonParseException e) { + return ErrorResponse.internalServerError(response); + } } private Optional<DeploymentId> getActiveDeployment(TenantName tenant) { @@ -700,7 +711,11 @@ public class ApplicationApiHandler extends LoggingRequestHandler { lockedTenant = lockedTenant.withSecretStore(tenantSecretStore); controller.tenants().store(lockedTenant); }); - return new MessageResponse("Configured secret store: " + tenantSecretStore); + + tenant = (CloudTenant) 
controller.tenants().require(TenantName.from(tenantName)); + var slime = new Slime(); + toSlime(slime.setObject(), tenant.tenantSecretStores()); + return new SlimeJsonResponse(slime); } private HttpResponse deleteSecretStore(String tenantName, String name, HttpRequest request) { @@ -715,15 +730,15 @@ public class ApplicationApiHandler extends LoggingRequestHandler { var tenantSecretStore = optionalSecretStore.get(); controller.serviceRegistry().tenantSecretService().deleteSecretStore(tenant.name(), tenantSecretStore); + controller.serviceRegistry().roleService().deleteTenantPolicy(tenant.name(), tenantSecretStore.getName(), tenantSecretStore.getRole()); controller.tenants().lockOrThrow(tenant.name(), LockedTenant.Cloud.class, lockedTenant -> { lockedTenant = lockedTenant.withoutSecretStore(tenantSecretStore); controller.tenants().store(lockedTenant); }); + + tenant = (CloudTenant) controller.tenants().require(TenantName.from(tenantName)); var slime = new Slime(); - var cursor = slime.setObject(); - cursor.setString("name", tenantSecretStore.getName()); - cursor.setString("awsId", tenantSecretStore.getAwsId()); - cursor.setString("role", tenantSecretStore.getRole()); + toSlime(slime.setObject(), tenant.tenantSecretStores()); return new SlimeJsonResponse(slime); } @@ -2004,13 +2019,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { keyObject.setString("user", user.getName()); }); - Cursor secretStore = object.setArray("secretStores"); - cloudTenant.tenantSecretStores().forEach(store -> { - Cursor storeObject = secretStore.addObject(); - storeObject.setString("name", store.getName()); - storeObject.setString("awsId", store.getAwsId()); - storeObject.setString("role", store.getRole()); - }); + toSlime(object, cloudTenant.tenantSecretStores()); var tenantQuota = controller.serviceRegistry().billingController().getQuota(tenant.name()); var usedQuota = applications.stream() @@ -2269,6 +2278,16 @@ public class ApplicationApiHandler extends 
LoggingRequestHandler { array.addString(string); } + private void toSlime(Cursor object, List<TenantSecretStore> tenantSecretStores) { + Cursor secretStore = object.setArray("secretStores"); + tenantSecretStores.forEach(store -> { + Cursor storeObject = secretStore.addObject(); + storeObject.setString("name", store.getName()); + storeObject.setString("awsId", store.getAwsId()); + storeObject.setString("role", store.getRole()); + }); + } + private String readToString(InputStream stream) { Scanner scanner = new Scanner(stream).useDelimiter("\\A"); if ( ! scanner.hasNext()) return null; diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java index b669e942494..b219ee7ee9f 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java @@ -576,8 +576,8 @@ public class ConfigServerMock extends AbstractComponent implements ConfigServer } @Override - public String validateSecretStore(DeploymentId deployment, TenantSecretStore tenantSecretStore) { - return deployment.toString() + " - " + tenantSecretStore.toString(); + public String validateSecretStore(DeploymentId deployment, TenantSecretStore tenantSecretStore, String region, String parameterName) { + return "{\"settings\":{\"name\":\"foo\",\"role\":\"vespa-secretstore-access\",\"awsId\":\"892075328880\",\"externalId\":\"*****\",\"region\":\"us-east-1\"},\"status\":\"ok\"}"; } public static class Application { diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiCloudTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiCloudTest.java index 88307018385..605abf63a66 100644 --- 
a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiCloudTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiCloudTest.java @@ -129,7 +129,7 @@ public class ApplicationApiCloudTest extends ControllerContainerCloudTest { "\"externalId\": \"321\"" + "}") .roles(Set.of(Role.administrator(tenantName))); - tester.assertResponse(secretStoreRequest, "{\"message\":\"Configured secret store: TenantSecretStore{name='some-name', awsId='123', role='role-id'}\"}", 200); + tester.assertResponse(secretStoreRequest, "{\"secretStores\":[{\"name\":\"some-name\",\"awsId\":\"123\",\"role\":\"role-id\"}]}", 200); tester.assertResponse(secretStoreRequest, "{" + "\"error-code\":\"BAD_REQUEST\"," + "\"message\":\"Secret store TenantSecretStore{name='some-name', awsId='123', role='role-id'} is already configured\"" + @@ -152,7 +152,7 @@ public class ApplicationApiCloudTest extends ControllerContainerCloudTest { @Test public void validate_secret_store() { var secretStoreRequest = - request("/application/v4/tenant/scoober/secret-store/secret-foo/validate", GET) + request("/application/v4/tenant/scoober/secret-store/secret-foo/region/us-west-1/parameter-name/foo/validate", GET) .roles(Set.of(Role.administrator(tenantName))); tester.assertResponse(secretStoreRequest, "{" + "\"error-code\":\"BAD_REQUEST\"," + @@ -161,7 +161,7 @@ public class ApplicationApiCloudTest extends ControllerContainerCloudTest { deployApplication(); secretStoreRequest = - request("/application/v4/tenant/scoober/secret-store/secret-foo/validate", GET) + request("/application/v4/tenant/scoober/secret-store/secret-foo/region/us-west-1/parameter-name/foo/validate", GET) .roles(Set.of(Role.administrator(tenantName))); tester.assertResponse(secretStoreRequest, "{" + "\"error-code\":\"NOT_FOUND\"," + @@ -175,11 +175,9 @@ public class ApplicationApiCloudTest extends ControllerContainerCloudTest { // ConfigServerMock 
returns message on format deployment.toString() + " - " + tenantSecretStore.toString() secretStoreRequest = - request("/application/v4/tenant/scoober/secret-store/secret-foo/validate", GET) + request("/application/v4/tenant/scoober/secret-store/secret-foo/region/us-west-1/parameter-name/foo/validate", GET) .roles(Set.of(Role.administrator(tenantName))); - tester.assertResponse(secretStoreRequest, "{" + - "\"message\":\"scoober.albums in prod.us-central-1 - TenantSecretStore{name='secret-foo', awsId='123', role='some-role'}\"" + - "}", 200); + tester.assertResponse(secretStoreRequest, "{\"target\":\"scoober.albums in prod.us-central-1\",\"result\":{\"settings\":{\"name\":\"foo\",\"role\":\"vespa-secretstore-access\",\"awsId\":\"892075328880\",\"externalId\":\"*****\",\"region\":\"us-east-1\"},\"status\":\"ok\"}}", 200); } @Test @@ -198,7 +196,7 @@ public class ApplicationApiCloudTest extends ControllerContainerCloudTest { }); var tenant = (CloudTenant) tester.controller().tenants().require(tenantName); assertEquals(1, tenant.tenantSecretStores().size()); - tester.assertResponse(deleteRequest, "{\"name\":\"secret-foo\",\"awsId\":\"123\",\"role\":\"some-role\"}", 200); + tester.assertResponse(deleteRequest, "{\"secretStores\":[]}", 200); tenant = (CloudTenant) tester.controller().tenants().require(tenantName); assertEquals(0, tenant.tenantSecretStores().size()); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiTest.java index 9e0d645583a..bbba115b0a8 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserApiTest.java @@ -178,7 +178,7 @@ public class UserApiTest extends ControllerContainerCloudTest { .principal("admin@tenant") .roles(Set.of(Role.administrator(id.tenant()))) 
.data("{\"awsId\":\"123\",\"role\":\"secret-role\",\"externalId\":\"abc\"}"), - "{\"message\":\"Configured secret store: TenantSecretStore{name='secret-foo', awsId='123', role='secret-role'}\"}", + "{\"secretStores\":[{\"name\":\"secret-foo\",\"awsId\":\"123\",\"role\":\"secret-role\"}]}", 200); // GET a tenant with secret stores configured diff --git a/default_build_settings.cmake b/default_build_settings.cmake index c88d82db59a..4bd2509d8f7 100644 --- a/default_build_settings.cmake +++ b/default_build_settings.cmake @@ -31,7 +31,11 @@ endfunction() function(setup_vespa_default_build_settings_centos_8) message("-- Setting up default build settings for centos 8") set(DEFAULT_EXTRA_INCLUDE_DIRECTORY "${VESPA_DEPS}/include" "/usr/include/openblas" PARENT_SCOPE) - set(DEFAULT_VESPA_LLVM_VERSION "10" PARENT_SCOPE) + if (VESPA_OS_DISTRO_NAME STREQUAL "CentOS Stream") + set(DEFAULT_VESPA_LLVM_VERSION "11" PARENT_SCOPE) + else() + set(DEFAULT_VESPA_LLVM_VERSION "10" PARENT_SCOPE) + endif() endfunction() function(setup_vespa_default_build_settings_darwin) @@ -72,13 +76,13 @@ endfunction() function(setup_vespa_default_build_settings_fedora_34) message("-- Setting up default build settings for fedora 34") set(DEFAULT_EXTRA_INCLUDE_DIRECTORY "${VESPA_DEPS}/include" "/usr/include/openblas" PARENT_SCOPE) - set(DEFAULT_VESPA_LLVM_VERSION "11" PARENT_SCOPE) + set(DEFAULT_VESPA_LLVM_VERSION "12" PARENT_SCOPE) endfunction() function(setup_vespa_default_build_settings_fedora_35) message("-- Setting up default build settings for fedora 35") set(DEFAULT_EXTRA_INCLUDE_DIRECTORY "${VESPA_DEPS}/include" "/usr/include/openblas" PARENT_SCOPE) - set(DEFAULT_VESPA_LLVM_VERSION "11" PARENT_SCOPE) + set(DEFAULT_VESPA_LLVM_VERSION "12" PARENT_SCOPE) endfunction() function(setup_vespa_default_build_settings_ubuntu) diff --git a/dist/vespa.spec b/dist/vespa.spec index 795d666bb76..6e154a64a80 100644 --- a/dist/vespa.spec +++ b/dist/vespa.spec @@ -41,11 +41,21 @@ BuildRequires: rh-maven35 %define 
_rhmaven35_enable /opt/rh/rh-maven35/enable %endif %if 0%{?el8} +%if 0%{?centos} +%global _centos_stream %(grep -qs '^NAME="CentOS Stream"' /etc/os-release && echo 1 || echo 0) +%endif +%if 0%{?_centos_stream} +BuildRequires: gcc-toolset-10-gcc-c++ +BuildRequires: gcc-toolset-10-binutils +%define _devtoolset_enable /opt/rh/gcc-toolset-10/enable +BuildRequires: vespa-boost-devel >= 1.75.0-1 +%else BuildRequires: gcc-toolset-9-gcc-c++ BuildRequires: gcc-toolset-9-binutils -BuildRequires: maven %define _devtoolset_enable /opt/rh/gcc-toolset-9/enable %endif +BuildRequires: maven +%endif %if 0%{?fedora} BuildRequires: gcc-c++ BuildRequires: libatomic @@ -64,7 +74,11 @@ BuildRequires: vespa-libzstd-devel >= 1.4.5-2 %endif %if 0%{?el8} BuildRequires: cmake >= 3.11.4-3 +%if 0%{?_centos_stream} +BuildRequires: llvm-devel >= 11.0.0 +%else BuildRequires: llvm-devel >= 10.0.1 +%endif BuildRequires: boost-devel >= 1.66 BuildRequires: openssl-devel BuildRequires: vespa-gtest >= 1.8.1-1 @@ -96,14 +110,14 @@ BuildRequires: gmock-devel %endif %if 0%{?fc34} BuildRequires: protobuf-devel -BuildRequires: llvm-devel >= 11.1.0 +BuildRequires: llvm-devel >= 12.0.0 BuildRequires: boost-devel >= 1.75 BuildRequires: gtest-devel BuildRequires: gmock-devel %endif %if 0%{?fc35} BuildRequires: protobuf-devel -BuildRequires: llvm-devel >= 11.1.0 +BuildRequires: llvm-devel >= 12.0.0 BuildRequires: boost-devel >= 1.75 BuildRequires: gtest-devel BuildRequires: gmock-devel @@ -181,8 +195,13 @@ Requires: vespa-zstd >= 1.4.5-2 %define _extra_include_directory /usr/include/llvm7.0;%{_vespa_deps_prefix}/include;/usr/include/openblas %endif %if 0%{?el8} +%if 0%{?_centos_stream} +Requires: llvm-libs >= 11.0.0 +%define _vespa_llvm_version 11 +%else Requires: llvm-libs >= 10.0.1 %define _vespa_llvm_version 10 +%endif Requires: openssl-libs Requires: vespa-lz4 >= 1.9.2-2 Requires: vespa-onnxruntime = 1.4.0 @@ -208,13 +227,13 @@ Requires: llvm-libs >= 11.0.0 %endif %if 0%{?fc34} Requires: protobuf -Requires: 
llvm-libs >= 11.1.0 -%define _vespa_llvm_version 11 +Requires: llvm-libs >= 12.0.0 +%define _vespa_llvm_version 12 %endif %if 0%{?fc35} Requires: protobuf -Requires: llvm-libs >= 11.1.0 -%define _vespa_llvm_version 11 +Requires: llvm-libs >= 12.0.0 +%define _vespa_llvm_version 12 %endif %define _extra_link_directory %{_vespa_deps_prefix}/lib64 %define _extra_include_directory %{_vespa_deps_prefix}/include;/usr/include/openblas diff --git a/document/src/vespa/document/update/tensor_modify_update.cpp b/document/src/vespa/document/update/tensor_modify_update.cpp index 791c3efe872..1358f2c8e46 100644 --- a/document/src/vespa/document/update/tensor_modify_update.cpp +++ b/document/src/vespa/document/update/tensor_modify_update.cpp @@ -22,6 +22,7 @@ using vespalib::IllegalArgumentException; using vespalib::IllegalStateException; using vespalib::make_string; using vespalib::eval::ValueType; +using vespalib::eval::CellType; using vespalib::eval::FastValueBuilderFactory; using join_fun_t = double (*)(double, double); @@ -77,7 +78,7 @@ convertToCompatibleType(const TensorDataType &tensorType) for (const auto &dim : tensorType.getTensorType().dimensions()) { list.emplace_back(dim.name); } - return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list), tensorType.getTensorType().cell_type())); + return std::make_unique<const TensorDataType>(ValueType::make_type(tensorType.getTensorType().cell_type(), std::move(list))); } } diff --git a/document/src/vespa/document/update/tensor_remove_update.cpp b/document/src/vespa/document/update/tensor_remove_update.cpp index 5c8c5c07116..0b1096fce0e 100644 --- a/document/src/vespa/document/update/tensor_remove_update.cpp +++ b/document/src/vespa/document/update/tensor_remove_update.cpp @@ -19,6 +19,7 @@ using vespalib::IllegalStateException; using vespalib::make_string; using vespalib::eval::Value; using vespalib::eval::ValueType; +using vespalib::eval::CellType; using vespalib::eval::FastValueBuilderFactory; 
namespace document { @@ -34,7 +35,7 @@ convertToCompatibleType(const TensorDataType &tensorType) list.emplace_back(dim.name); } } - return std::make_unique<const TensorDataType>(ValueType::tensor_type(std::move(list), tensorType.getTensorType().cell_type())); + return std::make_unique<const TensorDataType>(ValueType::make_type(tensorType.getTensorType().cell_type(), std::move(list))); } } diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index 05ac4cb78ab..7d3d94d7ed6 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -21,6 +21,7 @@ vespa_define_module( src/tests/eval/gbdt src/tests/eval/gen_spec src/tests/eval/inline_operation + src/tests/eval/int8float src/tests/eval/interpreted_function src/tests/eval/multiply_add src/tests/eval/nested_loop diff --git a/eval/src/apps/eval_expr/eval_expr.cpp b/eval/src/apps/eval_expr/eval_expr.cpp index f5e5c5d0dfd..12c94c6e68e 100644 --- a/eval/src/apps/eval_expr/eval_expr.cpp +++ b/eval/src/apps/eval_expr/eval_expr.cpp @@ -22,7 +22,7 @@ int main(int argc, char **argv) { auto type = ValueType::from_spec(result.type()); if (type.is_error()) { fprintf(stdout, "error\n"); - } else if (type.is_scalar()) { + } else if (type.is_double()) { fprintf(stdout, "%.32g\n", result.as_double()); } else { fprintf(stdout, "%s\n", result.to_string().c_str()); diff --git a/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp b/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp index b4b4b628ee2..6e882fc3d9d 100644 --- a/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp +++ b/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp @@ -21,9 +21,9 @@ template <typename T> std::vector<bool> with_cell_type_opts(); template <> std::vector<bool> with_cell_type_opts<double>() { return {false, true}; } template <> std::vector<bool> with_cell_type_opts<float>() { return {true}; } -template 
<typename T> uint8_t cell_type(); -template <> uint8_t cell_type<double>() { return 0; } -template <> uint8_t cell_type<float>() { return 1; } +template <typename T> uint8_t cell_type_id(); +template <> uint8_t cell_type_id<double>() { return 0; } +template <> uint8_t cell_type_id<float>() { return 1; } template <typename T> const char *cell_type_str(); template <> const char *cell_type_str<double>() { return ""; } @@ -33,7 +33,7 @@ template <typename T> nbostream make_sparse(bool with_cell_type) { nbostream data; if (with_cell_type) { data << uint8_t(0x5); - data << cell_type<T>(); + data << cell_type_id<T>(); } else { data << uint8_t(0x1); } @@ -44,7 +44,7 @@ template <typename T> nbostream make_dense(bool with_cell_type) { nbostream data; if (with_cell_type) { data << uint8_t(0x6); - data << cell_type<T>(); + data << cell_type_id<T>(); } else { data << uint8_t(0x2); } @@ -55,7 +55,7 @@ template <typename T> nbostream make_mixed(bool with_cell_type) { nbostream data; if (with_cell_type) { data << uint8_t(0x7); - data << cell_type<T>(); + data << cell_type_id<T>(); } else { data << uint8_t(0x3); } diff --git a/eval/src/tests/eval/fast_value/fast_value_test.cpp b/eval/src/tests/eval/fast_value/fast_value_test.cpp index 70c534b2010..2acb6c448c9 100644 --- a/eval/src/tests/eval/fast_value/fast_value_test.cpp +++ b/eval/src/tests/eval/fast_value/fast_value_test.cpp @@ -160,7 +160,8 @@ TEST(FastValueBuilderFactoryTest, fast_values_can_be_copied) { auto factory = FastValueBuilderFactory::get(); for (const auto &layout: layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec expect = layout.cpy().cells(ct); + auto expect = layout.cpy().cells(ct); + if (expect.bad_scalar()) continue; std::unique_ptr<Value> value = value_from_spec(expect, factory); std::unique_ptr<Value> copy = factory.copy(*value); TensorSpec actual = spec_from_value(*copy); diff --git a/eval/src/tests/eval/gen_spec/gen_spec_test.cpp b/eval/src/tests/eval/gen_spec/gen_spec_test.cpp index 
ba169b72489..9d8eb419a67 100644 --- a/eval/src/tests/eval/gen_spec/gen_spec_test.cpp +++ b/eval/src/tests/eval/gen_spec/gen_spec_test.cpp @@ -61,9 +61,8 @@ TEST(GenSpecTest, scalar_double) { EXPECT_EQ(GenSpec(5.0).gen(), scalar_5); } -TEST(GenSpecTest, not_scalar_float_just_yet) { - EXPECT_EQ(GenSpec().cells_float().gen(), scalar_1); - EXPECT_EQ(GenSpec(5.0).cells_float().gen(), scalar_5); +TEST(GenSpecTest, scalar_float_is_bad_scalar) { + EXPECT_TRUE(GenSpec().cells_float().bad_scalar()); } //----------------------------------------------------------------------------- @@ -126,7 +125,6 @@ GenSpec dbl() { return GenSpec().cells_double(); } TEST(GenSpecTest, value_type) { EXPECT_EQ(dbl().type().to_spec(), "double"); - EXPECT_EQ(flt().type().to_spec(), "double"); // NB EXPECT_EQ(dbl().idx("x", 10).type().to_spec(), "tensor(x[10])"); EXPECT_EQ(flt().idx("x", 10).type().to_spec(), "tensor<float>(x[10])"); EXPECT_EQ(dbl().map("y", {}).type().to_spec(), "tensor(y{})"); diff --git a/eval/src/tests/eval/int8float/CMakeLists.txt b/eval/src/tests/eval/int8float/CMakeLists.txt new file mode 100644 index 00000000000..fb258f46fc2 --- /dev/null +++ b/eval/src/tests/eval/int8float/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_int8float_test_app TEST + SOURCES + int8float_test.cpp + DEPENDS + vespaeval + GTest::GTest +) +vespa_add_test(NAME eval_int8float_test_app COMMAND eval_int8float_test_app) diff --git a/eval/src/tests/eval/int8float/int8float_test.cpp b/eval/src/tests/eval/int8float/int8float_test.cpp new file mode 100644 index 00000000000..9debf712150 --- /dev/null +++ b/eval/src/tests/eval/int8float/int8float_test.cpp @@ -0,0 +1,59 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/eval/eval/int8float.h> +#include <vespa/vespalib/objects/nbostream.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <cmath> + +using namespace vespalib; +using namespace vespalib::eval; + +static std::vector<float> simple_values = { + 0.0, 1.0, -1.0, -17.0, 42.0, 127.0, -128.0 +}; + +TEST(Int8FloatTest, normal_usage) { + EXPECT_EQ(sizeof(float), 4); + EXPECT_EQ(sizeof(Int8Float), 1); + Int8Float answer = 42; + double fortytwo = answer; + EXPECT_EQ(fortytwo, 42); + for (float value : simple_values) { + Int8Float b = value; + float recover = b; + EXPECT_EQ(value, recover); + } + // undefined behavior here: + Int8Float b1 = 128.0; + EXPECT_NE(float(b1), 128.0); + Int8Float b2 = -129.0; + EXPECT_NE(float(b2), -129.0); +} + +TEST(Int8FloatTest, with_nbostream) { + nbostream buf; + for (Int8Float value : simple_values) { + buf << value; + } + for (float value : simple_values) { + Int8Float stored; + buf >> stored; + EXPECT_EQ(float(stored), value); + } +} + +TEST(Int8FloatTest, traits_check) { + EXPECT_TRUE(std::is_trivially_constructible<Int8Float>::value); + EXPECT_TRUE(std::is_trivially_move_constructible<Int8Float>::value); + EXPECT_TRUE(std::is_trivially_default_constructible<Int8Float>::value); + EXPECT_TRUE((std::is_trivially_assignable<Int8Float,Int8Float>::value)); + EXPECT_TRUE(std::is_trivially_move_assignable<Int8Float>::value); + EXPECT_TRUE(std::is_trivially_copy_assignable<Int8Float>::value); + EXPECT_TRUE(std::is_trivially_copyable<Int8Float>::value); + EXPECT_TRUE(std::is_trivially_destructible<Int8Float>::value); + EXPECT_TRUE(std::is_trivial<Int8Float>::value); + EXPECT_TRUE(std::is_swappable<Int8Float>::value); + EXPECT_TRUE(std::has_unique_object_representations<Int8Float>::value); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp index 871a564bfa4..ba73a578f6f 100644 --- 
a/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp +++ b/eval/src/tests/eval/interpreted_function/interpreted_function_test.cpp @@ -80,7 +80,7 @@ struct MyEvalTest : test::EvalSpec::EvalTest { InterpretedFunction ifun(factory, function, node_types); InterpretedFunction::Context ictx(ifun); const Value &result_value = ifun.eval(ictx, params); - report_result(result_value.is_double(), result_value.as_double(), expected_result, description); + report_result(result_value.type().is_double(), result_value.as_double(), expected_result, description); } }; diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index 89c37af0a83..a5a17ea15a0 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -293,8 +293,8 @@ TEST("require that tensor concat resolves correct type") { } TEST("require that tensor cell_cast resolves correct type") { - TEST_DO(verify("cell_cast(double,float)", "double")); // NB - TEST_DO(verify("cell_cast(float,double)", "double")); + TEST_DO(verify("cell_cast(double,double)", "double")); + TEST_DO(verify("cell_cast(double,float)", "error")); TEST_DO(verify("cell_cast(tensor<double>(x{},y[5]),float)", "tensor<float>(x{},y[5])")); TEST_DO(verify("cell_cast(tensor<float>(x{},y[5]),double)", "tensor<double>(x{},y[5])")); TEST_DO(verify("cell_cast(tensor<float>(x{},y[5]),float)", "tensor<float>(x{},y[5])")); @@ -304,7 +304,7 @@ TEST("require that double only expressions can be detected") { auto plain_fun = Function::parse("1+2"); auto complex_fun = Function::parse("reduce(a,sum)"); NodeTypes plain_types(*plain_fun, {}); - NodeTypes complex_types(*complex_fun, {ValueType::tensor_type({{"x"}})}); + NodeTypes complex_types(*complex_fun, {ValueType::make_type(CellType::DOUBLE, {{"x"}})}); EXPECT_TRUE(plain_types.get_type(plain_fun->root()).is_double()); 
EXPECT_TRUE(complex_types.get_type(complex_fun->root()).is_double()); EXPECT_TRUE(plain_types.all_types_are_double()); diff --git a/eval/src/tests/eval/reference_operations/reference_operations_test.cpp b/eval/src/tests/eval/reference_operations/reference_operations_test.cpp index fdbf375fa3a..2edbefc7717 100644 --- a/eval/src/tests/eval/reference_operations/reference_operations_test.cpp +++ b/eval/src/tests/eval/reference_operations/reference_operations_test.cpp @@ -144,8 +144,9 @@ TEST(ReferenceCellCastTest, cell_cast_works) { for (CellType from_type: CellTypeUtils::list_types()) { for (CellType to_type: CellTypeUtils::list_types()) { for (const auto &gen: gen_list) { - TensorSpec input = gen.cpy().cells(from_type); - TensorSpec expect = gen.cpy().cells(to_type); + auto input = gen.cpy().cells(from_type); + auto expect = gen.cpy().cells(to_type); + if (input.bad_scalar() || expect.bad_scalar()) continue; auto actual = ReferenceOperations::cell_cast(input, to_type); EXPECT_EQ(actual, expect); } diff --git a/eval/src/tests/eval/simple_value/simple_value_test.cpp b/eval/src/tests/eval/simple_value/simple_value_test.cpp index 974f87a6055..57c71903bf1 100644 --- a/eval/src/tests/eval/simple_value/simple_value_test.cpp +++ b/eval/src/tests/eval/simple_value/simple_value_test.cpp @@ -69,7 +69,8 @@ TensorSpec simple_value_join(const TensorSpec &a, const TensorSpec &b, join_fun_ TEST(SimpleValueTest, simple_values_can_be_converted_from_and_to_tensor_spec) { for (const auto &layout: layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec expect = layout.cpy().cells(ct); + auto expect = layout.cpy().cells(ct); + if (expect.bad_scalar()) continue; std::unique_ptr<Value> value = value_from_spec(expect, SimpleValueBuilderFactory::get()); TensorSpec actual = spec_from_value(*value); EXPECT_EQ(actual, expect); @@ -80,7 +81,8 @@ TEST(SimpleValueTest, simple_values_can_be_converted_from_and_to_tensor_spec) { TEST(SimpleValueTest, simple_values_can_be_copied) { 
for (const auto &layout: layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec expect = layout.cpy().cells(ct); + auto expect = layout.cpy().cells(ct); + if (expect.bad_scalar()) continue; std::unique_ptr<Value> value = value_from_spec(expect, SimpleValueBuilderFactory::get()); std::unique_ptr<Value> copy = SimpleValueBuilderFactory::get().copy(*value); TensorSpec actual = spec_from_value(*copy); @@ -131,11 +133,13 @@ TEST(SimpleValueTest, new_generic_join_works_for_simple_values) { const auto l = join_layouts[i].cpy().seq(N_16ths); const auto r = join_layouts[i + 1].cpy().seq(N_16ths); for (CellType lct : CellTypeUtils::list_types()) { - TensorSpec lhs = l.cpy().cells(lct); + auto lhs = l.cpy().cells(lct); + if (lhs.bad_scalar()) continue; for (CellType rct : CellTypeUtils::list_types()) { - TensorSpec rhs = r.cpy().cells(rct); + auto rhs = r.cpy().cells(rct); + if (rhs.bad_scalar()) continue; for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Div::f}) { - SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); + SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str())); auto expect = ReferenceOperations::join(lhs, rhs, fun); auto actual = simple_value_join(lhs, rhs, fun); EXPECT_EQ(actual, expect); diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp index f1bd900b350..c457f68a614 100644 --- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp +++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp @@ -229,9 +229,9 @@ TEST("require that full tensor reduction works") { size_t a_id = ctx.add_tensor(ctx.make_tensor_reduce_input()); const auto &fun = reduce(inject(ValueType::from_spec("tensor(x[3],y[2])"), a_id, ctx.stash), Aggr::SUM, {}, ctx.stash); EXPECT_TRUE(fun.result_is_mutable()); - 
EXPECT_EQUAL(ValueType::from_spec("double"), fun.result_type()); + EXPECT_EQUAL(ValueType::double_type(), fun.result_type()); const Value &result = ctx.eval(fun); - EXPECT_TRUE(result.is_double()); + EXPECT_TRUE(result.type().is_double()); EXPECT_EQUAL(21.0, result.as_double()); } @@ -300,8 +300,8 @@ TEST("require that tensor create works") { size_t b_id = ctx.add_tensor(ctx.make_double(2.0)); Value::UP my_const = ctx.make_double(3.0); Value::UP expect = ctx.make_vector({1.0, 2.0, 3.0}); - const auto &a = inject(ValueType::from_spec("double"), a_id, ctx.stash); - const auto &b = inject(ValueType::from_spec("double"), b_id, ctx.stash); + const auto &a = inject(ValueType::double_type(), a_id, ctx.stash); + const auto &b = inject(ValueType::double_type(), b_id, ctx.stash); const auto &c = const_value(*my_const, ctx.stash); const auto &fun = create(ValueType::from_spec("tensor(x[3])"), { @@ -321,8 +321,8 @@ TEST("require that single value tensor peek works") { size_t b_id = ctx.add_tensor(ctx.make_double(1000.0)); Value::UP my_const = ctx.make_mixed_tensor(1.0, 2.0, 3.0, 4.0); Value::UP expect = ctx.make_vector({2.0, 3.0, 0.0}); - const auto &a = inject(ValueType::from_spec("double"), a_id, ctx.stash); - const auto &b = inject(ValueType::from_spec("double"), b_id, ctx.stash); + const auto &a = inject(ValueType::double_type(), a_id, ctx.stash); + const auto &b = inject(ValueType::double_type(), b_id, ctx.stash); const auto &t = const_value(*my_const, ctx.stash); const auto &peek1 = peek(t, {{"x", "foo"}, {"y", a}}, ctx.stash); const auto &peek2 = peek(t, {{"x", "bar"}, {"y", size_t(0)}}, ctx.stash); @@ -354,13 +354,13 @@ TEST("require that automatic string conversion tensor peek works") { EvalCtx ctx(simple_factory); size_t a_id = ctx.add_tensor(ctx.make_double(1.0)); Value::UP my_const = ctx.make_vector({1.0, 2.0, 3.0}, "x", true); - const auto &a = inject(ValueType::from_spec("double"), a_id, ctx.stash); + const auto &a = inject(ValueType::double_type(), a_id, 
ctx.stash); const auto &t = const_value(*my_const, ctx.stash); const auto &fun = peek(t, {{"x", a}}, ctx.stash); EXPECT_TRUE(fun.result_is_mutable()); EXPECT_TRUE(fun.result_type().is_double()); const Value &result = ctx.eval(fun); - EXPECT_TRUE(result.is_double()); + EXPECT_TRUE(result.type().is_double()); EXPECT_EQUAL(2.0, result.as_double()); } diff --git a/eval/src/tests/eval/value_codec/value_codec_test.cpp b/eval/src/tests/eval/value_codec/value_codec_test.cpp index 434ad0b2a53..0bb1bcfb337 100644 --- a/eval/src/tests/eval/value_codec/value_codec_test.cpp +++ b/eval/src/tests/eval/value_codec/value_codec_test.cpp @@ -33,7 +33,8 @@ const std::vector<GenSpec> layouts = { TEST(ValueCodecTest, simple_values_can_be_converted_from_and_to_tensor_spec) { for (const auto &layout: layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec expect = layout.cpy().cells(ct); + auto expect = layout.cpy().cells(ct); + if (expect.bad_scalar()) continue; std::unique_ptr<Value> value = value_from_spec(expect, factory); TensorSpec actual = spec_from_value(*value); EXPECT_EQ(actual, expect); diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp index c1b25d48bf7..e0c90166fa2 100644 --- a/eval/src/tests/eval/value_type/value_type_test.cpp +++ b/eval/src/tests/eval/value_type/value_type_test.cpp @@ -36,15 +36,8 @@ TEST("require that DOUBLE value type can be created") { EXPECT_EQUAL(t.dimensions().size(), 0u); } -TEST("require that FLOAT value type can be created") { - ValueType t = ValueType::make_type(CellType::FLOAT, {}); - EXPECT_FALSE(t.is_error()); - EXPECT_TRUE(t.cell_type() == CellType::FLOAT); - EXPECT_EQUAL(t.dimensions().size(), 0u); -} - TEST("require that TENSOR value type can be created") { - ValueType t = ValueType::tensor_type({{"x", 10},{"y"}}); + ValueType t = ValueType::make_type(CellType::DOUBLE, {{"x", 10},{"y"}}); EXPECT_FALSE(t.is_error()); EXPECT_TRUE(t.cell_type() == 
CellType::DOUBLE); ASSERT_EQUAL(t.dimensions().size(), 2u); @@ -55,7 +48,7 @@ TEST("require that TENSOR value type can be created") { } TEST("require that float TENSOR value type can be created") { - ValueType t = ValueType::tensor_type({{"x", 10},{"y"}}, CellType::FLOAT); + ValueType t = ValueType::make_type(CellType::FLOAT, {{"x", 10},{"y"}}); EXPECT_FALSE(t.is_error()); EXPECT_TRUE(t.cell_type() == CellType::FLOAT); ASSERT_EQUAL(t.dimensions().size(), 2u); @@ -66,7 +59,7 @@ TEST("require that float TENSOR value type can be created") { } TEST("require that TENSOR value type sorts dimensions") { - ValueType t = ValueType::tensor_type({{"x", 10}, {"z", 30}, {"y"}}); + ValueType t = ValueType::make_type(CellType::DOUBLE, {{"x", 10}, {"z", 30}, {"y"}}); EXPECT_FALSE(t.is_error()); EXPECT_TRUE(t.cell_type() == CellType::DOUBLE); ASSERT_EQUAL(t.dimensions().size(), 3u); @@ -78,19 +71,16 @@ TEST("require that TENSOR value type sorts dimensions") { EXPECT_EQUAL(t.dimensions()[2].size, 30u); } -TEST("require that 'tensor<float>()' is normalized to 'double'") { - ValueType t = ValueType::tensor_type({}, CellType::FLOAT); - EXPECT_FALSE(t.is_error()); - EXPECT_TRUE(t.cell_type() == CellType::DOUBLE); - EXPECT_EQUAL(t.dimensions().size(), 0u); +TEST("require that non-double scalar values are not allowed") { + EXPECT_TRUE(ValueType::make_type(CellType::FLOAT, {}).is_error()); } TEST("require that use of zero-size dimensions result in error types") { - EXPECT_TRUE(ValueType::tensor_type({{"x", 0}}).is_error()); + EXPECT_TRUE(ValueType::make_type(CellType::DOUBLE, {{"x", 0}}).is_error()); } TEST("require that duplicate dimension names result in error types") { - EXPECT_TRUE(ValueType::tensor_type({{"x"}, {"x"}}).is_error()); + EXPECT_TRUE(ValueType::make_type(CellType::DOUBLE, {{"x"}, {"x"}}).is_error()); } //----------------------------------------------------------------------------- @@ -116,18 +106,17 @@ void verify_not_equal(const ValueType &a, const ValueType &b) { 
TEST("require that value types can be compared") { TEST_DO(verify_equal(ValueType::error_type(), ValueType::error_type())); TEST_DO(verify_not_equal(ValueType::error_type(), ValueType::double_type())); - TEST_DO(verify_not_equal(ValueType::error_type(), ValueType::tensor_type({{"x"}}))); + TEST_DO(verify_not_equal(ValueType::error_type(), ValueType::make_type(CellType::DOUBLE, {{"x"}}))); TEST_DO(verify_equal(ValueType::double_type(), ValueType::double_type())); - TEST_DO(verify_not_equal(ValueType::double_type(), ValueType::make_type(CellType::FLOAT, {}))); - TEST_DO(verify_equal(ValueType::double_type(), ValueType::tensor_type({}))); - TEST_DO(verify_not_equal(ValueType::double_type(), ValueType::tensor_type({{"x"}}))); - TEST_DO(verify_equal(ValueType::tensor_type({{"x"}, {"y"}}), ValueType::tensor_type({{"y"}, {"x"}}))); - TEST_DO(verify_not_equal(ValueType::tensor_type({{"x"}, {"y"}}), ValueType::tensor_type({{"x"}, {"y"}, {"z"}}))); - TEST_DO(verify_equal(ValueType::tensor_type({{"x", 10}, {"y", 20}}), ValueType::tensor_type({{"y", 20}, {"x", 10}}))); - TEST_DO(verify_not_equal(ValueType::tensor_type({{"x", 10}, {"y", 20}}), ValueType::tensor_type({{"x", 10}, {"y", 10}}))); - TEST_DO(verify_not_equal(ValueType::tensor_type({{"x", 10}}), ValueType::tensor_type({{"x"}}))); - TEST_DO(verify_equal(ValueType::tensor_type({{"x", 10}}, CellType::FLOAT), ValueType::tensor_type({{"x", 10}}, CellType::FLOAT))); - TEST_DO(verify_not_equal(ValueType::tensor_type({{"x", 10}}, CellType::DOUBLE), ValueType::tensor_type({{"x", 10}}, CellType::FLOAT))); + TEST_DO(verify_equal(ValueType::double_type(), ValueType::make_type(CellType::DOUBLE, {}))); + TEST_DO(verify_not_equal(ValueType::double_type(), ValueType::make_type(CellType::DOUBLE, {{"x"}}))); + TEST_DO(verify_equal(ValueType::make_type(CellType::DOUBLE, {{"x"}, {"y"}}), ValueType::make_type(CellType::DOUBLE, {{"y"}, {"x"}}))); + TEST_DO(verify_not_equal(ValueType::make_type(CellType::DOUBLE, {{"x"}, {"y"}}), 
ValueType::make_type(CellType::DOUBLE, {{"x"}, {"y"}, {"z"}}))); + TEST_DO(verify_equal(ValueType::make_type(CellType::DOUBLE, {{"x", 10}, {"y", 20}}), ValueType::make_type(CellType::DOUBLE, {{"y", 20}, {"x", 10}}))); + TEST_DO(verify_not_equal(ValueType::make_type(CellType::DOUBLE, {{"x", 10}, {"y", 20}}), ValueType::make_type(CellType::DOUBLE, {{"x", 10}, {"y", 10}}))); + TEST_DO(verify_not_equal(ValueType::make_type(CellType::DOUBLE, {{"x", 10}}), ValueType::make_type(CellType::DOUBLE, {{"x"}}))); + TEST_DO(verify_equal(ValueType::make_type(CellType::FLOAT, {{"x", 10}}), ValueType::make_type(CellType::FLOAT, {{"x", 10}}))); + TEST_DO(verify_not_equal(ValueType::make_type(CellType::DOUBLE, {{"x", 10}}), ValueType::make_type(CellType::FLOAT, {{"x", 10}}))); } //----------------------------------------------------------------------------- @@ -135,46 +124,45 @@ TEST("require that value types can be compared") { TEST("require that value type can make spec") { EXPECT_EQUAL("error", ValueType::error_type().to_spec()); EXPECT_EQUAL("double", ValueType::double_type().to_spec()); - EXPECT_EQUAL("float", ValueType::make_type(CellType::FLOAT, {}).to_spec()); - EXPECT_EQUAL("double", ValueType::tensor_type({}).to_spec()); - EXPECT_EQUAL("double", ValueType::tensor_type({}, CellType::FLOAT).to_spec()); - EXPECT_EQUAL("tensor(x{})", ValueType::tensor_type({{"x"}}).to_spec()); - EXPECT_EQUAL("tensor(y[10])", ValueType::tensor_type({{"y", 10}}).to_spec()); - EXPECT_EQUAL("tensor(x{},y[10],z[5])", ValueType::tensor_type({{"x"}, {"y", 10}, {"z", 5}}).to_spec()); - EXPECT_EQUAL("tensor<float>(x{})", ValueType::tensor_type({{"x"}}, CellType::FLOAT).to_spec()); - EXPECT_EQUAL("tensor<float>(y[10])", ValueType::tensor_type({{"y", 10}}, CellType::FLOAT).to_spec()); - EXPECT_EQUAL("tensor<float>(x{},y[10],z[5])", ValueType::tensor_type({{"x"}, {"y", 10}, {"z", 5}}, CellType::FLOAT).to_spec()); + EXPECT_EQUAL("error", ValueType::make_type(CellType::FLOAT, {}).to_spec()); + 
EXPECT_EQUAL("double", ValueType::make_type(CellType::DOUBLE, {}).to_spec()); + EXPECT_EQUAL("tensor(x{})", ValueType::make_type(CellType::DOUBLE, {{"x"}}).to_spec()); + EXPECT_EQUAL("tensor(y[10])", ValueType::make_type(CellType::DOUBLE, {{"y", 10}}).to_spec()); + EXPECT_EQUAL("tensor(x{},y[10],z[5])", ValueType::make_type(CellType::DOUBLE, {{"x"}, {"y", 10}, {"z", 5}}).to_spec()); + EXPECT_EQUAL("tensor<float>(x{})", ValueType::make_type(CellType::FLOAT, {{"x"}}).to_spec()); + EXPECT_EQUAL("tensor<float>(y[10])", ValueType::make_type(CellType::FLOAT, {{"y", 10}}).to_spec()); + EXPECT_EQUAL("tensor<float>(x{},y[10],z[5])", ValueType::make_type(CellType::FLOAT, {{"x"}, {"y", 10}, {"z", 5}}).to_spec()); } //----------------------------------------------------------------------------- TEST("require that value type spec can be parsed") { - EXPECT_EQUAL(ValueType::double_type(), ValueType::from_spec("double")); - EXPECT_EQUAL(ValueType::make_type(CellType::FLOAT, {}), ValueType::from_spec("float")); - EXPECT_EQUAL(ValueType::tensor_type({}), ValueType::from_spec("tensor()")); - EXPECT_EQUAL(ValueType::tensor_type({{"x"}}), ValueType::from_spec("tensor(x{})")); - EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}), ValueType::from_spec("tensor(y[10])")); - EXPECT_EQUAL(ValueType::tensor_type({{"x"}, {"y", 10}, {"z", 5}}), ValueType::from_spec("tensor(x{},y[10],z[5])")); - EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}), ValueType::from_spec("tensor<double>(y[10])")); - EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}, CellType::FLOAT), ValueType::from_spec("tensor<float>(y[10])")); + EXPECT_EQUAL(ValueType::double_type(), type("double")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {}), type("tensor()")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {}), type("tensor<double>()")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {{"x"}}), type("tensor(x{})")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {{"y", 10}}), 
type("tensor(y[10])")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {{"x"}, {"y", 10}, {"z", 5}}), type("tensor(x{},y[10],z[5])")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {{"y", 10}}), type("tensor<double>(y[10])")); + EXPECT_EQUAL(ValueType::make_type(CellType::FLOAT, {{"y", 10}}), type("tensor<float>(y[10])")); } TEST("require that value type spec can be parsed with extra whitespace") { - EXPECT_EQUAL(ValueType::double_type(), ValueType::from_spec(" double ")); - EXPECT_EQUAL(ValueType::make_type(CellType::FLOAT, {}), ValueType::from_spec(" float ")); - EXPECT_EQUAL(ValueType::tensor_type({}), ValueType::from_spec(" tensor ( ) ")); - EXPECT_EQUAL(ValueType::tensor_type({{"x"}}), ValueType::from_spec(" tensor ( x { } ) ")); - EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}), ValueType::from_spec(" tensor ( y [ 10 ] ) ")); - EXPECT_EQUAL(ValueType::tensor_type({{"x"}, {"y", 10}, {"z", 5}}), - ValueType::from_spec(" tensor ( x { } , y [ 10 ] , z [ 5 ] ) ")); - EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}), ValueType::from_spec(" tensor < double > ( y [ 10 ] ) ")); - EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}, CellType::FLOAT), ValueType::from_spec(" tensor < float > ( y [ 10 ] ) ")); + EXPECT_EQUAL(ValueType::double_type(), type(" double ")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {}), type(" tensor ( ) ")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {}), type(" tensor < double > ( ) ")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {{"x"}}), type(" tensor ( x { } ) ")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {{"y", 10}}), type(" tensor ( y [ 10 ] ) ")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {{"x"}, {"y", 10}, {"z", 5}}), + type(" tensor ( x { } , y [ 10 ] , z [ 5 ] ) ")); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {{"y", 10}}), type(" tensor < double > ( y [ 10 ] ) ")); + EXPECT_EQUAL(ValueType::make_type(CellType::FLOAT, {{"y", 10}}), type(" tensor < float 
> ( y [ 10 ] ) ")); } TEST("require that the unsorted dimension list can be obtained when parsing type spec") { std::vector<ValueType::Dimension> unsorted; auto type = ValueType::from_spec("tensor(y[10],z[5],x{})", unsorted); - EXPECT_EQUAL(ValueType::tensor_type({{"x"}, {"y", 10}, {"z", 5}}), type); + EXPECT_EQUAL(ValueType::make_type(CellType::DOUBLE, {{"x"}, {"y", 10}, {"z", 5}}), type); ASSERT_EQUAL(unsorted.size(), 3u); EXPECT_EQUAL(unsorted[0].name, "y"); EXPECT_EQUAL(unsorted[0].size, 10u); @@ -207,6 +195,7 @@ TEST("require that malformed value type spec is parsed as error") { EXPECT_TRUE(ValueType::from_spec(" ").is_error()); EXPECT_TRUE(ValueType::from_spec("error").is_error()); EXPECT_TRUE(ValueType::from_spec("any").is_error()); + EXPECT_TRUE(ValueType::from_spec("float").is_error()); EXPECT_TRUE(ValueType::from_spec("tensor").is_error()); EXPECT_TRUE(ValueType::from_spec("tensor<double>").is_error()); EXPECT_TRUE(ValueType::from_spec("tensor() tensor()").is_error()); @@ -224,7 +213,8 @@ TEST("require that malformed value type spec is parsed as error") { EXPECT_TRUE(ValueType::from_spec("tensor(x{},x[10])").is_error()); EXPECT_TRUE(ValueType::from_spec("tensor(x{},x[])").is_error()); EXPECT_TRUE(ValueType::from_spec("tensor(z[])").is_error()); - EXPECT_TRUE(ValueType::from_spec("tensor<float16>(x[10])").is_error()); + EXPECT_TRUE(ValueType::from_spec("tensor<float>()").is_error()); + EXPECT_TRUE(ValueType::from_spec("tensor<int7>(x[10])").is_error()); } struct ParseResult { @@ -247,7 +237,7 @@ ParseResult::~ParseResult() = default; TEST("require that we can parse a partial string into a type with the low-level API") { ParseResult result("tensor(a[5]) , "); - EXPECT_EQUAL(result.type, ValueType::tensor_type({{"a", 5}})); + EXPECT_EQUAL(result.type, ValueType::make_type(CellType::DOUBLE, {{"a", 5}})); ASSERT_TRUE(result.after_inside()); EXPECT_EQUAL(*result.after, ','); } @@ -315,7 +305,7 @@ void verify_predicates(const ValueType &type, { 
EXPECT_EQUAL(type.is_error(), expect_error); EXPECT_EQUAL(type.is_double(), expect_double); - EXPECT_EQUAL(type.is_tensor(), expect_tensor); + EXPECT_EQUAL(type.has_dimensions(), expect_tensor); EXPECT_EQUAL(type.is_sparse(), expect_sparse); EXPECT_EQUAL(type.is_dense(), expect_dense); } @@ -507,8 +497,12 @@ void verify_cell_cast(const ValueType &type) { if (type.is_error()) { EXPECT_TRUE(res_type.is_error()); EXPECT_EQUAL(res_type, type); - } else if (type.is_scalar()) { - EXPECT_TRUE(res_type.is_double()); // NB + } else if (type.is_double()) { + if (cell_type == CellType::DOUBLE) { + EXPECT_TRUE(res_type.is_double()); + } else { + EXPECT_TRUE(res_type.is_error()); + } } else { EXPECT_FALSE(res_type.is_error()); EXPECT_EQUAL(int(res_type.cell_type()), int(cell_type)); @@ -519,7 +513,6 @@ void verify_cell_cast(const ValueType &type) { TEST("require that value type cell cast works correctly") { TEST_DO(verify_cell_cast(type("error"))); - TEST_DO(verify_cell_cast(type("float"))); TEST_DO(verify_cell_cast(type("double"))); TEST_DO(verify_cell_cast(type("tensor<float>(x[10])"))); TEST_DO(verify_cell_cast(type("tensor<double>(x[10])"))); @@ -548,4 +541,24 @@ TEST("require that cell type name recognition is strict") { EXPECT_FALSE(value_type::cell_type_from_name("").has_value()); } +TEST("require that map type inference works as expected") { + EXPECT_EQUAL(type("error").map(), type("error")); + EXPECT_EQUAL(type("double").map(), type("double")); + EXPECT_EQUAL(type("tensor(x[10])").map(), type("tensor(x[10])")); + EXPECT_EQUAL(type("tensor<float>(x{})").map(), type("tensor<float>(x{})")); +} + +TEST("require that peek type inference works as expected") { + auto input1 = type("tensor(a[2],b{},c[3],d{},e[5])"); + auto input2 = type("tensor<float>(a[2],b{},c[3],d{},e[5])"); + EXPECT_EQUAL(type("error").peek({}), type("error")); + EXPECT_EQUAL(type("double").peek({}), type("error")); + EXPECT_EQUAL(input1.peek({}), type("error")); + EXPECT_EQUAL(input1.peek({"x"}), 
type("error")); + EXPECT_EQUAL(input1.peek({"a", "c", "e"}), type("tensor(b{},d{})")); + EXPECT_EQUAL(input2.peek({"b", "d"}), type("tensor<float>(a[2],c[3],e[5])")); + EXPECT_EQUAL(input1.peek({"a", "b", "c", "d", "e"}), type("double")); + EXPECT_EQUAL(input2.peek({"a", "b", "c", "d", "e"}), type("double")); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/instruction/generic_cell_cast/generic_cell_cast_test.cpp b/eval/src/tests/instruction/generic_cell_cast/generic_cell_cast_test.cpp index eb156bbe531..6e4ac6fca2b 100644 --- a/eval/src/tests/instruction/generic_cell_cast/generic_cell_cast_test.cpp +++ b/eval/src/tests/instruction/generic_cell_cast/generic_cell_cast_test.cpp @@ -47,8 +47,10 @@ void test_generic_cell_cast_with(const ValueBuilderFactory &factory) { for (const auto &layout : layouts) { for (CellType in_type: CellTypeUtils::list_types()) { for (CellType out_type: CellTypeUtils::list_types()) { - TensorSpec lhs = layout.cpy().cells(in_type); - SCOPED_TRACE(fmt("\n===\nLHS: %s\n===\n", lhs.to_string().c_str())); + auto lhs = layout.cpy().cells(in_type); + auto res_check = layout.cpy().cells(out_type); + if (lhs.bad_scalar() || res_check.bad_scalar()) continue; + SCOPED_TRACE(fmt("\n===\nLHS: %s\n===\n", lhs.gen().to_string().c_str())); auto expect = ReferenceOperations::cell_cast(lhs, out_type); auto actual = perform_generic_cell_cast(lhs, out_type, factory); EXPECT_EQ(actual, expect); diff --git a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp index 6b6a803a4b1..a74b0f99841 100644 --- a/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp +++ b/eval/src/tests/instruction/generic_concat/generic_concat_test.cpp @@ -80,10 +80,12 @@ void test_generic_concat_with(const ValueBuilderFactory &factory) { const auto l = concat_layouts[i]; const auto r = concat_layouts[i+1].cpy().seq(N_16ths); for (CellType lct : CellTypeUtils::list_types()) { - 
TensorSpec lhs = l.cpy().cells(lct); + auto lhs = l.cpy().cells(lct); + if (lhs.bad_scalar()) continue; for (CellType rct : CellTypeUtils::list_types()) { - TensorSpec rhs = r.cpy().cells(rct); - SCOPED_TRACE(fmt("\n===\nin LHS: %s\nin RHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); + auto rhs = r.cpy().cells(rct); + if (rhs.bad_scalar()) continue; + SCOPED_TRACE(fmt("\n===\nin LHS: %s\nin RHS: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str())); auto actual = perform_generic_concat(lhs, rhs, "y", factory); auto expect = ReferenceOperations::concat(lhs, rhs, "y"); EXPECT_EQ(actual, expect); diff --git a/eval/src/tests/instruction/generic_create/generic_create_test.cpp b/eval/src/tests/instruction/generic_create/generic_create_test.cpp index 843a292612d..9389c8401e9 100644 --- a/eval/src/tests/instruction/generic_create/generic_create_test.cpp +++ b/eval/src/tests/instruction/generic_create/generic_create_test.cpp @@ -92,7 +92,7 @@ TensorSpec perform_generic_create(const TensorSpec &a, const ValueBuilderFactory void test_generic_create_with(const ValueBuilderFactory &factory) { for (const auto &layout : create_layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec full = layout.cpy().cells(ct); + auto full = layout.cpy().cells(ct); auto actual = perform_generic_create(full, factory); auto expect = reference_create(full); EXPECT_EQ(actual, expect); diff --git a/eval/src/tests/instruction/generic_join/generic_join_test.cpp b/eval/src/tests/instruction/generic_join/generic_join_test.cpp index 114881e6bee..a4f645c5dee 100644 --- a/eval/src/tests/instruction/generic_join/generic_join_test.cpp +++ b/eval/src/tests/instruction/generic_join/generic_join_test.cpp @@ -109,11 +109,13 @@ TEST(GenericJoinTest, generic_join_works_for_simple_and_fast_values) { const auto &l = join_layouts[i]; const auto &r = join_layouts[i+1]; for (CellType lct : CellTypeUtils::list_types()) { - TensorSpec lhs = l.cpy().cells(lct); + 
auto lhs = l.cpy().cells(lct); + if (lhs.bad_scalar()) continue; for (CellType rct : CellTypeUtils::list_types()) { - TensorSpec rhs = r.cpy().cells(rct); + auto rhs = r.cpy().cells(rct); + if (rhs.bad_scalar()) continue; for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Div::f}) { - SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); + SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str())); auto expect = ReferenceOperations::join(lhs, rhs, fun); auto simple = perform_generic_join(lhs, rhs, fun, SimpleValueBuilderFactory::get()); auto fast = perform_generic_join(lhs, rhs, fun, FastValueBuilderFactory::get()); diff --git a/eval/src/tests/instruction/generic_map/generic_map_test.cpp b/eval/src/tests/instruction/generic_map/generic_map_test.cpp index 56405eefdde..bfa7154968d 100644 --- a/eval/src/tests/instruction/generic_map/generic_map_test.cpp +++ b/eval/src/tests/instruction/generic_map/generic_map_test.cpp @@ -36,8 +36,7 @@ const std::vector<GenSpec> map_layouts = { TensorSpec perform_generic_map(const TensorSpec &a, map_fun_t func, const ValueBuilderFactory &factory) { auto lhs = value_from_spec(a, factory); - // XXX for now: - auto res_type = lhs->type(); + auto res_type = lhs->type().map(); auto my_op = GenericMap::make_instruction(res_type, lhs->type(), func); InterpretedFunction::EvalSingle single(factory, my_op); return spec_from_value(single.eval(std::vector<Value::CREF>({*lhs}))); @@ -46,9 +45,10 @@ TensorSpec perform_generic_map(const TensorSpec &a, map_fun_t func, const ValueB void test_generic_map_with(const ValueBuilderFactory &factory) { for (const auto &layout : map_layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec lhs = layout.cpy().cells(ct); + auto lhs = layout.cpy().cells(ct); + if (lhs.bad_scalar()) continue; for (auto func : {operation::Floor::f, operation::Fabs::f, 
operation::Square::f, operation::Inv::f}) { - SCOPED_TRACE(fmt("\n===\nLHS: %s\n===\n", lhs.to_string().c_str())); + SCOPED_TRACE(fmt("\n===\nLHS: %s\n===\n", lhs.gen().to_string().c_str())); auto expect = ReferenceOperations::map(lhs, func); auto actual = perform_generic_map(lhs, func, factory); EXPECT_EQ(actual, expect); diff --git a/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp b/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp index 9fde59a7c86..701fb26d3ff 100644 --- a/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp +++ b/eval/src/tests/instruction/generic_merge/generic_merge_test.cpp @@ -52,10 +52,12 @@ void test_generic_merge_with(const ValueBuilderFactory &factory) { const auto l = merge_layouts[i]; const auto r = merge_layouts[i+1].cpy().seq(N_16ths); for (CellType lct : CellTypeUtils::list_types()) { - TensorSpec lhs = l.cpy().cells(lct); + auto lhs = l.cpy().cells(lct); + if (lhs.bad_scalar()) continue; for (CellType rct : CellTypeUtils::list_types()) { - TensorSpec rhs = r.cpy().cells(rct); - SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); + auto rhs = r.cpy().cells(rct); + if (rhs.bad_scalar()) continue; + SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str())); for (auto fun: {operation::Add::f, operation::Mul::f, operation::Sub::f, operation::Max::f}) { auto expect = ReferenceOperations::merge(lhs, rhs, fun); auto actual = perform_generic_merge(lhs, rhs, fun, factory); diff --git a/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp b/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp index 073df8be7e9..4b773b07734 100644 --- a/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp +++ b/eval/src/tests/instruction/generic_peek/generic_peek_test.cpp @@ -150,7 +150,7 @@ void verify_peek_equal(const TensorSpec &input, reduce_dims.push_back(dim_name); } if 
(reduce_dims.empty()) return; - ValueType result_type = param_type.reduce(reduce_dims); + ValueType result_type = param_type.peek(reduce_dims); auto expect = reference_peek(input, spec); SCOPED_TRACE(fmt("peek input: %s\n peek spec: %s\n peek result %s\n", input.to_string().c_str(), @@ -195,8 +195,8 @@ void fill_dims_and_check(const TensorSpec &input, void test_generic_peek_with(const ValueBuilderFactory &factory) { for (const auto &layout : peek_layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec input = layout.cpy().cells(ct); - ValueType input_type = ValueType::from_spec(input.type()); + auto input = layout.cpy().cells(ct); + ValueType input_type = input.type(); const auto &dims = input_type.dimensions(); PeekSpec spec; fill_dims_and_check(input, spec, dims, factory); diff --git a/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp b/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp index 3babe80766a..e3eea84fdea 100644 --- a/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp +++ b/eval/src/tests/instruction/generic_reduce/generic_reduce_test.cpp @@ -72,8 +72,9 @@ TEST(GenericReduceTest, sparse_reduce_plan_can_be_created) { void test_generic_reduce_with(const ValueBuilderFactory &factory) { for (const auto &layout: layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec input = layout.cpy().cells(ct); - SCOPED_TRACE(fmt("tensor type: %s, num_cells: %zu", input.type().c_str(), input.cells().size())); + auto input = layout.cpy().cells(ct); + if (input.bad_scalar()) continue; + SCOPED_TRACE(fmt("tensor type: %s, num_cells: %zu", input.gen().type().c_str(), input.gen().cells().size())); for (Aggr aggr: {Aggr::SUM, Aggr::AVG, Aggr::MIN, Aggr::MAX}) { SCOPED_TRACE(fmt("aggregator: %s", AggrNames::name_of(aggr)->c_str())); auto t = layout.type(); diff --git a/eval/src/tests/instruction/generic_rename/generic_rename_test.cpp 
b/eval/src/tests/instruction/generic_rename/generic_rename_test.cpp index 4edf2a0ca87..ca14149f1ff 100644 --- a/eval/src/tests/instruction/generic_rename/generic_rename_test.cpp +++ b/eval/src/tests/instruction/generic_rename/generic_rename_test.cpp @@ -112,13 +112,13 @@ TensorSpec perform_generic_rename(const TensorSpec &a, void test_generic_rename_with(const ValueBuilderFactory &factory) { for (const auto &layout : rename_layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec lhs = layout.cpy().cells(ct); - ValueType lhs_type = ValueType::from_spec(lhs.type()); + auto lhs = layout.cpy().cells(ct); + ValueType lhs_type = lhs.type(); for (const auto & from_to : rename_from_to) { ValueType renamed_type = lhs_type.rename(from_to.from, from_to.to); if (renamed_type.is_error()) continue; // printf("type %s -> %s\n", lhs_type.to_spec().c_str(), renamed_type.to_spec().c_str()); - SCOPED_TRACE(fmt("\n===\nLHS: %s\n===\n", lhs.to_string().c_str())); + SCOPED_TRACE(fmt("\n===\nLHS: %s\n===\n", lhs.gen().to_string().c_str())); auto expect = ReferenceOperations::rename(lhs, from_to.from, from_to.to); auto actual = perform_generic_rename(lhs, from_to, factory); EXPECT_EQ(actual, expect); diff --git a/eval/src/tests/streamed/value/streamed_value_test.cpp b/eval/src/tests/streamed/value/streamed_value_test.cpp index a750ee88667..bb286dbfdc8 100644 --- a/eval/src/tests/streamed/value/streamed_value_test.cpp +++ b/eval/src/tests/streamed/value/streamed_value_test.cpp @@ -69,7 +69,8 @@ TensorSpec streamed_value_join(const TensorSpec &a, const TensorSpec &b, join_fu TEST(StreamedValueTest, streamed_values_can_be_converted_from_and_to_tensor_spec) { for (const auto &layout: layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec expect = layout.cpy().cells(ct); + auto expect = layout.cpy().cells(ct); + if (expect.bad_scalar()) continue; std::unique_ptr<Value> value = value_from_spec(expect, StreamedValueBuilderFactory::get()); TensorSpec actual = 
spec_from_value(*value); EXPECT_EQ(actual, expect); @@ -80,7 +81,8 @@ TEST(StreamedValueTest, streamed_values_can_be_converted_from_and_to_tensor_spec TEST(StreamedValueTest, streamed_values_can_be_copied) { for (const auto &layout: layouts) { for (CellType ct : CellTypeUtils::list_types()) { - TensorSpec expect = layout.cpy().cells(ct); + auto expect = layout.cpy().cells(ct); + if (expect.bad_scalar()) continue; std::unique_ptr<Value> value = value_from_spec(expect, StreamedValueBuilderFactory::get()); std::unique_ptr<Value> copy = StreamedValueBuilderFactory::get().copy(*value); TensorSpec actual = spec_from_value(*copy); @@ -131,11 +133,13 @@ TEST(StreamedValueTest, new_generic_join_works_for_streamed_values) { const auto l = join_layouts[i].cpy().seq(N_16ths); const auto r = join_layouts[i + 1].cpy().seq(N_16ths); for (CellType lct : CellTypeUtils::list_types()) { - TensorSpec lhs = l.cpy().cells(lct); + auto lhs = l.cpy().cells(lct); + if (lhs.bad_scalar()) continue; for (CellType rct : CellTypeUtils::list_types()) { - TensorSpec rhs = r.cpy().cells(rct); + auto rhs = r.cpy().cells(rct); + if (rhs.bad_scalar()) continue; for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Max::f}) { - SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); + SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.gen().to_string().c_str(), rhs.gen().to_string().c_str())); auto expect = ReferenceOperations::join(lhs, rhs, fun); auto actual = streamed_value_join(lhs, rhs, fun); EXPECT_EQ(actual, expect); diff --git a/eval/src/tests/tensor/instruction_benchmark/CMakeLists.txt b/eval/src/tests/tensor/instruction_benchmark/CMakeLists.txt index b5949398f50..e553692dfcb 100644 --- a/eval/src/tests/tensor/instruction_benchmark/CMakeLists.txt +++ b/eval/src/tests/tensor/instruction_benchmark/CMakeLists.txt @@ -1,5 +1,5 @@ # Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. 
See LICENSE in the project root. -vespa_add_executable(vespa-tensor-instructions-benchmark +vespa_add_executable(vespa-tensor-instructions-benchmark TEST SOURCES instruction_benchmark.cpp OUTPUT_NAME vespa-tensor-instructions-benchmark @@ -8,3 +8,4 @@ vespa_add_executable(vespa-tensor-instructions-benchmark vespaeval GTest::GTest ) +vespa_add_test(NAME eval_tensor_instruction_test COMMAND vespa-tensor-instructions-benchmark --smoke-test) diff --git a/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp b/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp index ec49f9772a8..ec9a5419ffb 100644 --- a/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp +++ b/eval/src/tests/tensor/instruction_benchmark/instruction_benchmark.cpp @@ -62,6 +62,7 @@ template <typename T> using CREF = std::reference_wrapper<const T>; //----------------------------------------------------------------------------- +TensorSpec NUM(double value) { return test::GenSpec(value).gen(); } test::GenSpec GS(double bias) { return test::GenSpec(bias).cells_float(); } //----------------------------------------------------------------------------- @@ -222,7 +223,7 @@ vespalib::string short_header(" vespalib::string ghost_name(" loaded from ghost.json"); vespalib::string ghost_short_name(" ghost"); -constexpr double budget = 5.0; +double budget = 5.0; constexpr double best_limit = 0.95; // everything within 95% of best performance gets a star constexpr double bad_limit = 0.90; // BAD: optimized has performance lower than 90% of un-optimized constexpr double good_limit = 1.10; // GOOD: optimized has performance higher than 110% of un-optimized @@ -365,6 +366,9 @@ struct EvalOp { } TensorSpec result() { return impl.create_spec(single.eval(stack)); } size_t suggest_loop_cnt() { + if (budget < 0.1) { + return 1; + } size_t loop_cnt = 1; auto my_loop = [&](){ for (size_t i = 0; i < loop_cnt; ++i) { @@ -389,19 +393,27 @@ struct EvalOp { } double 
estimate_cost_us(size_t self_loop_cnt, size_t ref_loop_cnt) { size_t loop_cnt = ((self_loop_cnt * 128) < ref_loop_cnt) ? self_loop_cnt : ref_loop_cnt; - assert((loop_cnt % 8) == 0); - auto my_loop = [&](){ - for (size_t i = 0; (i + 7) < loop_cnt; i += 8) { - for (size_t j = 0; j < 8; ++j) { - single.eval(stack); + BenchmarkTimer timer(budget); + if (loop_cnt == 1) { + while (timer.has_budget()) { + timer.before(); + single.eval(stack); + timer.after(); + } + } else { + assert((loop_cnt % 8) == 0); + auto my_loop = [&](){ + for (size_t i = 0; (i + 7) < loop_cnt; i += 8) { + for (size_t j = 0; j < 8; ++j) { + single.eval(stack); + } } + }; + while (timer.has_budget()) { + timer.before(); + my_loop(); + timer.after(); } - }; - BenchmarkTimer timer(budget); - while (timer.has_budget()) { - timer.before(); - my_loop(); - timer.after(); } return timer.min_time() * 1000.0 * 1000.0 / double(loop_cnt); } @@ -425,8 +437,9 @@ void benchmark(const vespalib::string &desc, const std::vector<EvalOp::UP> &list for (const auto &eval: list) { loop_cnt[eval->impl.order] = eval->suggest_loop_cnt(); } + size_t ref_idx = (list.size() > 1 ? 
1u : 0u); for (const auto &eval: list) { - double time = eval->estimate_cost_us(loop_cnt[eval->impl.order], loop_cnt[1]); + double time = eval->estimate_cost_us(loop_cnt[eval->impl.order], loop_cnt[ref_idx]); fprintf(stderr, " %s(%s): %10.3f us\n", eval->impl.name.c_str(), eval->impl.short_name.c_str(), time); result.sample(eval->impl.order, time); } @@ -567,7 +580,7 @@ void benchmark_tensor_create(const vespalib::string &desc, const TensorSpec &pro ASSERT_FALSE(proto_type.is_error()); std::vector<CREF<TensorSpec>> stack_spec; for (const auto &cell: proto.cells()) { - stack_spec.emplace_back(stash.create<TensorSpec>(GS(cell.second))); + stack_spec.emplace_back(stash.create<TensorSpec>(NUM(cell.second))); } std::vector<EvalOp::UP> list; for (const Impl &impl: impl_list) { @@ -603,7 +616,7 @@ void benchmark_tensor_peek(const vespalib::string &desc, const TensorSpec &lhs, stack_spec.emplace_back(lhs); if (peek_spec.is_dynamic) { for (const auto &entry: peek_spec.spec) { - stack_spec.emplace_back(stash.create<TensorSpec>(GS(double(entry.second)))); + stack_spec.emplace_back(stash.create<TensorSpec>(NUM(double(entry.second)))); } } std::vector<EvalOp::UP> list; @@ -618,12 +631,12 @@ void benchmark_tensor_peek(const vespalib::string &desc, const TensorSpec &lhs, //----------------------------------------------------------------------------- TEST(MakeInputTest, print_some_test_input) { - auto number = GS(5.0); + auto number = NUM(5.0); auto sparse = GS(1.0).map("x", 5, 3); auto dense = GS(10.0).idx("x", 5); auto mixed = GS(100.0).map("x", 3, 7).idx("y", 2).idx("z", 2); fprintf(stderr, "--------------------------------------------------------\n"); - fprintf(stderr, "simple number: %s\n", number.gen().to_string().c_str()); + fprintf(stderr, "simple number: %s\n", number.to_string().c_str()); fprintf(stderr, "sparse vector: %s\n", sparse.gen().to_string().c_str()); fprintf(stderr, "dense vector: %s\n", dense.gen().to_string().c_str()); fprintf(stderr, "mixed cube: %s\n", 
mixed.gen().to_string().c_str()); @@ -747,8 +760,8 @@ TEST(MixedConcat, large_mixed_b) { //----------------------------------------------------------------------------- TEST(NumberJoin, plain_op2) { - auto lhs = GS(2.0); - auto rhs = GS(3.0); + auto lhs = NUM(2.0); + auto rhs = NUM(3.0); benchmark_join("simple numbers multiply", lhs, rhs, operation::Mul::f); } @@ -793,7 +806,7 @@ TEST(DenseJoin, simple_expand) { } TEST(DenseJoin, multiply_by_number) { - auto lhs = GS(3.0); + auto lhs = NUM(3.0); auto rhs = GS(2.0).idx("a", 16).idx("b", 16).idx("c", 16); benchmark_join("dense cube multiply by number", lhs, rhs, operation::Mul::f); } @@ -837,7 +850,7 @@ TEST(SparseJoin, no_overlap) { } TEST(SparseJoin, multiply_by_number) { - auto lhs = GS(3.0); + auto lhs = NUM(3.0); auto rhs = GS(2.0).map("a", 16, 2).map("b", 16, 2).map("c", 16, 2); benchmark_join("sparse multiply by number", lhs, rhs, operation::Mul::f); } @@ -863,7 +876,7 @@ TEST(MixedJoin, no_overlap) { } TEST(MixedJoin, multiply_by_number) { - auto lhs = GS(3.0); + auto lhs = NUM(3.0); auto rhs = GS(2.0).map("a", 16, 2).map("b", 16, 2).idx("c", 16); benchmark_join("mixed multiply by number", lhs, rhs, operation::Mul::f); } @@ -871,7 +884,7 @@ TEST(MixedJoin, multiply_by_number) { //----------------------------------------------------------------------------- TEST(ReduceBench, number_reduce) { - auto lhs = GS(1.0); + auto lhs = NUM(1.0); benchmark_reduce("number reduce", lhs, Aggr::SUM, {}); } @@ -954,7 +967,7 @@ TEST(MergeBench, mixed_merge) { //----------------------------------------------------------------------------- TEST(MapBench, number_map) { - auto lhs = GS(1.75); + auto lhs = NUM(1.75); benchmark_map("number map", lhs, operation::Floor::f); } @@ -999,7 +1012,7 @@ TEST(TensorCreateBench, create_mixed) { TEST(TensorLambdaBench, simple_lambda) { auto type = ValueType::from_spec("tensor<float>(a[64],b[64])"); - auto p0 = GS(3.5); + auto p0 = NUM(3.5); auto function = Function::parse({"a", "b", "p0"}, 
"(a*64+b)*p0"); ASSERT_FALSE(function->has_error()); benchmark_tensor_lambda("simple tensor lambda", type, p0, *function); @@ -1089,13 +1102,21 @@ int main(int argc, char **argv) { load_ghost("ghost.json"); const std::string run_only_prod_option = "--limit-implementations"; const std::string ghost_mode_option = "--ghost-mode"; - if ((argc > 1) && (argv[1] == run_only_prod_option )) { + const std::string smoke_test_option = "--smoke-test"; + if ((argc > 1) && (argv[1] == run_only_prod_option)) { impl_list.clear(); impl_list.push_back(optimized_fast_value_impl); impl_list.push_back(fast_value_impl); ++argv; --argc; - } else if ((argc > 1) && (argv[1] == ghost_mode_option )) { + } else if ((argc > 1) && (argv[1] == ghost_mode_option)) { + impl_list.clear(); + impl_list.push_back(optimized_fast_value_impl); + has_ghost = true; + ++argv; + --argc; + } else if ((argc > 1) && (argv[1] == smoke_test_option)) { + budget = 0.001; impl_list.clear(); impl_list.push_back(optimized_fast_value_impl); has_ghost = true; diff --git a/eval/src/vespa/eval/eval/CMakeLists.txt b/eval/src/vespa/eval/eval/CMakeLists.txt index 2c0922e617c..639ac3b5864 100644 --- a/eval/src/vespa/eval/eval/CMakeLists.txt +++ b/eval/src/vespa/eval/eval/CMakeLists.txt @@ -15,6 +15,7 @@ vespa_add_library(eval_eval OBJECT fast_value.cpp function.cpp gbdt.cpp + int8float.cpp interpreted_function.cpp key_gen.cpp lazy_params.cpp diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h index beb0b32386f..6bb9756421f 100644 --- a/eval/src/vespa/eval/eval/cell_type.h +++ b/eval/src/vespa/eval/eval/cell_type.h @@ -5,34 +5,147 @@ #include <vespa/vespalib/util/typify.h> #include <vector> #include <cstdint> +#include <cassert> namespace vespalib::eval { enum class CellType : char { FLOAT, DOUBLE }; -// utility templates +// converts actual cell type to CellType enum value +template <typename CT> constexpr CellType get_cell_type(); +template <> constexpr CellType get_cell_type<double>() { 
return CellType::DOUBLE; } +template <> constexpr CellType get_cell_type<float>() { return CellType::FLOAT; } -template <typename CT> inline bool check_cell_type(CellType type); -template <> inline bool check_cell_type<double>(CellType type) { return (type == CellType::DOUBLE); } -template <> inline bool check_cell_type<float>(CellType type) { return (type == CellType::FLOAT); } +// check if the given CellType enum value and actual cell type match +template <typename CT> constexpr bool check_cell_type(CellType type) { + return (type == get_cell_type<CT>()); +} -template <typename LCT, typename RCT> struct UnifyCellTypes{}; -template <> struct UnifyCellTypes<double, double> { using type = double; }; -template <> struct UnifyCellTypes<double, float> { using type = double; }; -template <> struct UnifyCellTypes<float, double> { using type = double; }; -template <> struct UnifyCellTypes<float, float> { using type = float; }; +// converts CellType enum value to actual cell type by using the +// return value as a type tag. 
usage: +// decltype(get_cell_value<my_cell_type>()) +template <CellType cell_type> constexpr auto get_cell_value() { + if constexpr (cell_type == CellType::DOUBLE) { + return double(); + } else if constexpr (cell_type == CellType::FLOAT) { + return float(); + } else { + static_assert((cell_type == CellType::DOUBLE), "unknown cell type"); + } +} +template <CellType cell_type> using CellValueType = decltype(get_cell_value<cell_type>()); + +// simple CellMeta value wrapper to reduce template expansion +// -> for values that are results of operations that are not scalars +struct LimitedCellMetaNotScalar { + const CellType cell_type; +}; + +// simple CellMeta value wrapper to reduce template expansion +// -> for values that are results of operations +struct LimitedCellMeta { + const CellType cell_type; + const bool is_scalar; + constexpr LimitedCellMetaNotScalar not_scalar() const { + assert(!is_scalar); + return {cell_type}; + } +}; + +// simple CellMeta value wrapper to reduce template expansion +// -> for values that we known are not scalar +struct CellMetaNotScalar { + const CellType cell_type; +}; + +// meta-information about the cell type and 'scalar-ness' of a value +struct CellMeta { + const CellType cell_type; + const bool is_scalar; + constexpr CellMeta(CellType cell_type_in, bool is_scalar_in) + : cell_type(cell_type_in), is_scalar(is_scalar_in) + { + // is_scalar -> double cell type + assert(!is_scalar || (cell_type == CellType::DOUBLE)); + } + constexpr bool is_limited() const { + return ((cell_type == CellType::DOUBLE) || (cell_type == CellType::FLOAT)); + } + constexpr LimitedCellMeta limit() const { + assert(is_limited()); + return {cell_type, is_scalar}; + } + constexpr CellMetaNotScalar not_scalar() const { + assert(!is_scalar); + return {cell_type}; + } + + constexpr CellMeta self() const { return *this; } + + constexpr bool eq(const CellMeta &rhs) const { + return ((cell_type == rhs.cell_type) && (is_scalar == rhs.is_scalar)); + } + + // promote cell 
type to at least float + constexpr CellMeta decay() const { + if (cell_type == CellType::DOUBLE) { + return self(); + } + return {CellType::FLOAT, is_scalar}; + } + + // normalize to make sure scalar values have cell type double + static constexpr CellMeta normalize(CellType cell_type, bool is_scalar) { + if (is_scalar) { + return CellMeta(CellType::DOUBLE, true); + } else { + return CellMeta(cell_type, false); + } + } + + // unify the cell meta across two values + static constexpr CellMeta unify(CellMeta a, CellMeta b) { + if (a.is_scalar) { + return b; + } else if (b.is_scalar) { + return a; + } + if (a.cell_type == b.cell_type) { + return {a.cell_type, false}; + } else if ((a.cell_type == CellType::DOUBLE) || (b.cell_type == CellType::DOUBLE)) { + return {CellType::DOUBLE, false}; + } else { + return {CellType::FLOAT, false}; + } + } + + // convenience functions to be used for specific operations + constexpr CellMeta map() const { return decay(); } + static constexpr CellMeta reduce(CellType input_cell_type, bool output_is_scalar) { + return normalize(input_cell_type, output_is_scalar).decay(); + } + static constexpr CellMeta join(CellMeta a, CellMeta b) { return unify(a, b).decay(); } + static constexpr CellMeta merge(CellMeta a, CellMeta b) { return unify(a, b).decay(); } + static constexpr CellMeta concat(CellMeta a, CellMeta b) { return unify(a, b); } + static constexpr CellMeta peek(CellType input_cell_type, bool output_is_scalar) { + return normalize(input_cell_type, output_is_scalar); + } + constexpr CellMeta rename() const { return self(); } +}; -template <typename CT> inline CellType get_cell_type(); -template <> inline CellType get_cell_type<double>() { return CellType::DOUBLE; } -template <> inline CellType get_cell_type<float>() { return CellType::FLOAT; } +template <typename A, typename B> constexpr auto unify_cell_types() { + constexpr CellMeta a(get_cell_type<A>(), false); + constexpr CellMeta b(get_cell_type<B>(), false); + return 
get_cell_value<CellMeta::unify(a, b).cell_type>(); +} struct CellTypeUtils { static void bad_argument [[ noreturn ]] (uint32_t id); static constexpr uint32_t alignment(CellType cell_type) { switch (cell_type) { - case CellType::DOUBLE: return sizeof(double); - case CellType::FLOAT: return sizeof(float); + case CellType::DOUBLE: return alignof(double); + case CellType::FLOAT: return alignof(float); default: bad_argument((uint32_t)cell_type); } } @@ -53,7 +166,7 @@ struct CellTypeUtils { struct TypifyCellType { template <typename T> using Result = TypifyResultType<T>; template <typename F> static decltype(auto) resolve(CellType value, F &&f) { - switch(value) { + switch (value) { case CellType::DOUBLE: return f(Result<double>()); case CellType::FLOAT: return f(Result<float>()); } @@ -61,4 +174,52 @@ struct TypifyCellType { } }; +struct TypifyCellMeta { + template <CellMeta VALUE> using Result = TypifyResultValue<CellMeta, VALUE>; + template <typename F> static decltype(auto) resolve(CellMeta value, F &&f) { + if (value.is_scalar) { + if (value.cell_type == CellType::DOUBLE) { + return f(Result<CellMeta(CellType::DOUBLE, true)>()); + } + abort(); + } else { + switch (value.cell_type) { + case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); + case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); + } + abort(); + } + } + template <typename F> static decltype(auto) resolve(CellMetaNotScalar value, F &&f) { + switch (value.cell_type) { + case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); + case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); + } + abort(); + } + template <typename F> static decltype(auto) resolve(LimitedCellMeta value, F &&f) { + if (value.is_scalar) { + if (value.cell_type == CellType::DOUBLE) { + return f(Result<CellMeta(CellType::DOUBLE, true)>()); + } + abort(); + } else { + switch (value.cell_type) { + case CellType::DOUBLE: return 
f(Result<CellMeta(CellType::DOUBLE, false)>()); + case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); + default: break; + } + abort(); + } + } + template <typename F> static decltype(auto) resolve(LimitedCellMetaNotScalar value, F &&f) { + switch (value.cell_type) { + case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); + case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); + default: break; + } + abort(); + } +}; + } // namespace diff --git a/eval/src/vespa/eval/eval/fast_value.cpp b/eval/src/vespa/eval/eval/fast_value.cpp index 4b39fc48d9b..07d76bb6c97 100644 --- a/eval/src/vespa/eval/eval/fast_value.cpp +++ b/eval/src/vespa/eval/eval/fast_value.cpp @@ -15,8 +15,8 @@ struct CreateFastValueBuilderBase { size_t num_mapped_dims, size_t subspace_size, size_t expected_subspaces) { assert(check_cell_type<T>(type.cell_type())); - if (type.is_scalar()) { - return std::make_unique<FastScalarBuilder<T>>(); + if (type.is_double()) { + return std::make_unique<FastDoubleValueBuilder>(); } else if (num_mapped_dims == 0) { return std::make_unique<FastDenseValue<T>>(type, subspace_size); } else { diff --git a/eval/src/vespa/eval/eval/fast_value.hpp b/eval/src/vespa/eval/eval/fast_value.hpp index 69a496e9bff..a8e1e38eba2 100644 --- a/eval/src/vespa/eval/eval/fast_value.hpp +++ b/eval/src/vespa/eval/eval/fast_value.hpp @@ -45,9 +45,9 @@ struct FastLookupView : public Value::Index::View { struct FastFilterView : public Value::Index::View { const FastAddrMap ↦ - std::vector<size_t> match_dims; - std::vector<size_t> extract_dims; - std::vector<string_id> query; + SmallVector<size_t> match_dims; + SmallVector<size_t> extract_dims; + SmallVector<string_id> query; size_t pos; bool is_match(ConstArrayRef<string_id> addr) const { @@ -332,12 +332,11 @@ template <typename T> FastDenseValue<T>::~FastDenseValue() = default; //----------------------------------------------------------------------------- -template <typename 
T> -struct FastScalarBuilder final : ValueBuilder<T> { - T _value; - ArrayRef<T> add_subspace(ConstArrayRef<vespalib::stringref>) final override { return ArrayRef<T>(&_value, 1); } - ArrayRef<T> add_subspace(ConstArrayRef<string_id>) final override { return ArrayRef<T>(&_value, 1); }; - std::unique_ptr<Value> build(std::unique_ptr<ValueBuilder<T>>) final override { return std::make_unique<ScalarValue<T>>(_value); } +struct FastDoubleValueBuilder final : ValueBuilder<double> { + double _value; + ArrayRef<double> add_subspace(ConstArrayRef<vespalib::stringref>) final override { return ArrayRef<double>(&_value, 1); } + ArrayRef<double> add_subspace(ConstArrayRef<string_id>) final override { return ArrayRef<double>(&_value, 1); }; + std::unique_ptr<Value> build(std::unique_ptr<ValueBuilder<double>>) final override { return std::make_unique<DoubleValue>(_value); } }; //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp index b03c2c1ed24..580a9b120d5 100644 --- a/eval/src/vespa/eval/eval/function.cpp +++ b/eval/src/vespa/eval/eval/function.cpp @@ -777,7 +777,7 @@ bool maybe_parse_tensor_generator(ParseContext &ctx) { ctx.restore_input_mark(my_mark); return false; } - bool is_create = (type.is_tensor() && (ctx.get() == ':')); + bool is_create = (type.has_dimensions() && (ctx.get() == ':')); bool is_lambda = (type.is_dense() && (ctx.get() == '(')); if (is_create) { parse_tensor_create(ctx, type, dim_list); diff --git a/eval/src/vespa/eval/eval/inline_operation.h b/eval/src/vespa/eval/eval/inline_operation.h index 21516c4d94e..d52705e652d 100644 --- a/eval/src/vespa/eval/eval/inline_operation.h +++ b/eval/src/vespa/eval/eval/inline_operation.h @@ -94,7 +94,10 @@ template <> struct InlineOp2<Mul> { }; template <> struct InlineOp2<Pow> { InlineOp2(op2_t) {} - template <typename A, typename B> constexpr auto operator()(A a, B b) const { return std::pow(a,b); } 
+ constexpr float operator()(float a, float b) const { return std::pow(a,b); } + constexpr double operator()(float a, double b) const { return std::pow(a,b); } + constexpr double operator()(double a, float b) const { return std::pow(a,b); } + constexpr double operator()(double a, double b) const { return std::pow(a,b); } }; template <> struct InlineOp2<Sub> { InlineOp2(op2_t) {} diff --git a/eval/src/vespa/eval/eval/int8float.cpp b/eval/src/vespa/eval/eval/int8float.cpp new file mode 100644 index 00000000000..38af441fd08 --- /dev/null +++ b/eval/src/vespa/eval/eval/int8float.cpp @@ -0,0 +1,3 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "int8float.h" diff --git a/eval/src/vespa/eval/eval/int8float.h b/eval/src/vespa/eval/eval/int8float.h new file mode 100644 index 00000000000..7279d519bed --- /dev/null +++ b/eval/src/vespa/eval/eval/int8float.h @@ -0,0 +1,46 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <cstdint> +#include <vespa/vespalib/objects/nbostream.h> + +namespace vespalib::eval { + +/** + * Class holding an 8-bit integer which decays into float. 
+ **/ +class Int8Float { +private: + int8_t _bits; +public: + constexpr Int8Float(float value) noexcept : _bits(value) {} + Int8Float() noexcept = default; + ~Int8Float() noexcept = default; + constexpr Int8Float(const Int8Float &other) noexcept = default; + constexpr Int8Float(Int8Float &&other) noexcept = default; + constexpr Int8Float& operator=(const Int8Float &other) noexcept = default; + constexpr Int8Float& operator=(Int8Float &&other) noexcept = default; + constexpr Int8Float& operator=(float value) noexcept { + _bits = value; + return *this; + } + + constexpr operator float () const noexcept { return _bits; } + + constexpr float to_float() const noexcept { return _bits; } + constexpr void assign(float value) noexcept { _bits = value; } + + constexpr int8_t get_bits() const { return _bits; } + constexpr void assign_bits(int8_t value) noexcept { _bits = value; } +}; + +inline nbostream & operator << (nbostream &stream, Int8Float v) { return stream << v.get_bits(); } +inline nbostream & operator >> (nbostream &stream, Int8Float & v) { + int8_t byte; + stream >> byte; + v.assign_bits(byte); + return stream; +} + +} diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index 2df22d5433c..63da6d79c6f 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -109,7 +109,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { } void resolve_op1(const Node &node) { - bind(type(node.get_child(0)), node); + bind(type(node.get_child(0)).map(), node); } void resolve_op2(const Node &node) { @@ -234,7 +234,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { } } } - bind(param_type.reduce(dimensions), node); + bind(param_type.peek(dimensions), node); } void visit(const Add &node) override { resolve_op2(node); } void visit(const Sub &node) override { resolve_op2(node); } diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp 
b/eval/src/vespa/eval/eval/tensor_function.cpp index a93c1dea83e..fd8b20bfed0 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -316,7 +316,7 @@ Lambda::create_spec_impl(const ValueType &type, const LazyParams ¶ms, const InterpretedFunction::Instruction Lambda::compile_self(const ValueBuilderFactory &factory, Stash &stash) const { - return instruction::GenericLambda::make_instruction(result_type(), *this, factory, stash); + return instruction::GenericLambda::make_instruction(*this, factory, stash); } void @@ -447,7 +447,7 @@ const TensorFunction &reduce(const TensorFunction &child, Aggr aggr, const std:: } const TensorFunction &map(const TensorFunction &child, map_fun_t function, Stash &stash) { - ValueType result_type = child.result_type(); + ValueType result_type = child.result_type().map(); return stash.create<Map>(result_type, child, function); } @@ -485,7 +485,7 @@ const TensorFunction &peek(const TensorFunction ¶m, const std::map<vespalib: dimensions.push_back(dim_spec.first); } assert(!dimensions.empty()); - ValueType result_type = param.result_type().reduce(dimensions); + ValueType result_type = param.result_type().peek(dimensions); return stash.create<Peek>(result_type, param, spec); } diff --git a/eval/src/vespa/eval/eval/test/gen_spec.cpp b/eval/src/vespa/eval/eval/test/gen_spec.cpp index 2f28f559051..a308cbeff46 100644 --- a/eval/src/vespa/eval/eval/test/gen_spec.cpp +++ b/eval/src/vespa/eval/eval/test/gen_spec.cpp @@ -60,6 +60,12 @@ GenSpec &GenSpec::operator=(const GenSpec &other) = default; GenSpec::~GenSpec() = default; +bool +GenSpec::bad_scalar() const +{ + return (_dims.empty() && (_cells != CellType::DOUBLE)); +} + ValueType GenSpec::type() const { @@ -67,7 +73,7 @@ GenSpec::type() const for (const auto &dim: _dims) { dim_types.push_back(dim.type()); } - auto type = ValueType::tensor_type(dim_types, _cells); + auto type = ValueType::make_type(_cells, dim_types); 
assert(!type.is_error()); return type; } @@ -77,6 +83,7 @@ GenSpec::gen() const { size_t idx = 0; TensorSpec::Address addr; + assert(!bad_scalar()); TensorSpec result(type().to_spec()); std::function<void(size_t)> add_cells = [&](size_t dim_idx) { if (dim_idx == _dims.size()) { diff --git a/eval/src/vespa/eval/eval/test/gen_spec.h b/eval/src/vespa/eval/eval/test/gen_spec.h index b3a3916967b..bbc1663a11f 100644 --- a/eval/src/vespa/eval/eval/test/gen_spec.h +++ b/eval/src/vespa/eval/eval/test/gen_spec.h @@ -128,6 +128,7 @@ public: _seq = seq_in; return *this; } + bool bad_scalar() const; ValueType type() const; TensorSpec gen() const; operator TensorSpec() const { return gen(); } diff --git a/eval/src/vespa/eval/eval/test/reference_operations.cpp b/eval/src/vespa/eval/eval/test/reference_operations.cpp index f09c74ca009..58c90de65a2 100644 --- a/eval/src/vespa/eval/eval/test/reference_operations.cpp +++ b/eval/src/vespa/eval/eval/test/reference_operations.cpp @@ -213,7 +213,7 @@ TensorSpec ReferenceOperations::peek(const PeekSpec &peek_spec, const std::vecto } TensorSpec param = children[0].normalize(); ValueType param_type = ValueType::from_spec(param.type()); - ValueType result_type = param_type.reduce(peek_dims); + ValueType result_type = param_type.peek(peek_dims); TensorSpec result(result_type.to_spec()); if (result_type.is_error()) { return result; diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index 117e8c9b149..17ad75ae455 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -560,8 +560,10 @@ struct TestContext { void test_cell_cast(const GenSpec &a) { for (CellType cell_type: CellTypeUtils::list_types()) { + auto expect = a.cpy().cells(cell_type); + if (expect.bad_scalar()) continue; vespalib::string expr = fmt("cell_cast(a,%s)", value_type::cell_type_to_name(cell_type).c_str()); - 
TEST_DO(verify_result(factory, expr, {a}, a.cpy().cells(cell_type))); + TEST_DO(verify_result(factory, expr, {a}, expect)); } } @@ -570,8 +572,8 @@ struct TestContext { for (CellType cell_type: CellTypeUtils::list_types()) { gen_list.push_back(GenSpec(-3).cells(cell_type)); } + TEST_DO(test_cell_cast(GenSpec(42))); for (const auto &gen: gen_list) { - TEST_DO(test_cell_cast(gen)); TEST_DO(test_cell_cast(gen.cpy().idx("x", 10))); TEST_DO(test_cell_cast(gen.cpy().map("x", 10, 1))); TEST_DO(test_cell_cast(gen.cpy().map("x", 4, 1).idx("y", 4))); diff --git a/eval/src/vespa/eval/eval/value.cpp b/eval/src/vespa/eval/eval/value.cpp index 9cd5ef45765..b799658cfae 100644 --- a/eval/src/vespa/eval/eval/value.cpp +++ b/eval/src/vespa/eval/eval/value.cpp @@ -58,11 +58,7 @@ Value::as_double() const return typify_invoke<1,TypifyCellType,MySum>(type().cell_type(), cells()); } -template <typename T> -ValueType ScalarValue<T>::_type = ValueType::make_type(get_cell_type<T>(), {}); - -template class ScalarValue<double>; -template class ScalarValue<float>; +ValueType DoubleValue::_type = ValueType::double_type(); namespace { diff --git a/eval/src/vespa/eval/eval/value.h b/eval/src/vespa/eval/eval/value.h index ee850cdd47e..fcdaa7131c7 100644 --- a/eval/src/vespa/eval/eval/value.h +++ b/eval/src/vespa/eval/eval/value.h @@ -22,7 +22,6 @@ struct Value { virtual const ValueType &type() const = 0; virtual ~Value() {} -// ---- new interface enabling separation of values and operations // Root lookup structure for mapping labels to dense subspace indexes struct Index { @@ -59,16 +58,9 @@ struct Value { }; virtual TypedCells cells() const = 0; virtual const Index &index() const = 0; -// --- end of new interface - virtual MemoryUsage get_memory_usage() const = 0; - -// --- old interface that may be (partially) removed in the future - virtual bool is_double() const { return type().is_double(); } - virtual bool is_tensor() const { return type().is_tensor(); } virtual double as_double() const; 
bool as_bool() const { return (as_double() != 0.0); } -// --- end of old interface }; /** @@ -84,28 +76,21 @@ public: std::unique_ptr<View> create_view(ConstArrayRef<size_t> dims) const override; }; -template <typename T> -class ScalarValue final : public Value +class DoubleValue final : public Value { private: - T _value; + double _value; static ValueType _type; public: - ScalarValue(T value) : _value(value) {} - TypedCells cells() const final override { return TypedCells(ConstArrayRef<T>(&_value, 1)); } + DoubleValue(double value) : _value(value) {} + TypedCells cells() const final override { return TypedCells(ConstArrayRef<double>(&_value, 1)); } const Index &index() const final override { return TrivialIndex::get(); } - MemoryUsage get_memory_usage() const final override { return self_memory_usage<ScalarValue<T>>(); } - bool is_double() const final override { return std::is_same_v<T,double>; } + MemoryUsage get_memory_usage() const final override { return self_memory_usage<DoubleValue>(); } double as_double() const final override { return _value; } const ValueType &type() const final override { return _type; } static const ValueType &shared_type() { return _type; } }; -extern template class ScalarValue<double>; -extern template class ScalarValue<float>; - -using DoubleValue = ScalarValue<double>; - /** * A generic value without any mapped dimensions referencing its * components without owning anything. 
@@ -227,7 +212,6 @@ protected: } -VESPA_CAN_SKIP_DESTRUCTION(::vespalib::eval::ScalarValue<double>); -VESPA_CAN_SKIP_DESTRUCTION(::vespalib::eval::ScalarValue<float>); +VESPA_CAN_SKIP_DESTRUCTION(::vespalib::eval::DoubleValue); VESPA_CAN_SKIP_DESTRUCTION(::vespalib::eval::DenseValueView); VESPA_CAN_SKIP_DESTRUCTION(::vespalib::eval::ValueView); diff --git a/eval/src/vespa/eval/eval/value_codec.cpp b/eval/src/vespa/eval/eval/value_codec.cpp index 0016dfc694f..bf45f34fd64 100644 --- a/eval/src/vespa/eval/eval/value_codec.cpp +++ b/eval/src/vespa/eval/eval/value_codec.cpp @@ -6,6 +6,7 @@ #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/vespalib/util/typify.h> +#include <vespa/vespalib/util/small_vector.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/util/shared_string_repo.h> @@ -116,10 +117,12 @@ ValueType decode_type(nbostream &input, const Format &format) { dim_list.emplace_back(name, input.getInt1_4Bytes()); } } - if (dim_list.empty()) { - assert(cell_type == CellType::DOUBLE); + auto result = ValueType::make_type(cell_type, std::move(dim_list)); + if (result.is_error()) { + throw IllegalArgumentException(fmt("Invalid type (with %zu dimensions and cell type %u)", + dim_list.size(), (uint32_t)cell_type)); } - return ValueType::tensor_type(std::move(dim_list), cell_type); + return result; } size_t maybe_decode_num_blocks(nbostream &input, bool has_mapped_dims, const Format &format) { @@ -129,14 +132,14 @@ size_t maybe_decode_num_blocks(nbostream &input, bool has_mapped_dims, const For return 1; } -void encode_mapped_labels(nbostream &output, size_t num_mapped_dims, const std::vector<string_id> &addr) { +void encode_mapped_labels(nbostream &output, size_t num_mapped_dims, const SmallVector<string_id> &addr) { for (size_t i = 0; i < num_mapped_dims; ++i) { vespalib::string str = SharedStringRepo::Handle::string_from_id(addr[i]); output.writeSmallString(str); } } -void 
decode_mapped_labels(nbostream &input, size_t num_mapped_dims, std::vector<vespalib::stringref> &addr) { +void decode_mapped_labels(nbostream &input, size_t num_mapped_dims, SmallVector<vespalib::stringref> &addr) { for (size_t i = 0; i < num_mapped_dims; ++i) { size_t strSize = input.getInt1_4Bytes(); addr[i] = vespalib::stringref(input.peek(), strSize); @@ -163,7 +166,7 @@ struct DecodeState { struct ContentDecoder { template<typename T> static std::unique_ptr<Value> invoke(nbostream &input, const DecodeState &state, const ValueBuilderFactory &factory) { - std::vector<vespalib::stringref> address(state.num_mapped_dims); + SmallVector<vespalib::stringref> address(state.num_mapped_dims); if (state.num_blocks * state.subspace_size * sizeof(T) > input.size()) { auto err = fmt("serialized input claims %zu blocks of size %zu*%zu, but only %zu bytes available", state.num_blocks, state.subspace_size, sizeof(T), input.size()); @@ -190,7 +193,7 @@ struct CreateValueFromTensorSpec { size_t dense_size = type.dense_subspace_size(); ArrayArrayMap<vespalib::stringref,T> map(type.count_mapped_dimensions(), dense_size, std::max(spec.cells().size() / dense_size, size_t(1))); - std::vector<vespalib::stringref> sparse_key; + SmallVector<vespalib::stringref> sparse_key; for (const auto &entry: spec.cells()) { sparse_key.clear(); size_t dense_key = 0; @@ -231,8 +234,8 @@ struct CreateTensorSpecFromValue { TensorSpec spec(value.type().to_spec()); size_t subspace_id = 0; size_t subspace_size = value.type().dense_subspace_size(); - std::vector<string_id> labels(value.type().count_mapped_dimensions()); - std::vector<string_id*> label_refs; + SmallVector<string_id> labels(value.type().count_mapped_dimensions()); + SmallVector<string_id*> label_refs; for (auto &label: labels) { label_refs.push_back(&label); } @@ -272,8 +275,8 @@ struct EncodeState { struct ContentEncoder { template<typename T> static void invoke(const Value &value, const EncodeState &state, nbostream &output) { - 
std::vector<string_id> address(state.num_mapped_dims); - std::vector<string_id*> a_refs(state.num_mapped_dims);; + SmallVector<string_id> address(state.num_mapped_dims); + SmallVector<string_id*> a_refs(state.num_mapped_dims);; for (size_t i = 0; i < state.num_mapped_dims; ++i) { a_refs[i] = &address[i]; } diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp index b70edef7153..fbd04babc70 100644 --- a/eval/src/vespa/eval/eval/value_type.cpp +++ b/eval/src/vespa/eval/eval/value_type.cpp @@ -3,6 +3,7 @@ #include "value_type.h" #include "value_type_spec.h" #include <algorithm> +#include <cassert> namespace vespalib::eval { @@ -11,29 +12,6 @@ namespace { using Dimension = ValueType::Dimension; using DimensionList = std::vector<Dimension>; -template <typename A, typename B> -CellType unify() { - using type = typename UnifyCellTypes<A,B>::type; - return get_cell_type<type>(); -} - -template <typename A> -CellType unify(CellType b) { - switch (b) { - case CellType::DOUBLE: return unify<A,double>(); - case CellType::FLOAT: return unify<A,float>(); - } - abort(); -} - -CellType unify(CellType a, CellType b) { - switch (a) { - case CellType::DOUBLE: return unify<double>(b); - case CellType::FLOAT: return unify<float>(b); - } - abort(); -} - size_t my_dimension_index(const std::vector<Dimension> &list, const vespalib::string &name) { for (size_t idx = 0; idx < list.size(); ++idx) { if (list[idx].name == name) { @@ -65,6 +43,28 @@ bool verify_dimensions(const DimensionList &dimensions) { return true; } +struct MyReduce { + bool has_error; + std::vector<Dimension> dimensions; + MyReduce(const std::vector<Dimension> &dim_list, const std::vector<vespalib::string> &rem_list) + : has_error(false), dimensions() + { + if (!rem_list.empty()) { + size_t removed = 0; + for (const Dimension &dim: dim_list) { + if (std::find(rem_list.begin(), rem_list.end(), dim.name) == rem_list.end()) { + dimensions.push_back(dim); + } else { + ++removed; + } + 
} + if (removed != rem_list.size()) { + has_error = true; + } + } + } +}; + struct MyJoin { bool mismatch; DimensionList dimensions; @@ -142,9 +142,28 @@ struct Renamer { constexpr ValueType::Dimension::size_type ValueType::Dimension::npos; +ValueType +ValueType::error_if(bool has_error, ValueType else_type) +{ + if (has_error) { + return error_type(); + } else { + return else_type; + } +} + ValueType::~ValueType() = default; bool +ValueType::is_double() const { + if (!_error && _dimensions.empty()) { + assert(_cell_type == CellType::DOUBLE); + return true; + } + return false; +} + +bool ValueType::is_sparse() const { if (dimensions().empty()) { @@ -246,26 +265,28 @@ ValueType::dimension_names() const } ValueType +ValueType::map() const +{ + auto meta = cell_meta().map(); + return error_if(_error, make_type(meta.cell_type, _dimensions)); +} + +ValueType ValueType::reduce(const std::vector<vespalib::string> &dimensions_in) const { - if (is_error()) { - return error_type(); - } else if (dimensions_in.empty()) { - return double_type(); - } - size_t removed = 0; - std::vector<Dimension> result; - for (const Dimension &d: _dimensions) { - if (std::find(dimensions_in.begin(), dimensions_in.end(), d.name) == dimensions_in.end()) { - result.push_back(d); - } else { - ++removed; - } - } - if (removed != dimensions_in.size()) { - return error_type(); - } - return tensor_type(std::move(result), _cell_type); + MyReduce result(_dimensions, dimensions_in); + auto meta = CellMeta::reduce(_cell_type, result.dimensions.empty()); + return error_if(_error || result.has_error, + make_type(meta.cell_type, std::move(result.dimensions))); +} + +ValueType +ValueType::peek(const std::vector<vespalib::string> &dimensions_in) const +{ + MyReduce result(_dimensions, dimensions_in); + auto meta = CellMeta::peek(_cell_type, result.dimensions.empty()); + return error_if(_error || result.has_error || dimensions_in.empty(), + make_type(meta.cell_type, std::move(result.dimensions))); } ValueType @@ 
-280,25 +301,24 @@ ValueType::rename(const std::vector<vespalib::string> &from, for (const auto &dim: _dimensions) { dim_list.emplace_back(renamer.rename(dim.name), dim.size); } - if (!renamer.matched_all()) { - return error_type(); - } - return tensor_type(dim_list, _cell_type); + auto meta = cell_meta().rename(); + return error_if(!renamer.matched_all(), + make_type(meta.cell_type, std::move(dim_list))); } ValueType ValueType::cell_cast(CellType to_cell_type) const { - if (is_error()) { - return error_type(); - } - // TODO: return make_type(to_cell_type, _dimensions); - return tensor_type(_dimensions, to_cell_type); + return error_if(_error, make_type(to_cell_type, _dimensions)); } ValueType ValueType::make_type(CellType cell_type, std::vector<Dimension> dimensions_in) { + if (dimensions_in.empty() && (cell_type != CellType::DOUBLE)) { + // Note: all scalar values must have cell_type double + return error_type(); + } sort_dimensions(dimensions_in); if (!verify_dimensions(dimensions_in)) { return error_type(); @@ -307,15 +327,6 @@ ValueType::make_type(CellType cell_type, std::vector<Dimension> dimensions_in) } ValueType -ValueType::tensor_type(std::vector<Dimension> dimensions_in, CellType cell_type) -{ - if (dimensions_in.empty()) { - return double_type(); - } - return make_type(cell_type, std::move(dimensions_in)); -} - -ValueType ValueType::from_spec(const vespalib::string &spec) { return value_type::from_spec(spec); @@ -336,66 +347,35 @@ ValueType::to_spec() const ValueType ValueType::join(const ValueType &lhs, const ValueType &rhs) { - if (lhs.is_error() || rhs.is_error()) { - return error_type(); - } else if (lhs.is_double()) { - return rhs; - } else if (rhs.is_double()) { - return lhs; - } MyJoin result(lhs._dimensions, rhs._dimensions); - if (result.mismatch) { - return error_type(); - } - return tensor_type(std::move(result.dimensions), unify(lhs._cell_type, rhs._cell_type)); + auto meta = CellMeta::join(lhs.cell_meta(), rhs.cell_meta()); + return 
error_if(lhs._error || rhs._error || result.mismatch, + make_type(meta.cell_type, std::move(result.dimensions))); } ValueType ValueType::merge(const ValueType &lhs, const ValueType &rhs) { - if ((lhs.is_error() != rhs.is_error()) || - (lhs.dimensions() != rhs.dimensions())) - { - return error_type(); - } - if (lhs.dimensions().empty()) { - return lhs; - } - return tensor_type(lhs.dimensions(), unify(lhs._cell_type, rhs._cell_type)); -} - -CellType -ValueType::unify_cell_types(const ValueType &a, const ValueType &b) { - if (a.is_double()) { - return b.cell_type(); - } else if (b.is_double()) { - return a.cell_type(); - } - return unify(a.cell_type(), b.cell_type()); + auto meta = CellMeta::merge(lhs.cell_meta(), rhs.cell_meta()); + return error_if(lhs._error || rhs._error || (lhs._dimensions != rhs._dimensions), + make_type(meta.cell_type, lhs._dimensions)); } ValueType ValueType::concat(const ValueType &lhs, const ValueType &rhs, const vespalib::string &dimension) { - if (lhs.is_error() || rhs.is_error()) { - return error_type(); - } MyJoin result(lhs._dimensions, rhs._dimensions, dimension); - if (result.mismatch) { - return error_type(); - } if (!find_dimension(result.dimensions, dimension)) { result.dimensions.emplace_back(dimension, 2); } - return tensor_type(std::move(result.dimensions), unify_cell_types(lhs, rhs)); + auto meta = CellMeta::concat(lhs.cell_meta(), rhs.cell_meta()); + return error_if(lhs._error || rhs._error || result.mismatch, + make_type(meta.cell_type, std::move(result.dimensions))); } ValueType ValueType::either(const ValueType &one, const ValueType &other) { - if (one != other) { - return error_type(); - } - return one; + return error_if(one != other, one); } std::ostream & diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index 247912b274a..e1a0d073337 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -45,6 +45,8 @@ private: ValueType(CellType 
cell_type_in, std::vector<Dimension> &&dimensions_in) : _error(false), _cell_type(cell_type_in), _dimensions(std::move(dimensions_in)) {} + static ValueType error_if(bool has_error, ValueType else_type); + public: ValueType(ValueType &&) noexcept = default; ValueType(const ValueType &) = default; @@ -52,22 +54,12 @@ public: ValueType &operator=(const ValueType &) = default; ~ValueType(); CellType cell_type() const { return _cell_type; } + CellMeta cell_meta() const { return {_cell_type, is_double()}; } bool is_error() const { return _error; } - bool is_scalar() const { return _dimensions.empty(); } + bool is_double() const; + bool has_dimensions() const { return !_dimensions.empty(); } bool is_sparse() const; bool is_dense() const; - - // TODO: remove is_double and is_tensor - // is_tensor should no longer be useful - // is_double should be replaced with is_scalar where you also - // handle cell type correctly (free float values will - // not be introduced by type-resolving just yet, so - // is_double and is_scalar will be interchangeable in - // most cases for a while) - - bool is_double() const { return (!_error && is_scalar() && (_cell_type == CellType::DOUBLE)); } - bool is_tensor() const { return (!_dimensions.empty()); } - size_t count_indexed_dimensions() const; size_t count_mapped_dimensions() const; size_t dense_subspace_size() const; @@ -83,27 +75,21 @@ public: } bool operator!=(const ValueType &rhs) const { return !(*this == rhs); } + ValueType map() const; ValueType reduce(const std::vector<vespalib::string> &dimensions_in) const; + ValueType peek(const std::vector<vespalib::string> &dimensions_in) const; ValueType rename(const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to) const; ValueType cell_cast(CellType to_cell_type) const; static ValueType error_type() { return ValueType(); } static ValueType make_type(CellType cell_type, std::vector<Dimension> dimensions_in); - - // TODO: remove double_type and tensor_type and use 
make_type - // directly. Currently the tensor_type function contains - // protection against ending up with scalar float values. - static ValueType double_type() { return make_type(CellType::DOUBLE, {}); } - static ValueType tensor_type(std::vector<Dimension> dimensions_in, CellType cell_type = CellType::DOUBLE); - static ValueType from_spec(const vespalib::string &spec); static ValueType from_spec(const vespalib::string &spec, std::vector<ValueType::Dimension> &unsorted); vespalib::string to_spec() const; static ValueType join(const ValueType &lhs, const ValueType &rhs); static ValueType merge(const ValueType &lhs, const ValueType &rhs); - static CellType unify_cell_types(const ValueType &a, const ValueType &b); static ValueType concat(const ValueType &lhs, const ValueType &rhs, const vespalib::string &dimension); static ValueType either(const ValueType &one, const ValueType &other); }; diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp index 470da4f63a3..b518ccd1b30 100644 --- a/eval/src/vespa/eval/eval/value_type_spec.cpp +++ b/eval/src/vespa/eval/eval/value_type_spec.cpp @@ -192,9 +192,7 @@ parse_spec(const char *pos_in, const char *end_in, const char *&pos_out, if (type_name == "error") { return ValueType::error_type(); } else if (type_name == "double") { - return ValueType::make_type(CellType::DOUBLE, {}); - } else if (type_name == "float") { - return ValueType::make_type(CellType::FLOAT, {}); + return ValueType::double_type(); } else if (type_name == "tensor") { CellType cell_type = parse_cell_type(ctx); std::vector<ValueType::Dimension> list = parse_dimension_list(ctx); @@ -202,7 +200,7 @@ parse_spec(const char *pos_in, const char *end_in, const char *&pos_out, if (unsorted != nullptr) { *unsorted = list; } - return ValueType::tensor_type(std::move(list), cell_type); + return ValueType::make_type(cell_type, std::move(list)); } } else { ctx.fail(); @@ -241,8 +239,8 @@ to_spec(const ValueType &type) size_t 
cnt = 0; if (type.is_error()) { os << "error"; - } else if (type.is_scalar()) { - os << cell_type_to_name(type.cell_type()); + } else if (type.is_double()) { + os << "double"; } else { os << "tensor"; if (type.cell_type() != CellType::DOUBLE) { diff --git a/eval/src/vespa/eval/instruction/dense_cell_range_function.cpp b/eval/src/vespa/eval/instruction/dense_cell_range_function.cpp index 4c655c67747..78fecbd10d8 100644 --- a/eval/src/vespa/eval/instruction/dense_cell_range_function.cpp +++ b/eval/src/vespa/eval/instruction/dense_cell_range_function.cpp @@ -31,6 +31,7 @@ DenseCellRangeFunction::DenseCellRangeFunction(const ValueType &result_type, _offset(offset), _length(length) { + assert(result_type.cell_type() == child.result_type().cell_type()); } DenseCellRangeFunction::~DenseCellRangeFunction() = default; diff --git a/eval/src/vespa/eval/instruction/dense_lambda_peek_optimizer.cpp b/eval/src/vespa/eval/instruction/dense_lambda_peek_optimizer.cpp index 39b2bed4017..e584d94edbe 100644 --- a/eval/src/vespa/eval/instruction/dense_lambda_peek_optimizer.cpp +++ b/eval/src/vespa/eval/instruction/dense_lambda_peek_optimizer.cpp @@ -65,9 +65,9 @@ Node_UP make_floor(Node_UP a) { } struct PeekAnalyzer { - std::vector<size_t> dst_dim_sizes; - std::vector<size_t> src_dim_sizes; - std::vector<CompiledFunction::UP> src_dim_funs; + SmallVector<size_t> dst_dim_sizes; + SmallVector<size_t> src_dim_sizes; + SmallVector<CompiledFunction::UP> src_dim_funs; std::shared_ptr<Function const> src_idx_fun; struct CellRange { @@ -111,7 +111,7 @@ struct PeekAnalyzer { src_idx_fun = Function::create(std::move(idx_expr), dst_type.dimension_names()); } - bool step_params(std::vector<double> ¶ms) { + bool step_params(SmallVector<double> ¶ms) { for (size_t idx = params.size(); idx-- > 0; ) { if (size_t(params[idx] += 1.0) < dst_dim_sizes[idx]) { return true; @@ -122,7 +122,7 @@ struct PeekAnalyzer { return false; } - size_t calculate_index(const std::vector<size_t> &src_address) { + size_t 
calculate_index(const SmallVector<size_t> &src_address) { size_t result = 0; for (size_t i = 0; i < src_address.size(); ++i) { result *= src_dim_sizes[i]; @@ -134,8 +134,8 @@ struct PeekAnalyzer { Result analyze_indexes() { CellRange range{0,0}; bool is_complex = false; - std::vector<double> params(dst_dim_sizes.size(), 0.0); - std::vector<size_t> src_address(src_dim_sizes.size(), 0); + SmallVector<double> params(dst_dim_sizes.size(), 0.0); + SmallVector<size_t> src_address(src_dim_sizes.size(), 0); do { for (size_t i = 0; i < src_dim_funs.size(); ++i) { auto dim_fun = src_dim_funs[i]->get_function(); diff --git a/eval/src/vespa/eval/instruction/dense_matmul_function.cpp b/eval/src/vespa/eval/instruction/dense_matmul_function.cpp index 33d9054820b..11ad646d0f5 100644 --- a/eval/src/vespa/eval/instruction/dense_matmul_function.cpp +++ b/eval/src/vespa/eval/instruction/dense_matmul_function.cpp @@ -28,7 +28,7 @@ double my_dot_product(const LCT *lhs, const RCT *rhs, size_t lhs_size, size_t co template <typename LCT, typename RCT, bool lhs_common_inner, bool rhs_common_inner> void my_matmul_op(InterpretedFunction::State &state, uint64_t param) { const DenseMatMulFunction::Self &self = unwrap_param<DenseMatMulFunction::Self>(param); - using OCT = typename UnifyCellTypes<LCT,RCT>::type; + using OCT = decltype(unify_cell_types<LCT,RCT>()); auto lhs_cells = state.peek(1).cells().typify<LCT>(); auto rhs_cells = state.peek(0).cells().typify<RCT>(); auto dst_cells = state.stash.create_uninitialized_array<OCT>(self.lhs_size * self.rhs_size); diff --git a/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp b/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp index e67aa042881..55c6760a391 100644 --- a/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp +++ b/eval/src/vespa/eval/instruction/dense_simple_expand_function.cpp @@ -36,7 +36,7 @@ template <typename LCT, typename RCT, typename Fun, bool rhs_inner> void my_simple_expand_op(State 
&state, uint64_t param) { using ICT = typename std::conditional<rhs_inner,RCT,LCT>::type; using OCT = typename std::conditional<rhs_inner,LCT,RCT>::type; - using DCT = typename UnifyCellTypes<ICT,OCT>::type; + using DCT = decltype(unify_cell_types<LCT,RCT>()); using OP = typename std::conditional<rhs_inner,SwapArgs2<Fun>,Fun>::type; const ExpandParams ¶ms = unwrap_param<ExpandParams>(param); OP my_op(params.function); diff --git a/eval/src/vespa/eval/instruction/dense_tensor_peek_function.cpp b/eval/src/vespa/eval/instruction/dense_tensor_peek_function.cpp index 07fd0f8938c..b5f99ad6c8b 100644 --- a/eval/src/vespa/eval/instruction/dense_tensor_peek_function.cpp +++ b/eval/src/vespa/eval/instruction/dense_tensor_peek_function.cpp @@ -6,7 +6,7 @@ namespace vespalib::eval { using Child = TensorFunction::Child; -using SpecVector = std::vector<std::pair<int64_t,size_t>>; +using SpecVector = SmallVector<std::pair<int64_t,size_t>>; using namespace tensor_function; namespace { @@ -30,7 +30,7 @@ void my_tensor_peek_op(InterpretedFunction::State &state, uint64_t param) { } auto cells = state.peek(0).cells().typify<CT>(); state.stack.pop_back(); - const Value &result = state.stash.create<DoubleValue>(valid ? cells[idx] : 0.0); + const Value &result = state.stash.create<DoubleValue>(valid ? 
double(cells[idx]) : 0.0); state.stack.emplace_back(result); } @@ -42,7 +42,7 @@ struct MyTensorPeekOp { } // namespace <unnamed> DenseTensorPeekFunction::DenseTensorPeekFunction(std::vector<Child> children, - std::vector<std::pair<int64_t,size_t>> spec) + SmallVector<std::pair<int64_t,size_t>> spec) : TensorFunction(), _children(std::move(children)), _spec(std::move(spec)) @@ -73,7 +73,7 @@ DenseTensorPeekFunction::optimize(const TensorFunction &expr, Stash &stash) if (auto peek = as<Peek>(expr)) { const ValueType &peek_type = peek->param_type(); if (expr.result_type().is_double() && peek_type.is_dense()) { - std::vector<std::pair<int64_t,size_t>> spec; + SmallVector<std::pair<int64_t,size_t>> spec; assert(peek_type.dimensions().size() == peek->map().size()); for (auto dim = peek_type.dimensions().rbegin(); dim != peek_type.dimensions().rend(); ++dim) { auto dim_spec = peek->map().find(dim->name); diff --git a/eval/src/vespa/eval/instruction/dense_tensor_peek_function.h b/eval/src/vespa/eval/instruction/dense_tensor_peek_function.h index bfd217991aa..30ba510c168 100644 --- a/eval/src/vespa/eval/instruction/dense_tensor_peek_function.h +++ b/eval/src/vespa/eval/instruction/dense_tensor_peek_function.h @@ -20,9 +20,9 @@ private: // index and size of all dimensions in reverse order // when index is -1, use result of next child expression // (note that child expression order is inverted by the stack) - std::vector<std::pair<int64_t,size_t>> _spec; + SmallVector<std::pair<int64_t,size_t>> _spec; public: - DenseTensorPeekFunction(std::vector<Child> children, std::vector<std::pair<int64_t,size_t>> spec); + DenseTensorPeekFunction(std::vector<Child> children, SmallVector<std::pair<int64_t,size_t>> spec); ~DenseTensorPeekFunction(); const ValueType &result_type() const override { return DoubleValue::shared_type(); } void push_children(std::vector<Child::CREF> &children) const override; diff --git a/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp 
b/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp index 7d124555f55..b68a3a87ef1 100644 --- a/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp +++ b/eval/src/vespa/eval/instruction/dense_xw_product_function.cpp @@ -29,7 +29,7 @@ double my_dot_product(const LCT *lhs, const RCT *rhs, size_t vector_size, size_t template <typename LCT, typename RCT, bool common_inner> void my_xw_product_op(InterpretedFunction::State &state, uint64_t param) { const DenseXWProductFunction::Self &self = unwrap_param<DenseXWProductFunction::Self>(param); - using OCT = typename UnifyCellTypes<LCT,RCT>::type; + using OCT = decltype(unify_cell_types<LCT,RCT>()); auto vector_cells = state.peek(1).cells().typify<LCT>(); auto matrix_cells = state.peek(0).cells().typify<RCT>(); auto dst_cells = state.stash.create_uninitialized_array<OCT>(self.result_size); diff --git a/eval/src/vespa/eval/instruction/generic_concat.cpp b/eval/src/vespa/eval/instruction/generic_concat.cpp index c878d099c5e..61f736d43d2 100644 --- a/eval/src/vespa/eval/instruction/generic_concat.cpp +++ b/eval/src/vespa/eval/instruction/generic_concat.cpp @@ -26,9 +26,10 @@ struct ConcatParam DenseConcatPlan dense_plan; const ValueBuilderFactory &factory; - ConcatParam(const ValueType &lhs_type, const ValueType &rhs_type, + ConcatParam(const ValueType &res_type_in, + const ValueType &lhs_type, const ValueType &rhs_type, const vespalib::string &dimension, const ValueBuilderFactory &factory_in) - : res_type(ValueType::concat(lhs_type, rhs_type, dimension)), + : res_type(res_type_in), sparse_plan(lhs_type, rhs_type), dense_plan(lhs_type, rhs_type, dimension, res_type), factory(factory_in) @@ -243,8 +244,8 @@ GenericConcat::make_instruction(const ValueType &result_type, const vespalib::string &dimension, const ValueBuilderFactory &factory, Stash &stash) { - auto ¶m = stash.create<ConcatParam>(lhs_type, rhs_type, dimension, factory); - assert(result_type == param.res_type); + auto ¶m = 
stash.create<ConcatParam>(result_type, lhs_type, rhs_type, dimension, factory); + assert(result_type == ValueType::concat(lhs_type, rhs_type, dimension)); auto fun = typify_invoke<3,TypifyCellType,SelectGenericConcatOp>( lhs_type.cell_type(), rhs_type.cell_type(), param.res_type.cell_type(), param); diff --git a/eval/src/vespa/eval/instruction/generic_join.cpp b/eval/src/vespa/eval/instruction/generic_join.cpp index 313aa38f753..8881794c6bb 100644 --- a/eval/src/vespa/eval/instruction/generic_join.cpp +++ b/eval/src/vespa/eval/instruction/generic_join.cpp @@ -113,30 +113,35 @@ void my_dense_join_op(State &state, uint64_t param_in) { //----------------------------------------------------------------------------- -template <typename LCT, typename RCT, typename OCT, typename Fun> -void my_scalar_join_op(State &state, uint64_t param_in) { +template <typename Fun> +void my_double_join_op(State &state, uint64_t param_in) { Fun fun(unwrap_param<JoinParam>(param_in).function); - state.pop_pop_push(state.stash.create<ScalarValue<OCT>>(fun(state.peek(1).cells().typify<LCT>()[0], - state.peek(0).cells().typify<RCT>()[0]))); + state.pop_pop_push(state.stash.create<DoubleValue>(fun(state.peek(1).as_double(), + state.peek(0).as_double()))); }; //----------------------------------------------------------------------------- struct SelectGenericJoinOp { - template <typename LCT, typename RCT, typename OCT, typename Fun> static auto invoke(const JoinParam ¶m) { - if (param.res_type.is_scalar()) { - return my_scalar_join_op<LCT,RCT,OCT,Fun>; - } - if (param.sparse_plan.sources.empty()) { - return my_dense_join_op<LCT,RCT,OCT,Fun>; - } - if (param.sparse_plan.should_forward_lhs_index()) { - return my_mixed_dense_join_op<LCT,RCT,OCT,Fun,true>; - } - if (param.sparse_plan.should_forward_rhs_index()) { - return my_mixed_dense_join_op<LCT,RCT,OCT,Fun,false>; + template <typename LCM, typename RCM, typename Fun> static auto invoke(const JoinParam ¶m) { + constexpr CellMeta ocm = 
CellMeta::join(LCM::value, RCM::value); + using LCT = CellValueType<LCM::value.cell_type>; + using RCT = CellValueType<RCM::value.cell_type>; + using OCT = CellValueType<ocm.cell_type>; + if constexpr (ocm.is_scalar) { + return my_double_join_op<Fun>; + } else { + if (param.sparse_plan.sources.empty()) { + return my_dense_join_op<LCT,RCT,OCT,Fun>; + } + if (param.sparse_plan.should_forward_lhs_index()) { + return my_mixed_dense_join_op<LCT,RCT,OCT,Fun,true>; + } + if (param.sparse_plan.should_forward_rhs_index()) { + return my_mixed_dense_join_op<LCT,RCT,OCT,Fun,false>; + } + return my_mixed_join_op<LCT,RCT,OCT,Fun>; } - return my_mixed_join_op<LCT,RCT,OCT,Fun>; } }; @@ -284,16 +289,17 @@ JoinParam::~JoinParam() = default; //----------------------------------------------------------------------------- -using JoinTypify = TypifyValue<TypifyCellType,operation::TypifyOp2>; +using JoinTypify = TypifyValue<TypifyCellMeta,operation::TypifyOp2>; Instruction GenericJoin::make_instruction(const ValueType &result_type, const ValueType &lhs_type, const ValueType &rhs_type, join_fun_t function, const ValueBuilderFactory &factory, Stash &stash) { - auto ¶m = stash.create<JoinParam>(lhs_type, rhs_type, function, factory); - assert(result_type == param.res_type); - auto fun = typify_invoke<4,JoinTypify,SelectGenericJoinOp>(lhs_type.cell_type(), rhs_type.cell_type(), param.res_type.cell_type(), function, param); + auto ¶m = stash.create<JoinParam>(result_type, lhs_type, rhs_type, function, factory); + assert(result_type == ValueType::join(lhs_type, rhs_type)); + assert(param.res_type.cell_meta().eq(CellMeta::join(lhs_type.cell_meta(), rhs_type.cell_meta()))); + auto fun = typify_invoke<3,JoinTypify,SelectGenericJoinOp>(lhs_type.cell_meta(), rhs_type.cell_meta(), function, param); return Instruction(fun, wrap_param<JoinParam>(param)); } diff --git a/eval/src/vespa/eval/instruction/generic_join.h b/eval/src/vespa/eval/instruction/generic_join.h index 6ac2472ea2a..80a1179e0d5 100644 
--- a/eval/src/vespa/eval/instruction/generic_join.h +++ b/eval/src/vespa/eval/instruction/generic_join.h @@ -98,9 +98,10 @@ struct JoinParam { DenseJoinPlan dense_plan; join_fun_t function; const ValueBuilderFactory &factory; - JoinParam(const ValueType &lhs_type, const ValueType &rhs_type, + JoinParam(const ValueType &res_type_in, + const ValueType &lhs_type, const ValueType &rhs_type, join_fun_t function_in, const ValueBuilderFactory &factory_in) - : res_type(ValueType::join(lhs_type, rhs_type)), + : res_type(res_type_in), sparse_plan(lhs_type, rhs_type), dense_plan(lhs_type, rhs_type), function(function_in), diff --git a/eval/src/vespa/eval/instruction/generic_lambda.cpp b/eval/src/vespa/eval/instruction/generic_lambda.cpp index 19d98773aa6..2b0d6a18035 100644 --- a/eval/src/vespa/eval/instruction/generic_lambda.cpp +++ b/eval/src/vespa/eval/instruction/generic_lambda.cpp @@ -121,10 +121,10 @@ struct MyInterpretedLambdaOp { } // namespace <unnamed> Instruction -GenericLambda::make_instruction(const ValueType &result_type, - const tensor_function::Lambda &lambda_in, +GenericLambda::make_instruction(const tensor_function::Lambda &lambda_in, const ValueBuilderFactory &factory, Stash &stash) { + const ValueType & result_type = lambda_in.result_type(); assert(result_type.count_mapped_dimensions() == 0); if (!CompiledFunction::detect_issues(lambda_in.lambda()) && lambda_in.types().all_types_are_double()) diff --git a/eval/src/vespa/eval/instruction/generic_lambda.h b/eval/src/vespa/eval/instruction/generic_lambda.h index eef09c9d79f..a5f4c10e214 100644 --- a/eval/src/vespa/eval/instruction/generic_lambda.h +++ b/eval/src/vespa/eval/instruction/generic_lambda.h @@ -10,8 +10,7 @@ namespace vespalib::eval::instruction { struct GenericLambda { static InterpretedFunction::Instruction - make_instruction(const ValueType &result_type, - const tensor_function::Lambda &lambda_in, + make_instruction(const tensor_function::Lambda &lambda_in, const ValueBuilderFactory &factory, 
Stash &stash); }; diff --git a/eval/src/vespa/eval/instruction/generic_map.cpp b/eval/src/vespa/eval/instruction/generic_map.cpp index 0144a39a58a..4f6780c2276 100644 --- a/eval/src/vespa/eval/instruction/generic_map.cpp +++ b/eval/src/vespa/eval/instruction/generic_map.cpp @@ -32,17 +32,17 @@ void my_generic_map_op(State &state, uint64_t param_in) { state.pop_push(result_ref); } -template <typename CT, typename Func> -void my_scalar_map_op(State &state, uint64_t param_in) { - Func function(to_map_fun(param_in)); - const Value &a = state.peek(0); - state.pop_push(state.stash.create<ScalarValue<CT>>(function(a.cells().typify<CT>()[0]))); +template <typename Func> +void my_double_map_op(State &state, uint64_t param_in) { + Func fun(to_map_fun(param_in)); + state.pop_push(state.stash.create<DoubleValue>(fun(state.peek(0).as_double()))); } struct SelectGenericMapOp { template <typename CT, typename Func> static auto invoke(const ValueType &type) { - if (type.is_scalar()) { - return my_scalar_map_op<CT, Func>; + if (type.is_double()) { + assert((std::is_same_v<CT,double>)); + return my_double_map_op<Func>; } return my_generic_map_op<CT, Func>; } @@ -56,8 +56,7 @@ InterpretedFunction::Instruction GenericMap::make_instruction(const ValueType &result_type, const ValueType &input_type, map_fun_t function) { - // for now: - assert(result_type == input_type); + assert(result_type == input_type.map()); auto op = typify_invoke<2,MapTypify,SelectGenericMapOp>(input_type.cell_type(), function, input_type); return Instruction(op, to_param(function)); } diff --git a/eval/src/vespa/eval/instruction/generic_merge.cpp b/eval/src/vespa/eval/instruction/generic_merge.cpp index 0ab6bdab67b..218746d492a 100644 --- a/eval/src/vespa/eval/instruction/generic_merge.cpp +++ b/eval/src/vespa/eval/instruction/generic_merge.cpp @@ -106,8 +106,8 @@ GenericMerge::make_instruction(const ValueType &result_type, const ValueType &lhs_type, const ValueType &rhs_type, join_fun_t function, const 
ValueBuilderFactory &factory, Stash &stash) { - const auto &param = stash.create<MergeParam>(lhs_type, rhs_type, function, factory); - assert(result_type == param.res_type); + const auto &param = stash.create<MergeParam>(result_type, lhs_type, rhs_type, function, factory); + assert(result_type == ValueType::merge(lhs_type, rhs_type)); auto fun = typify_invoke<4,MergeTypify,SelectGenericMergeOp>(lhs_type.cell_type(), rhs_type.cell_type(), param.res_type.cell_type(), function); return Instruction(fun, wrap_param<MergeParam>(param)); } diff --git a/eval/src/vespa/eval/instruction/generic_merge.h b/eval/src/vespa/eval/instruction/generic_merge.h index 0e0013ac8b9..4f06e4259fc 100644 --- a/eval/src/vespa/eval/instruction/generic_merge.h +++ b/eval/src/vespa/eval/instruction/generic_merge.h @@ -11,11 +11,12 @@ struct MergeParam { const join_fun_t function; const size_t num_mapped_dimensions; const size_t dense_subspace_size; - std::vector<size_t> all_view_dims; + SmallVector<size_t> all_view_dims; const ValueBuilderFactory &factory; - MergeParam(const ValueType &lhs_type, const ValueType &rhs_type, + MergeParam(const ValueType &res_type_in, + const ValueType &lhs_type, const ValueType &rhs_type, join_fun_t function_in, const ValueBuilderFactory &factory_in) - : res_type(ValueType::merge(lhs_type, rhs_type)), + : res_type(res_type_in), function(function_in), num_mapped_dimensions(lhs_type.count_mapped_dimensions()), dense_subspace_size(lhs_type.dense_subspace_size()), diff --git a/eval/src/vespa/eval/instruction/generic_peek.cpp b/eval/src/vespa/eval/instruction/generic_peek.cpp index 4658b20e79d..c8198526b3d 100644 --- a/eval/src/vespa/eval/instruction/generic_peek.cpp +++ b/eval/src/vespa/eval/instruction/generic_peek.cpp @@ -274,8 +274,8 @@ struct PeekParam { size_t num_children; const ValueBuilderFactory &factory; - PeekParam(const ValueType &input_type, - const ValueType &res_type_in, + PeekParam(const ValueType &res_type_in, + const ValueType &input_type, const 
GenericPeek::SpecMap &spec_in, const ValueBuilderFactory &factory_in) : res_type(res_type_in), @@ -362,7 +362,7 @@ GenericPeek::make_instruction(const ValueType &result_type, const ValueBuilderFactory &factory, Stash &stash) { - const auto &param = stash.create<PeekParam>(input_type, result_type, spec, factory); + const auto &param = stash.create<PeekParam>(result_type, input_type, spec, factory); auto fun = typify_invoke<2,TypifyCellType,SelectGenericPeekOp>(input_type.cell_type(), result_type.cell_type()); return Instruction(fun, wrap_param<PeekParam>(param)); } diff --git a/eval/src/vespa/eval/instruction/generic_reduce.cpp b/eval/src/vespa/eval/instruction/generic_reduce.cpp index 7f3cf7ef0c6..2c630ca0419 100644 --- a/eval/src/vespa/eval/instruction/generic_reduce.cpp +++ b/eval/src/vespa/eval/instruction/generic_reduce.cpp @@ -154,7 +154,7 @@ void my_generic_dense_reduce_op(State &state, uint64_t param_in) { } }; -template <typename ICT, typename OCT, typename AGGR> +template <typename ICT, typename AGGR> void my_full_reduce_op(State &state, uint64_t) { auto cells = state.peek(0).cells().typify<ICT>(); if (cells.size() >= 8) { @@ -176,31 +176,34 @@ void my_full_reduce_op(State &state, uint64_t) { aggrs[0].merge(aggrs[2]); aggrs[1].merge(aggrs[3]); aggrs[0].merge(aggrs[1]); - state.pop_push(state.stash.create<ScalarValue<OCT>>(aggrs[0].result())); + state.pop_push(state.stash.create<DoubleValue>(aggrs[0].result())); } else if (cells.size() > 0) { AGGR aggr; for (ICT value: cells) { aggr.sample(value); } - state.pop_push(state.stash.create<ScalarValue<OCT>>(aggr.result())); + state.pop_push(state.stash.create<DoubleValue>(aggr.result())); } else { - state.pop_push(state.stash.create<ScalarValue<OCT>>(OCT{0})); + state.pop_push(state.stash.create<DoubleValue>(0.0)); } }; struct SelectGenericReduceOp { - template <typename ICT, typename OCT, typename AGGR> static auto invoke(const ReduceParam &param) { + template <typename ICM, typename OCM, typename AGGR> static auto 
invoke(const ReduceParam &param) { + using ICT = CellValueType<ICM::value.cell_type>; + using OCT = CellValueType<OCM::value.cell_type>; using AggrType = typename AGGR::template templ<OCT>; - if (param.res_type.is_scalar()) { - return my_full_reduce_op<ICT, OCT, AggrType>; - } - if (param.sparse_plan.should_forward_index()) { - return my_generic_dense_reduce_op<ICT, OCT, AggrType, true>; - } - if (param.res_type.is_dense()) { - return my_generic_dense_reduce_op<ICT, OCT, AggrType, false>; + if constexpr (OCM::value.is_scalar) { + return my_full_reduce_op<ICT, AggrType>; + } else { + if (param.sparse_plan.should_forward_index()) { + return my_generic_dense_reduce_op<ICT, OCT, AggrType, true>; + } + if (param.res_type.is_dense()) { + return my_generic_dense_reduce_op<ICT, OCT, AggrType, false>; + } + return my_generic_reduce_op<ICT, OCT, AggrType>; } - return my_generic_reduce_op<ICT, OCT, AggrType>; } }; @@ -287,7 +290,7 @@ SparseReducePlan::~SparseReducePlan() = default; //----------------------------------------------------------------------------- -using ReduceTypify = TypifyValue<TypifyCellType,TypifyAggr>; +using ReduceTypify = TypifyValue<TypifyCellMeta,TypifyAggr>; Instruction GenericReduce::make_instruction(const ValueType &result_type, @@ -296,7 +299,8 @@ GenericReduce::make_instruction(const ValueType &result_type, { auto &param = stash.create<ReduceParam>(input_type, dimensions, factory); assert(result_type == param.res_type); - auto fun = typify_invoke<3,ReduceTypify,SelectGenericReduceOp>(input_type.cell_type(), result_type.cell_type(), aggr, param); + assert(result_type.cell_meta().eq(CellMeta::reduce(input_type.cell_type(), result_type.is_double()))); + auto fun = typify_invoke<3,ReduceTypify,SelectGenericReduceOp>(input_type.cell_meta(), result_type.cell_meta().limit(), aggr, param); return Instruction(fun, wrap_param<ReduceParam>(param)); } diff --git a/eval/src/vespa/eval/instruction/join_with_number_function.cpp 
b/eval/src/vespa/eval/instruction/join_with_number_function.cpp index cd95a109e60..c574e3f8ad9 100644 --- a/eval/src/vespa/eval/instruction/join_with_number_function.cpp +++ b/eval/src/vespa/eval/instruction/join_with_number_function.cpp @@ -93,7 +93,7 @@ JoinWithNumberFunction::visit_self(vespalib::ObjectVisitor &visitor) const const TensorFunction & JoinWithNumberFunction::optimize(const TensorFunction &expr, Stash &stash) { - if (! expr.result_type().is_scalar()) { + if (! expr.result_type().is_double()) { if (const auto *join = as<Join>(expr)) { const ValueType &result_type = join->result_type(); const TensorFunction &lhs = join->lhs(); diff --git a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp index c8a4df2b82d..a223463240a 100644 --- a/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp +++ b/eval/src/vespa/eval/instruction/mixed_inner_product_function.cpp @@ -105,7 +105,7 @@ MixedInnerProductFunction::compile_self(const ValueBuilderFactory &, Stash &stas bool MixedInnerProductFunction::compatible_types(const ValueType &res, const ValueType &mixed, const ValueType &vector) { - if (vector.is_dense() && ! res.is_scalar()) { + if (vector.is_dense() && ! res.is_double()) { auto dense_dims = vector.nontrivial_indexed_dimensions(); auto mixed_dims = mixed.nontrivial_indexed_dimensions(); while (! dense_dims.empty()) { @@ -139,7 +139,7 @@ MixedInnerProductFunction::optimize(const TensorFunction &expr, Stash &stash) { const auto & res_type = expr.result_type(); auto reduce = as<Reduce>(expr); - if ((! res_type.is_scalar()) && reduce && (reduce->aggr() == Aggr::SUM)) { + if ((! 
res_type.is_double()) && reduce && (reduce->aggr() == Aggr::SUM)) { auto join = as<Join>(reduce->child()); if (join && (join->function() == Mul::f)) { const TensorFunction &lhs = join->lhs(); diff --git a/eval/src/vespa/eval/instruction/mixed_map_function.cpp b/eval/src/vespa/eval/instruction/mixed_map_function.cpp index 06b53006952..69917ae94e0 100644 --- a/eval/src/vespa/eval/instruction/mixed_map_function.cpp +++ b/eval/src/vespa/eval/instruction/mixed_map_function.cpp @@ -75,7 +75,7 @@ const TensorFunction & MixedMapFunction::optimize(const TensorFunction &expr, Stash &stash) { if (auto map = as<Map>(expr)) { - if (! map->child().result_type().is_scalar()) { + if (! map->child().result_type().is_double()) { return stash.create<MixedMapFunction>(map->result_type(), map->child(), map->function()); } } diff --git a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp index ed88da77f3e..d487ab42d26 100644 --- a/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp +++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp @@ -59,7 +59,7 @@ template <typename LCT, typename RCT, typename Fun, bool swap, Overlap overlap, void my_simple_join_op(State &state, uint64_t param) { using PCT = typename std::conditional<swap,RCT,LCT>::type; using SCT = typename std::conditional<swap,LCT,RCT>::type; - using OCT = typename UnifyCellTypes<PCT,SCT>::type; + using OCT = decltype(unify_cell_types<LCT,RCT>()); using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type; const JoinParams &params = unwrap_param<JoinParams>(param); OP my_op(params.function); diff --git a/eval/src/vespa/eval/instruction/pow_as_map_optimizer.cpp b/eval/src/vespa/eval/instruction/pow_as_map_optimizer.cpp index aa33e98c939..5c09ba2c8cc 100644 --- a/eval/src/vespa/eval/instruction/pow_as_map_optimizer.cpp +++ b/eval/src/vespa/eval/instruction/pow_as_map_optimizer.cpp @@ -15,7 +15,7 @@ 
PowAsMapOptimizer::optimize(const TensorFunction &expr, Stash &stash) const TensorFunction &lhs = join->lhs(); const TensorFunction &rhs = join->rhs(); if ((join->function() == Pow::f) && - rhs.result_type().is_scalar()) + rhs.result_type().is_double()) { if (auto const_value = as<ConstValue>(rhs)) { if (const_value->value().as_double() == 2.0) { diff --git a/eval/src/vespa/eval/instruction/remove_trivial_dimension_optimizer.cpp b/eval/src/vespa/eval/instruction/remove_trivial_dimension_optimizer.cpp index 77f5247aaaa..06c3fb886b5 100644 --- a/eval/src/vespa/eval/instruction/remove_trivial_dimension_optimizer.cpp +++ b/eval/src/vespa/eval/instruction/remove_trivial_dimension_optimizer.cpp @@ -28,11 +28,11 @@ RemoveTrivialDimensionOptimizer::optimize(const TensorFunction &expr, Stash &sta { if (auto reduce = as<Reduce>(expr)) { const TensorFunction &child = reduce->child(); - if ((! expr.result_type().dimensions().empty()) && + if (expr.result_type().has_dimensions() && aggr::is_ident(reduce->aggr()) && - is_trivial_dim_list(child.result_type(), reduce->dimensions())) + is_trivial_dim_list(child.result_type(), reduce->dimensions()) && + (expr.result_type().cell_type() == child.result_type().cell_type())) { - assert(expr.result_type().cell_type() == child.result_type().cell_type()); return ReplaceTypeFunction::create_compact(expr.result_type(), child, stash); } } diff --git a/eval/src/vespa/eval/instruction/sparse_dot_product_function.cpp b/eval/src/vespa/eval/instruction/sparse_dot_product_function.cpp index 7cc4417bdbb..4da3dbe4f5b 100644 --- a/eval/src/vespa/eval/instruction/sparse_dot_product_function.cpp +++ b/eval/src/vespa/eval/instruction/sparse_dot_product_function.cpp @@ -73,7 +73,7 @@ void my_sparse_dot_product_op(InterpretedFunction::State &state, uint64_t num_ma double result = __builtin_expect(are_fast(lhs_idx, rhs_idx), true) ? 
my_fast_sparse_dot_product<CT,single_dim>(&as_fast(lhs_idx).map, &as_fast(rhs_idx).map, lhs_cells, rhs_cells) : my_sparse_dot_product_fallback<CT>(lhs_idx, rhs_idx, lhs_cells, rhs_cells, num_mapped_dims); - state.pop_pop_push(state.stash.create<ScalarValue<double>>(result)); + state.pop_pop_push(state.stash.create<DoubleValue>(result)); } struct MyGetFun { @@ -87,7 +87,7 @@ using MyTypify = TypifyValue<TypifyCellType,TypifyBool>; SparseDotProductFunction::SparseDotProductFunction(const TensorFunction &lhs_in, const TensorFunction &rhs_in) - : tensor_function::Op2(ValueType::make_type(CellType::DOUBLE, {}), lhs_in, rhs_in) + : tensor_function::Op2(ValueType::double_type(), lhs_in, rhs_in) { } @@ -103,7 +103,7 @@ SparseDotProductFunction::compile_self(const ValueBuilderFactory &, Stash &) con bool SparseDotProductFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) { - return (res.is_scalar() && (res.cell_type() == CellType::DOUBLE) && + return (res.is_double() && lhs.is_sparse() && (rhs.dimensions() == lhs.dimensions()) && lhs.cell_type() == rhs.cell_type()); } diff --git a/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp b/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp index 480af3315b1..0c6ac51cde0 100644 --- a/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp +++ b/eval/src/vespa/eval/instruction/sparse_full_overlap_join_function.cpp @@ -72,11 +72,14 @@ void my_sparse_full_overlap_join_op(InterpretedFunction::State &state, uint64_t } struct SelectSparseFullOverlapJoinOp { - template <typename CT, typename Fun, typename SINGLE_DIM> - static auto invoke() { return my_sparse_full_overlap_join_op<CT,Fun,SINGLE_DIM::value>; } + template <typename R1, typename Fun, typename SINGLE_DIM> + static auto invoke() { + using CT = CellValueType<R1::value.cell_type>; + return my_sparse_full_overlap_join_op<CT,Fun,SINGLE_DIM::value>; + } }; -using MyTypify = 
TypifyValue<TypifyCellType,operation::TypifyOp2,TypifyBool>; +using MyTypify = TypifyValue<TypifyCellMeta,operation::TypifyOp2,TypifyBool>; bool is_sparse_like(const ValueType &type) { return ((type.count_mapped_dimensions() > 0) && (type.dense_subspace_size() == 1)); @@ -96,10 +99,10 @@ SparseFullOverlapJoinFunction::SparseFullOverlapJoinFunction(const tensor_functi InterpretedFunction::Instruction SparseFullOverlapJoinFunction::compile_self(const ValueBuilderFactory &factory, Stash &stash) const { - const auto &param = stash.create<JoinParam>(lhs().result_type(), rhs().result_type(), function(), factory); - assert(param.res_type == result_type()); + const auto &param = stash.create<JoinParam>(result_type(), lhs().result_type(), rhs().result_type(), function(), factory); + assert(result_type() == ValueType::join(lhs().result_type(), rhs().result_type())); bool single_dim = (result_type().count_mapped_dimensions() == 1); - auto op = typify_invoke<3,MyTypify,SelectSparseFullOverlapJoinOp>(result_type().cell_type(), function(), single_dim); + auto op = typify_invoke<3,MyTypify,SelectSparseFullOverlapJoinOp>(result_type().cell_meta().limit(), function(), single_dim); return InterpretedFunction::Instruction(op, wrap_param<JoinParam>(param)); } @@ -107,12 +110,12 @@ bool SparseFullOverlapJoinFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) { if ((lhs.cell_type() == rhs.cell_type()) && + (res.cell_type() == lhs.cell_type()) && is_sparse_like(lhs) && is_sparse_like(rhs) && (res.count_mapped_dimensions() == lhs.count_mapped_dimensions()) && (res.count_mapped_dimensions() == rhs.count_mapped_dimensions())) { assert(is_sparse_like(res)); - assert(res.cell_type() == lhs.cell_type()); return true; } return false; diff --git a/eval/src/vespa/eval/instruction/sparse_merge_function.cpp b/eval/src/vespa/eval/instruction/sparse_merge_function.cpp index 924c4d69fe9..728a5be43b6 100644 --- a/eval/src/vespa/eval/instruction/sparse_merge_function.cpp 
+++ b/eval/src/vespa/eval/instruction/sparse_merge_function.cpp @@ -87,11 +87,14 @@ void my_sparse_merge_op(InterpretedFunction::State &state, uint64_t param_in) { } struct SelectSparseMergeOp { - template <typename CT, typename SINGLE_DIM, typename Fun> - static auto invoke() { return my_sparse_merge_op<CT,SINGLE_DIM::value,Fun>; } + template <typename R1, typename SINGLE_DIM, typename Fun> + static auto invoke() { + using CT = CellValueType<R1::value.cell_type>; + return my_sparse_merge_op<CT,SINGLE_DIM::value,Fun>; + } }; -using MyTypify = TypifyValue<TypifyCellType,TypifyBool,operation::TypifyOp2>; +using MyTypify = TypifyValue<TypifyCellMeta,TypifyBool,operation::TypifyOp2>; } // namespace <unnamed> @@ -107,10 +110,11 @@ SparseMergeFunction::SparseMergeFunction(const tensor_function::Merge &original) InterpretedFunction::Instruction SparseMergeFunction::compile_self(const ValueBuilderFactory &factory, Stash &stash) const { - const auto &param = stash.create<MergeParam>(lhs().result_type(), rhs().result_type(), + const auto &param = stash.create<MergeParam>(result_type(), + lhs().result_type(), rhs().result_type(), function(), factory); size_t num_dims = result_type().count_mapped_dimensions(); - auto op = typify_invoke<3,MyTypify,SelectSparseMergeOp>(result_type().cell_type(), + auto op = typify_invoke<3,MyTypify,SelectSparseMergeOp>(result_type().cell_meta().limit(), num_dims == 1, function()); return InterpretedFunction::Instruction(op, wrap_param<MergeParam>(param)); @@ -120,6 +124,7 @@ bool SparseMergeFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) { if ((lhs.cell_type() == rhs.cell_type()) + && (lhs.cell_type() == res.cell_type()) && (lhs.count_mapped_dimensions() > 0) && (lhs.dense_subspace_size() == 1)) { diff --git a/eval/src/vespa/eval/instruction/sparse_no_overlap_join_function.cpp b/eval/src/vespa/eval/instruction/sparse_no_overlap_join_function.cpp index a9f68c7314d..2dfe1b07373 100644 --- 
a/eval/src/vespa/eval/instruction/sparse_no_overlap_join_function.cpp +++ b/eval/src/vespa/eval/instruction/sparse_no_overlap_join_function.cpp @@ -22,9 +22,9 @@ const Value &my_fast_no_overlap_sparse_join(const FastAddrMap &lhs_map, const Fa const auto &addr_sources = param.sparse_plan.sources; size_t num_mapped_dims = addr_sources.size(); auto &result = stash.create<FastValue<CT,true>>(param.res_type, num_mapped_dims, 1, lhs_map.size() * rhs_map.size()); - std::vector<string_id> output_addr(num_mapped_dims); - std::vector<size_t> store_lhs_idx; - std::vector<size_t> store_rhs_idx; + SmallVector<string_id> output_addr(num_mapped_dims); + SmallVector<size_t> store_lhs_idx; + SmallVector<size_t> store_rhs_idx; size_t out_idx = 0; for (auto source: addr_sources) { switch (source) { @@ -78,11 +78,14 @@ void my_sparse_no_overlap_join_op(InterpretedFunction::State &state, uint64_t pa } struct SelectSparseNoOverlapJoinOp { - template <typename CT, typename Fun> - static auto invoke() { return my_sparse_no_overlap_join_op<CT,Fun>; } + template <typename R1, typename Fun> + static auto invoke() { + using CT = CellValueType<R1::value.cell_type>; + return my_sparse_no_overlap_join_op<CT,Fun>; + } }; -using MyTypify = TypifyValue<TypifyCellType,operation::TypifyOp2>; +using MyTypify = TypifyValue<TypifyCellMeta,operation::TypifyOp2>; bool is_sparse_like(const ValueType &type) { return ((type.count_mapped_dimensions() > 0) && (type.dense_subspace_size() == 1)); @@ -102,8 +105,10 @@ SparseNoOverlapJoinFunction::SparseNoOverlapJoinFunction(const tensor_function:: InterpretedFunction::Instruction SparseNoOverlapJoinFunction::compile_self(const ValueBuilderFactory &factory, Stash &stash) const { - const auto &param = stash.create<JoinParam>(lhs().result_type(), rhs().result_type(), function(), factory); - auto op = typify_invoke<2,MyTypify,SelectSparseNoOverlapJoinOp>(result_type().cell_type(), function()); + const auto &param = stash.create<JoinParam>(result_type(), + 
lhs().result_type(), rhs().result_type(), + function(), factory); + auto op = typify_invoke<2,MyTypify,SelectSparseNoOverlapJoinOp>(result_type().cell_meta().limit(), function()); return InterpretedFunction::Instruction(op, wrap_param<JoinParam>(param)); } @@ -111,11 +116,11 @@ bool SparseNoOverlapJoinFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) { if ((lhs.cell_type() == rhs.cell_type()) && + (res.cell_type() == lhs.cell_type()) && is_sparse_like(lhs) && is_sparse_like(rhs) && (res.count_mapped_dimensions() == (lhs.count_mapped_dimensions() + rhs.count_mapped_dimensions()))) { assert(is_sparse_like(res)); - assert(res.cell_type() == lhs.cell_type()); return true; } return false; diff --git a/eval/src/vespa/eval/instruction/sum_max_dot_product_function.cpp b/eval/src/vespa/eval/instruction/sum_max_dot_product_function.cpp index 4b541062007..bdf1682cccd 100644 --- a/eval/src/vespa/eval/instruction/sum_max_dot_product_function.cpp +++ b/eval/src/vespa/eval/instruction/sum_max_dot_product_function.cpp @@ -25,7 +25,7 @@ void my_sum_max_dot_product_op(InterpretedFunction::State &state, uint64_t dp_si result += max_dp; } } - state.pop_pop_push(state.stash.create<ScalarValue<double>>(result)); + state.pop_pop_push(state.stash.create<DoubleValue>(result)); } const Reduce *check_reduce(const TensorFunction &expr, Aggr aggr) { @@ -49,7 +49,7 @@ const Join *check_mul(const TensorFunction &expr) { bool check_params(const ValueType &res_type, const ValueType &query, const ValueType &document, const vespalib::string &sum_dim, const vespalib::string &max_dim, const vespalib::string &dp_dim) { - if (res_type.is_scalar() && (res_type.cell_type() == CellType::DOUBLE) && + if (res_type.is_double() && (query.dimensions().size() == 2) && (query.cell_type() == CellType::FLOAT) && (document.dimensions().size() == 2) && (document.cell_type() == CellType::FLOAT)) { diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp 
b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp index d9c0d659b1e..2891b37ebe8 100644 --- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp @@ -297,7 +297,7 @@ Onnx::WirePlanner::make_output_type(const TensorInfo &onnx_out) const } dim_list.emplace_back(fmt("d%zu", dim_list.size()), dim_size); } - return ValueType::tensor_type(std::move(dim_list), to_cell_type(elements)); + return ValueType::make_type(to_cell_type(elements), std::move(dim_list)); } Onnx::WireInfo diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 3cff6b1ba72..0f40cf1df95 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -207,7 +207,7 @@ public class Flags { APPLICATION_ID); public static final UnboundBooleanFlag ENABLE_FEED_BLOCK_IN_DISTRIBUTOR = defineFeatureFlag( - "enable-feed-block-in-distributor", false, + "enable-feed-block-in-distributor", true, List.of("geirst"), "2021-01-27", "2021-04-01", "Enables blocking of feed in the distributor if resource usage is above limit on at least one content node", "Takes effect at redeployment", diff --git a/functions.cmake b/functions.cmake index 0152554669a..9fa1f326e0a 100644 --- a/functions.cmake +++ b/functions.cmake @@ -711,6 +711,10 @@ function(vespa_detect_build_platform) message(FATAL_ERROR "-- Could not determine ${OS_DISTRO} version") endif() endif() + file(STRINGS /etc/os-release OS_DISTRO_NAME REGEX "^NAME=") + if (OS_DISTRO_NAME) + string(REGEX REPLACE "NAME=\"?([^\"]+)\"?" 
"\\1" OS_DISTRO_NAME ${OS_DISTRO_NAME}) + endif() elseif(EXISTS /etc/redhat-release) set(OS_DISTRO "rhel") file(STRINGS "/etc/redhat-release" OS_DISTRO_VERSION) @@ -724,6 +728,9 @@ function(vespa_detect_build_platform) set(VESPA_OS_DISTRO_VERSION ${OS_DISTRO_VERSION} PARENT_SCOPE) string(CONCAT OS_DISTRO_COMBINED ${OS_DISTRO} " " ${OS_DISTRO_VERSION}) set(VESPA_OS_DISTRO_COMBINED ${OS_DISTRO_COMBINED} PARENT_SCOPE) + if (OS_DISTRO_NAME) + set(VESPA_OS_DISTRO_NAME ${OS_DISTRO_NAME} PARENT_SCOPE) + endif() else() message(FATAL_ERROR "-- Could not determine vespa build platform") endif() diff --git a/hosted-api/src/main/java/ai/vespa/hosted/api/Properties.java b/hosted-api/src/main/java/ai/vespa/hosted/api/Properties.java index 22c32cfa9ec..30b9396678b 100644 --- a/hosted-api/src/main/java/ai/vespa/hosted/api/Properties.java +++ b/hosted-api/src/main/java/ai/vespa/hosted/api/Properties.java @@ -24,7 +24,7 @@ public class Properties { public static ApplicationId application() { return ApplicationId.from(requireNonBlankProperty("tenant"), requireNonBlankProperty("application"), - getNonBlankProperty("instance").orElse(user())); + requireNonBlankProperty("instance")); } /** Returns the relevant environment, if this is set with the 'environment' property */ diff --git a/http-utils/src/test/java/ai/vespa/util/http/retry/DelayedConnectionLevelRetryHandlerTest.java b/http-utils/src/test/java/ai/vespa/util/http/retry/DelayedConnectionLevelRetryHandlerTest.java index 85adeae6d78..82bf4ff8080 100644 --- a/http-utils/src/test/java/ai/vespa/util/http/retry/DelayedConnectionLevelRetryHandlerTest.java +++ b/http-utils/src/test/java/ai/vespa/util/http/retry/DelayedConnectionLevelRetryHandlerTest.java @@ -1,7 +1,6 @@ // Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package ai.vespa.util.http.retry; -import com.yahoo.vespa.jdk8compat.List; import org.apache.http.client.protocol.HttpClientContext; import org.junit.Test; @@ -9,6 +8,7 @@ import javax.net.ssl.SSLException; import java.io.IOException; import java.net.ConnectException; import java.time.Duration; +import java.util.Arrays; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; @@ -105,7 +105,7 @@ public class DelayedConnectionLevelRetryHandlerTest { DelayedConnectionLevelRetryHandler handler = DelayedConnectionLevelRetryHandler.Builder .withFixedDelay(Duration.ofSeconds(2), maxRetries) - .retryForExceptions(List.of(SSLException.class, ConnectException.class)) + .retryForExceptions(Arrays.asList(SSLException.class, ConnectException.class)) .withSleeper(mock(Sleeper.class)) .build(); @@ -122,7 +122,7 @@ public class DelayedConnectionLevelRetryHandlerTest { public void does_not_retry_for_non_listed_exception() { DelayedConnectionLevelRetryHandler handler = DelayedConnectionLevelRetryHandler.Builder .withFixedDelay(Duration.ofSeconds(2), 2) - .retryForExceptions(List.of(SSLException.class, ConnectException.class)) + .retryForExceptions(Arrays.asList(SSLException.class, ConnectException.class)) .withSleeper(mock(Sleeper.class)) .build(); diff --git a/http-utils/src/test/java/ai/vespa/util/http/retry/DelayedResponseLevelRetryHandlerTest.java b/http-utils/src/test/java/ai/vespa/util/http/retry/DelayedResponseLevelRetryHandlerTest.java index dbc93f28d6b..b1d78fc09eb 100644 --- a/http-utils/src/test/java/ai/vespa/util/http/retry/DelayedResponseLevelRetryHandlerTest.java +++ b/http-utils/src/test/java/ai/vespa/util/http/retry/DelayedResponseLevelRetryHandlerTest.java @@ -9,6 +9,7 @@ import org.apache.http.message.BasicStatusLine; import org.junit.Test; import java.time.Duration; +import java.util.Arrays; import java.util.List; import static org.junit.Assert.assertEquals; @@ -82,7 +83,7 @@ public class DelayedResponseLevelRetryHandlerTest { 
HttpClientContext ctx = new HttpClientContext(); int lastExecutionCount = maxRetries + 1; List<Duration> expectedIntervals = - com.yahoo.vespa.jdk8compat.List.of( + Arrays.asList( startDelay, Duration.ofSeconds(1), Duration.ofSeconds(2), Duration.ofSeconds(4), Duration.ofSeconds(5), Duration.ofSeconds(5), Duration.ofSeconds(5), Duration.ofSeconds(5), Duration.ofSeconds(5), Duration.ofSeconds(5), Duration.ofSeconds(5)); @@ -98,7 +99,7 @@ public class DelayedResponseLevelRetryHandlerTest { DelayedResponseLevelRetryHandler handler = DelayedResponseLevelRetryHandler.Builder .withFixedDelay(Duration.ofSeconds(2), maxRetries) - .retryForStatusCodes(com.yahoo.vespa.jdk8compat.List.of(HttpStatus.SC_SERVICE_UNAVAILABLE, HttpStatus.SC_BAD_GATEWAY)) + .retryForStatusCodes(Arrays.asList(HttpStatus.SC_SERVICE_UNAVAILABLE, HttpStatus.SC_BAD_GATEWAY)) .build(); HttpResponse response = createResponse(HttpStatus.SC_SERVICE_UNAVAILABLE); @@ -114,7 +115,7 @@ public class DelayedResponseLevelRetryHandlerTest { public void does_not_retry_for_non_listed_exception() { DelayedResponseLevelRetryHandler handler = DelayedResponseLevelRetryHandler.Builder .withFixedDelay(Duration.ofSeconds(2), 2) - .retryForStatusCodes(com.yahoo.vespa.jdk8compat.List.of(HttpStatus.SC_SERVICE_UNAVAILABLE, HttpStatus.SC_BAD_GATEWAY)) + .retryForStatusCodes(Arrays.asList(HttpStatus.SC_SERVICE_UNAVAILABLE, HttpStatus.SC_BAD_GATEWAY)) .build(); HttpResponse response = createResponse(HttpStatus.SC_OK); diff --git a/jdisc-cloud-aws/src/main/java/com/yahoo/jdisc/cloud/aws/AwsParameterStore.java b/jdisc-cloud-aws/src/main/java/com/yahoo/jdisc/cloud/aws/AwsParameterStore.java index 3e90e4ca204..f2cac68c030 100644 --- a/jdisc-cloud-aws/src/main/java/com/yahoo/jdisc/cloud/aws/AwsParameterStore.java +++ b/jdisc-cloud-aws/src/main/java/com/yahoo/jdisc/cloud/aws/AwsParameterStore.java @@ -3,6 +3,7 @@ package com.yahoo.jdisc.cloud.aws; import com.amazonaws.auth.STSAssumeRoleSessionCredentialsProvider; +import 
com.amazonaws.regions.Regions; import com.amazonaws.services.securitytoken.AWSSecurityTokenService; import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder; import com.amazonaws.services.simplesystemsmanagement.AWSSimpleSystemsManagement; @@ -14,6 +15,11 @@ import com.yahoo.component.AbstractComponent; import com.yahoo.container.jdisc.secretstore.SecretNotFoundException; import com.yahoo.container.jdisc.secretstore.SecretStore; import com.yahoo.container.jdisc.secretstore.SecretStoreConfig; +import com.yahoo.slime.Cursor; +import com.yahoo.slime.Slime; + +import java.util.List; +import java.util.stream.Collectors; /** * @author mortent @@ -21,32 +27,36 @@ import com.yahoo.container.jdisc.secretstore.SecretStoreConfig; public class AwsParameterStore extends AbstractComponent implements SecretStore { private final VespaAwsCredentialsProvider credentialsProvider; - private final SecretStoreConfig secretStoreConfig; + private final List<AwsSettings> configuredStores; @Inject public AwsParameterStore(SecretStoreConfig secretStoreConfig) { - this.secretStoreConfig = secretStoreConfig; + this(translateConfig(secretStoreConfig)); + } + + public AwsParameterStore(List<AwsSettings> configuredStores) { + this.configuredStores = configuredStores; this.credentialsProvider = new VespaAwsCredentialsProvider(); } @Override public String getSecret(String key) { - for (var group : secretStoreConfig.groups()) { + for (var store : configuredStores) { AWSSecurityTokenService tokenService = AWSSecurityTokenServiceClientBuilder .standard() - .withRegion(group.region()) + .withRegion(Regions.DEFAULT_REGION) .withCredentials(credentialsProvider) .build(); STSAssumeRoleSessionCredentialsProvider assumeExtAccountRole = new STSAssumeRoleSessionCredentialsProvider - .Builder(toRoleArn(group.awsId(), group.role()), "vespa") - .withExternalId(group.externalId()) + .Builder(toRoleArn(store.getAwsId(), store.getRole()), "vespa") + .withExternalId(store.getExternalId()) 
.withStsClient(tokenService) .build(); AWSSimpleSystemsManagement client = AWSSimpleSystemsManagementClient.builder() .withCredentials(assumeExtAccountRole) - .withRegion(group.region()) + .withRegion(store.getRegion()) .build(); GetParametersRequest parametersRequest = new GetParametersRequest().withNames(key).withWithDecryption(true); @@ -70,4 +80,73 @@ public class AwsParameterStore extends AbstractComponent implements SecretStore private String toRoleArn(String awsId, String role) { return "arn:aws:iam::" + awsId + ":role/" + role; } + + private static List<AwsSettings> translateConfig(SecretStoreConfig secretStoreConfig) { + return secretStoreConfig.groups() + .stream() + .map(config -> new AwsSettings(config.name(), config.role(), config.awsId(), config.externalId(), config.region())) + .collect(Collectors.toList()); + } + + public static class AwsSettings { + String name; + String role; + String awsId; + String externalId; + String region; + + AwsSettings(String name, String role, String awsId, String externalId, String region) { + this.name = validate(name, "name"); + this.role = validate(role, "role"); + this.awsId = validate(awsId, "awsId"); + this.externalId = validate(externalId, "externalId"); + this.region = validate(region, "region"); + } + + + public String getName() { + return name; + } + + public String getRole() { + return role; + } + + public String getAwsId() { + return awsId; + } + + public String getExternalId() { + return externalId; + } + + public String getRegion() { + return region; + } + + static AwsSettings fromSlime(Slime slime) { + var json = slime.get(); + return new AwsSettings( + json.field("name").asString(), + json.field("role").asString(), + json.field("awsId").asString(), + json.field("externalId").asString(), + json.field("region").asString() + ); + } + + void toSlime(Cursor slime) { + slime.setString("name", name); + slime.setString("role", role); + slime.setString("awsId", awsId); + slime.setString("externalId", "*****"); + 
slime.setString("region", region); + } + + static String validate(String value, String name) { + if (value == null || value.isBlank()) + throw new IllegalArgumentException("Config parameter '" + name + "' was blank or empty"); + return value; + } + } } diff --git a/jdisc-cloud-aws/src/main/java/com/yahoo/jdisc/cloud/aws/AwsParameterStoreValidationHandler.java b/jdisc-cloud-aws/src/main/java/com/yahoo/jdisc/cloud/aws/AwsParameterStoreValidationHandler.java index d45ead37480..665e55c8f24 100644 --- a/jdisc-cloud-aws/src/main/java/com/yahoo/jdisc/cloud/aws/AwsParameterStoreValidationHandler.java +++ b/jdisc-cloud-aws/src/main/java/com/yahoo/jdisc/cloud/aws/AwsParameterStoreValidationHandler.java @@ -8,13 +8,14 @@ import com.yahoo.container.jdisc.LoggingRequestHandler; import com.yahoo.io.IOUtils; import com.yahoo.restapi.ErrorResponse; import com.yahoo.restapi.SlimeJsonResponse; -import com.yahoo.slime.Cursor; import com.yahoo.slime.Slime; import com.yahoo.slime.SlimeUtils; import com.yahoo.yolean.Exceptions; +import com.yahoo.jdisc.cloud.aws.AwsParameterStore.AwsSettings; import java.io.IOException; import java.io.InputStream; +import java.util.List; import java.util.logging.Level; import java.util.logging.Logger; @@ -27,12 +28,10 @@ import java.util.logging.Logger; public class AwsParameterStoreValidationHandler extends LoggingRequestHandler { private static final Logger log = Logger.getLogger(AwsParameterStoreValidationHandler.class.getName()); - private final AwsParameterStore awsParameterStore; @Inject - public AwsParameterStoreValidationHandler(Context ctx, AwsParameterStore awsParameterStore) { + public AwsParameterStoreValidationHandler(Context ctx) { super(ctx); - this.awsParameterStore = awsParameterStore; } @Override @@ -50,14 +49,22 @@ public class AwsParameterStoreValidationHandler extends LoggingRequestHandler { private HttpResponse handlePOST(HttpRequest request) { var json = toSlime(request.getData()); - var settings = AwsSettings.fromSlime(json); + 
AwsSettings settings; + + try { + settings = AwsSettings.fromSlime(json); + } catch (IllegalArgumentException e) { + return ErrorResponse.badRequest(Exceptions.toMessageString(e)); + } var response = new Slime(); var root = response.setObject(); settings.toSlime(root.setObject("settings")); try { - awsParameterStore.getSecret("vespa-secret"); + var parameterName = json.get().field("parameterName").asString(); + var store = new AwsParameterStore(List.of(settings)); + store.getSecret(parameterName); root.setString("status", "ok"); } catch (RuntimeException e) { root.setString("status", "error"); @@ -78,34 +85,4 @@ public class AwsParameterStoreValidationHandler extends LoggingRequestHandler { } } - private static class AwsSettings { - String name; - String role; - String awsId; - String externalId; - - AwsSettings(String name, String role, String awsId, String externalId) { - this.name = name; - this.role = role; - this.awsId = awsId; - this.externalId = externalId; - } - - static AwsSettings fromSlime(Slime slime) { - var json = slime.get(); - return new AwsSettings( - json.field("name").asString(), - json.field("role").asString(), - json.field("awsId").asString(), - json.field("externalId").asString() - ); - } - - void toSlime(Cursor slime) { - slime.setString("name", name); - slime.setString("role", role); - slime.setString("awsId", awsId); - slime.setString("externalId", "*****"); - } - } -} +}
\ No newline at end of file diff --git a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyConnectionLogger.java b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyConnectionLogger.java index 34a91a3bbb4..cd1ca490f61 100644 --- a/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyConnectionLogger.java +++ b/jdisc_http_service/src/main/java/com/yahoo/jdisc/http/server/jetty/JettyConnectionLogger.java @@ -49,8 +49,8 @@ class JettyConnectionLogger extends AbstractLifeCycle implements Connection.List private static final Logger log = Logger.getLogger(JettyConnectionLogger.class.getName()); - private final ConcurrentMap<SocketChannelEndPoint, ConnectionInfo> connectionInfo = new ConcurrentHashMap<>(); - private final ConcurrentMap<SSLEngine, ConnectionInfo> sslToConnectionInfo = new ConcurrentHashMap<>(); + private final ConcurrentMap<IdentityKey<SocketChannelEndPoint>, ConnectionInfo> connectionInfo = new ConcurrentHashMap<>(); + private final ConcurrentMap<IdentityKey<SSLEngine>, ConnectionInfo> sslToConnectionInfo = new ConcurrentHashMap<>(); private final boolean enabled; private final ConnectionLog connectionLog; @@ -88,14 +88,15 @@ class JettyConnectionLogger extends AbstractLifeCycle implements Connection.List public void onOpened(Connection connection) { handleListenerInvocation("Connection.Listener", "onOpened", "%h", List.of(connection), () -> { SocketChannelEndPoint endpoint = findUnderlyingSocketEndpoint(connection.getEndPoint()); - ConnectionInfo info = connectionInfo.get(endpoint); + var endpointKey = IdentityKey.of(endpoint); + ConnectionInfo info = connectionInfo.get(endpointKey); if (info == null) { info = ConnectionInfo.from(endpoint); - connectionInfo.put(endpoint, info); + connectionInfo.put(IdentityKey.of(endpoint), info); } if (connection instanceof SslConnection) { SSLEngine sslEngine = ((SslConnection) connection).getSSLEngine(); - sslToConnectionInfo.put(sslEngine, info); + 
sslToConnectionInfo.put(IdentityKey.of(sslEngine), info); } if (connection.getEndPoint() instanceof ProxyConnectionFactory.ProxyEndPoint) { InetSocketAddress remoteAddress = connection.getEndPoint().getRemoteAddress(); @@ -108,7 +109,8 @@ class JettyConnectionLogger extends AbstractLifeCycle implements Connection.List public void onClosed(Connection connection) { handleListenerInvocation("Connection.Listener", "onClosed", "%h", List.of(connection), () -> { SocketChannelEndPoint endpoint = findUnderlyingSocketEndpoint(connection.getEndPoint()); - ConnectionInfo info = connectionInfo.get(endpoint); + var endpointKey = IdentityKey.of(endpoint); + ConnectionInfo info = connectionInfo.get(endpointKey); if (info == null) return; // Closed connection already handled if (connection instanceof HttpConnection) { info.setHttpBytes(connection.getBytesIn(), connection.getBytesOut()); @@ -116,7 +118,7 @@ class JettyConnectionLogger extends AbstractLifeCycle implements Connection.List if (!endpoint.isOpen()) { info.setClosedAt(System.currentTimeMillis()); connectionLog.log(info.toLogEntry()); - connectionInfo.remove(endpoint); + connectionInfo.remove(endpointKey); } }); } @@ -131,7 +133,7 @@ class JettyConnectionLogger extends AbstractLifeCycle implements Connection.List public void onRequestBegin(Request request) { handleListenerInvocation("HttpChannel.Listener", "onRequestBegin", "%h", List.of(request), () -> { SocketChannelEndPoint endpoint = findUnderlyingSocketEndpoint(request.getHttpChannel().getEndPoint()); - ConnectionInfo info = Objects.requireNonNull(connectionInfo.get(endpoint)); + ConnectionInfo info = Objects.requireNonNull(connectionInfo.get(IdentityKey.of(endpoint))); info.incrementRequests(); request.setAttribute(CONNECTION_ID_REQUEST_ATTRIBUTE, info.uuid()); }); @@ -141,7 +143,7 @@ class JettyConnectionLogger extends AbstractLifeCycle implements Connection.List public void onResponseBegin(Request request) { handleListenerInvocation("HttpChannel.Listener", 
"onResponseBegin", "%h", List.of(request), () -> { SocketChannelEndPoint endpoint = findUnderlyingSocketEndpoint(request.getHttpChannel().getEndPoint()); - ConnectionInfo info = Objects.requireNonNull(connectionInfo.get(endpoint)); + ConnectionInfo info = Objects.requireNonNull(connectionInfo.get(IdentityKey.of(endpoint))); info.incrementResponses(); }); } @@ -156,7 +158,7 @@ class JettyConnectionLogger extends AbstractLifeCycle implements Connection.List public void handshakeSucceeded(Event event) { SSLEngine sslEngine = event.getSSLEngine(); handleListenerInvocation("SslHandshakeListener", "handshakeSucceeded", "sslEngine=%h", List.of(sslEngine), () -> { - ConnectionInfo info = sslToConnectionInfo.remove(sslEngine); + ConnectionInfo info = sslToConnectionInfo.remove(IdentityKey.of(sslEngine)); info.setSslSessionDetails(sslEngine.getSession()); }); } @@ -166,7 +168,7 @@ class JettyConnectionLogger extends AbstractLifeCycle implements Connection.List SSLEngine sslEngine = event.getSSLEngine(); handleListenerInvocation("SslHandshakeListener", "handshakeFailed", "sslEngine=%h,failure=%s", List.of(sslEngine, failure), () -> { log.log(Level.FINE, failure, failure::toString); - ConnectionInfo info = sslToConnectionInfo.remove(sslEngine); + ConnectionInfo info = sslToConnectionInfo.remove(IdentityKey.of(sslEngine)); info.setSslHandshakeFailure((SSLHandshakeException)failure); }); } @@ -350,4 +352,22 @@ class JettyConnectionLogger extends AbstractLifeCycle implements Connection.List } } + + private static class IdentityKey<T> { + final T instance; + + IdentityKey(T instance) { this.instance = instance; } + + static <T> IdentityKey<T> of(T instance) { return new IdentityKey<>(instance); } + + @Override public int hashCode() { return System.identityHashCode(instance); } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (!(obj instanceof IdentityKey<?>)) return false; + IdentityKey<?> other = (IdentityKey<?>) obj; + return this.instance 
== other.instance; + } + } } diff --git a/jdisc_http_service/src/test/java/com/yahoo/container/logging/JsonConnectionLogWriterTest.java b/jdisc_http_service/src/test/java/com/yahoo/container/logging/JsonConnectionLogWriterTest.java index 8944ae9d288..15118b23f85 100644 --- a/jdisc_http_service/src/test/java/com/yahoo/container/logging/JsonConnectionLogWriterTest.java +++ b/jdisc_http_service/src/test/java/com/yahoo/container/logging/JsonConnectionLogWriterTest.java @@ -1,13 +1,13 @@ package com.yahoo.container.logging;// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. import com.yahoo.test.json.JsonTestHelper; -import com.yahoo.vespa.jdk8compat.List; import org.junit.jupiter.api.Test; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.time.Instant; +import java.util.List; import java.util.UUID; /** diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/core/MetricsManager.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/core/MetricsManager.java index 33827634ebf..53398073314 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/core/MetricsManager.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/core/MetricsManager.java @@ -32,7 +32,8 @@ import static java.util.stream.Collectors.toList; * @author gjoranv */ public class MetricsManager { - private static Logger log = Logger.getLogger(MetricsManager.class.getName()); + + private static final Logger log = Logger.getLogger(MetricsManager.class.getName()); static final DimensionId VESPA_VERSION = toDimensionId("vespaVersion"); @@ -75,8 +76,8 @@ public class MetricsManager { /** * Returns the metrics for the given services. The empty list is returned if no services are given. * - * @param services The services to retrieve metrics for. - * @return Metrics for all matching services. 
+ * @param services the services to retrieve metrics for + * @return metrics for all matching services */ public List<MetricsPacket> getMetrics(List<VespaService> services, Instant startTime) { return getMetricsAsBuilders(services, startTime).stream() diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/core/VespaMetrics.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/core/VespaMetrics.java index b895d6221c3..6dda8350c8e 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/core/VespaMetrics.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/core/VespaMetrics.java @@ -256,8 +256,7 @@ public class VespaMetrics { } } if (isForwarded) { - b.append(formatter.format(s, alias, metric.getValue())) - .append(" "); + b.append(formatter.format(s, alias, metric.getValue())).append(" "); } } } diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/ExternalMetrics.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/ExternalMetrics.java index 337967f3075..158f626cf84 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/ExternalMetrics.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/ExternalMetrics.java @@ -25,11 +25,12 @@ import static java.util.logging.Level.FINE; import static java.util.stream.Collectors.toCollection; /** - * This class is responsible for handling metrics received from external processes. + * Handling of metrics received from external processes. * * @author gjoranv */ public class ExternalMetrics { + private static final Logger log = Logger.getLogger(ExternalMetrics.class.getName()); // NOTE: node service id must be kept in sync with the same constant _value_ used in docker-api:Metrics.java @@ -52,7 +53,8 @@ public class ExternalMetrics { public void setExtraMetrics(List<MetricsPacket.Builder> externalPackets) { // TODO: Metrics filtering per consumer is not yet implemented. 
- // Split each packet per metric, and re-aggregate based on the metrics each consumer wants. Then filter out all packages with no consumers. + // Split each packet per metric, and re-aggregate based on the metrics each consumer wants. + // Then filter out all packages with no consumers. log.log(FINE, () -> "Setting new external metrics with " + externalPackets.size() + " metrics packets."); externalPackets.forEach(packet -> { packet.addConsumers(consumers.getAllConsumers()) @@ -95,4 +97,5 @@ public class ExternalMetrics { dimensions.keySet().retainAll(Set.of(ROLE_DIMENSION, STATE_DIMENSION, ORCHESTRATOR_STATE_DIMENSION)); return dimensions; } + } diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/Metric.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/Metric.java index a7ea70495f7..63e147bf6b4 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/Metric.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/Metric.java @@ -13,6 +13,7 @@ import java.util.Set; * @author Jo Kristian Bergum */ public class Metric { + private final long time; private final Number value; private final String description; diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/Metrics.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/Metrics.java index 45d76375c07..b45e7743640 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/Metrics.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/metric/Metrics.java @@ -13,6 +13,7 @@ import java.util.List; // TODO: remove timestamp, only used as temporary storage. // TODO: instances of this class can probably be replaced by a simple freezable map. 
public class Metrics { + private final List<Metric> metrics = new ArrayList<>(); private long timestamp; private boolean isFrozen = false; diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/rpc/RpcServer.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/rpc/RpcServer.java index 30e7d7b81f8..79af500cde7 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/rpc/RpcServer.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/rpc/RpcServer.java @@ -37,7 +37,7 @@ public class RpcServer { private static final Logger log = Logger.getLogger(RpcServer.class.getName()); - private static int LOG_SPENT_TIME_LIMIT = 10 * 1000; // ms. same as default client RPC timeout used in rpc_invoke + private static final int LOG_SPENT_TIME_LIMIT = 10 * 1000; // ms. same as default client RPC timeout used in rpc_invoke private final VespaServices vespaServices; private final MetricsManager metricsManager; diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java new file mode 100644 index 00000000000..f9445e5b26a --- /dev/null +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/MetricsParser.java @@ -0,0 +1,154 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.metricsproxy.service; + +import ai.vespa.metricsproxy.metric.Metric; +import ai.vespa.metricsproxy.metric.Metrics; +import ai.vespa.metricsproxy.metric.model.DimensionId; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.core.JsonToken; +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; + +import java.io.IOException; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; + +import static ai.vespa.metricsproxy.metric.model.DimensionId.toDimensionId; + +/** + * Fetch metrics for a given vespa service + * + * @author Jo Kristian Bergum + */ +public class MetricsParser { + + private static final ObjectMapper jsonMapper = new ObjectMapper(); + + static Metrics parse(String data) throws IOException { + JsonParser parser = jsonMapper.createParser(data); + + if (parser.nextToken() != JsonToken.START_OBJECT) { + throw new IOException("Expected start of object, got " + parser.currentToken()); + } + + Metrics metrics = new Metrics(); + for (parser.nextToken(); parser.getCurrentToken() != JsonToken.END_OBJECT; parser.nextToken()) { + String fieldName = parser.getCurrentName(); + JsonToken token = parser.nextToken(); + if (fieldName.equals("metrics")) { + metrics = parseMetrics(parser); + } else { + if (token == JsonToken.START_OBJECT || token == JsonToken.START_ARRAY) { + parser.skipChildren(); + } + } + } + return metrics; + } + + static private Metrics parseSnapshot(JsonParser parser) throws IOException { + if (parser.getCurrentToken() != JsonToken.START_OBJECT) { + throw new IOException("Expected start of 'snapshot' object, got " + parser.currentToken()); + } + Metrics metrics = new Metrics(); + for (parser.nextToken(); parser.getCurrentToken() != JsonToken.END_OBJECT; parser.nextToken()) { + String fieldName = parser.getCurrentName(); + JsonToken token = parser.nextToken(); + if (fieldName.equals("to")) { + long timestamp = 
parser.getLongValue(); + long now = System.currentTimeMillis() / 1000; + timestamp = Metric.adjustTime(timestamp, now); + metrics = new Metrics(timestamp); + } else { + if (token == JsonToken.START_OBJECT || token == JsonToken.START_ARRAY) { + parser.skipChildren(); + } + } + } + return metrics; + } + + static private void parseValues(JsonParser parser, Metrics metrics) throws IOException { + if (parser.getCurrentToken() != JsonToken.START_ARRAY) { + throw new IOException("Expected start of 'metrics:values' array, got " + parser.currentToken()); + } + + Map<String, Map<DimensionId, String>> uniqueDimensions = new HashMap<>(); + while (parser.nextToken() == JsonToken.START_OBJECT) { + // read everything from this START_OBJECT to the matching END_OBJECT + // and return it as a tree model ObjectNode + JsonNode value = jsonMapper.readTree(parser); + handleValue(value, metrics.getTimeStamp(), metrics, uniqueDimensions); + + // do whatever you need to do with this object + } + } + + static private Metrics parseMetrics(JsonParser parser) throws IOException { + if (parser.getCurrentToken() != JsonToken.START_OBJECT) { + throw new IOException("Expected start of 'metrics' object, got " + parser.currentToken()); + } + Metrics metrics = new Metrics(); + for (parser.nextToken(); parser.getCurrentToken() != JsonToken.END_OBJECT; parser.nextToken()) { + String fieldName = parser.getCurrentName(); + JsonToken token = parser.nextToken(); + if (fieldName.equals("snapshot")) { + metrics = parseSnapshot(parser); + } else if (fieldName.equals("values")) { + parseValues(parser, metrics); + } else { + if (token == JsonToken.START_OBJECT || token == JsonToken.START_ARRAY) { + parser.skipChildren(); + } + } + } + return metrics; + } + + static private void handleValue(JsonNode metric, long timestamp, Metrics metrics, Map<String, Map<DimensionId, String>> uniqueDimensions) { + String name = metric.get("name").textValue(); + String description = ""; + + if (metric.has("description")) { + 
description = metric.get("description").textValue(); + } + + Map<DimensionId, String> dim = Collections.emptyMap(); + if (metric.has("dimensions")) { + JsonNode dimensions = metric.get("dimensions"); + StringBuilder sb = new StringBuilder(); + for (Iterator<?> it = dimensions.fieldNames(); it.hasNext(); ) { + String k = (String) it.next(); + String v = dimensions.get(k).asText(); + sb.append(toDimensionId(k)).append(v); + } + if ( ! uniqueDimensions.containsKey(sb.toString())) { + dim = new HashMap<>(); + for (Iterator<?> it = dimensions.fieldNames(); it.hasNext(); ) { + String k = (String) it.next(); + String v = dimensions.get(k).textValue(); + dim.put(toDimensionId(k), v); + } + uniqueDimensions.put(sb.toString(), Collections.unmodifiableMap(dim)); + } + dim = uniqueDimensions.get(sb.toString()); + } + + JsonNode aggregates = metric.get("values"); + for (Iterator<?> it = aggregates.fieldNames(); it.hasNext(); ) { + String aggregator = (String) it.next(); + JsonNode aggregatorValue = aggregates.get(aggregator); + if (aggregatorValue == null) { + throw new IllegalArgumentException("Value for aggregator '" + aggregator + "' is missing"); + } + Number value = aggregatorValue.numberValue(); + if (value == null) { + throw new IllegalArgumentException("Value for aggregator '" + aggregator + "' is not a number"); + } + StringBuilder metricName = (new StringBuilder()).append(name).append(".").append(aggregator); + metrics.add(new Metric(metricName.toString(), value, timestamp, dim, description)); + } + } +} diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/RemoteMetricsFetcher.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/RemoteMetricsFetcher.java index 464f215edc4..314d556b9b4 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/RemoteMetricsFetcher.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/RemoteMetricsFetcher.java @@ -1,20 +1,9 @@ // Copyright 2020 Oath Inc. 
Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.metricsproxy.service; -import ai.vespa.metricsproxy.metric.Metric; import ai.vespa.metricsproxy.metric.Metrics; -import ai.vespa.metricsproxy.metric.model.DimensionId; -import com.fasterxml.jackson.databind.JsonNode; -import com.fasterxml.jackson.databind.ObjectMapper; -import com.fasterxml.jackson.databind.node.ArrayNode; import java.io.IOException; -import java.util.Collections; -import java.util.HashMap; -import java.util.Iterator; -import java.util.Map; - -import static ai.vespa.metricsproxy.metric.model.DimensionId.toDimensionId; /** * Fetch metrics for a given vespa service @@ -23,8 +12,6 @@ import static ai.vespa.metricsproxy.metric.model.DimensionId.toDimensionId; */ public class RemoteMetricsFetcher extends HttpMetricFetcher { - private static final ObjectMapper jsonMapper = new ObjectMapper(); - final static String METRICS_PATH = STATE_PATH + "metrics"; RemoteMetricsFetcher(VespaService service, int port) { @@ -45,90 +32,14 @@ public class RemoteMetricsFetcher extends HttpMetricFetcher { return createMetrics(data, fetchCount); } - /** - * Connect to remote service over http and fetch metrics - */ Metrics createMetrics(String data, int fetchCount) { Metrics remoteMetrics = new Metrics(); try { - remoteMetrics = parse(data); + remoteMetrics = MetricsParser.parse(data); } catch (Exception e) { handleException(e, data, fetchCount); } return remoteMetrics; } - - private Metrics parse(String data) throws IOException { - JsonNode o = jsonMapper.readTree(data); - if (!(o.has("metrics"))) { - return new Metrics(); //empty - } - - JsonNode metrics = o.get("metrics"); - ArrayNode values; - long timestamp; - - try { - JsonNode snapshot = metrics.get("snapshot"); - timestamp = snapshot.get("to").asLong(); - values = (ArrayNode) metrics.get("values"); - } catch (Exception e) { - // snapshot might not have been produced. 
Do not throw exception into log - return new Metrics(); - } - long now = System.currentTimeMillis() / 1000; - timestamp = Metric.adjustTime(timestamp, now); - Metrics m = new Metrics(timestamp); - - Map<DimensionId, String> noDims = Collections.emptyMap(); - Map<String, Map<DimensionId, String>> uniqueDimensions = new HashMap<>(); - for (int i = 0; i < values.size(); i++) { - JsonNode metric = values.get(i); - String name = metric.get("name").textValue(); - String description = ""; - - if (metric.has("description")) { - description = metric.get("description").textValue(); - } - - Map<DimensionId, String> dim = noDims; - if (metric.has("dimensions")) { - JsonNode dimensions = metric.get("dimensions"); - StringBuilder sb = new StringBuilder(); - for (Iterator<?> it = dimensions.fieldNames(); it.hasNext(); ) { - String k = (String) it.next(); - String v = dimensions.get(k).asText(); - sb.append(toDimensionId(k)).append(v); - } - if ( ! uniqueDimensions.containsKey(sb.toString())) { - dim = new HashMap<>(); - for (Iterator<?> it = dimensions.fieldNames(); it.hasNext(); ) { - String k = (String) it.next(); - String v = dimensions.get(k).textValue(); - dim.put(toDimensionId(k), v); - } - uniqueDimensions.put(sb.toString(), Collections.unmodifiableMap(dim)); - } - dim = uniqueDimensions.get(sb.toString()); - } - - JsonNode aggregates = metric.get("values"); - for (Iterator<?> it = aggregates.fieldNames(); it.hasNext(); ) { - String aggregator = (String) it.next(); - JsonNode aggregatorValue = aggregates.get(aggregator); - if (aggregatorValue == null) { - throw new IllegalArgumentException("Value for aggregator '" + aggregator + "' is missing"); - } - Number value = aggregatorValue.numberValue(); - if (value == null) { - throw new IllegalArgumentException("Value for aggregator '" + aggregator + "' is not a number"); - } - StringBuilder metricName = (new StringBuilder()).append(name).append(".").append(aggregator); - m.add(new Metric(metricName.toString(), value, timestamp, 
dim, description)); - } - } - - return m; - } } diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/SystemPoller.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/SystemPoller.java index a6c2220b5a2..05e65449163 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/SystemPoller.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/SystemPoller.java @@ -25,7 +25,7 @@ import java.util.logging.Logger; */ public class SystemPoller { - final private static Logger log = Logger.getLogger(SystemPoller.class.getName()); + private static final Logger log = Logger.getLogger(SystemPoller.class.getName()); private final int pollingIntervalSecs; private final List<VespaService> services; @@ -55,12 +55,10 @@ public class SystemPoller { * @return array[0] = memoryResident, array[1] = memoryVirtual (kB units) */ long[] getMemoryUsage(VespaService service) { - long size[] = new long[2]; + long[] size = new long[2]; BufferedReader br; int pid = service.getPid(); - size[0] = 0; - size[1] = 0; try { br = new BufferedReader(new FileReader("/proc/" + pid + "/smaps")); } catch (FileNotFoundException ex) { diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/VespaService.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/VespaService.java index 69558b0c474..b069256f527 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/VespaService.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/VespaService.java @@ -27,7 +27,6 @@ public class VespaService implements Comparable<VespaService> { private final String monitoringPrefix; private final Map<DimensionId, String> dimensions; - private volatile int pid = -1; private volatile String state = "UNKNOWN"; private volatile boolean isAlive; @@ -42,8 +41,8 @@ public class VespaService implements Comparable<VespaService> { // Used to keep track of log level when health or metrics requests fail - private AtomicInteger 
metricsFetchCount = new AtomicInteger(0); - private AtomicInteger healthFetchCount = new AtomicInteger(0); + private final AtomicInteger metricsFetchCount = new AtomicInteger(0); + private final AtomicInteger healthFetchCount = new AtomicInteger(0); public static VespaService create(String name, String id, int statePort) { diff --git a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/VespaServices.java b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/VespaServices.java index e01a68a1f7f..71f8a5f2b21 100644 --- a/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/VespaServices.java +++ b/metrics-proxy/src/main/java/ai/vespa/metricsproxy/service/VespaServices.java @@ -24,6 +24,7 @@ import static java.util.logging.Level.FINE; * @author gjoranv */ public class VespaServices { + private static final Logger log = Logger.getLogger(VespaServices.class.getName()); public static final String ALL_SERVICES = "all"; diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/AddNode.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/AddNode.java index d1b70450226..1cf9c2bfb76 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/AddNode.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/AddNode.java @@ -15,6 +15,7 @@ import java.util.Set; public class AddNode { public final String hostname; + public final Optional<String> id; public final Optional<String> parentHostname; public final Optional<String> nodeFlavor; public final Optional<FlavorOverrides> flavorOverrides; @@ -23,19 +24,20 @@ public class AddNode { public final Set<String> ipAddresses; public final Set<String> additionalIpAddresses; - public static AddNode forHost(String hostname, String nodeFlavor, Optional<FlavorOverrides> flavorOverrides, NodeType nodeType, Set<String> ipAddresses, Set<String> 
additionalIpAddresses) { - return new AddNode(hostname, Optional.empty(), Optional.of(nodeFlavor), flavorOverrides, Optional.empty(), nodeType, ipAddresses, additionalIpAddresses); + public static AddNode forHost(String hostname, Optional<String> id, String nodeFlavor, Optional<FlavorOverrides> flavorOverrides, NodeType nodeType, Set<String> ipAddresses, Set<String> additionalIpAddresses) { + return new AddNode(hostname, id, Optional.empty(), Optional.of(nodeFlavor), flavorOverrides, Optional.empty(), nodeType, ipAddresses, additionalIpAddresses); } public static AddNode forNode(String hostname, String parentHostname, NodeResources nodeResources, NodeType nodeType, Set<String> ipAddresses) { - return new AddNode(hostname, Optional.of(parentHostname), Optional.empty(), Optional.empty(), Optional.of(nodeResources), nodeType, ipAddresses, Set.of()); + return new AddNode(hostname, Optional.empty(), Optional.of(parentHostname), Optional.empty(), Optional.empty(), Optional.of(nodeResources), nodeType, ipAddresses, Set.of()); } - private AddNode(String hostname, Optional<String> parentHostname, + private AddNode(String hostname, Optional<String> id, Optional<String> parentHostname, Optional<String> nodeFlavor, Optional<FlavorOverrides> flavorOverrides, Optional<NodeResources> nodeResources, NodeType nodeType, Set<String> ipAddresses, Set<String> additionalIpAddresses) { this.hostname = hostname; + this.id = id; this.parentHostname = parentHostname; this.nodeFlavor = nodeFlavor; this.flavorOverrides = flavorOverrides; @@ -51,8 +53,11 @@ public class AddNode { if (o == null || getClass() != o.getClass()) return false; AddNode addNode = (AddNode) o; return Objects.equals(hostname, addNode.hostname) && + Objects.equals(id, addNode.id) && Objects.equals(parentHostname, addNode.parentHostname) && Objects.equals(nodeFlavor, addNode.nodeFlavor) && + Objects.equals(flavorOverrides, addNode.flavorOverrides) && + Objects.equals(nodeResources, addNode.nodeResources) && nodeType == 
addNode.nodeType && Objects.equals(ipAddresses, addNode.ipAddresses) && Objects.equals(additionalIpAddresses, addNode.additionalIpAddresses); @@ -60,15 +65,18 @@ public class AddNode { @Override public int hashCode() { - return Objects.hash(hostname, parentHostname, nodeFlavor, nodeType, ipAddresses, additionalIpAddresses); + return Objects.hash(hostname, id, parentHostname, nodeFlavor, flavorOverrides, nodeResources, nodeType, ipAddresses, additionalIpAddresses); } @Override public String toString() { return "AddNode{" + "hostname='" + hostname + '\'' + + ", id=" + id + ", parentHostname=" + parentHostname + ", nodeFlavor='" + nodeFlavor + '\'' + + ", flavorOverrides='" + flavorOverrides + '\'' + + ", nodeResources='" + nodeResources + '\'' + ", nodeType=" + nodeType + ", ipAddresses=" + ipAddresses + ", additionalIpAddresses=" + additionalIpAddresses + diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java index f7d68fe87ea..41f0932419b 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeSpec.java @@ -25,6 +25,7 @@ import static com.yahoo.config.provision.NodeResources.DiskSpeed.slow; public class NodeSpec { private final String hostname; + private final Optional<String> id; private final NodeState state; private final NodeType type; private final String flavor; @@ -66,6 +67,7 @@ public class NodeSpec { public NodeSpec( String hostname, + Optional<String> id, Optional<DockerImage> wantedDockerImage, Optional<DockerImage> currentDockerImage, NodeState state, @@ -102,6 +104,7 @@ public class NodeSpec { } this.hostname = Objects.requireNonNull(hostname); + this.id = Objects.requireNonNull(id); this.wantedDockerImage = 
Objects.requireNonNull(wantedDockerImage); this.currentDockerImage = Objects.requireNonNull(currentDockerImage); this.state = Objects.requireNonNull(state); @@ -134,6 +137,11 @@ public class NodeSpec { return hostname; } + /** Returns the cloud-specific ID of the host. */ + public Optional<String> id() { + return id; + } + public NodeState state() { return state; } @@ -268,11 +276,13 @@ public class NodeSpec { NodeSpec that = (NodeSpec) o; return Objects.equals(hostname, that.hostname) && + Objects.equals(id, that.id) && Objects.equals(wantedDockerImage, that.wantedDockerImage) && Objects.equals(currentDockerImage, that.currentDockerImage) && Objects.equals(state, that.state) && Objects.equals(type, that.type) && Objects.equals(flavor, that.flavor) && + Objects.equals(modelName, that.modelName) && Objects.equals(wantedVespaVersion, that.wantedVespaVersion) && Objects.equals(currentVespaVersion, that.currentVespaVersion) && Objects.equals(wantedOsVersion, that.wantedOsVersion) && @@ -299,11 +309,13 @@ public class NodeSpec { public int hashCode() { return Objects.hash( hostname, + id, wantedDockerImage, currentDockerImage, state, type, flavor, + modelName, wantedVespaVersion, currentVespaVersion, wantedOsVersion, @@ -330,11 +342,13 @@ public class NodeSpec { public String toString() { return getClass().getSimpleName() + " {" + " hostname=" + hostname + + " id=" + id + " wantedDockerImage=" + wantedDockerImage + " currentDockerImage=" + currentDockerImage + " state=" + state + " type=" + type + " flavor=" + flavor + + " modelName=" + modelName + " wantedVespaVersion=" + wantedVespaVersion + " currentVespaVersion=" + currentVespaVersion + " wantedOsVersion=" + wantedOsVersion @@ -360,6 +374,7 @@ public class NodeSpec { public static class Builder { private String hostname; + private Optional<String> id = Optional.empty(); private NodeState state; private NodeType type; private String flavor; @@ -423,6 +438,11 @@ public class NodeSpec { return this; } + public Builder 
id(String id) { + this.id = Optional.of(id); + return this; + } + public Builder wantedDockerImage(DockerImage wantedDockerImage) { this.wantedDockerImage = Optional.of(wantedDockerImage); return this; @@ -681,7 +701,7 @@ public class NodeSpec { } public NodeSpec build() { - return new NodeSpec(hostname, wantedDockerImage, currentDockerImage, state, type, flavor, + return new NodeSpec(hostname, id, wantedDockerImage, currentDockerImage, state, type, flavor, wantedVespaVersion, currentVespaVersion, wantedOsVersion, currentOsVersion, orchestratorStatus, owner, membership, wantedRestartGeneration, currentRestartGeneration, diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java index 0747912bba2..5069f02c6b7 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java @@ -153,6 +153,7 @@ public class RealNodeRepository implements NodeRepository { NodeReports reports = NodeReports.fromMap(Optional.ofNullable(node.reports).orElseGet(Map::of)); return new NodeSpec( node.hostname, + Optional.ofNullable(node.openStackId), Optional.ofNullable(node.wantedDockerImage).map(DockerImage::fromString), Optional.ofNullable(node.currentDockerImage).map(DockerImage::fromString), nodeState, @@ -227,7 +228,7 @@ public class RealNodeRepository implements NodeRepository { private static NodeRepositoryNode nodeRepositoryNodeFromAddNode(AddNode addNode) { NodeRepositoryNode node = new NodeRepositoryNode(); - node.openStackId = "fake-" + addNode.hostname; + node.openStackId = addNode.id.orElse("fake-" + addNode.hostname); node.hostname = addNode.hostname; node.parentHostname = addNode.parentHostname.orElse(null); 
addNode.nodeFlavor.ifPresent(f -> node.flavor = f); diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Template.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Template.java index cef35803e98..1ab4a30dc58 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Template.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/task/util/file/Template.java @@ -3,7 +3,6 @@ package com.yahoo.vespa.hosted.node.admin.task.util.file; import org.apache.velocity.VelocityContext; import org.apache.velocity.app.Velocity; -import org.apache.velocity.app.VelocityEngine; import java.io.StringWriter; import java.nio.file.Files; @@ -19,16 +18,16 @@ import static com.yahoo.yolean.Exceptions.uncheck; */ public class Template { - private final VelocityEngine velocityEngine = new VelocityEngine(); + static { + Velocity.addProperty(Velocity.RUNTIME_LOG_LOGSYSTEM_CLASS, "org.apache.velocity.runtime.log.NullLogSystem"); + Velocity.init(); + } + private final VelocityContext velocityContext = new VelocityContext(); private final String template; private Template(String template) { this.template = template; - - velocityEngine.addProperty(Velocity.RUNTIME_LOG_LOGSYSTEM_CLASS, - "org.apache.velocity.runtime.log.NullLogSystem"); - velocityEngine.init(); } public static Template at(Path templatePath) { @@ -50,7 +49,7 @@ public class Template { public String render() { StringWriter writer = new StringWriter(); - velocityEngine.evaluate(velocityContext, writer, "Template", template); + Velocity.evaluate(velocityContext, writer, "Template", template); return writer.toString(); } diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java index 6d26e16f314..fe06812c608 100644 --- 
a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java @@ -163,6 +163,7 @@ public class RealNodeRepositoryTest { @Test public void testAddNodes() { AddNode host = AddNode.forHost("host123.domain.tld", + Optional.of("id1"), "default", Optional.of(FlavorOverrides.ofDisk(123)), NodeType.confighost, @@ -175,6 +176,7 @@ public class RealNodeRepositoryTest { nodeRepositoryApi.addNodes(List.of(host, node)); NodeSpec hostSpec = nodeRepositoryApi.getOptionalNode("host123.domain.tld").orElseThrow(); + assertEquals("id1", hostSpec.id().orElseThrow()); assertEquals("default", hostSpec.flavor()); assertEquals(123, hostSpec.diskGb(), 0); assertEquals(NodeType.confighost, hostSpec.type()); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java index b16859fa6fb..bddbcf43bd0 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/applications/Cluster.java @@ -5,6 +5,7 @@ import com.yahoo.config.provision.ClusterResources; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.vespa.hosted.provision.autoscale.Autoscaler; +import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.List; @@ -127,6 +128,32 @@ public class Cluster { return new Cluster(id, exclusive, min, max, suggested, target, scalingEvents, autoscalingStatus); } + /** The predicted duration of a rescaling of this cluster */ + public Duration scalingDuration(ClusterSpec clusterSpec) { + int completedEventCount = 0; + Duration totalDuration = Duration.ZERO; + for (ScalingEvent event : scalingEvents()) { + if (event.duration().isEmpty()) 
continue; + completedEventCount++; + totalDuration = totalDuration.plus(event.duration().get()); + } + + if (completedEventCount == 0) { // Use defaults + if (clusterSpec.isStateful()) return Duration.ofHours(12); + return Duration.ofMinutes(10); + } + else { + Duration predictedDuration = totalDuration.dividedBy(completedEventCount); + + // TODO: Remove when we have reliable completion for content clusters + if (clusterSpec.isStateful() && predictedDuration.minus(Duration.ofHours(12)).isNegative()) + return Duration.ofHours(12); + + if (predictedDuration.minus(Duration.ofMinutes(5)).isNegative()) return Duration.ofMinutes(5); // minimum + return predictedDuration; + } + } + @Override public int hashCode() { return id.hashCode(); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/AllocationOptimizer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/AllocationOptimizer.java index 84634b26c4a..14e68bc2f0f 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/AllocationOptimizer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/AllocationOptimizer.java @@ -16,8 +16,7 @@ import java.util.Optional; public class AllocationOptimizer { // The min and max nodes to consider when not using application supplied limits - private static final int minimumStatelessNodes = 2; // Since this number includes redundancy it cannot be lower than 2 - private static final int minimumStatefulNodes = 3; // Leader election requires 3 nodes to have redundancy + private static final int minimumNodes = 2; // Since this number includes redundancy it cannot be lower than 2 private static final int maximumNodes = 150; // When a query is issued on a node the cost is the sum of a fixed cost component and a cost component @@ -41,7 +40,7 @@ public class AllocationOptimizer { public Optional<AllocatableClusterResources> findBestAllocation(ResourceTarget target, 
AllocatableClusterResources current, Limits limits) { - int minimumNodes = current.clusterSpec().isStateful() ? minimumStatefulNodes : minimumStatelessNodes; + int minimumNodes = AllocationOptimizer.minimumNodes; if (limits.isEmpty()) limits = Limits.of(new ClusterResources(minimumNodes, 1, NodeResources.unspecified()), new ClusterResources(maximumNodes, maximumNodes, NodeResources.unspecified())); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java index 2d192fae11f..5d5c6fdac5a 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Autoscaler.java @@ -2,14 +2,12 @@ package com.yahoo.vespa.hosted.provision.autoscale; import com.yahoo.config.provision.ClusterResources; -import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.NodeResources; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.applications.Application; import com.yahoo.vespa.hosted.provision.applications.Cluster; -import com.yahoo.vespa.hosted.provision.applications.ScalingEvent; import java.time.Duration; import java.time.Instant; @@ -23,9 +21,9 @@ import java.util.Optional; */ public class Autoscaler { - /** What cost difference factor is worth a reallocation? */ + /** What cost difference is worth a reallocation? */ private static final double costDifferenceWorthReallocation = 0.1; - /** What difference factor for a resource is worth a reallocation? */ + /** What resource difference is worth a reallocation? */ private static final double resourceDifferenceWorthReallocation = 0.1; private final MetricsDb metricsDb; @@ -64,42 +62,38 @@ public class Autoscaler { if ( ! 
stable(clusterNodes, nodeRepository)) return Advice.none("Cluster change in progress"); - Duration scalingWindow = scalingWindow(clusterNodes.clusterSpec(), cluster); + Duration scalingWindow = cluster.scalingDuration(clusterNodes.clusterSpec()); if (scaledIn(scalingWindow, cluster)) - return Advice.dontScale("Won't autoscale now: Less than " + scalingWindow + " since last rescaling"); + return Advice.dontScale("Won't autoscale now: Less than " + scalingWindow + " since last resource change"); - ClusterTimeseries clusterTimeseries = - new ClusterTimeseries(scalingWindow, cluster, clusterNodes, metricsDb); - AllocatableClusterResources currentAllocation = - new AllocatableClusterResources(clusterNodes.asList(), nodeRepository, cluster.exclusive()); + var clusterNodesTimeseries = new ClusterNodesTimeseries(scalingWindow, cluster, clusterNodes, metricsDb); + var currentAllocation = new AllocatableClusterResources(clusterNodes.asList(), nodeRepository, cluster.exclusive()); - int measurementsPerNode = clusterTimeseries.measurementsPerNode(); + int measurementsPerNode = clusterNodesTimeseries.measurementsPerNode(); if (measurementsPerNode < minimumMeasurementsPerNode(scalingWindow)) - return Advice.none("Collecting more data before making new scaling decisions: " + - "Have " + measurementsPerNode + " measurements per node but require " + - minimumMeasurementsPerNode(scalingWindow)); + return Advice.none("Collecting more data before making new scaling decisions: Need to measure for " + + scalingWindow + " since the last resource change completed"); - int nodesMeasured = clusterTimeseries.nodesMeasured(); + int nodesMeasured = clusterNodesTimeseries.nodesMeasured(); if (nodesMeasured != clusterNodes.size()) return Advice.none("Collecting more data before making new scaling decisions: " + - "Have measurements from " + nodesMeasured + " but require from " + clusterNodes.size()); + "Have measurements from " + nodesMeasured + " nodes, but require from " + 
clusterNodes.size()); - double cpuLoad = clusterTimeseries.averageLoad(Resource.cpu); - double memoryLoad = clusterTimeseries.averageLoad(Resource.memory); - double diskLoad = clusterTimeseries.averageLoad(Resource.disk); - var target = ResourceTarget.idealLoad(cpuLoad, memoryLoad, diskLoad, currentAllocation, application); + var clusterTimeseries = metricsDb.getClusterTimeseries(application.id(), cluster.id()); + var target = ResourceTarget.idealLoad(clusterTimeseries, clusterNodesTimeseries, currentAllocation, application); Optional<AllocatableClusterResources> bestAllocation = allocationOptimizer.findBestAllocation(target, currentAllocation, limits); if (bestAllocation.isEmpty()) - return Advice.dontScale("No allocation changes are possible within configured limits"); + return Advice.dontScale("No allocation improvements are possible within configured limits"); if (similar(bestAllocation.get().realResources(), currentAllocation.realResources())) return Advice.dontScale("Cluster is ideally scaled within configured limits"); if (isDownscaling(bestAllocation.get(), currentAllocation) && scaledIn(scalingWindow.multipliedBy(3), cluster)) - return Advice.dontScale("Waiting " + scalingWindow.multipliedBy(3) + " since last rescaling before reducing resources"); + return Advice.dontScale("Waiting " + scalingWindow.multipliedBy(3) + + " since the last change before reducing resources"); return Advice.scaleTo(bestAllocation.get().advertisedResources()); } @@ -128,32 +122,6 @@ public class Autoscaler { .isAfter(nodeRepository.clock().instant().minus(delay)); } - /** The duration of the window we need to consider to make a scaling decision. 
See also minimumMeasurementsPerNode */ - private Duration scalingWindow(ClusterSpec clusterSpec, Cluster cluster) { - int completedEventCount = 0; - Duration totalDuration = Duration.ZERO; - for (ScalingEvent event : cluster.scalingEvents()) { - if (event.duration().isEmpty()) continue; - completedEventCount++; - totalDuration = totalDuration.plus(event.duration().get()); - } - - if (completedEventCount == 0) { // Use defaults - if (clusterSpec.isStateful()) return Duration.ofHours(12); - return Duration.ofMinutes(10); - } - else { - Duration predictedDuration = totalDuration.dividedBy(completedEventCount); - - // TODO: Remove when we have reliable completion for content clusters - if (clusterSpec.isStateful() && predictedDuration.minus(Duration.ofHours(12)).isNegative()) - return Duration.ofHours(12); - - if (predictedDuration.minus(Duration.ofMinutes(5)).isNegative()) return Duration.ofMinutes(5); // minimum - return predictedDuration; - } - } - static Duration maxScalingWindow() { return Duration.ofHours(48); } @@ -213,7 +181,7 @@ public class Autoscaler { private static Advice none(String reason) { return new Advice(Optional.empty(), false, reason); } private static Advice dontScale(String reason) { return new Advice(Optional.empty(), true, reason); } private static Advice scaleTo(ClusterResources target) { - return new Advice(Optional.of(target), true, "Scaling due to load changes"); + return new Advice(Optional.of(target), true, "Scheduled scaling to " + target + " due to load changes"); } @Override diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterMetricSnapshot.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterMetricSnapshot.java new file mode 100644 index 00000000000..fd8e91584c4 --- /dev/null +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterMetricSnapshot.java @@ -0,0 +1,42 @@ +// Copyright Verizon Media. 
Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.autoscale; + +import java.time.Instant; + +/** + * Cluster level metrics. + * These are aggregated at fetch time over the nodes in the cluster at that point in time. + * + * @author bratseth + */ +public class ClusterMetricSnapshot implements Comparable<ClusterMetricSnapshot> { + + private final Instant at; + + private final double queryRate; + + public ClusterMetricSnapshot(Instant at, double queryRate) { + this.at = at; + this.queryRate = queryRate; + } + + public Instant at() { return at; } + + /** Queries per second */ + public double queryRate() { return queryRate; } + + public ClusterMetricSnapshot withQueryRate(double queryRate) { + return new ClusterMetricSnapshot(at, queryRate); + } + + @Override + public int compareTo(ClusterMetricSnapshot other) { + return at.compareTo(other.at); + } + + @Override + public String toString() { return "metrics at " + at + ":" + + " queryRate: " + queryRate; + } + +} diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterNodesTimeseries.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterNodesTimeseries.java new file mode 100644 index 00000000000..173d76e4c26 --- /dev/null +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterNodesTimeseries.java @@ -0,0 +1,76 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.hosted.provision.autoscale; + +import com.yahoo.vespa.hosted.provision.NodeList; +import com.yahoo.vespa.hosted.provision.applications.Cluster; + +import java.time.Duration; +import java.util.List; +import java.util.function.Predicate; +import java.util.stream.Collectors; + +/** + * A series of metric snapshots for the nodes of a cluster used to compute load + * + * @author bratseth + */ +public class ClusterNodesTimeseries { + + private final Cluster cluster; + private final NodeList clusterNodes; + + /** The measurements for all nodes in this snapshot */ + private final List<NodeTimeseries> timeseries; + + public ClusterNodesTimeseries(Duration period, Cluster cluster, NodeList clusterNodes, MetricsDb db) { + this.cluster = cluster; + this.clusterNodes = clusterNodes; + var timeseries = db.getNodeTimeseries(period, clusterNodes); + + if (cluster.lastScalingEvent().isPresent()) + timeseries = filter(timeseries, snapshot -> snapshot.generation() < 0 || // Content nodes do not yet send generation + snapshot.generation() >= cluster.lastScalingEvent().get().generation()); + timeseries = filter(timeseries, snapshot -> snapshot.inService() && snapshot.stable()); + + this.timeseries = timeseries; + } + + /** The cluster this is a timeseries for */ + public Cluster cluster() { return cluster; } + + /** The nodes of the cluster this is a timeseries for */ + public NodeList clusterNodes() { return clusterNodes; } + + /** Returns the average number of measurements per node */ + public int measurementsPerNode() { + int measurementCount = timeseries.stream().mapToInt(m -> m.size()).sum(); + return measurementCount / clusterNodes.size(); + } + + /** Returns the number of nodes measured in this */ + public int nodesMeasured() { + return timeseries.size(); + } + + /** Returns the average load of this resource in this */ + public double averageLoad(Resource resource) { + int measurementCount = timeseries.stream().mapToInt(m -> m.size()).sum(); + if 
(measurementCount == 0) return 0; + double measurementSum = timeseries.stream().flatMap(m -> m.asList().stream()).mapToDouble(m -> value(resource, m)).sum(); + return measurementSum / measurementCount; + } + + private double value(Resource resource, NodeMetricSnapshot snapshot) { + switch (resource) { + case cpu: return snapshot.cpu(); + case memory: return snapshot.memory(); + case disk: return snapshot.disk(); + default: throw new IllegalArgumentException("Got an unknown resource " + resource); + } + } + + private List<NodeTimeseries> filter(List<NodeTimeseries> timeseries, Predicate<NodeMetricSnapshot> filter) { + return timeseries.stream().map(nodeTimeseries -> nodeTimeseries.filter(filter)).collect(Collectors.toList()); + } + +} diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterTimeseries.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterTimeseries.java index e359579117f..5b6ed43b713 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterTimeseries.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterTimeseries.java @@ -1,70 +1,104 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision.autoscale; +import com.yahoo.config.provision.ClusterSpec; import com.yahoo.vespa.hosted.provision.NodeList; -import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.applications.Cluster; import java.time.Duration; import java.time.Instant; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.function.Predicate; import java.util.stream.Collectors; /** - * A series of metric snapshots for all nodes in a cluster + * A list of metric snapshots from a cluster, sorted by increasing time (newest last). 
* * @author bratseth */ public class ClusterTimeseries { - private final NodeList clusterNodes; + private final ClusterSpec.Id cluster; + private final List<ClusterMetricSnapshot> snapshots; - /** The measurements for all nodes in this snapshot */ - private final List<NodeTimeseries> allTimeseries; + ClusterTimeseries(ClusterSpec.Id cluster, List<ClusterMetricSnapshot> snapshots) { + this.cluster = cluster; + List<ClusterMetricSnapshot> sortedSnapshots = new ArrayList<>(snapshots); + Collections.sort(sortedSnapshots); + this.snapshots = Collections.unmodifiableList(sortedSnapshots); + } + + public boolean isEmpty() { return snapshots.isEmpty(); } + + public int size() { return snapshots.size(); } + + public ClusterMetricSnapshot get(int index) { return snapshots.get(index); } - public ClusterTimeseries(Duration period, Cluster cluster, NodeList clusterNodes, MetricsDb db) { - this.clusterNodes = clusterNodes; - var timeseries = db.getNodeTimeseries(period, clusterNodes); + public List<ClusterMetricSnapshot> asList() { return snapshots; } - if (cluster.lastScalingEvent().isPresent()) - timeseries = filter(timeseries, snapshot -> snapshot.generation() < 0 || // Content nodes do not yet send generation - snapshot.generation() >= cluster.lastScalingEvent().get().generation()); - timeseries = filter(timeseries, snapshot -> snapshot.inService() && snapshot.stable()); + public ClusterSpec.Id cluster() { return cluster; } - this.allTimeseries = timeseries; + public ClusterTimeseries add(ClusterMetricSnapshot snapshot) { + List<ClusterMetricSnapshot> list = new ArrayList<>(snapshots); + list.add(snapshot); + return new ClusterTimeseries(cluster, list); } - /** Returns the average number of measurements per node */ - public int measurementsPerNode() { - int measurementCount = allTimeseries.stream().mapToInt(m -> m.size()).sum(); - return measurementCount / clusterNodes.size(); + /** The max query growth rate we can predict from this time-series as a fraction of the current 
traffic per minute */ + public double maxQueryGrowthRate() { + if (snapshots.isEmpty()) return 0.1; + + // Find the period having the highest growth rate, where total growth exceeds 30% increase + double maxGrowthRate = 0; // In query rate per minute + for (int start = 0; start < snapshots.size(); start++) { + if (start > 0) { // Optimization: Skip this point when starting from the previous is better relative to the best rate so far + Duration duration = durationBetween(start - 1, start); + if (duration.toMinutes() != 0) { + double growthRate = (queryRateAt(start - 1) - queryRateAt(start)) / duration.toMinutes(); + if (growthRate >= maxGrowthRate) + continue; + } + } + for (int end = start + 1; end < snapshots.size(); end++) { + if (queryRateAt(end) >= queryRateAt(start) * 1.3) { + Duration duration = durationBetween(start, end); + if (duration.toMinutes() == 0) continue; + double growthRate = (queryRateAt(end) - queryRateAt(start)) / duration.toMinutes(); + if (growthRate > maxGrowthRate) + maxGrowthRate = growthRate; + } + } + } + if (maxGrowthRate == 0) { // No periods of significant growth + if (durationBetween(0, snapshots.size() - 1).toHours() < 24) + return 0.1; // ... because not much data + else + return 0.0; // ... 
because load is stable + } + if (queryRateNow() == 0) return 0.1; // Growth not expressible as a fraction of the current rate + return maxGrowthRate / queryRateNow(); } - /** Returns the number of nodes measured in this */ - public int nodesMeasured() { - return allTimeseries.size(); + /** The current query rate as a fraction of the peak rate in this timeseries */ + public double currentQueryFractionOfMax() { + if (snapshots.isEmpty()) return 0.5; + var max = snapshots.stream().mapToDouble(ClusterMetricSnapshot::queryRate).max().getAsDouble(); + if (max == 0) return 1.0; + return snapshots.get(snapshots.size() - 1).queryRate() / max; } - /** Returns the average load of this resource in this */ - public double averageLoad(Resource resource) { - int measurementCount = allTimeseries.stream().mapToInt(m -> m.size()).sum(); - if (measurementCount == 0) return 0; - double measurementSum = allTimeseries.stream().flatMap(m -> m.asList().stream()).mapToDouble(m -> value(resource, m)).sum(); - return measurementSum / measurementCount; + private double queryRateAt(int index) { + return snapshots.get(index).queryRate(); } - private double value(Resource resource, MetricSnapshot snapshot) { - switch (resource) { - case cpu: return snapshot.cpu(); - case memory: return snapshot.memory(); - case disk: return snapshot.disk(); - default: throw new IllegalArgumentException("Got an unknown resource " + resource); - } + private double queryRateNow() { + return queryRateAt(snapshots.size() - 1); } - private List<NodeTimeseries> filter(List<NodeTimeseries> timeseries, Predicate<MetricSnapshot> filter) { - return timeseries.stream().map(nodeTimeseries -> nodeTimeseries.filter(filter)).collect(Collectors.toList()); + private Duration durationBetween(int startIndex, int endIndex) { + return Duration.between(snapshots.get(startIndex).at(), snapshots.get(endIndex).at()); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MemoryMetricsDb.java 
b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MemoryMetricsDb.java index 1b1e5933604..bf8d354665a 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MemoryMetricsDb.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MemoryMetricsDb.java @@ -2,9 +2,12 @@ package com.yahoo.vespa.hosted.provision.autoscale; import com.yahoo.collections.Pair; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeRepository; +import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.ArrayList; @@ -26,8 +29,10 @@ public class MemoryMetricsDb implements MetricsDb { private final NodeRepository nodeRepository; - /** Metric time seriest by node (hostname). Each list of metric snapshots is sorted by increasing timestamp */ - private final Map<String, NodeTimeseries> db = new HashMap<>(); + /** Metric time series by node (hostname). 
Each list of metric snapshots is sorted by increasing timestamp */ + private final Map<String, NodeTimeseries> nodeTimeseries = new HashMap<>(); + + private final Map<Pair<ApplicationId, ClusterSpec.Id>, ClusterTimeseries> clusterTimeseries = new HashMap<>(); /** Lock all access for now since we modify lists inside a map */ private final Object lock = new Object(); @@ -37,7 +42,10 @@ public class MemoryMetricsDb implements MetricsDb { } @Override - public void add(Collection<Pair<String, MetricSnapshot>> nodeMetrics) { + public Clock clock() { return nodeRepository.clock(); } + + @Override + public void addNodeMetrics(Collection<Pair<String, NodeMetricSnapshot>> nodeMetrics) { synchronized (lock) { for (var value : nodeMetrics) { add(value.getFirst(), value.getSecond()); @@ -46,27 +54,48 @@ public class MemoryMetricsDb implements MetricsDb { } @Override + public void addClusterMetrics(ApplicationId application, Map<ClusterSpec.Id, ClusterMetricSnapshot> clusterMetrics) { + synchronized (lock) { + for (var value : clusterMetrics.entrySet()) { + add(application, value.getKey(), value.getValue()); + } + } + } + + public void clearClusterMetrics(ApplicationId application, ClusterSpec.Id cluster) { + synchronized (lock) { + clusterTimeseries.remove(new Pair<>(application, cluster)); + } + } + + @Override public List<NodeTimeseries> getNodeTimeseries(Duration period, Set<String> hostnames) { Instant startTime = nodeRepository.clock().instant().minus(period); synchronized (lock) { return hostnames.stream() - .map(hostname -> db.getOrDefault(hostname, new NodeTimeseries(hostname, List.of())).justAfter(startTime)) + .map(hostname -> nodeTimeseries.getOrDefault(hostname, new NodeTimeseries(hostname, List.of())).justAfter(startTime)) .collect(Collectors.toList()); } } @Override + public ClusterTimeseries getClusterTimeseries(ApplicationId application, ClusterSpec.Id cluster) { + return clusterTimeseries.computeIfAbsent(new Pair<>(application, cluster), + __ -> new 
ClusterTimeseries(cluster, new ArrayList<>())); + } + + @Override public void gc() { synchronized (lock) { // Each measurement is Object + long + float = 16 + 8 + 4 = 28 bytes // 12 hours with 1k nodes and 3 resources and 1 measurement/sec is about 5Gb - for (String hostname : db.keySet()) { - var timeseries = db.get(hostname); + for (String hostname : nodeTimeseries.keySet()) { + var timeseries = nodeTimeseries.get(hostname); timeseries = timeseries.justAfter(nodeRepository.clock().instant().minus(Autoscaler.maxScalingWindow())); if (timeseries.isEmpty()) - db.remove(hostname); + nodeTimeseries.remove(hostname); else - db.put(hostname, timeseries); + nodeTimeseries.put(hostname, timeseries); } } } @@ -74,16 +103,22 @@ public class MemoryMetricsDb implements MetricsDb { @Override public void close() {} - private void add(String hostname, MetricSnapshot snapshot) { - NodeTimeseries timeseries = db.get(hostname); + private void add(String hostname, NodeMetricSnapshot snapshot) { + NodeTimeseries timeseries = nodeTimeseries.get(hostname); if (timeseries == null) { // new node Optional<Node> node = nodeRepository.nodes().node(hostname); if (node.isEmpty()) return; if (node.get().allocation().isEmpty()) return; timeseries = new NodeTimeseries(hostname, new ArrayList<>()); - db.put(hostname, timeseries); + nodeTimeseries.put(hostname, timeseries); } - db.put(hostname, timeseries.add(snapshot)); + nodeTimeseries.put(hostname, timeseries.add(snapshot)); + } + + private void add(ApplicationId application, ClusterSpec.Id cluster, ClusterMetricSnapshot snapshot) { + var key = new Pair<>(application, cluster); + var existing = clusterTimeseries.computeIfAbsent(key, __ -> new ClusterTimeseries(cluster, new ArrayList<>())); + clusterTimeseries.put(key, existing.add(snapshot)); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsDb.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsDb.java index 
6fdc87f2448..568c5f88661 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsDb.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsDb.java @@ -2,15 +2,17 @@ package com.yahoo.vespa.hosted.provision.autoscale; import com.yahoo.collections.Pair; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import java.time.Clock; import java.time.Duration; -import java.time.Instant; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -21,8 +23,12 @@ import java.util.stream.Collectors; */ public interface MetricsDb { - /** Adds snapshots to this. */ - void add(Collection<Pair<String, MetricSnapshot>> nodeMetrics); + Clock clock(); + + /** Adds node snapshots to this. 
*/ + void addNodeMetrics(Collection<Pair<String, NodeMetricSnapshot>> nodeMetrics); + + void addClusterMetrics(ApplicationId application, Map<ClusterSpec.Id, ClusterMetricSnapshot> clusterMetrics); /** * Returns a list with one entry for each hostname containing @@ -36,12 +42,15 @@ public interface MetricsDb { return getNodeTimeseries(period, nodes.stream().map(Node::hostname).collect(Collectors.toSet())); } + /** Returns all cluster level metric snapshots for a given cluster */ + ClusterTimeseries getClusterTimeseries(ApplicationId applicationId, ClusterSpec.Id clusterId); + /** Must be called intermittently (as long as add is called) to gc old data */ void gc(); void close(); - static MetricsDb createTestInstance(NodeRepository nodeRepository) { + static MemoryMetricsDb createTestInstance(NodeRepository nodeRepository) { return new MemoryMetricsDb(nodeRepository); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsResponse.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsResponse.java index d6661b89536..0fa7a0e0bb1 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsResponse.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsResponse.java @@ -11,7 +11,6 @@ import com.yahoo.slime.SlimeUtils; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; -import com.yahoo.vespa.hosted.provision.applications.Application; import java.time.Instant; import java.util.ArrayList; @@ -28,14 +27,21 @@ import java.util.Optional; */ public class MetricsResponse { - private final Collection<Pair<String, MetricSnapshot>> nodeMetrics; + /** Node level metrics */ + private final Collection<Pair<String, NodeMetricSnapshot>> nodeMetrics; + + /** + * Cluster level metrics. 
+ * Must be aggregated at fetch time to avoid issues with nodes and nodes joining/leaving the cluster over time. + */ + private final Map<ClusterSpec.Id, ClusterMetricSnapshot> clusterMetrics = new HashMap<>(); /** Creates this from a metrics/V2 response */ public MetricsResponse(String response, NodeList applicationNodes, NodeRepository nodeRepository) { this(SlimeUtils.jsonToSlime(response), applicationNodes, nodeRepository); } - public MetricsResponse(Collection<Pair<String, MetricSnapshot>> metrics) { + public MetricsResponse(Collection<Pair<String, NodeMetricSnapshot>> metrics) { this.nodeMetrics = metrics; } @@ -46,7 +52,9 @@ public class MetricsResponse { nodes.traverse((ArrayTraverser)(__, node) -> consumeNode(node, applicationNodes, nodeRepository)); } - public Collection<Pair<String, MetricSnapshot>> metrics() { return nodeMetrics; } + public Collection<Pair<String, NodeMetricSnapshot>> nodeMetrics() { return nodeMetrics; } + + public Map<ClusterSpec.Id, ClusterMetricSnapshot> clusterMetrics() { return clusterMetrics; } private void consumeNode(Inspector node, NodeList applicationNodes, NodeRepository nodeRepository) { String hostname = node.field("hostname").asString(); @@ -59,14 +67,21 @@ public class MetricsResponse { if (node.isEmpty()) return; // Node is not part of this cluster any more long timestampSecond = nodeData.field("timestamp").asLong(); Map<String, Double> values = consumeMetrics(nodeData.field("metrics")); - nodeMetrics.add(new Pair<>(hostname, new MetricSnapshot(Instant.ofEpochMilli(timestampSecond * 1000), - Metric.cpu.from(values), - Metric.memory.from(values), - Metric.disk.from(values), - (long)Metric.generation.from(values), - Metric.inService.from(values) > 0, - clusterIsStable(node.get(), applicationNodes, nodeRepository), - Metric.queryRate.from(values)))); + Instant at = Instant.ofEpochMilli(timestampSecond * 1000); + + nodeMetrics.add(new Pair<>(hostname, new NodeMetricSnapshot(at, + Metric.cpu.from(values), + 
Metric.memory.from(values), + Metric.disk.from(values), + (long)Metric.generation.from(values), + Metric.inService.from(values) > 0, + clusterIsStable(node.get(), applicationNodes, nodeRepository), + Metric.queryRate.from(values)))); + + var cluster = node.get().allocation().get().membership().cluster().id(); + var metrics = clusterMetrics.getOrDefault(cluster, new ClusterMetricSnapshot(at, 0.0)); + metrics = metrics.withQueryRate(metrics.queryRate() + Metric.queryRate.from(values)); + clusterMetrics.put(cluster, metrics); } private boolean clusterIsStable(Node node, NodeList applicationNodes, NodeRepository nodeRepository) { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricSnapshot.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/NodeMetricSnapshot.java index 82812592809..be9f7bd4819 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/MetricSnapshot.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/NodeMetricSnapshot.java @@ -8,7 +8,7 @@ import java.time.Instant; * * @author bratseth */ -public class MetricSnapshot implements Comparable<MetricSnapshot> { +public class NodeMetricSnapshot implements Comparable<NodeMetricSnapshot> { private final Instant at; @@ -20,9 +20,9 @@ public class MetricSnapshot implements Comparable<MetricSnapshot> { private final boolean stable; private final double queryRate; - public MetricSnapshot(Instant at, double cpu, double memory, double disk, - long generation, boolean inService, boolean stable, - double queryRate) { + public NodeMetricSnapshot(Instant at, double cpu, double memory, double disk, + long generation, boolean inService, boolean stable, + double queryRate) { this.at = at; this.cpu = cpu; this.memory = memory; @@ -48,7 +48,7 @@ public class MetricSnapshot implements Comparable<MetricSnapshot> { public boolean stable() { return stable; } @Override - public int 
compareTo(MetricSnapshot other) { + public int compareTo(NodeMetricSnapshot other) { return at.compareTo(other.at); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/NodeTimeseries.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/NodeTimeseries.java index 24876609f58..cedc2edfe63 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/NodeTimeseries.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/NodeTimeseries.java @@ -16,11 +16,11 @@ import java.util.stream.Collectors; public class NodeTimeseries { private final String hostname; - private final List<MetricSnapshot> snapshots; + private final List<NodeMetricSnapshot> snapshots; - NodeTimeseries(String hostname, List<MetricSnapshot> snapshots) { + NodeTimeseries(String hostname, List<NodeMetricSnapshot> snapshots) { this.hostname = hostname; - List<MetricSnapshot> sortedSnapshots = new ArrayList<>(snapshots); + List<NodeMetricSnapshot> sortedSnapshots = new ArrayList<>(snapshots); Collections.sort(sortedSnapshots); this.snapshots = Collections.unmodifiableList(sortedSnapshots); } @@ -29,19 +29,19 @@ public class NodeTimeseries { public int size() { return snapshots.size(); } - public MetricSnapshot get(int index) { return snapshots.get(index); } + public NodeMetricSnapshot get(int index) { return snapshots.get(index); } - public List<MetricSnapshot> asList() { return snapshots; } + public List<NodeMetricSnapshot> asList() { return snapshots; } public String hostname() { return hostname; } - public NodeTimeseries add(MetricSnapshot snapshot) { - List<MetricSnapshot> list = new ArrayList<>(snapshots); + public NodeTimeseries add(NodeMetricSnapshot snapshot) { + List<NodeMetricSnapshot> list = new ArrayList<>(snapshots); list.add(snapshot); return new NodeTimeseries(hostname(), list); } - public NodeTimeseries filter(Predicate<MetricSnapshot> filter) { + public NodeTimeseries 
filter(Predicate<NodeMetricSnapshot> filter) { return new NodeTimeseries(hostname, snapshots.stream().filter(filter).collect(Collectors.toList())); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java index 37e70e3539a..efa1de6bb97 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDb.java @@ -5,6 +5,8 @@ import com.google.inject.Inject; import com.yahoo.collections.ListMap; import com.yahoo.collections.Pair; import com.yahoo.component.AbstractComponent; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; import com.yahoo.io.IOUtils; import com.yahoo.vespa.defaults.Defaults; import io.questdb.cairo.CairoConfiguration; @@ -30,6 +32,7 @@ import java.time.format.DateTimeFormatter; import java.util.ArrayList; import java.util.Collection; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; @@ -45,7 +48,8 @@ import java.util.stream.Collectors; public class QuestMetricsDb extends AbstractComponent implements MetricsDb { private static final Logger log = Logger.getLogger(QuestMetricsDb.class.getName()); - private static final String table = "metrics"; + private static final String nodeTable = "metrics"; + private static final String clusterTable = "clusterMetrics"; private final Clock clock; private final String dataDir; @@ -69,7 +73,8 @@ public class QuestMetricsDb extends AbstractComponent implements MetricsDb { } private void initializeDb() { - IOUtils.createDirectory(dataDir + "/" + table); + IOUtils.createDirectory(dataDir + "/" + nodeTable); + IOUtils.createDirectory(dataDir + "/" + clusterTable); // silence Questdb's custom logging system 
IOUtils.writeFile(new File(dataDir, "quest-log.conf"), new byte[0]); @@ -78,32 +83,36 @@ public class QuestMetricsDb extends AbstractComponent implements MetricsDb { CairoConfiguration configuration = new DefaultCairoConfiguration(dataDir); engine = new CairoEngine(configuration); - ensureExists(table); + ensureTablesExist(); } @Override - public void add(Collection<Pair<String, MetricSnapshot>> snapshots) { - try (TableWriter writer = engine.getWriter(newContext().getCairoSecurityContext(), table)) { - add(snapshots, writer); + public Clock clock() { return clock; } + + @Override + public void addNodeMetrics(Collection<Pair<String, NodeMetricSnapshot>> snapshots) { + try (TableWriter writer = engine.getWriter(newContext().getCairoSecurityContext(), nodeTable)) { + addNodeMetrics(snapshots, writer); } catch (CairoException e) { if (e.getMessage().contains("Cannot read offset")) { // This error seems non-recoverable repair(e); - try (TableWriter writer = engine.getWriter(newContext().getCairoSecurityContext(), table)) { - add(snapshots, writer); + try (TableWriter writer = engine.getWriter(newContext().getCairoSecurityContext(), nodeTable)) { + addNodeMetrics(snapshots, writer); } } } } - private void add(Collection<Pair<String, MetricSnapshot>> snapshots, TableWriter writer) { + private void addNodeMetrics(Collection<Pair<String, NodeMetricSnapshot>> snapshots, TableWriter writer) { for (var snapshot : snapshots) { long atMillis = adjustIfRecent(snapshot.getSecond().at().toEpochMilli(), highestTimestampAdded); if (atMillis < highestTimestampAdded) continue; // Ignore old data highestTimestampAdded = atMillis; TableWriter.Row row = writer.newRow(atMillis * 1000); // in microseconds row.putStr(0, snapshot.getFirst()); + // (1 is timestamp) row.putFloat(2, (float)snapshot.getSecond().cpu()); row.putFloat(3, (float)snapshot.getSecond().memory()); row.putFloat(4, (float)snapshot.getSecond().disk()); @@ -117,23 +126,70 @@ public class QuestMetricsDb extends 
AbstractComponent implements MetricsDb { } @Override + public void addClusterMetrics(ApplicationId application, Map<ClusterSpec.Id, ClusterMetricSnapshot> snapshots) { + try (TableWriter writer = engine.getWriter(newContext().getCairoSecurityContext(), clusterTable)) { + addClusterMetrics(application, snapshots, writer); + } + catch (CairoException e) { + if (e.getMessage().contains("Cannot read offset")) { + // This error seems non-recoverable + repair(e); + try (TableWriter writer = engine.getWriter(newContext().getCairoSecurityContext(), clusterTable)) { + addClusterMetrics(application, snapshots, writer); + } + } + } + } + + private void addClusterMetrics(ApplicationId applicationId, Map<ClusterSpec.Id, ClusterMetricSnapshot> snapshots, TableWriter writer) { + for (var snapshot : snapshots.entrySet()) { + long atMillis = adjustIfRecent(snapshot.getValue().at().toEpochMilli(), highestTimestampAdded); + if (atMillis < highestTimestampAdded) continue; // Ignore old data + highestTimestampAdded = atMillis; + TableWriter.Row row = writer.newRow(atMillis * 1000); // in microseconds + row.putStr(0, applicationId.serializedForm()); + row.putStr(1, snapshot.getKey().value()); + // (2 is timestamp) + row.putFloat(3, (float)snapshot.getValue().queryRate()); + row.append(); + } + writer.commit(); + } + + @Override public List<NodeTimeseries> getNodeTimeseries(Duration period, Set<String> hostnames) { try (SqlCompiler compiler = new SqlCompiler(engine)) { SqlExecutionContext context = newContext(); - var snapshots = getSnapshots(clock.instant().minus(period), hostnames, compiler, context); + var snapshots = getNodeSnapshots(clock.instant().minus(period), hostnames, compiler, context); return snapshots.entrySet().stream() .map(entry -> new NodeTimeseries(entry.getKey(), entry.getValue())) .collect(Collectors.toList()); } catch (SqlException e) { - throw new IllegalStateException("Could not read timeseries data in Quest stored in " + dataDir, e); + throw new 
IllegalStateException("Could not read node timeseries data in Quest stored in " + dataDir, e); + } + } + + @Override + public ClusterTimeseries getClusterTimeseries(ApplicationId applicationId, ClusterSpec.Id clusterId) { + try (SqlCompiler compiler = new SqlCompiler(engine)) { + SqlExecutionContext context = newContext(); + return getClusterSnapshots(applicationId, clusterId, compiler, context); + } + catch (SqlException e) { + throw new IllegalStateException("Could not read cluster timeseries data in Quest stored in " + dataDir, e); } } @Override public void gc() { - // Since we remove full days at once we need to keep at least the scaling window + 1 day - Instant oldestToKeep = clock.instant().minus(Autoscaler.maxScalingWindow().plus(Duration.ofDays(1))); + gc(nodeTable); + gc(clusterTable); + } + + private void gc(String table) { + // We remove full days at once and we want to see at least three days to not every only see weekend data + Instant oldestToKeep = clock.instant().minus(Duration.ofDays(4)); SqlExecutionContext context = newContext(); int partitions = 0; try (SqlCompiler compiler = new SqlCompiler(engine)) { @@ -157,7 +213,7 @@ public class QuestMetricsDb extends AbstractComponent implements MetricsDb { context); } catch (SqlException e) { - log.log(Level.WARNING, "Failed to gc old metrics data in " + dataDir, e); + log.log(Level.WARNING, "Failed to gc old metrics data in " + dataDir + " table " + table, e); } } @@ -181,18 +237,26 @@ public class QuestMetricsDb extends AbstractComponent implements MetricsDb { initializeDb(); } - private void ensureExists(String table) { + private boolean exists(String table, SqlExecutionContext context) { + return 0 == engine.getStatus(context.getCairoSecurityContext(), new Path(), table); + } + + private void ensureTablesExist() { SqlExecutionContext context = newContext(); - if (0 == engine.getStatus(context.getCairoSecurityContext(), new Path(), table)) { // table exists - ensureTableIsUpdated(table, context); - } 
else { - createTable(table, context); - } + if (exists(nodeTable, context)) + ensureNodeTableIsUpdated(context); + else + createNodeTable(context); + + if (exists(clusterTable, context)) + ensureClusterTableIsUpdated(context); + else + createClusterTable(context); } - private void createTable(String table, SqlExecutionContext context) { + private void createNodeTable(SqlExecutionContext context) { try (SqlCompiler compiler = new SqlCompiler(engine)) { - compiler.compile("create table " + table + + compiler.compile("create table " + nodeTable + " (hostname string, at timestamp, cpu_util float, mem_total_util float, disk_util float," + " application_generation long, inService boolean, stable boolean, queries_rate float)" + " timestamp(at)" + @@ -202,20 +266,39 @@ public class QuestMetricsDb extends AbstractComponent implements MetricsDb { // compiler.compile("alter table " + tableName + " alter column hostname add index", context); } catch (SqlException e) { - throw new IllegalStateException("Could not create Quest db table '" + table + "'", e); + throw new IllegalStateException("Could not create Quest db table '" + nodeTable + "'", e); + } + } + + private void createClusterTable(SqlExecutionContext context) { + try (SqlCompiler compiler = new SqlCompiler(engine)) { + compiler.compile("create table " + clusterTable + + " (application string, cluster string, at timestamp, queries_rate float)" + + " timestamp(at)" + + "PARTITION BY DAY;", + context); + // We should do this if we get a version where selecting on strings work embedded, see below + // compiler.compile("alter table " + tableName + " alter column cluster add index", context); + } + catch (SqlException e) { + throw new IllegalStateException("Could not create Quest db table '" + clusterTable + "'", e); } } - private void ensureTableIsUpdated(String table, SqlExecutionContext context) { + private void ensureNodeTableIsUpdated(SqlExecutionContext context) { try (SqlCompiler compiler = new SqlCompiler(engine)) { 
- if (0 == engine.getStatus(context.getCairoSecurityContext(), new Path(), table)) { - ensureColumnExists("queries_rate", "float", table, compiler, context); // TODO: Remove after March 2021 + if (0 == engine.getStatus(context.getCairoSecurityContext(), new Path(), nodeTable)) { + ensureColumnExists("queries_rate", "float", nodeTable, compiler, context); // TODO: Remove after March 2021 } } catch (SqlException e) { repair(e); } } + private void ensureClusterTableIsUpdated(SqlExecutionContext context) { + // Nothing to do for now + } + private void ensureColumnExists(String column, String columnType, String table, SqlCompiler compiler, SqlExecutionContext context) throws SqlException { if (columnNamesOf(table, compiler, context).contains(column)) return; @@ -246,34 +329,34 @@ public class QuestMetricsDb extends AbstractComponent implements MetricsDb { return timestamp; } - private ListMap<String, MetricSnapshot> getSnapshots(Instant startTime, - Set<String> hostnames, - SqlCompiler compiler, - SqlExecutionContext context) throws SqlException { + private ListMap<String, NodeMetricSnapshot> getNodeSnapshots(Instant startTime, + Set<String> hostnames, + SqlCompiler compiler, + SqlExecutionContext context) throws SqlException { DateTimeFormatter formatter = DateTimeFormatter.ISO_DATE_TIME.withZone(ZoneId.of("UTC")); String from = formatter.format(startTime).substring(0, 19) + ".000000Z"; String to = formatter.format(clock.instant()).substring(0, 19) + ".000000Z"; - String sql = "select * from " + table + " where at in('" + from + "', '" + to + "');"; + String sql = "select * from " + nodeTable + " where at in('" + from + "', '" + to + "');"; // WHERE clauses does not work: // String sql = "select * from " + tableName + " where hostname in('host1', 'host2', 'host3');"; try (RecordCursorFactory factory = compiler.compile(sql, context).getRecordCursorFactory()) { - ListMap<String, MetricSnapshot> snapshots = new ListMap<>(); + ListMap<String, NodeMetricSnapshot> snapshots 
= new ListMap<>(); try (RecordCursor cursor = factory.getCursor(context)) { Record record = cursor.getRecord(); while (cursor.hasNext()) { String hostname = record.getStr(0).toString(); if (hostnames.contains(hostname)) { snapshots.put(hostname, - new MetricSnapshot(Instant.ofEpochMilli(record.getTimestamp(1) / 1000), - record.getFloat(2), - record.getFloat(3), - record.getFloat(4), - record.getLong(5), - record.getBool(6), - record.getBool(7), - record.getFloat(8))); + new NodeMetricSnapshot(Instant.ofEpochMilli(record.getTimestamp(1) / 1000), + record.getFloat(2), + record.getFloat(3), + record.getFloat(4), + record.getLong(5), + record.getBool(6), + record.getBool(7), + record.getFloat(8))); } } } @@ -281,6 +364,29 @@ public class QuestMetricsDb extends AbstractComponent implements MetricsDb { } } + private ClusterTimeseries getClusterSnapshots(ApplicationId application, + ClusterSpec.Id cluster, + SqlCompiler compiler, + SqlExecutionContext context) throws SqlException { + String sql = "select * from " + clusterTable; + try (RecordCursorFactory factory = compiler.compile(sql, context).getRecordCursorFactory()) { + List<ClusterMetricSnapshot> snapshots = new ArrayList<>(); + try (RecordCursor cursor = factory.getCursor(context)) { + Record record = cursor.getRecord(); + while (cursor.hasNext()) { + String applicationIdString = record.getStr(0).toString(); + if ( ! 
application.serializedForm().equals(applicationIdString)) continue; + String clusterId = record.getStr(1).toString(); + if (cluster.value().equals(clusterId)) { + snapshots.add(new ClusterMetricSnapshot(Instant.ofEpochMilli(record.getTimestamp(2) / 1000), + record.getFloat(3))); + } + } + } + return new ClusterTimeseries(cluster, snapshots); + } + } + private SqlExecutionContext newContext() { return new SqlExecutionContextImpl(engine, 1); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Resource.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Resource.java index 8353f56df91..b841b31833f 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Resource.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/Resource.java @@ -12,7 +12,7 @@ public enum Resource { /** Cpu utilization ratio */ cpu { - public double idealAverageLoad() { return 0.4; } + public double idealAverageLoad() { return 0.8; } double valueFrom(NodeResources resources) { return resources.vcpu(); } }, diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ResourceTarget.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ResourceTarget.java index a2fbeb3b710..d2bfecfcdae 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ResourceTarget.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/autoscale/ResourceTarget.java @@ -3,6 +3,8 @@ package com.yahoo.vespa.hosted.provision.autoscale; import com.yahoo.vespa.hosted.provision.applications.Application; +import java.time.Duration; + /** * A resource target to hit for the allocation optimizer. * The target is measured in cpu, memory and disk per node in the allocation given by current. 
@@ -47,11 +49,16 @@ public class ResourceTarget { } /** Create a target of achieving ideal load given a current load */ - public static ResourceTarget idealLoad(double currentCpuLoad, double currentMemoryLoad, double currentDiskLoad, - AllocatableClusterResources current, Application application) { - return new ResourceTarget(nodeUsage(Resource.cpu, currentCpuLoad, current) / idealCpuLoad(application), - nodeUsage(Resource.memory, currentMemoryLoad, current) / Resource.memory.idealAverageLoad(), - nodeUsage(Resource.disk, currentDiskLoad, current) / Resource.disk.idealAverageLoad(), + public static ResourceTarget idealLoad(ClusterTimeseries clusterTimeseries, + ClusterNodesTimeseries clusterNodesTimeseries, + AllocatableClusterResources current, + Application application) { + return new ResourceTarget(nodeUsage(Resource.cpu, clusterNodesTimeseries.averageLoad(Resource.cpu), current) + / idealCpuLoad(clusterTimeseries, clusterNodesTimeseries, application), + nodeUsage(Resource.memory, clusterNodesTimeseries.averageLoad(Resource.memory), current) + / Resource.memory.idealAverageLoad(), + nodeUsage(Resource.disk, clusterNodesTimeseries.averageLoad(Resource.disk), current) + / Resource.disk.idealAverageLoad(), true); } @@ -64,16 +71,29 @@ public class ResourceTarget { } /** Ideal cpu load must take the application traffic fraction into account */ - private static double idealCpuLoad(Application application) { - double trafficFactor; + private static double idealCpuLoad(ClusterTimeseries clusterTimeseries, + ClusterNodesTimeseries clusterNodesTimeseries, + Application application) { + // What's needed to have headroom for growth during scale-up as a fraction of current resources? 
+ double maxGrowthRate = clusterTimeseries.maxQueryGrowthRate(); // in fraction per minute of the current traffic + Duration scalingDuration = clusterNodesTimeseries.cluster().scalingDuration(clusterNodesTimeseries.clusterNodes().clusterSpec()); + double growthRateHeadroom = 1 + maxGrowthRate * scalingDuration.toMinutes(); + // Cap headroom at 10% above the historical observed peak + double fractionOfMax = clusterTimeseries.currentQueryFractionOfMax(); + if (fractionOfMax != 0) + growthRateHeadroom = Math.min(growthRateHeadroom, 1 / fractionOfMax + 0.1); + + // How much headroom is needed to handle sudden arrival of additional traffic due to another zone going down? + double trafficShiftHeadroom; if (application.status().maxReadShare() == 0) // No traffic fraction data - trafficFactor = 0.5; // assume we currently get half of the global share of traffic + trafficShiftHeadroom = 2.0; // assume we currently get half of the global share of traffic else - trafficFactor = application.status().currentReadShare() / application.status().maxReadShare(); + trafficShiftHeadroom = application.status().maxReadShare() / application.status().currentReadShare(); + + if (trafficShiftHeadroom > 2.0) // The expectation that we have almost no load with almost no queries is incorrect due + trafficShiftHeadroom = 2.0; // to write traffic; once that is separated we can increase this threshold - if (trafficFactor < 0.5) // The expectation that we have almost no load with almost no queries is incorrect due - trafficFactor = 0.5; // to write traffic; once that is separated we can lower this threshold (but not to 0) - return trafficFactor * Resource.cpu.idealAverageLoad(); + return 1 / growthRateHeadroom * 1 / trafficShiftHeadroom * Resource.cpu.idealAverageLoad(); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java index 
bcfdaefb305..9d910df01d9 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java @@ -14,7 +14,7 @@ import com.yahoo.vespa.hosted.provision.applications.Applications; import com.yahoo.vespa.hosted.provision.applications.Cluster; import com.yahoo.vespa.hosted.provision.autoscale.AllocatableClusterResources; import com.yahoo.vespa.hosted.provision.autoscale.Autoscaler; -import com.yahoo.vespa.hosted.provision.autoscale.MetricSnapshot; +import com.yahoo.vespa.hosted.provision.autoscale.NodeMetricSnapshot; import com.yahoo.vespa.hosted.provision.autoscale.MetricsDb; import com.yahoo.vespa.hosted.provision.autoscale.NodeTimeseries; import com.yahoo.vespa.hosted.provision.node.History; @@ -110,7 +110,7 @@ public class AutoscalingMaintainer extends NodeRepositoryMaintainer { // - 2. all nodes have switched to the right config generation for (NodeTimeseries nodeTimeseries : metricsDb.getNodeTimeseries(Duration.between(event.at(), clock().instant()), clusterNodes)) { - Optional<MetricSnapshot> firstOnNewGeneration = + Optional<NodeMetricSnapshot> firstOnNewGeneration = nodeTimeseries.asList().stream() .filter(snapshot -> snapshot.generation() >= event.generation()).findFirst(); if (firstOnNewGeneration.isEmpty()) return cluster; // Not completed diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java index 22c8e49825d..f072891f210 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirer.java @@ -23,14 +23,18 @@ import java.util.List; */ public class DirtyExpirer extends Expirer { + private final boolean keepAllocationOnExpiry; + 
DirtyExpirer(NodeRepository nodeRepository, Duration dirtyTimeout, Metric metric) { super(Node.State.dirty, History.Event.Type.deallocated, nodeRepository, dirtyTimeout, metric); + // Do not keep allocation in dynamically provisioned zones so that the hosts can be deprovisioned + this.keepAllocationOnExpiry = ! nodeRepository.zone().getCloud().dynamicProvisioning(); } @Override protected void expire(List<Node> expired) { for (Node expiredNode : expired) - nodeRepository().nodes().fail(expiredNode.hostname(), Agent.DirtyExpirer, "Node is stuck in dirty"); + nodeRepository().nodes().fail(expiredNode.hostname(), keepAllocationOnExpiry, Agent.DirtyExpirer, "Node is stuck in dirty"); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainer.java index b8548c4c3f4..f4509c0713e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainer.java @@ -70,11 +70,13 @@ public class NodeMetricsDbMaintainer extends NodeRepositoryMaintainer { ApplicationId application) { if (exception != null) { if (warnings.get() < maxWarningsPerInvocation) - log.log(Level.WARNING, "Could not update metrics for " + application, exception); + log.log(Level.WARNING, "Could not update metrics for " + application + ": " + + Exceptions.toMessageString(exception)); warnings.add(1); } else if (response != null) { - metricsDb.add(response.metrics()); + metricsDb.addNodeMetrics(response.nodeMetrics()); + metricsDb.addClusterMetrics(application, response.clusterMetrics()); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java 
b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java index 33dc67801b9..7f41f89f664 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintenance.java @@ -137,7 +137,7 @@ public class NodeRepositoryMaintenance extends AbstractComponent { // Vespa upgrade frequency is higher in CD so (de)activate OS upgrades more frequently as well osUpgradeActivatorInterval = zone.system().isCd() ? Duration.ofSeconds(30) : Duration.ofMinutes(5); periodicRedeployInterval = Duration.ofMinutes(60); - provisionedExpiry = Duration.ofHours(4); + provisionedExpiry = zone.getCloud().dynamicProvisioning() ? Duration.ofMinutes(40) : Duration.ofHours(4); rebalancerInterval = Duration.ofMinutes(120); redeployMaintainerInterval = Duration.ofMinutes(1); // Need to be long enough for deployment to be finished for all config model versions diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/RetiredExpirer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/RetiredExpirer.java index e0a11aa5dac..10db9a08eeb 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/RetiredExpirer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/RetiredExpirer.java @@ -59,7 +59,7 @@ public class RetiredExpirer extends NodeRepositoryMaintainer { List<Node> retiredNodes = entry.getValue(); try (MaintenanceDeployment deployment = new MaintenanceDeployment(application, deployer, metric, nodeRepository())) { - if ( ! deployment.isValid()) continue; // this will be done at another config server + if ( ! 
deployment.isValid()) continue; List<Node> nodesToRemove = retiredNodes.stream().filter(this::canRemove).collect(Collectors.toList()); if (nodesToRemove.isEmpty()) continue; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java index 534115342f3..bb50d6fcc6f 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/node/Nodes.java @@ -287,7 +287,11 @@ public class Nodes { * @throws NoSuchNodeException if the node is not found */ public Node fail(String hostname, Agent agent, String reason) { - return move(hostname, true, Node.State.failed, agent, Optional.of(reason)); + return fail(hostname, true, agent, reason); + } + + public Node fail(String hostname, boolean keepAllocation, Agent agent, String reason) { + return move(hostname, keepAllocation, Node.State.failed, agent, Optional.of(reason)); } /** diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java index c3cb805499c..3ff4765dd00 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeSpec.java @@ -64,6 +64,9 @@ public interface NodeSpec { /** Returns true if there exist some circumstance where we may accept to have this node allocated */ boolean acceptable(NodeCandidate node); + /** Returns true if nodes with non-active parent hosts should be rejected */ + boolean rejectNonActiveParent(); + /** * Returns true if a node with given current resources and current spare host resources can be resized * in-place to resources in this spec. 
@@ -164,6 +167,11 @@ public interface NodeSpec { public boolean acceptable(NodeCandidate node) { return true; } @Override + public boolean rejectNonActiveParent() { + return false; + } + + @Override public String toString() { return "request for " + count + " nodes with " + requestedNodeResources; } } @@ -229,6 +237,11 @@ public interface NodeSpec { } @Override + public boolean rejectNonActiveParent() { + return true; + } + + @Override public String toString() { return "request for all nodes of type '" + type + "'"; } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Preparer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Preparer.java index 2eee3c3f01c..d2b701e5312 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Preparer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/Preparer.java @@ -5,17 +5,17 @@ import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.OutOfCapacityException; -import com.yahoo.lang.MutableInteger; import com.yahoo.vespa.flags.FlagSource; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; -import com.yahoo.vespa.hosted.provision.node.Agent; +import com.yahoo.vespa.hosted.provision.node.Nodes; import java.util.ArrayList; import java.util.List; import java.util.ListIterator; import java.util.Optional; +import java.util.stream.Collectors; /** * Performs preparation of node activation changes for an application. 
@@ -72,6 +72,15 @@ class Preparer { List<Node> accepted = groupPreparer.prepare(application, clusterGroup, requestedNodes.fraction(wantedGroups), surplusNodes, indices, wantedGroups); + + if (requestedNodes.rejectNonActiveParent()) { + Nodes nodes = nodeRepository.nodes(); + NodeList activeHosts = nodes.list(Node.State.active).parents().nodeType(requestedNodes.type().hostType()); + accepted = accepted.stream() + .filter(node -> node.parentHostname().isEmpty() || activeHosts.parentOf(node).isPresent()) + .collect(Collectors.toList()); + } + replace(acceptedNodes, accepted); } moveToActiveGroup(surplusNodes, wantedGroups, cluster.group()); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java index ceaf88dd7d9..4235bae6850 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/ApplicationSerializer.java @@ -9,8 +9,7 @@ import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.applications.Application; import com.yahoo.vespa.hosted.provision.applications.Cluster; import com.yahoo.vespa.hosted.provision.applications.ScalingEvent; -import com.yahoo.vespa.hosted.provision.autoscale.AllocatableClusterResources; -import com.yahoo.vespa.hosted.provision.autoscale.ClusterTimeseries; +import com.yahoo.vespa.hosted.provision.autoscale.ClusterNodesTimeseries; import com.yahoo.vespa.hosted.provision.autoscale.MetricsDb; import com.yahoo.vespa.hosted.provision.autoscale.Resource; @@ -74,7 +73,7 @@ public class ApplicationSerializer { } private static void clusterUtilizationToSlime(Cluster cluster, NodeList nodes, MetricsDb metricsDb, Cursor utilizationObject) { - var timeseries = new ClusterTimeseries(Duration.ofHours(1), cluster, nodes, metricsDb); + var 
timeseries = new ClusterNodesTimeseries(Duration.ofHours(1), cluster, nodes, metricsDb); utilizationObject.setDouble("cpu", timeseries.averageLoad(Resource.cpu)); utilizationObject.setDouble("memory", timeseries.averageLoad(Resource.memory)); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingIntegrationTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingIntegrationTest.java index 87b8ccdc348..8c6c116a225 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingIntegrationTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingIntegrationTest.java @@ -46,7 +46,7 @@ public class AutoscalingIntegrationTest { for (int i = 0; i < 1000; i++) { tester.clock().advance(Duration.ofSeconds(10)); - fetcher.fetchMetrics(application1).whenComplete((r, e) -> tester.nodeMetricsDb().add(r.metrics())); + fetcher.fetchMetrics(application1).whenComplete((r, e) -> tester.nodeMetricsDb().addNodeMetrics(r.nodeMetrics())); tester.clock().advance(Duration.ofSeconds(10)); tester.nodeMetricsDb().gc(); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTest.java index 3fef1d9746b..baf7d2dbe15 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTest.java @@ -56,7 +56,7 @@ public class AutoscalingTest { tester.clock().advance(Duration.ofDays(1)); tester.addCpuMeasurements(0.25f, 1f, 120, application1); ClusterResources scaledResources = tester.assertResources("Scaling up since resource usage is too high", - 15, 1, 1.3, 28.6, 28.6, + 14, 1, 1.4, 30.8, 30.8, tester.autoscale(application1, cluster1.id(), min, max).target()); 
tester.deploy(application1, cluster1, scaledResources); @@ -74,7 +74,7 @@ public class AutoscalingTest { tester.addCpuMeasurements(0.1f, 1f, 120, application1); tester.assertResources("Scaling down to minimum since usage has gone down significantly", - 14, 1, 1.0, 30.8, 30.8, + 15, 1, 1.0, 28.6, 28.6, tester.autoscale(application1, cluster1.id(), min, max).target()); var events = tester.nodeRepository().applications().get(application1).get().cluster(cluster1.id()).get().scalingEvents(); @@ -129,7 +129,7 @@ public class AutoscalingTest { ClusterResources max = new ClusterResources(20, 1, new NodeResources(100, 1000, 1000, 1, NodeResources.DiskSpeed.any)); ClusterResources scaledResources = tester.assertResources("Scaling up since resource usage is too high", - 15, 1, 1.3, 28.6, 28.6, + 14, 1, 1.4, 30.8, 30.8, tester.autoscale(application1, cluster1.id(), min, max).target()); assertEquals("Disk speed from min/max is used", NodeResources.DiskSpeed.any, scaledResources.nodeResources().diskSpeed()); @@ -343,7 +343,7 @@ public class AutoscalingTest { tester.clock().advance(Duration.ofDays(1)); tester.addMemMeasurements(1.0f, 1f, 1000, application1); tester.assertResources("Increase group size to reduce memory load", - 8, 2, 12.9, 89.3, 62.5, + 8, 2, 13.6, 89.3, 62.5, tester.autoscale(application1, cluster1.id(), min, max).target()); } @@ -362,7 +362,7 @@ public class AutoscalingTest { tester.clock().advance(Duration.ofDays(2)); tester.addMemMeasurements(0.02f, 0.95f, 120, application1); tester.assertResources("Scaling down", - 6, 1, 2.8, 4.0, 95.0, + 6, 1, 2.9, 4.0, 95.0, tester.autoscale(application1, cluster1.id(), min, max).target()); } @@ -386,7 +386,7 @@ public class AutoscalingTest { tester.clock().advance(Duration.ofDays(2)); tester.addMemMeasurements(0.02f, 0.95f, 120, application1); tester.assertResources("Scaling down", - 6, 1, 2.8, 4.0, 95.0, + 6, 1, 2.9, 4.0, 95.0, tester.autoscale(application1, cluster1.id(), min, max).target()); } @@ -405,7 +405,7 @@ public 
class AutoscalingTest { tester.deploy(application1, cluster1, min); tester.addMeasurements(1.0f, 1.0f, 0.7f, 0, 1000, application1); tester.assertResources("Scaling up", - 4, 1, 7.0, 20, 200, + 4, 1, 7.4, 20, 200, tester.autoscale(application1, cluster1.id(), min, max).target()); } @@ -418,7 +418,7 @@ public class AutoscalingTest { tester.deploy(application1, cluster1, min); tester.addMeasurements(1.0f, 1.0f, 0.7f, 0, 1000, application1); tester.assertResources("Scaling up", - 4, 1, 7.0, 34, 200, + 4, 1, 7.4, 34, 200, tester.autoscale(application1, cluster1.id(), min, max).target()); } } @@ -457,7 +457,7 @@ public class AutoscalingTest { tester.clock().advance(Duration.ofDays(2)); tester.addMemMeasurements(0.3f, 0.6f, 1000, application1); tester.assertResources("Scaling down since resource usage has gone down", - 5, 1, 3, 83, 36, + 6, 1, 3, 83, 28.8, tester.autoscale(application1, cluster1.id(), min, max).target()); } @@ -491,6 +491,62 @@ public class AutoscalingTest { } + @Test + public void test_autoscaling_considers_growth_rate() { + NodeResources resources = new NodeResources(3, 100, 100, 1); + ClusterResources min = new ClusterResources( 1, 1, resources); + ClusterResources max = new ClusterResources(10, 1, resources); + AutoscalingTester tester = new AutoscalingTester(resources.withVcpu(resources.vcpu() * 2)); + + ApplicationId application1 = tester.applicationId("application1"); + ClusterSpec cluster1 = tester.clusterSpec(ClusterSpec.Type.container, "cluster1"); + + tester.deploy(application1, cluster1, 5, 1, resources); + tester.addCpuMeasurements(0.25f, 1f, 120, application1); + + // (no query rate data) + tester.assertResources("Advice to scale up since we assume we need 2x cpu for growth when no data", + 7, 1, 3, 100, 100, + tester.autoscale(application1, cluster1.id(), min, max).target()); + + tester.setScalingDuration(application1, cluster1.id(), Duration.ofMinutes(5)); + tester.addQueryRateMeasurements(application1, cluster1.id(), + 100, + t -> 10.0 + 
(t < 50 ? t : 100 - t)); + tester.assertResources("Advice to scale down since observed growth is much slower than scaling time", + 4, 1, 3, 100, 100, + tester.autoscale(application1, cluster1.id(), min, max).target()); + + tester.clearQueryRateMeasurements(application1, cluster1.id()); + + tester.setScalingDuration(application1, cluster1.id(), Duration.ofMinutes(60)); + tester.addQueryRateMeasurements(application1, cluster1.id(), + 100, + t -> 10.0 + (t < 50 ? t * t * t : 125000 - (t - 49) * (t - 49) * (t - 49))); + tester.assertResources("Advice to scale up since observed growth is much faster than scaling time", + 10, 1, 3, 100, 100, + tester.autoscale(application1, cluster1.id(), min, max).target()); + } + + @Test + public void test_cd_autoscaling_test() { + NodeResources resources = new NodeResources(1, 4, 50, 1); + ClusterResources min = new ClusterResources( 2, 1, resources); + ClusterResources max = new ClusterResources(3, 1, resources); + AutoscalingTester tester = new AutoscalingTester(resources.withVcpu(resources.vcpu() * 2)); + ApplicationId application1 = tester.applicationId("application1"); + ClusterSpec cluster1 = tester.clusterSpec(ClusterSpec.Type.container, "cluster1"); + tester.deploy(application1, cluster1, 2, 1, resources); + + tester.addCpuMeasurements(0.5f, 1f, 10, application1); + tester.addQueryRateMeasurements(application1, cluster1.id(), + 500, t -> 0.0); + + tester.assertResources("Advice to scale up since observed growth is much faster than scaling time", + 3, 1, 1, 4, 50, + tester.autoscale(application1, cluster1.id(), min, max).target()); + } + /** * This calculator subtracts the memory tax when forecasting overhead, but not when actually * returning information about nodes. This is allowed because the forecast is a *worst case*. 
diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTester.java index 156542ef1d4..ce3293aa518 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTester.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/AutoscalingTester.java @@ -20,15 +20,21 @@ import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.Nodelike; import com.yahoo.vespa.hosted.provision.applications.Application; +import com.yahoo.vespa.hosted.provision.applications.Cluster; +import com.yahoo.vespa.hosted.provision.applications.ScalingEvent; import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.node.IP; import com.yahoo.vespa.hosted.provision.provisioning.HostResourcesCalculator; import com.yahoo.vespa.hosted.provision.provisioning.ProvisioningTester; import java.time.Duration; +import java.time.Instant; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.IntFunction; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -40,7 +46,7 @@ class AutoscalingTester { private final ProvisioningTester provisioningTester; private final Autoscaler autoscaler; - private final MetricsDb db; + private final MemoryMetricsDb db; private final MockHostResourcesCalculator hostResourcesCalculator; /** Creates an autoscaling tester with a single host type ready */ @@ -137,14 +143,14 @@ class AutoscalingTester { float cpu = value * oneExtraNodeFactor; float memory = (float) Resource.memory.idealAverageLoad() * otherResourcesLoad * oneExtraNodeFactor; float disk = (float) Resource.disk.idealAverageLoad() * otherResourcesLoad * oneExtraNodeFactor; - 
db.add(List.of(new Pair<>(node.hostname(), new MetricSnapshot(clock().instant(), - cpu, - memory, - disk, - 0, - true, - true, - 0.0)))); + db.addNodeMetrics(List.of(new Pair<>(node.hostname(), new NodeMetricSnapshot(clock().instant(), + cpu, + memory, + disk, + 0, + true, + true, + 0.0)))); } } } @@ -169,14 +175,14 @@ class AutoscalingTester { float cpu = (float) 0.2 * otherResourcesLoad * oneExtraNodeFactor; float memory = value * oneExtraNodeFactor; float disk = (float) Resource.disk.idealAverageLoad() * otherResourcesLoad * oneExtraNodeFactor; - db.add(List.of(new Pair<>(node.hostname(), new MetricSnapshot(clock().instant(), - cpu, - memory, - disk, - 0, - true, - true, - 0.0)))); + db.addNodeMetrics(List.of(new Pair<>(node.hostname(), new NodeMetricSnapshot(clock().instant(), + cpu, + memory, + disk, + 0, + true, + true, + 0.0)))); } } } @@ -191,14 +197,14 @@ class AutoscalingTester { for (int i = 0; i < count; i++) { clock().advance(Duration.ofMinutes(1)); for (Node node : nodes) { - db.add(List.of(new Pair<>(node.hostname(), new MetricSnapshot(clock().instant(), - cpu, - memory, - disk, - generation, - inService, - stable, - 0.0)))); + db.addNodeMetrics(List.of(new Pair<>(node.hostname(), new NodeMetricSnapshot(clock().instant(), + cpu, + memory, + disk, + generation, + inService, + stable, + 0.0)))); } } } @@ -210,6 +216,41 @@ class AutoscalingTester { nodeRepository().applications().put(application, nodeRepository().nodes().lock(applicationId)); } + /** Creates a single redeployment event with bogus data except for the given duration */ + public void setScalingDuration(ApplicationId applicationId, ClusterSpec.Id clusterId, Duration duration) { + Application application = nodeRepository().applications().require(applicationId); + Cluster cluster = application.cluster(clusterId).get(); + cluster = new Cluster(clusterId, + cluster.exclusive(), + cluster.minResources(), + cluster.maxResources(), + cluster.suggestedResources(), + cluster.targetResources(), + 
List.of(), // Remove scaling events + cluster.autoscalingStatus()); + cluster = cluster.with(ScalingEvent.create(cluster.minResources(), cluster.minResources(), + 0, + clock().instant().minus(Duration.ofDays(1).minus(duration))).withCompletion(clock().instant().minus(Duration.ofDays(1)))); + application = application.with(cluster); + nodeRepository().applications().put(application, nodeRepository().nodes().lock(applicationId)); + } + + /** Creates the given number of measurements, spaced 5 minutes between, using the given function */ + public void addQueryRateMeasurements(ApplicationId application, + ClusterSpec.Id cluster, + int measurements, + IntFunction<Double> queryRate) { + Instant time = clock().instant(); + for (int i = 0; i < measurements; i++) { + db.addClusterMetrics(application, Map.of(cluster, new ClusterMetricSnapshot(time, queryRate.apply(i)))); + time = time.plus(Duration.ofMinutes(5)); + } + } + + public void clearQueryRateMeasurements(ApplicationId application, ClusterSpec.Id cluster) { + db.clearClusterMetrics(application, cluster); + } + public Autoscaler.Advice autoscale(ApplicationId applicationId, ClusterSpec.Id clusterId, ClusterResources min, ClusterResources max) { Application application = nodeRepository().applications().get(applicationId).orElse(Application.empty(applicationId)) diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterTimeseriesTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterTimeseriesTest.java new file mode 100644 index 00000000000..89fe2d76159 --- /dev/null +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/ClusterTimeseriesTest.java @@ -0,0 +1,109 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.vespa.hosted.provision.autoscale; + +import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.test.ManualClock; +import org.junit.Test; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; +import java.util.function.IntFunction; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class ClusterTimeseriesTest { + + private static final double delta = 0.001; + private static final ClusterSpec.Id cluster = new ClusterSpec.Id("test"); + + @Test + public void test_empty() { + var timeseries = new ClusterTimeseries(cluster, List.of()); + assertEquals(0.1, timeseries.maxQueryGrowthRate(), delta); + } + + @Test + public void test_constant_rate_short() { + var clock = new ManualClock(); + var timeseries = new ClusterTimeseries(cluster, rate(10, clock, t -> 50.0)); + assertEquals(0.1, timeseries.maxQueryGrowthRate(), delta); + } + + @Test + public void test_constant_rate_long() { + var clock = new ManualClock(); + var timeseries = new ClusterTimeseries(cluster, rate(10000, clock, t -> 50.0)); + assertEquals(0.0, timeseries.maxQueryGrowthRate(), delta); + } + + @Test + public void test_single_spike() { + var clock = new ManualClock(); + var snapshots = new ArrayList<ClusterMetricSnapshot>(); + snapshots.addAll(rate(1000, clock, t -> 50.0)); + snapshots.addAll(rate( 10, clock, t -> 400.0)); + snapshots.addAll(rate(1000, clock, t -> 50.0)); + assertEquals((400-50)/5.0/50.0, new ClusterTimeseries(cluster, snapshots).maxQueryGrowthRate(), delta); + } + + @Test + public void test_three_spikes() { + var clock = new ManualClock(); + var snapshots = new ArrayList<ClusterMetricSnapshot>(); + snapshots.addAll(rate(1000, clock, t -> 50.0)); + snapshots.addAll(rate( 10, clock, t -> 400.0)); + snapshots.addAll(rate(1000, clock, t -> 50.0)); + snapshots.addAll(rate( 10, clock, t -> 600.0)); + snapshots.addAll(rate(1000, clock, t -> 50.0)); + snapshots.addAll(rate( 10, clock, t -> 800.0)); + 
snapshots.addAll(rate(1000, clock, t -> 50.0)); + assertEquals((800-50)/5.0/50.0, new ClusterTimeseries(cluster, snapshots).maxQueryGrowthRate(), delta); + } + + @Test + public void test_single_hill() { + var clock = new ManualClock(); + var snapshots = new ArrayList<ClusterMetricSnapshot>(); + snapshots.addAll(rate(100, clock, t -> (double)t)); + snapshots.addAll(rate(100, clock, t -> 100.0 - t)); + assertEquals(1/5.0, new ClusterTimeseries(cluster, snapshots).maxQueryGrowthRate(), delta); + } + + @Test + public void test_smooth_curve() { + var clock = new ManualClock(); + var timeseries = new ClusterTimeseries(cluster, rate(10000, clock, + t -> 10.0 + 100.0 * Math.sin(t))); + assertEquals(0.26, timeseries.maxQueryGrowthRate(), delta); + } + + @Test + public void test_smooth_curve_small_variation() { + var clock = new ManualClock(); + var timeseries = new ClusterTimeseries(cluster, rate(10000, clock, + t -> 1000.0 + 10.0 * Math.sin(t))); + assertEquals(0.0, timeseries.maxQueryGrowthRate(), delta); + } + + @Test + public void test_two_periods() { + var clock = new ManualClock(); + var timeseries = new ClusterTimeseries(cluster, rate(10000, clock, + t -> 10.0 + 100.0 * Math.sin(t) + 80.0 * Math.sin(10 * t)) ); + assertEquals(1.765, timeseries.maxQueryGrowthRate(), delta); + } + + private List<ClusterMetricSnapshot> rate(int count, ManualClock clock, IntFunction<Double> rate) { + List<ClusterMetricSnapshot> snapshots = new ArrayList<>(); + for (int i = 0; i < count; i++) { + snapshots.add(new ClusterMetricSnapshot(clock.instant(), rate.apply(i))); + clock.advance(Duration.ofMinutes(5)); + } + return snapshots; + } + +} diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcherTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcherTest.java index 384e8dd8439..14a9a596e78 100644 --- 
a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcherTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/MetricsV2MetricsFetcherTest.java @@ -45,7 +45,7 @@ public class MetricsV2MetricsFetcherTest { { httpClient.cannedResponse = cannedResponseForApplication1; - List<Pair<String, MetricSnapshot>> values = new ArrayList<>(fetcher.fetchMetrics(application1).get().metrics()); + List<Pair<String, NodeMetricSnapshot>> values = new ArrayList<>(fetcher.fetchMetrics(application1).get().nodeMetrics()); assertEquals("http://host-1.yahoo.com:4080/metrics/v2/values?consumer=autoscaling", httpClient.requestsReceived.get(0)); assertEquals(2, values.size()); @@ -63,7 +63,7 @@ public class MetricsV2MetricsFetcherTest { { httpClient.cannedResponse = cannedResponseForApplication2; - List<Pair<String, MetricSnapshot>> values = new ArrayList<>(fetcher.fetchMetrics(application2).get().metrics()); + List<Pair<String, NodeMetricSnapshot>> values = new ArrayList<>(fetcher.fetchMetrics(application2).get().nodeMetrics()); assertEquals("http://host-3.yahoo.com:4080/metrics/v2/values?consumer=autoscaling", httpClient.requestsReceived.get(1)); assertEquals(1, values.size()); @@ -81,7 +81,7 @@ public class MetricsV2MetricsFetcherTest { tester.nodeRepository().nodes().write(tester.nodeRepository().nodes().list(Node.State.active).owner(application2) .first().get().retire(tester.clock().instant()), lock); } - List<Pair<String, MetricSnapshot>> values = new ArrayList<>(fetcher.fetchMetrics(application2).get().metrics()); + List<Pair<String, NodeMetricSnapshot>> values = new ArrayList<>(fetcher.fetchMetrics(application2).get().nodeMetrics()); assertFalse(values.get(0).getSecond().stable()); } } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/NodeMetricsDbTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/NodeMetricsDbTest.java index 
c1c94c7dd24..76e56004871 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/NodeMetricsDbTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/NodeMetricsDbTest.java @@ -38,19 +38,19 @@ public class NodeMetricsDbTest { ManualClock clock = tester.clock(); MetricsDb db = MetricsDb.createTestInstance(tester.nodeRepository()); - Collection<Pair<String, MetricSnapshot>> values = new ArrayList<>(); + Collection<Pair<String, NodeMetricSnapshot>> values = new ArrayList<>(); for (int i = 0; i < 40; i++) { - values.add(new Pair<>(node0, new MetricSnapshot(clock.instant(), - 0.9f, - 0.6f, - 0.6f, - 0, - true, - false, - 0.0))); + values.add(new Pair<>(node0, new NodeMetricSnapshot(clock.instant(), + 0.9f, + 0.6f, + 0.6f, + 0, + true, + false, + 0.0))); clock.advance(Duration.ofMinutes(120)); } - db.add(values); + db.addNodeMetrics(values); // Avoid off-by-one bug when the below windows starts exactly on one of the above getEpochSecond() timestamps. 
clock.advance(Duration.ofMinutes(1)); diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDbTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDbTest.java index 70f9d581816..18b92fa6b0f 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDbTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/autoscale/QuestMetricsDbTest.java @@ -2,6 +2,8 @@ package com.yahoo.vespa.hosted.provision.autoscale; import com.yahoo.collections.Pair; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterSpec; import com.yahoo.io.IOUtils; import com.yahoo.test.ManualClock; import org.junit.Ignore; @@ -12,7 +14,9 @@ import java.time.Duration; import java.time.Instant; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -29,7 +33,7 @@ public class QuestMetricsDbTest { private static final double delta = 0.0000001; @Test - public void testReadWrite() { + public void testNodeMetricsReadWrite() { String dataDir = "data/QuestMetricsDbReadWrite"; IOUtils.recursiveDeleteDir(new File(dataDir)); IOUtils.createDirectory(dataDir + "/metrics"); @@ -38,7 +42,7 @@ public class QuestMetricsDbTest { Instant startTime = clock.instant(); clock.advance(Duration.ofSeconds(1)); - db.add(timeseries(1000, Duration.ofSeconds(1), clock, "host1", "host2", "host3")); + db.addNodeMetrics(nodeTimeseries(1000, Duration.ofSeconds(1), clock, "host1", "host2", "host3")); clock.advance(Duration.ofSeconds(1)); @@ -48,7 +52,7 @@ public class QuestMetricsDbTest { assertEquals(1, nodeTimeSeries1.size()); assertEquals("host1", nodeTimeSeries1.get(0).hostname()); assertEquals(1000, nodeTimeSeries1.get(0).size()); - MetricSnapshot snapshot = nodeTimeSeries1.get(0).asList().get(0); + 
NodeMetricSnapshot snapshot = nodeTimeSeries1.get(0).asList().get(0); assertEquals(startTime.plus(Duration.ofSeconds(1)), snapshot.at()); assertEquals(0.1, snapshot.cpu(), delta); assertEquals(0.2, snapshot.memory(), delta); @@ -75,6 +79,56 @@ public class QuestMetricsDbTest { } @Test + public void testClusterMetricsReadWrite() { + String dataDir = "data/QuestMetricsDbReadWrite"; + IOUtils.recursiveDeleteDir(new File(dataDir)); + IOUtils.createDirectory(dataDir + "/clusterMetrics"); + ManualClock clock = new ManualClock("2020-10-01T00:00:00"); + QuestMetricsDb db = new QuestMetricsDb(dataDir, clock); + Instant startTime = clock.instant(); + + var application1 = ApplicationId.from("t1", "a1", "i1"); + var application2 = ApplicationId.from("t1", "a2", "i1"); + var cluster1 = new ClusterSpec.Id("cluster1"); + var cluster2 = new ClusterSpec.Id("cluster2"); + db.addClusterMetrics(application1, Map.of(cluster1, new ClusterMetricSnapshot(clock.instant(), 30.0))); + db.addClusterMetrics(application1, Map.of(cluster2, new ClusterMetricSnapshot(clock.instant(), 60.0))); + clock.advance(Duration.ofMinutes(1)); + db.addClusterMetrics(application1, Map.of(cluster1, new ClusterMetricSnapshot(clock.instant(), 45.0))); + clock.advance(Duration.ofMinutes(1)); + db.addClusterMetrics(application2, Map.of(cluster1, new ClusterMetricSnapshot(clock.instant(), 90.0))); + + ClusterTimeseries clusterTimeseries11 = db.getClusterTimeseries(application1, cluster1); + assertEquals(cluster1, clusterTimeseries11.cluster()); + assertEquals(2, clusterTimeseries11.asList().size()); + + ClusterMetricSnapshot snapshot111 = clusterTimeseries11.get(0); + assertEquals(startTime, snapshot111.at()); + assertEquals(30, snapshot111.queryRate(), delta); + ClusterMetricSnapshot snapshot112 = clusterTimeseries11.get(1); + assertEquals(startTime.plus(Duration.ofMinutes(1)), snapshot112.at()); + assertEquals(45, snapshot112.queryRate(), delta); + + + ClusterTimeseries clusterTimeseries12 = 
db.getClusterTimeseries(application1, cluster2); + assertEquals(cluster2, clusterTimeseries12.cluster()); + assertEquals(1, clusterTimeseries12.asList().size()); + + ClusterMetricSnapshot snapshot121 = clusterTimeseries12.get(0); + assertEquals(startTime, snapshot121.at()); + assertEquals(60, snapshot121.queryRate(), delta); + + + ClusterTimeseries clusterTimeseries21 = db.getClusterTimeseries(application2, cluster1); + assertEquals(cluster1, clusterTimeseries21.cluster()); + assertEquals(1, clusterTimeseries21.asList().size()); + + ClusterMetricSnapshot snapshot211 = clusterTimeseries21.get(0); + assertEquals(startTime.plus(Duration.ofMinutes(2)), snapshot211.at()); + assertEquals(90, snapshot211.queryRate(), delta); + } + + @Test public void testWriteOldData() { String dataDir = "data/QuestMetricsDbWriteOldData"; IOUtils.recursiveDeleteDir(new File(dataDir)); @@ -83,19 +137,19 @@ public class QuestMetricsDbTest { QuestMetricsDb db = new QuestMetricsDb(dataDir, clock); Instant startTime = clock.instant(); clock.advance(Duration.ofSeconds(300)); - db.add(timeseriesAt(10, clock.instant(), "host1", "host2", "host3")); + db.addNodeMetrics(timeseriesAt(10, clock.instant(), "host1", "host2", "host3")); clock.advance(Duration.ofSeconds(1)); List<NodeTimeseries> nodeTimeSeries1 = db.getNodeTimeseries(Duration.between(startTime, clock.instant()), Set.of("host1")); assertEquals(10, nodeTimeSeries1.get(0).size()); - db.add(timeseriesAt(10, clock.instant().minus(Duration.ofSeconds(20)), "host1", "host2", "host3")); + db.addNodeMetrics(timeseriesAt(10, clock.instant().minus(Duration.ofSeconds(20)), "host1", "host2", "host3")); List<NodeTimeseries> nodeTimeSeries2 = db.getNodeTimeseries(Duration.between(startTime, clock.instant()), Set.of("host1")); assertEquals("Recent data is accepted", 20, nodeTimeSeries2.get(0).size()); - db.add(timeseriesAt(10, clock.instant().minus(Duration.ofSeconds(200)), "host1", "host2", "host3")); + db.addNodeMetrics(timeseriesAt(10, 
clock.instant().minus(Duration.ofSeconds(200)), "host1", "host2", "host3")); List<NodeTimeseries> nodeTimeSeries3 = db.getNodeTimeseries(Duration.between(startTime, clock.instant()), Set.of("host1")); assertEquals("Too old data is rejected", 20, nodeTimeSeries3.get(0).size()); @@ -111,15 +165,15 @@ public class QuestMetricsDbTest { Instant startTime = clock.instant(); int dayOffset = 3; clock.advance(Duration.ofHours(dayOffset)); - db.add(timeseries(24 * 10, Duration.ofHours(1), clock, "host1", "host2", "host3")); + db.addNodeMetrics(nodeTimeseries(24 * 10, Duration.ofHours(1), clock, "host1", "host2", "host3")); assertEquals(24 * 10, db.getNodeTimeseries(Duration.between(startTime, clock.instant()), Set.of("host1")).get(0).size()); db.gc(); - assertEquals(48 * 1 + dayOffset, db.getNodeTimeseries(Duration.between(startTime, clock.instant()), + assertEquals(75, db.getNodeTimeseries(Duration.between(startTime, clock.instant()), Set.of("host1")).get(0).size()); db.gc(); // no-op - assertEquals(48 * 1 + dayOffset, db.getNodeTimeseries(Duration.between(startTime, clock.instant()), + assertEquals(75, db.getNodeTimeseries(Duration.between(startTime, clock.instant()), Set.of("host1")).get(0).size()); } @@ -146,7 +200,7 @@ public class QuestMetricsDbTest { System.out.println(" " + snapshot); clock.advance(Duration.ofSeconds(1)); - db.add(timeseries(2, Duration.ofSeconds(1), clock, "host1")); + db.addNodeMetrics(nodeTimeseries(2, Duration.ofSeconds(1), clock, "host1")); System.out.println("New data written and read:"); timeseries = db.getNodeTimeseries(Duration.ofSeconds(2), Set.of("host1")); for (var snapshot : timeseries.get(0).asList()) @@ -163,7 +217,7 @@ public class QuestMetricsDbTest { ManualClock clock = new ManualClock("2020-10-01T00:00:00"); QuestMetricsDb db = new QuestMetricsDb(dataDir, clock); Instant startTime = clock.instant(); - db.add(timeseries(10, Duration.ofSeconds(1), clock, "host1")); + db.addNodeMetrics(nodeTimeseries(10, Duration.ofSeconds(1), clock, 
"host1")); int added = db.getNodeTimeseries(Duration.between(startTime, clock.instant()), Set.of("host1")).get(0).asList().size(); @@ -171,36 +225,46 @@ public class QuestMetricsDbTest { db.close(); } - private Collection<Pair<String, MetricSnapshot>> timeseries(int countPerHost, Duration sampleRate, ManualClock clock, - String ... hosts) { - Collection<Pair<String, MetricSnapshot>> timeseries = new ArrayList<>(); + private Collection<Pair<String, NodeMetricSnapshot>> nodeTimeseries(int countPerHost, Duration sampleRate, ManualClock clock, + String ... hosts) { + Collection<Pair<String, NodeMetricSnapshot>> timeseries = new ArrayList<>(); for (int i = 1; i <= countPerHost; i++) { for (String host : hosts) - timeseries.add(new Pair<>(host, new MetricSnapshot(clock.instant(), + timeseries.add(new Pair<>(host, new NodeMetricSnapshot(clock.instant(), i * 0.1, i * 0.2, i * 0.4, i % 100, - true, - true, - 30.0))); + true, + true, + 30.0))); + clock.advance(sampleRate); + } + return timeseries; + } + + private List<ClusterMetricSnapshot> clusterTimeseries(int count, Duration sampleRate, ManualClock clock, + ClusterSpec.Id cluster) { + List<ClusterMetricSnapshot> timeseries = new ArrayList<>(); + for (int i = 1; i <= count; i++) { + timeseries.add(new ClusterMetricSnapshot(clock.instant(), 30.0)); clock.advance(sampleRate); } return timeseries; } - private Collection<Pair<String, MetricSnapshot>> timeseriesAt(int countPerHost, Instant at, String ... hosts) { - Collection<Pair<String, MetricSnapshot>> timeseries = new ArrayList<>(); + private Collection<Pair<String, NodeMetricSnapshot>> timeseriesAt(int countPerHost, Instant at, String ... 
hosts) { + Collection<Pair<String, NodeMetricSnapshot>> timeseries = new ArrayList<>(); for (int i = 1; i <= countPerHost; i++) { for (String host : hosts) - timeseries.add(new Pair<>(host, new MetricSnapshot(at, + timeseries.add(new Pair<>(host, new NodeMetricSnapshot(at, i * 0.1, i * 0.2, i * 0.4, i % 100, - true, - false, - 0.0))); + true, + false, + 0.0))); } return timeseries; } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainerTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainerTester.java index 1b531fd3237..e8cfe6a2310 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainerTester.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainerTester.java @@ -16,8 +16,7 @@ import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.applications.Cluster; -import com.yahoo.vespa.hosted.provision.applications.ScalingEvent; -import com.yahoo.vespa.hosted.provision.autoscale.MetricSnapshot; +import com.yahoo.vespa.hosted.provision.autoscale.NodeMetricSnapshot; import com.yahoo.vespa.hosted.provision.autoscale.MetricsDb; import com.yahoo.vespa.hosted.provision.provisioning.FlavorConfigBuilder; import com.yahoo.vespa.hosted.provision.provisioning.ProvisioningTester; @@ -75,14 +74,14 @@ public class AutoscalingMaintainerTester { NodeList nodes = nodeRepository().nodes().list(Node.State.active).owner(applicationId); for (int i = 0; i < count; i++) { for (Node node : nodes) - metricsDb.add(List.of(new Pair<>(node.hostname(), new MetricSnapshot(clock().instant(), - cpu, - mem, - disk, - generation, - true, - true, - 0.0)))); + metricsDb.addNodeMetrics(List.of(new Pair<>(node.hostname(), new NodeMetricSnapshot(clock().instant(), + 
cpu, + mem, + disk, + generation, + true, + true, + 0.0)))); } } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java new file mode 100644 index 00000000000..cf5cfb93782 --- /dev/null +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DirtyExpirerTest.java @@ -0,0 +1,68 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.provision.maintenance; + +import com.yahoo.component.Version; +import com.yahoo.config.provision.Cloud; +import com.yahoo.config.provision.ClusterMembership; +import com.yahoo.config.provision.Environment; +import com.yahoo.config.provision.Flavor; +import com.yahoo.config.provision.NodeResources; +import com.yahoo.config.provision.NodeType; +import com.yahoo.config.provision.RegionName; +import com.yahoo.config.provision.SystemName; +import com.yahoo.config.provision.Zone; +import com.yahoo.vespa.hosted.provision.Node; +import com.yahoo.vespa.hosted.provision.node.Agent; +import com.yahoo.vespa.hosted.provision.node.Allocation; +import com.yahoo.vespa.hosted.provision.node.Generation; +import com.yahoo.vespa.hosted.provision.provisioning.ProvisioningTester; +import com.yahoo.vespa.hosted.provision.testutils.MockHostProvisioner; +import org.junit.Test; + +import java.time.Duration; +import java.util.List; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; + +/** + * @author freva + */ +public class DirtyExpirerTest { + + @Test + public void assert_allocation_after_expiry() { + assertAllocationAfterExpiry(true); + assertAllocationAfterExpiry(false); + } + + private void assertAllocationAfterExpiry(boolean dynamicProvisioning) { + Zone zone = new Zone(Cloud.builder().dynamicProvisioning(dynamicProvisioning).build(), SystemName.main, 
Environment.prod, RegionName.from("us-east")); + ProvisioningTester tester = new ProvisioningTester.Builder().zone(zone) + .hostProvisioner(dynamicProvisioning ? new MockHostProvisioner(List.of()) : null) + .build(); + + Node node = Node.create("id", "node1.domain.tld", new Flavor(NodeResources.unspecified()), Node.State.dirty, NodeType.tenant) + .allocation(new Allocation(ProvisioningTester.applicationId(), + ClusterMembership.from("container/default/0/0", Version.fromString("1.2.3"), Optional.empty()), + NodeResources.unspecified(), + Generation.initial(), + false)) + .build(); + + tester.nodeRepository().database().addNodesInState(List.of(node), node.state(), Agent.system); + + Duration expiryTimeout = Duration.ofMinutes(30); + DirtyExpirer expirer = new DirtyExpirer(tester.nodeRepository(), expiryTimeout, new TestMetric()); + + assertEquals(Node.State.dirty, tester.nodeRepository().nodes().list().first().get().state()); + expirer.run(); + assertEquals(Node.State.dirty, tester.nodeRepository().nodes().list().first().get().state()); + + tester.clock().advance(expiryTimeout.plusSeconds(1)); + expirer.run(); + assertEquals(Node.State.failed, tester.nodeRepository().nodes().list().first().get().state()); + assertEquals(dynamicProvisioning, tester.nodeRepository().nodes().list().first().get().allocation().isEmpty()); + } + +}
\ No newline at end of file diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DynamicProvisioningMaintainerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DynamicProvisioningMaintainerTest.java index 337bbb0cbb4..076a0e24620 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DynamicProvisioningMaintainerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/DynamicProvisioningMaintainerTest.java @@ -12,7 +12,6 @@ import com.yahoo.config.provision.Flavor; import com.yahoo.config.provision.NodeFlavors; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; -import com.yahoo.config.provision.ParentHostUnavailableException; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.Zone; @@ -53,7 +52,6 @@ import static com.yahoo.vespa.hosted.provision.testutils.MockHostProvisioner.Beh import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; -import static org.junit.Assert.fail; /** * @author freva @@ -454,12 +452,8 @@ public class DynamicProvisioningMaintainerTest { Supplier<Node> hostToRemove = () -> tester.nodeRepository().nodes().node(hostnameToRemove).get(); Supplier<Node> nodeToRemove = () -> tester.nodeRepository().nodes().node(configNodes.childrenOf(hostnameToRemove).first().get().hostname()).get(); - // Retire and deprovision host + // Set want to retire and deprovision on host and children tester.nodeRepository().nodes().deprovision(hostToRemove.get(), Agent.system, tester.clock().instant()); - tester.nodeRepository().nodes().deallocate(hostToRemove.get(), Agent.system, getClass().getSimpleName()); - assertSame("Host moves to parked", Node.State.parked, hostToRemove.get().state()); - assertSame("Node remains active", Node.State.active, 
nodeToRemove.get().state()); - assertTrue("Node wants to retire", nodeToRemove.get().status().wantToRetire()); // Redeployment of config server application retires node tester.prepareAndActivateInfraApplication(configSrvApp, NodeType.config); @@ -477,6 +471,10 @@ public class DynamicProvisioningMaintainerTest { tester.nodeRepository().nodes().removeRecursively(inactiveConfigServer, true); assertEquals(2, tester.nodeRepository().nodes().list().nodeType(NodeType.config).size()); + // ExpiredRetirer moves host to inactive after child has moved to parked + tester.nodeRepository().nodes().deallocate(hostToRemove.get(), Agent.system, getClass().getSimpleName()); + assertSame("Host moves to parked", Node.State.parked, hostToRemove.get().state()); + // Host is removed dynamicProvisioningTester.maintainer.maintain(); assertEquals(2, tester.nodeRepository().nodes().list().nodeType(NodeType.confighost).size()); @@ -488,10 +486,9 @@ public class DynamicProvisioningMaintainerTest { // Deployment on another config server starts provisioning a new host and child HostName.setHostNameForTestingOnly("cfg3.example.com"); - try { - tester.prepareAndActivateInfraApplication(configSrvApp, NodeType.config); - fail("Expected provisioning to fail"); - } catch (ParentHostUnavailableException ignored) {} + assertEquals(0, tester.nodeRepository().nodes().list(Node.State.reserved).nodeType(NodeType.config).size()); + assertEquals(2, tester.prepareAndActivateInfraApplication(configSrvApp, NodeType.config).size()); + assertEquals(1, tester.nodeRepository().nodes().list(Node.State.reserved).nodeType(NodeType.config).size()); Node newNode = tester.nodeRepository().nodes().list(Node.State.reserved).nodeType(NodeType.config).first().get(); // Resume provisioning and activate host diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainerTest.java 
b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainerTest.java index e99f7740c29..5af787092d5 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/NodeMetricsDbMaintainerTest.java @@ -4,7 +4,7 @@ package com.yahoo.vespa.hosted.provision.maintenance; import com.yahoo.config.provision.Capacity; import com.yahoo.config.provision.ClusterResources; import com.yahoo.config.provision.NodeResources; -import com.yahoo.vespa.hosted.provision.autoscale.MetricSnapshot; +import com.yahoo.vespa.hosted.provision.autoscale.NodeMetricSnapshot; import com.yahoo.vespa.hosted.provision.autoscale.MetricsDb; import com.yahoo.vespa.hosted.provision.autoscale.MetricsV2MetricsFetcher; import com.yahoo.vespa.hosted.provision.autoscale.NodeTimeseries; @@ -49,9 +49,9 @@ public class NodeMetricsDbMaintainerTest { List<NodeTimeseries> timeseriesList = db.getNodeTimeseries(Duration.ofDays(1), Set.of("host-1.yahoo.com", "host-2.yahoo.com")); assertEquals(2, timeseriesList.size()); - List<MetricSnapshot> allSnapshots = timeseriesList.stream() - .flatMap(timeseries -> timeseries.asList().stream()) - .collect(Collectors.toList()); + List<NodeMetricSnapshot> allSnapshots = timeseriesList.stream() + .flatMap(timeseries -> timeseries.asList().stream()) + .collect(Collectors.toList()); assertTrue(allSnapshots.stream().anyMatch(snapshot -> snapshot.inService())); assertTrue(allSnapshots.stream().anyMatch(snapshot -> ! 
snapshot.inService())); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainerTest.java index d5b7903b94c..88d39e887d3 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/ScalingSuggestionsMaintainerTest.java @@ -17,7 +17,7 @@ import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeList; import com.yahoo.vespa.hosted.provision.NodeRepository; import com.yahoo.vespa.hosted.provision.applications.Cluster; -import com.yahoo.vespa.hosted.provision.autoscale.MetricSnapshot; +import com.yahoo.vespa.hosted.provision.autoscale.NodeMetricSnapshot; import com.yahoo.vespa.hosted.provision.autoscale.MetricsDb; import com.yahoo.vespa.hosted.provision.autoscale.Resource; import com.yahoo.vespa.hosted.provision.provisioning.FlavorConfigBuilder; @@ -74,7 +74,7 @@ public class ScalingSuggestionsMaintainerTest { assertEquals("14 nodes with [vcpu: 6.9, memory: 5.1 Gb, disk 15.0 Gb, bandwidth: 0.1 Gbps]", suggestionOf(app1, cluster1, tester).get().resources().toString()); - assertEquals("8 nodes with [vcpu: 14.7, memory: 4.0 Gb, disk 11.8 Gb, bandwidth: 0.1 Gbps]", + assertEquals("9 nodes with [vcpu: 13.8, memory: 4.0 Gb, disk 10.3 Gb, bandwidth: 0.1 Gbps]", suggestionOf(app2, cluster2, tester).get().resources().toString()); // Utilization goes way down @@ -125,14 +125,14 @@ public class ScalingSuggestionsMaintainerTest { NodeList nodes = nodeRepository.nodes().list(Node.State.active).owner(applicationId); for (int i = 0; i < count; i++) { for (Node node : nodes) - db.add(List.of(new Pair<>(node.hostname(), new MetricSnapshot(nodeRepository.clock().instant(), - cpu, - memory, - disk, - generation, - true, - true, - 0.0)))); 
+ db.addNodeMetrics(List.of(new Pair<>(node.hostname(), new NodeMetricSnapshot(nodeRepository.clock().instant(), + cpu, + memory, + disk, + generation, + true, + true, + 0.0)))); } } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisionTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisionTest.java index 4db1b86419b..a6e67f2747c 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisionTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/DynamicDockerProvisionTest.java @@ -1,4 +1,4 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision.provisioning; import com.yahoo.component.Version; @@ -247,14 +247,14 @@ public class DynamicDockerProvisionTest { tester.activate(app1, cluster1, Capacity.from(resources(2, 1, 2, 20, 40), resources(4, 1, 2, 20, 40))); tester.assertNodes("Allocation specifies memory in the advertised amount", - 3, 1, 2, 20, 40, + 2, 1, 2, 20, 40, app1, cluster1); // Redeploy the same tester.activate(app1, cluster1, Capacity.from(resources(2, 1, 2, 20, 40), resources(4, 1, 2, 20, 40))); tester.assertNodes("Allocation specifies memory in the advertised amount", - 3, 1, 2, 20, 40, + 2, 1, 2, 20, 40, app1, cluster1); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java index c269b4642ea..0db5453c963 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java +++ 
b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTest.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.provision.provisioning; import com.yahoo.component.Version; @@ -440,25 +440,7 @@ public class ProvisioningTest { } @Test - public void test_node_limits_only_container() { - Flavor hostFlavor = new Flavor(new NodeResources(20, 40, 100, 4)); - ProvisioningTester tester = new ProvisioningTester.Builder().zone(new Zone(Environment.prod, RegionName.from("us-east"))) - .flavors(List.of(hostFlavor)) - .build(); - tester.makeReadyHosts(4, hostFlavor.resources()).activateTenantHosts(); - - ApplicationId app1 = ProvisioningTester.applicationId("app1"); - ClusterSpec cluster1 = ClusterSpec.request(ClusterSpec.Type.container, new ClusterSpec.Id("cluster1")).vespaVersion("7").build(); - - tester.activate(app1, cluster1, Capacity.from(new ClusterResources(2, 1, NodeResources.unspecified()), - new ClusterResources(4, 1, NodeResources.unspecified()))); - tester.assertNodes("Initial allocation at min with default resources", - 2, 1, 1.5, 8, 50, 0.3, - app1, cluster1); - } - - @Test - public void test_node_limits_only_content() { + public void test_node_limits() { Flavor hostFlavor = new Flavor(new NodeResources(20, 40, 100, 4)); ProvisioningTester tester = new ProvisioningTester.Builder().zone(new Zone(Environment.prod, RegionName.from("us-east"))) .flavors(List.of(hostFlavor)) @@ -471,7 +453,7 @@ public class ProvisioningTest { tester.activate(app1, cluster1, Capacity.from(new ClusterResources(2, 1, NodeResources.unspecified()), new ClusterResources(4, 1, NodeResources.unspecified()))); tester.assertNodes("Initial allocation at (allowable) min with default resources", - 3, 1, 1.5, 8, 50, 0.3, + 2, 1, 1.5, 
8, 50, 0.3, app1, cluster1); } diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java index eefbd03ce4e..a30735df01c 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/provisioning/ProvisioningTester.java @@ -226,7 +226,7 @@ public class ProvisioningTester { } } - public void prepareAndActivateInfraApplication(ApplicationId application, NodeType nodeType, Version version) { + public List<HostSpec> prepareAndActivateInfraApplication(ApplicationId application, NodeType nodeType, Version version) { ClusterSpec cluster = ClusterSpec.request(ClusterSpec.Type.container, ClusterSpec.Id.from(nodeType.toString())) .vespaVersion(version) .stateful(nodeType == NodeType.config || nodeType == NodeType.controller) @@ -234,10 +234,11 @@ public class ProvisioningTester { Capacity capacity = Capacity.fromRequiredNodeType(nodeType); List<HostSpec> hostSpecs = prepare(application, cluster, capacity); activate(application, hostSpecs); + return hostSpecs; } - public void prepareAndActivateInfraApplication(ApplicationId application, NodeType nodeType) { - prepareAndActivateInfraApplication(application, nodeType, Version.fromString("6.42")); + public List<HostSpec> prepareAndActivateInfraApplication(ApplicationId application, NodeType nodeType) { + return prepareAndActivateInfraApplication(application, nodeType, Version.fromString("6.42")); } public void deactivate(ApplicationId applicationId) { diff --git a/parent/pom.xml b/parent/pom.xml index 12e82033e1a..043c0dc80dc 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -790,7 +790,7 @@ find zkfacade/src/main/java/org/apache/curator -name package-info.java | \ xargs perl -pi -e 's/major = [0-9]+, minor = [0-9]+, micro = [0-9]+/major = 4, minor = 3, 
micro = 0/g' --> - <curator.version>4.3.0</curator.version> + <curator.version>5.1.0</curator.version> <jna.version>4.5.2</jna.version> <commons.math3.version>3.6.1</commons.math3.version> <junit.version>5.7.0</junit.version> @@ -813,7 +813,7 @@ <protobuf.version>3.7.0</protobuf.version> <surefire.version>2.22.0</surefire.version> <tensorflow.version>1.12.0</tensorflow.version> - <zookeeper.client.version>3.5.8</zookeeper.client.version> + <zookeeper.client.version>3.6.2</zookeeper.client.version> <doclint>all</doclint> <test.hide>true</test.hide> diff --git a/searchcore/src/tests/proton/documentdb/documentbucketmover/bucketmover_common.h b/searchcore/src/tests/proton/documentdb/documentbucketmover/bucketmover_common.h index e955f017d67..f70a4bfad11 100644 --- a/searchcore/src/tests/proton/documentdb/documentbucketmover/bucketmover_common.h +++ b/searchcore/src/tests/proton/documentdb/documentbucketmover/bucketmover_common.h @@ -64,8 +64,13 @@ struct MyDocumentRetriever : public DocumentRetrieverBaseForTest { using DocumentVector = std::vector<Document::SP>; std::shared_ptr<const DocumentTypeRepo> _repo; DocumentVector _docs; + uint32_t _lid2Fail; - MyDocumentRetriever(std::shared_ptr<const DocumentTypeRepo> repo) : _repo(std::move(repo)), _docs() { + MyDocumentRetriever(std::shared_ptr<const DocumentTypeRepo> repo) + : _repo(std::move(repo)), + _docs(), + _lid2Fail(0) + { _docs.push_back(Document::UP()); // lid 0 invalid } @@ -76,9 +81,11 @@ struct MyDocumentRetriever : public DocumentRetrieverBaseForTest { DocumentMetaData getDocumentMetaData(const DocumentId &) const override { return DocumentMetaData(); } Document::UP getFullDocument(DocumentIdT lid) const override { - return Document::UP(_docs[lid]->clone()); + return (lid != _lid2Fail) ? 
Document::UP(_docs[lid]->clone()) : Document::UP(); } + void failRetrieveForLid(uint32_t lid) { _lid2Fail = lid; } + CachedSelect::SP parseSelect(const vespalib::string &) const override { return {}; } @@ -115,6 +122,10 @@ struct MySubDb { void insertDocs(const UserDocuments &docs_); + void failRetrieveForLid(uint32_t lid) { + _realRetriever->failRetrieveForLid(lid); + } + BucketId bucket(uint32_t userId) const { return _docs.getBucket(userId); } diff --git a/searchcore/src/tests/proton/documentdb/documentbucketmover/documentbucketmover_v2_test.cpp b/searchcore/src/tests/proton/documentdb/documentbucketmover/documentbucketmover_v2_test.cpp index 99692ec53bd..5223e20bd5c 100644 --- a/searchcore/src/tests/proton/documentdb/documentbucketmover/documentbucketmover_v2_test.cpp +++ b/searchcore/src/tests/proton/documentdb/documentbucketmover/documentbucketmover_v2_test.cpp @@ -6,6 +6,7 @@ #include <vespa/searchcore/proton/server/document_db_maintenance_config.h> #include <vespa/persistence/dummyimpl/dummy_bucket_executor.h> #include <vespa/vespalib/util/threadstackexecutor.h> +#include <vespa/vespalib/util/lambdatask.h> #include <vespa/vespalib/gtest/gtest.h> #include <vespa/log/log.h> @@ -70,6 +71,14 @@ struct ControllerFixtureBase : public ::testing::Test _bucketHandler.notifyBucketStateChanged(bucket, BucketInfo::ActiveState::NOT_ACTIVE); return *this; } + void failRetrieveForLid(uint32_t lid) { + _ready.failRetrieveForLid(lid); + _notReady.failRetrieveForLid(lid); + } + void fixRetriever() { + _ready.failRetrieveForLid(0); + _notReady.failRetrieveForLid(0); + } const MoveOperationVector &docsMoved() const { return _moveHandler._moves; } @@ -92,6 +101,11 @@ struct ControllerFixtureBase : public ::testing::Test _master.sync(); _master.sync(); // Handle that master schedules onto master again } + template <typename FunctionType> + void masterExecute(FunctionType &&function) { + _master.execute(vespalib::makeLambdaTask(std::forward<FunctionType>(function))); + 
_master.sync(); + } }; ControllerFixtureBase::ControllerFixtureBase(const BlockableMaintenanceJobConfig &blockableConfig, bool storeMoveDoneContexts) @@ -155,8 +169,10 @@ TEST_F(ControllerFixture, require_that_nothing_is_moved_if_bucket_state_says_so) addReady(_ready.bucket(1)); addReady(_ready.bucket(2)); _bmj.recompute(); - EXPECT_TRUE(_bmj.scanAndMove(4, 3)); - EXPECT_TRUE(_bmj.done()); + masterExecute([this]() { + EXPECT_TRUE(_bmj.scanAndMove(4, 3)); + EXPECT_TRUE(_bmj.done()); + }); EXPECT_TRUE(docsMoved().empty()); EXPECT_TRUE(bucketsModified().empty()); } @@ -171,9 +187,11 @@ TEST_F(ControllerFixture, require_that_not_ready_bucket_is_moved_to_ready_if_buc EXPECT_EQ(0, numPending()); _bmj.recompute(); EXPECT_EQ(1, numPending()); - EXPECT_FALSE(_bmj.done()); - EXPECT_TRUE(_bmj.scanAndMove(4, 3)); - EXPECT_TRUE(_bmj.done()); + masterExecute([this]() { + EXPECT_FALSE(_bmj.done()); + EXPECT_TRUE(_bmj.scanAndMove(4, 3)); + EXPECT_TRUE(_bmj.done()); + }); sync(); EXPECT_EQ(0, numPending()); EXPECT_EQ(3u, docsMoved().size()); @@ -189,9 +207,37 @@ TEST_F(ControllerFixture, require_that_ready_bucket_is_moved_to_not_ready_if_buc // bucket 2 should be moved addReady(_ready.bucket(1)); _bmj.recompute(); + masterExecute([this]() { + EXPECT_FALSE(_bmj.done()); + EXPECT_TRUE(_bmj.scanAndMove(4, 3)); + EXPECT_TRUE(_bmj.done()); + }); + sync(); + EXPECT_EQ(2u, docsMoved().size()); + assertEqual(_ready.bucket(2), _ready.docs(2)[0], 1, 2, docsMoved()[0]); + assertEqual(_ready.bucket(2), _ready.docs(2)[1], 1, 2, docsMoved()[1]); + EXPECT_EQ(1u, bucketsModified().size()); + EXPECT_EQ(_ready.bucket(2), bucketsModified()[0]); +} + +TEST_F(ControllerFixture, require_that_bucket_is_moved_even_with_error) +{ + // bucket 2 should be moved + addReady(_ready.bucket(1)); + _bmj.recompute(); + failRetrieveForLid(5); + masterExecute([this]() { + EXPECT_FALSE(_bmj.done()); + EXPECT_TRUE(_bmj.scanAndMove(4, 3)); + EXPECT_TRUE(_bmj.done()); + }); + sync(); EXPECT_FALSE(_bmj.done()); - 
EXPECT_TRUE(_bmj.scanAndMove(4, 3)); - EXPECT_TRUE(_bmj.done()); + fixRetriever(); + masterExecute([this]() { + EXPECT_TRUE(_bmj.scanAndMove(4, 3)); + EXPECT_TRUE(_bmj.done()); + }); sync(); EXPECT_EQ(2u, docsMoved().size()); assertEqual(_ready.bucket(2), _ready.docs(2)[0], 1, 2, docsMoved()[0]); @@ -210,29 +256,37 @@ TEST_F(ControllerFixture, require_that_we_move_buckets_in_several_steps) _bmj.recompute(); EXPECT_EQ(3, numPending()); - EXPECT_FALSE(_bmj.done()); + masterExecute([this]() { + EXPECT_FALSE(_bmj.done()); - EXPECT_FALSE(_bmj.scanAndMove(1, 2)); - EXPECT_FALSE(_bmj.done()); + EXPECT_FALSE(_bmj.scanAndMove(1, 2)); + EXPECT_FALSE(_bmj.done()); + }); sync(); EXPECT_EQ(2, numPending()); EXPECT_EQ(2u, docsMoved().size()); - EXPECT_FALSE(_bmj.scanAndMove(1, 2)); - EXPECT_FALSE(_bmj.done()); + masterExecute([this]() { + EXPECT_FALSE(_bmj.scanAndMove(1, 2)); + EXPECT_FALSE(_bmj.done()); + }); sync(); EXPECT_EQ(2, numPending()); EXPECT_EQ(4u, docsMoved().size()); - EXPECT_FALSE(_bmj.scanAndMove(1, 2)); - EXPECT_FALSE(_bmj.done()); + masterExecute([this]() { + EXPECT_FALSE(_bmj.scanAndMove(1, 2)); + EXPECT_FALSE(_bmj.done()); + }); sync(); EXPECT_EQ(1, numPending()); EXPECT_EQ(6u, docsMoved().size()); // move bucket 4, docs 3 - EXPECT_TRUE(_bmj.scanAndMove(1,2)); - EXPECT_TRUE(_bmj.done()); + masterExecute([this]() { + EXPECT_TRUE(_bmj.scanAndMove(1, 2)); + EXPECT_TRUE(_bmj.done()); + }); sync(); EXPECT_EQ(0, numPending()); EXPECT_EQ(7u, docsMoved().size()); @@ -249,15 +303,19 @@ TEST_F(ControllerFixture, require_that_last_bucket_is_moved_before_reporting_don addReady(_ready.bucket(2)); addReady(_notReady.bucket(4)); _bmj.recompute(); - EXPECT_FALSE(_bmj.done()); + masterExecute([this]() { + EXPECT_FALSE(_bmj.done()); - EXPECT_FALSE(_bmj.scanAndMove(1, 1)); - EXPECT_FALSE(_bmj.done()); + EXPECT_FALSE(_bmj.scanAndMove(1, 1)); + EXPECT_FALSE(_bmj.done()); + }); sync(); EXPECT_EQ(1u, docsMoved().size()); EXPECT_EQ(4u, calcAsked().size()); - 
EXPECT_TRUE(_bmj.scanAndMove(1, 2)); - EXPECT_TRUE(_bmj.done()); + masterExecute([this]() { + EXPECT_TRUE(_bmj.scanAndMove(1, 2)); + EXPECT_TRUE(_bmj.done()); + }); sync(); EXPECT_EQ(3u, docsMoved().size()); EXPECT_EQ(4u, calcAsked().size()); @@ -271,16 +329,20 @@ TEST_F(ControllerFixture, require_that_active_bucket_is_not_moved_from_ready_to_ _bmj.recompute(); EXPECT_FALSE(_bmj.done()); activateBucket(_ready.bucket(1)); - EXPECT_TRUE(_bmj.scanAndMove(4, 3)); // scan all, delay active bucket 1 - EXPECT_TRUE(_bmj.done()); + masterExecute([this]() { + EXPECT_TRUE(_bmj.scanAndMove(4, 3)); // scan all, delay active bucket 1 + EXPECT_TRUE(_bmj.done()); + }); sync(); EXPECT_EQ(0u, docsMoved().size()); EXPECT_EQ(0u, bucketsModified().size()); deactivateBucket(_ready.bucket(1)); - EXPECT_FALSE(_bmj.done()); - EXPECT_TRUE(_bmj.scanAndMove(4, 3)); // move delayed and de-activated bucket 1 - EXPECT_TRUE(_bmj.done()); + masterExecute([this]() { + EXPECT_FALSE(_bmj.done()); + EXPECT_TRUE(_bmj.scanAndMove(4, 3)); // move delayed and de-activated bucket 1 + EXPECT_TRUE(_bmj.done()); + }); sync(); EXPECT_EQ(3u, docsMoved().size()); EXPECT_EQ(1u, bucketsModified().size()); @@ -291,37 +353,46 @@ TEST_F(ControllerFixture, require_that_current_bucket_moving_is_cancelled_when_w { // bucket 1 should be moved addReady(_ready.bucket(2)); - _bmj.recompute(); - _bmj.scanAndMove(1, 1); - EXPECT_FALSE(_bmj.done()); + + masterExecute([this]() { + _bmj.recompute(); + _bmj.scanAndMove(1, 1); + EXPECT_FALSE(_bmj.done()); + }); sync(); EXPECT_EQ(1u, docsMoved().size()); EXPECT_EQ(4u, calcAsked().size()); - changeCalc(); // Not cancelled, bucket 1 still moving to notReady - EXPECT_EQ(4u, calcAsked().size()); - EXPECT_EQ(_ready.bucket(1), calcAsked()[0]); - _calc->resetAsked(); - _bmj.scanAndMove(1, 1); - EXPECT_FALSE(_bmj.done()); + masterExecute([this]() { + changeCalc(); // Not cancelled, bucket 1 still moving to notReady + EXPECT_EQ(4u, calcAsked().size()); + EXPECT_EQ(_ready.bucket(1), 
calcAsked()[0]); + _calc->resetAsked(); + _bmj.scanAndMove(1, 1); + EXPECT_FALSE(_bmj.done()); + }); sync(); EXPECT_EQ(1u, docsMoved().size()); EXPECT_EQ(0u, calcAsked().size()); addReady(_ready.bucket(1)); - changeCalc(); // cancelled, bucket 1 no longer moving to notReady - EXPECT_EQ(4u, calcAsked().size()); - EXPECT_EQ(_ready.bucket(1), calcAsked()[0]); - _calc->resetAsked(); - remReady(_ready.bucket(1)); - _calc->resetAsked(); - changeCalc(); // not cancelled. No active bucket move - EXPECT_EQ(4u, calcAsked().size()); - _bmj.scanAndMove(1, 1); + masterExecute([this]() { + changeCalc(); // cancelled, bucket 1 no longer moving to notReady + EXPECT_EQ(4u, calcAsked().size()); + EXPECT_EQ(_ready.bucket(1), calcAsked()[0]); + _calc->resetAsked(); + remReady(_ready.bucket(1)); + _calc->resetAsked(); + changeCalc(); // not cancelled. No active bucket move + EXPECT_EQ(4u, calcAsked().size()); + _bmj.scanAndMove(1, 1); + }); sync(); EXPECT_EQ(1u, docsMoved().size()); EXPECT_EQ(4u, calcAsked().size()); EXPECT_EQ(_ready.bucket(2), calcAsked()[1]); EXPECT_EQ(_notReady.bucket(3), calcAsked()[2]); - _bmj.scanAndMove(2, 3); + masterExecute([this]() { + _bmj.scanAndMove(2, 3); + }); EXPECT_TRUE(_bmj.done()); sync(); EXPECT_EQ(3u, docsMoved().size()); @@ -335,16 +406,20 @@ TEST_F(ControllerFixture, require_that_de_activated_bucket_is_not_moved_if_new_c // bucket 1 should be moved addReady(_ready.bucket(2)); _bmj.recompute(); - activateBucket(_ready.bucket(1)); - _bmj.scanAndMove(4, 3); // scan all, delay active bucket 1 + masterExecute([this]() { + activateBucket(_ready.bucket(1)); + _bmj.scanAndMove(4, 3); // scan all, delay active bucket 1 + }); sync(); EXPECT_EQ(0u, docsMoved().size()); EXPECT_EQ(0u, bucketsModified().size()); - deactivateBucket(_ready.bucket(1)); - addReady(_ready.bucket(1)); - changeCalc(); - _bmj.scanAndMove(4, 3); // consider delayed bucket 3 + masterExecute([this]() { + deactivateBucket(_ready.bucket(1)); + addReady(_ready.bucket(1)); + changeCalc(); + 
_bmj.scanAndMove(4, 3); // consider delayed bucket 3 + }); sync(); EXPECT_EQ(0u, docsMoved().size()); EXPECT_EQ(0u, bucketsModified().size()); @@ -357,9 +432,11 @@ TEST_F(ControllerFixture, ready_bucket_not_moved_to_not_ready_if_node_is_marked_ _calc->setNodeRetired(true); // Bucket 2 would be moved from ready to not ready in a non-retired case, but not when retired. addReady(_ready.bucket(1)); - _bmj.recompute(); - _bmj.scanAndMove(4, 3); - EXPECT_TRUE(_bmj.done()); + masterExecute([this]() { + _bmj.recompute(); + _bmj.scanAndMove(4, 3); + EXPECT_TRUE(_bmj.done()); + }); sync(); EXPECT_EQ(0u, docsMoved().size()); } @@ -372,9 +449,11 @@ TEST_F(ControllerFixture, inactive_not_ready_bucket_not_moved_to_ready_if_node_i addReady(_ready.bucket(1)); addReady(_ready.bucket(2)); addReady(_notReady.bucket(3)); - _bmj.recompute(); - _bmj.scanAndMove(4, 3); - EXPECT_TRUE(_bmj.done()); + masterExecute([this]() { + _bmj.recompute(); + _bmj.scanAndMove(4, 3); + EXPECT_TRUE(_bmj.done()); + }); sync(); EXPECT_EQ(0u, docsMoved().size()); } @@ -386,9 +465,11 @@ TEST_F(ControllerFixture, explicitly_active_not_ready_bucket_can_be_moved_to_rea addReady(_ready.bucket(2)); addReady(_notReady.bucket(3)); _bmj.recompute(); - activateBucket(_notReady.bucket(3)); - _bmj.scanAndMove(4, 3); - EXPECT_TRUE(_bmj.done()); + masterExecute([this]() { + activateBucket(_notReady.bucket(3)); + _bmj.scanAndMove(4, 3); + EXPECT_TRUE(_bmj.done()); + }); sync(); ASSERT_EQ(2u, docsMoved().size()); assertEqual(_notReady.bucket(3), _notReady.docs(3)[0], 2, 1, docsMoved()[0]); @@ -412,9 +493,11 @@ TEST_F(ControllerFixture, require_that_notifyCreateBucket_causes_bucket_to_be_re EXPECT_TRUE(_bmj.done()); // move job still believes work done sync(); EXPECT_TRUE(bucketsModified().empty()); - _bmj.notifyCreateBucket(_bucketDB->takeGuard(), _notReady.bucket(3)); // reconsider bucket 3 - EXPECT_FALSE(_bmj.done()); - EXPECT_TRUE(bucketsModified().empty()); + masterExecute([this]() { + 
_bmj.notifyCreateBucket(_bucketDB->takeGuard(), _notReady.bucket(3)); // reconsider bucket 3 + EXPECT_FALSE(_bmj.done()); + EXPECT_TRUE(bucketsModified().empty()); + }); sync(); EXPECT_TRUE(bucketsModified().empty()); runLoop(); diff --git a/searchcore/src/tests/proton/documentdb/documentbucketmover/documentmover_test.cpp b/searchcore/src/tests/proton/documentdb/documentbucketmover/documentmover_test.cpp index f49d806b6d4..143d8f290c6 100644 --- a/searchcore/src/tests/proton/documentdb/documentbucketmover/documentmover_test.cpp +++ b/searchcore/src/tests/proton/documentdb/documentbucketmover/documentmover_test.cpp @@ -35,6 +35,7 @@ struct DocumentMoverTest : ::testing::Test test::UserDocumentsBuilder _builder; std::shared_ptr<bucketdb::BucketDBOwner> _bucketDB; MyMoveOperationLimiter _limiter; + //TODO When we retire old bucket move job me must make rewrite this test to use the BucketMover directly. DocumentBucketMover _mover; MySubDbTwoBuckets _source; bucketdb::BucketDBOwner _bucketDb; @@ -71,8 +72,10 @@ TEST_F(DocumentMoverTest, require_that_initial_bucket_mover_is_done) MyMoveOperationLimiter limiter; DocumentBucketMover mover(limiter, _bucketDb); EXPECT_TRUE(mover.bucketDone()); + EXPECT_FALSE(mover.needReschedule()); mover.moveDocuments(2); EXPECT_TRUE(mover.bucketDone()); + EXPECT_FALSE(mover.needReschedule()); } TEST_F(DocumentMoverTest, require_that_we_can_move_all_documents) @@ -136,4 +139,16 @@ TEST_F(DocumentMoverTest, require_that_we_can_move_documents_in_several_steps) EXPECT_EQ(5u, _handler._moves.size()); } +TEST_F(DocumentMoverTest, require_that_cancel_signal_rescheduling_need) { + setupForBucket(_source.bucket(1), 6, 9); + EXPECT_FALSE(_mover.bucketDone()); + EXPECT_FALSE(_mover.needReschedule()); + EXPECT_TRUE(moveDocuments(2)); + EXPECT_FALSE(_mover.bucketDone()); + EXPECT_FALSE(_mover.needReschedule()); + _mover.cancel(); + EXPECT_TRUE(_mover.bucketDone()); + EXPECT_TRUE(_mover.needReschedule()); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git 
a/searchcore/src/tests/proton/documentmetastore/documentmetastore_test.cpp b/searchcore/src/tests/proton/documentmetastore/documentmetastore_test.cpp index 8f2c82fb21b..4e04677f972 100644 --- a/searchcore/src/tests/proton/documentmetastore/documentmetastore_test.cpp +++ b/searchcore/src/tests/proton/documentmetastore/documentmetastore_test.cpp @@ -193,7 +193,7 @@ assertWhiteList(const SimpleResult &exp, Blueprint::UP whiteListBlueprint, bool void assertSearchResult(const SimpleResult &exp, const DocumentMetaStore &dms, - const vespalib::string &term, const QueryTermSimple::SearchTerm &termType, + const vespalib::string &term, const QueryTermSimple::Type &termType, bool strict, uint32_t docIdLimit = 100) { AttributeVector::SearchContext::UP sc = diff --git a/searchcore/src/tests/proton/matching/request_context/request_context_test.cpp b/searchcore/src/tests/proton/matching/request_context/request_context_test.cpp index 7b545344e9b..d8aa7e0ffa8 100644 --- a/searchcore/src/tests/proton/matching/request_context/request_context_test.cpp +++ b/searchcore/src/tests/proton/matching/request_context/request_context_test.cpp @@ -67,7 +67,7 @@ TEST_F(RequestContextTest, query_tensor_can_be_retrieved) { auto tensor = get_query_tensor("my_tensor"); ASSERT_TRUE(tensor); - EXPECT_TRUE(tensor->is_tensor()); + EXPECT_TRUE(tensor->type().has_dimensions()); EXPECT_EQ(expected_query_tensor(), spec_from_value(*tensor)); } diff --git a/searchcore/src/vespa/searchcore/proton/attribute/flushableattribute.cpp b/searchcore/src/vespa/searchcore/proton/attribute/flushableattribute.cpp index 956e9a1b430..ae3edc93f6d 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/flushableattribute.cpp +++ b/searchcore/src/vespa/searchcore/proton/attribute/flushableattribute.cpp @@ -168,7 +168,8 @@ FlushableAttribute::FlushableAttribute(const AttributeVectorSP attr, _lastStats.setPathElementsToLog(8); auto &config = attr->getConfig(); if (config.basicType() == 
search::attribute::BasicType::Type::TENSOR && - config.tensorType().is_tensor() && config.tensorType().is_dense() && config.hnsw_index_params().has_value()) { + config.tensorType().is_dense() && config.hnsw_index_params().has_value()) + { _replay_operation_cost = 100.0; // replaying operations to hnsw index is 100 times more expensive than reading from tls } } diff --git a/searchcore/src/vespa/searchcore/proton/server/bucketmovejobv2.cpp b/searchcore/src/vespa/searchcore/proton/server/bucketmovejobv2.cpp index 84b693db4ba..1c1a77475fc 100644 --- a/searchcore/src/vespa/searchcore/proton/server/bucketmovejobv2.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/bucketmovejobv2.cpp @@ -53,8 +53,6 @@ blockedDueToClusterState(const std::shared_ptr<IBucketStateCalculator> &calc) return !(clusterUp && nodeUp && !nodeInitializing); } -constexpr BucketId RECOMPUTE_TOKEN; - } BucketMoveJobV2::BucketMoveJobV2(const std::shared_ptr<IBucketStateCalculator> &calc, @@ -88,7 +86,6 @@ BucketMoveJobV2::BucketMoveJobV2(const std::shared_ptr<IBucketStateCalculator> & _movers(), _bucketsInFlight(), _buckets2Move(), - _postponedUntilSafe(), _stopped(false), _startedCount(0), _executedCount(0), @@ -107,7 +104,7 @@ BucketMoveJobV2::BucketMoveJobV2(const std::shared_ptr<IBucketStateCalculator> & _clusterStateChangedNotifier.addClusterStateChangedHandler(this); _bucketStateChangedNotifier.addBucketStateChangedHandler(this); _diskMemUsageNotifier.addDiskMemUsageListener(this); - recompute(); + recompute(_ready.meta_store()->getBucketDB().takeGuard()); } BucketMoveJobV2::~BucketMoveJobV2() @@ -201,13 +198,8 @@ private: void BucketMoveJobV2::failOperation(BucketId bucketId) { IncOnDestruct countGuard(_executedCount); - if (_stopped.load(std::memory_order_relaxed)) return; _master.execute(makeLambdaTask([this, bucketId]() { if (_stopped.load(std::memory_order_relaxed)) return; - cancelBucket(bucketId); - if (_bucketsInFlight.contains(bucketId)) { - 
handleMoveResult(_bucketsInFlight[bucketId]); - } considerBucket(_ready.meta_store()->getBucketDB().takeGuard(), bucketId); })); } @@ -216,10 +208,9 @@ void BucketMoveJobV2::startMove(BucketMoverSP mover, size_t maxDocsToMove) { auto [keys, done] = mover->getKeysToMove(maxDocsToMove); if (done) { - mover->setBucketDone(); + mover->setAllScheduled(); } if (keys.empty()) return; - if (_stopped.load(std::memory_order_relaxed)) return; mover->updateLastValidGid(keys.back()._gid); Bucket spiBucket(document::Bucket(_bucketSpace, mover->getBucket())); auto bucketTask = std::make_unique<StartMove>(*this, std::move(mover), std::move(keys), getLimiter().beginOperation()); @@ -231,7 +222,6 @@ void BucketMoveJobV2::prepareMove(BucketMoverSP mover, std::vector<MoveKey> keys, IDestructorCallbackSP onDone) { IncOnDestruct countGuard(_executedCount); - if (_stopped.load(std::memory_order_relaxed)) return; auto moveOps = mover->createMoveOperations(std::move(keys)); _master.execute(makeLambdaTask([this, mover=std::move(mover), moveOps=std::move(moveOps), onDone=std::move(onDone)]() mutable { if (_stopped.load(std::memory_order_relaxed)) return; @@ -240,28 +230,40 @@ BucketMoveJobV2::prepareMove(BucketMoverSP mover, std::vector<MoveKey> keys, IDe } void -BucketMoveJobV2::completeMove(BucketMoverSP mover, std::vector<GuardedMoveOp> ops, IDestructorCallbackSP onDone) { - mover->moveDocuments(std::move(ops), std::move(onDone)); - handleMoveResult(std::move(mover)); +BucketMoveJobV2::completeMove(BucketMoverSP mover, GuardedMoveOps ops, IDestructorCallbackSP onDone) { + mover->moveDocuments(std::move(ops.success), std::move(onDone)); + ops.failed.clear(); + if (checkIfMoverComplete(*mover)) { + reconsiderBucket(_ready.meta_store()->getBucketDB().takeGuard(), mover->getBucket()); + } } -void -BucketMoveJobV2::handleMoveResult(BucketMoverSP mover) { - if (mover->bucketDone() && mover->inSync()) { - BucketId bucket = mover->getBucket(); - assert(_bucketsInFlight.contains(bucket)); - 
_modifiedHandler.notifyBucketModified(bucket); - _bucketsInFlight.erase(bucket); - updatePending(); - if (_postponedUntilSafe.contains(bucket)) { - _postponedUntilSafe.erase(bucket); - reconsiderBucket(_ready.meta_store()->getBucketDB().takeGuard(), bucket); - } - if (_bucketsInFlight.empty() && _postponedUntilSafe.contains(RECOMPUTE_TOKEN)) { - _postponedUntilSafe.erase(RECOMPUTE_TOKEN); - recompute(); +bool +BucketMoveJobV2::checkIfMoverComplete(const BucketMover & mover) { + bool bucketMoveComplete = mover.allScheduled() && mover.inSync(); + bool needReschedule = mover.needReschedule(); + if (bucketMoveComplete || needReschedule) { + BucketId bucket = mover.getBucket(); + auto found = _bucketsInFlight.find(bucket); + if (needReschedule) { + if ((found != _bucketsInFlight.end()) && (&mover == found->second.get())) { + //Prevent old disconnected mover from creating havoc. + _bucketsInFlight.erase(found); + _movers.erase(std::remove_if(_movers.begin(), _movers.end(), + [bucket](const BucketMoverSP &cand) { + return cand->getBucket() == bucket; + }), + _movers.end()); + return true; + } + } else { + assert(found != _bucketsInFlight.end()); + _bucketsInFlight.erase(found); + _modifiedHandler.notifyBucketModified(bucket); } } + updatePending(); + return false; } void @@ -269,21 +271,15 @@ BucketMoveJobV2::cancelBucket(BucketId bucket) { auto inFlight = _bucketsInFlight.find(bucket); if (inFlight != _bucketsInFlight.end()) { inFlight->second->cancel(); - _movers.erase(std::remove_if(_movers.begin(), _movers.end(), - [bucket](const BucketMoverSP &mover) { return mover->getBucket() == bucket; }), - _movers.end()); - handleMoveResult(inFlight->second); + checkIfMoverComplete(*inFlight->second); } } void BucketMoveJobV2::considerBucket(const bucketdb::Guard & guard, BucketId bucket) { cancelBucket(bucket); - if (_bucketsInFlight.contains(bucket)) { - _postponedUntilSafe.insert(bucket); - } else { - reconsiderBucket(guard, bucket); - } + assert( 
!_bucketsInFlight.contains(bucket)); + reconsiderBucket(guard, bucket); } void @@ -296,6 +292,7 @@ BucketMoveJobV2::reconsiderBucket(const bucketdb::Guard & guard, BucketId bucket } else { _buckets2Move.erase(bucket); } + updatePending(); considerRun(); } @@ -306,10 +303,10 @@ BucketMoveJobV2::notifyCreateBucket(const bucketdb::Guard & guard, const BucketI } BucketMoveJobV2::BucketMoveSet -BucketMoveJobV2::computeBuckets2Move() +BucketMoveJobV2::computeBuckets2Move(const bucketdb::Guard & guard) { BucketMoveJobV2::BucketMoveSet toMove; - for (ScanIterator itr(_ready.meta_store()->getBucketDB().takeGuard(), BucketId()); itr.valid(); ++itr) { + for (ScanIterator itr(guard, BucketId()); itr.valid(); ++itr) { auto [mustMove, wantReady] = needMove(itr); if (mustMove) { toMove[itr.getBucket()] = wantReady; @@ -348,9 +345,9 @@ BucketMoveJobV2::moveDocs(size_t maxDocsToMove) { const auto & mover = _movers[index]; //Move, or reduce movers as we are tailing off - if (!mover->bucketDone()) { + if (!mover->allScheduled()) { startMove(mover, maxDocsToMove); - if (mover->bucketDone()) { + if (mover->allScheduled()) { _movers.erase(_movers.begin() + index); } } @@ -366,7 +363,7 @@ BucketMoveJobV2::scanAndMove(size_t maxBuckets2Move, size_t maxDocsToMovePerBuck bool BucketMoveJobV2::done() const { - return _buckets2Move.empty() && _movers.empty() && _postponedUntilSafe.empty() && !isBlocked(); + return _buckets2Move.empty() && _movers.empty() && !isBlocked(); } bool @@ -389,7 +386,11 @@ BucketMoveJobV2::run() void BucketMoveJobV2::recompute() { - _buckets2Move = computeBuckets2Move(); + recompute(_ready.meta_store()->getBucketDB().takeGuard()); +} +void +BucketMoveJobV2::recompute(const bucketdb::Guard & guard) { + _buckets2Move = computeBuckets2Move(guard); updatePending(); } @@ -400,6 +401,7 @@ BucketMoveJobV2::backFillMovers() { auto mover = greedyCreateMover(); _movers.push_back(mover); auto bucketId = mover->getBucket(); + assert( ! 
_bucketsInFlight.contains(bucketId)); _bucketsInFlight[bucketId] = std::move(mover); } updatePending(); @@ -416,12 +418,8 @@ BucketMoveJobV2::notifyClusterStateChanged(const std::shared_ptr<IBucketStateCal unBlock(BlockedReason::CLUSTER_STATE); _movers.clear(); std::for_each(_bucketsInFlight.begin(), _bucketsInFlight.end(), [](auto & entry) { entry.second->cancel();}); - std::erase_if(_bucketsInFlight, [](auto & entry) { return entry.second->inSync(); }); - if (_bucketsInFlight.empty()) { - recompute(); - } else { - _postponedUntilSafe.insert(RECOMPUTE_TOKEN); - } + _bucketsInFlight.clear(); + recompute(_ready.meta_store()->getBucketDB().takeGuard()); } } diff --git a/searchcore/src/vespa/searchcore/proton/server/bucketmovejobv2.h b/searchcore/src/vespa/searchcore/proton/server/bucketmovejobv2.h index 0f1c567b1b3..620b76ac81c 100644 --- a/searchcore/src/vespa/searchcore/proton/server/bucketmovejobv2.h +++ b/searchcore/src/vespa/searchcore/proton/server/bucketmovejobv2.h @@ -9,7 +9,6 @@ #include "iclusterstatechangedhandler.h" #include <vespa/searchcore/proton/bucketdb/bucketscaniterator.h> #include <vespa/searchcore/proton/bucketdb/i_bucket_create_listener.h> -#include <vespa/vespalib/stllike/hash_set.h> namespace storage::spi { struct BucketExecutor; } namespace searchcorespi::index { struct IThreadService; } @@ -55,8 +54,7 @@ private: using Bucket2Mover = std::map<BucketId, BucketMoverSP>; using Movers = std::vector<BucketMoverSP>; using MoveKey = BucketMover::MoveKey; - using GuardedMoveOp = BucketMover::GuardedMoveOp; - using BucketSet = vespalib::hash_set<BucketId, BucketId::hash>; + using GuardedMoveOps = BucketMover::GuardedMoveOps; std::shared_ptr<IBucketStateCalculator> _calc; IDocumentMoveHandler &_moveHandler; IBucketModifiedHandler &_modifiedHandler; @@ -69,7 +67,6 @@ private: Movers _movers; Bucket2Mover _bucketsInFlight; BucketMoveSet _buckets2Move; - BucketSet _postponedUntilSafe; std::atomic<bool> _stopped; std::atomic<size_t> _startedCount; @@ 
-83,19 +80,21 @@ private: void startMove(BucketMoverSP mover, size_t maxDocsToMove); void prepareMove(BucketMoverSP mover, std::vector<MoveKey> keysToMove, IDestructorCallbackSP context); - void completeMove(BucketMoverSP mover, std::vector<GuardedMoveOp> keys, IDestructorCallbackSP context); - void handleMoveResult(BucketMoverSP mover); + void completeMove(BucketMoverSP mover, GuardedMoveOps moveOps, IDestructorCallbackSP context); + bool checkIfMoverComplete(const BucketMover & mover); + void checkForReschedule(const bucketdb::Guard & guard, BucketId bucket); void considerBucket(const bucketdb::Guard & guard, BucketId bucket); void reconsiderBucket(const bucketdb::Guard & guard, BucketId bucket); void updatePending(); void cancelBucket(BucketId bucket); // True if something to cancel NeedResult needMove(const ScanIterator &itr) const; - BucketMoveSet computeBuckets2Move(); + BucketMoveSet computeBuckets2Move(const bucketdb::Guard & guard); BucketMoverSP createMover(BucketId bucket, bool wantReady); BucketMoverSP greedyCreateMover(); void backFillMovers(); void moveDocs(size_t maxDocsToMove); void failOperation(BucketId bucket); + void recompute(const bucketdb::Guard & guard); friend class StartMove; public: BucketMoveJobV2(const std::shared_ptr<IBucketStateCalculator> &calc, @@ -117,7 +116,7 @@ public: bool scanAndMove(size_t maxBuckets2Move, size_t maxDocsToMovePerBucket); bool done() const; - void recompute(); + void recompute(); // Only for testing bool inSync() const; bool run() override; diff --git a/searchcore/src/vespa/searchcore/proton/server/documentbucketmover.cpp b/searchcore/src/vespa/searchcore/proton/server/documentbucketmover.cpp index 662e9e2e920..f08cd4d7ab7 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentbucketmover.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/documentbucketmover.cpp @@ -20,8 +20,8 @@ namespace proton::bucketdb { typedef IDocumentMetaStore::Iterator Iterator; -BucketMover::GuardedMoveOp 
-BucketMover::createMoveOperation(MoveKey &key) { +MoveOperation::UP +BucketMover::createMoveOperation(const MoveKey &key) { if (_source->lidNeedsCommit(key._lid)) return {}; const RawDocumentMetaData &metaNow = _source->meta_store()->getRawMetaData(key._lid); @@ -29,13 +29,13 @@ BucketMover::createMoveOperation(MoveKey &key) { if (metaNow.getTimestamp() != key._timestamp) return {}; Document::SP doc(_source->retriever()->getFullDocument(key._lid)); - if (!doc || doc->getId().getGlobalId() != key._gid) - return {}; // Failed to retrieve document, removed or changed identity + if (!doc || doc->getId().getGlobalId() != key._gid) { + // Failed to retrieve document, removed or changed identity + return {}; + } BucketId bucketId = _bucket.stripUnused(); - return BucketMover::GuardedMoveOp(std::make_unique<MoveOperation>(bucketId, key._timestamp, std::move(doc), - DbDocumentId(_source->sub_db_id(), key._lid), - _targetSubDbId), - std::move(key._guard)); + return std::make_unique<MoveOperation>(bucketId, key._timestamp, std::move(doc), + DbDocumentId(_source->sub_db_id(), key._lid), _targetSubDbId); } void @@ -60,7 +60,8 @@ BucketMover::BucketMover(const BucketId &bucket, const MaintenanceDocumentSubDB _targetSubDbId(targetSubDbId), _started(0), _completed(0), - _bucketDone(false), + _needReschedule(false), + _allScheduled(false), _lastGidValid(false), _lastGid() { } @@ -88,18 +89,26 @@ BucketMover::getKeysToMove(size_t maxDocsToMove) { return result; } -std::vector<BucketMover::GuardedMoveOp> +BucketMover::GuardedMoveOps BucketMover::createMoveOperations(std::vector<MoveKey> toMove) { - std::vector<GuardedMoveOp> successfulReads; - successfulReads.reserve(toMove.size()); + GuardedMoveOps moveOps; + moveOps.success.reserve(toMove.size()); for (MoveKey &key : toMove) { - auto moveOp = createMoveOperation(key); - if (!moveOp.first) { - break; + if (moveOps.failed.empty()) { + auto moveOp = createMoveOperation(key); + if (moveOp) { + 
moveOps.success.emplace_back(std::move(moveOp), std::move(key._guard)); + } else { + moveOps.failed.push_back(std::move(key._guard)); + } + } else { + moveOps.failed.push_back(std::move(key._guard)); } - successfulReads.push_back(std::move(moveOp)); } - return successfulReads; + if ( ! moveOps.failed.empty()) { + _needReschedule.store(true, std::memory_order_relaxed); + } + return moveOps; } void @@ -109,6 +118,12 @@ BucketMover::moveDocuments(std::vector<GuardedMoveOp> moveOps, IDestructorCallba } } +void +BucketMover::cancel() { + setAllScheduled(); + _needReschedule.store(true, std::memory_order_relaxed); +} + } namespace proton { @@ -137,21 +152,20 @@ DocumentBucketMover::moveDocuments(size_t maxDocsToMove) { bool DocumentBucketMover::moveDocuments(size_t maxDocsToMove, IMoveOperationLimiter &limiter) { - if (_impl->bucketDone()) { + if (_impl->allScheduled()) { return true; } auto [keys, done] = _impl->getKeysToMove(maxDocsToMove); - size_t numKeys = keys.size(); auto moveOps = _impl->createMoveOperations(std::move(keys)); - bool allOk = (numKeys == moveOps.size()); + bool allOk = moveOps.failed.empty(); if (done && allOk) { - _impl->setBucketDone(); + _impl->setAllScheduled(); } - if (moveOps.empty()) return allOk; + if (moveOps.success.empty()) return allOk; - _impl->updateLastValidGid(moveOps.back().first->getDocument()->getId().getGlobalId()); + _impl->updateLastValidGid(moveOps.success.back().first->getDocument()->getId().getGlobalId()); - for (auto & moveOp : moveOps) { + for (auto & moveOp : moveOps.success) { // We cache the bucket for the document we are going to move to avoid getting // inconsistent bucket info (getBucketInfo()) while moving between ready and not-ready // sub dbs as the bucket info is not updated atomically in this case. 
diff --git a/searchcore/src/vespa/searchcore/proton/server/documentbucketmover.h b/searchcore/src/vespa/searchcore/proton/server/documentbucketmover.h index 807151b0769..fc7760a4dc4 100644 --- a/searchcore/src/vespa/searchcore/proton/server/documentbucketmover.h +++ b/searchcore/src/vespa/searchcore/proton/server/documentbucketmover.h @@ -64,6 +64,12 @@ public: MoveGuard _guard; }; + using GuardedMoveOp = std::pair<MoveOperationUP, MoveGuard>; + struct GuardedMoveOps { + std::vector<GuardedMoveOp> success; + std::vector<MoveGuard> failed; + }; + BucketMover(const document::BucketId &bucket, const MaintenanceDocumentSubDB *source, uint32_t targetSubDbId, IDocumentMoveHandler &handler) noexcept; BucketMover(BucketMover &&) noexcept = delete; @@ -72,20 +78,20 @@ public: BucketMover & operator=(const BucketMover &) = delete; ~BucketMover(); - using GuardedMoveOp = std::pair<MoveOperationUP, MoveGuard>; /// Must be called in master thread std::pair<std::vector<MoveKey>, bool> getKeysToMove(size_t maxDocsToMove); /// Call from any thread - std::vector<GuardedMoveOp> createMoveOperations(std::vector<MoveKey> toMove); + GuardedMoveOps createMoveOperations(std::vector<MoveKey> toMove); /// Must be called in master thread void moveDocuments(std::vector<GuardedMoveOp> moveOps, IDestructorCallbackSP onDone); void moveDocument(MoveOperationUP moveOp, IDestructorCallbackSP onDone); const document::BucketId &getBucket() const { return _bucket; } - void cancel() { setBucketDone(); } - void setBucketDone() { _bucketDone = true; } + void cancel(); + void setAllScheduled() { _allScheduled = true; } /// Signals all documents have been scheduled for move - bool bucketDone() const { return _bucketDone; } + bool allScheduled() const { return _allScheduled; } + bool needReschedule() const { return _needReschedule.load(std::memory_order_relaxed); } const MaintenanceDocumentSubDB * getSource() const { return _source; } /// Must be called in master thread void updateLastValidGid(const 
document::GlobalId &gid) { @@ -103,10 +109,11 @@ private: std::atomic<uint32_t> _started; std::atomic<uint32_t> _completed; - bool _bucketDone; // All moves started, or operation has been cancelled + std::atomic<bool> _needReschedule; + bool _allScheduled; // All moves started, or operation has been cancelled bool _lastGidValid; document::GlobalId _lastGid; - GuardedMoveOp createMoveOperation(MoveKey & key); + MoveOperationUP createMoveOperation(const MoveKey & key); size_t pending() const { return _started.load(std::memory_order_relaxed) - _completed.load(std::memory_order_relaxed); } @@ -139,8 +146,9 @@ public: const document::BucketId &getBucket() const { return _impl->getBucket(); } bool moveDocuments(size_t maxDocsToMove); void cancel() { _impl->cancel(); } + bool needReschedule() { return _impl && _impl->needReschedule(); } bool bucketDone() const { - return !_impl || _impl->bucketDone(); + return !_impl || _impl->allScheduled(); } const MaintenanceDocumentSubDB * getSource() const { return _impl->getSource(); } }; diff --git a/searchcore/src/vespa/searchcore/proton/server/memory_flush_config_updater.cpp b/searchcore/src/vespa/searchcore/proton/server/memory_flush_config_updater.cpp index 1e60bb4f388..88e2096aa63 100644 --- a/searchcore/src/vespa/searchcore/proton/server/memory_flush_config_updater.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/memory_flush_config_updater.cpp @@ -128,14 +128,14 @@ MemoryFlushConfigUpdater::convertConfig(const ProtonConfig::Flush::Memory &confi const size_t hardMemoryLimit = getHardMemoryLimit(memory); size_t totalMaxMemory = config.maxmemory; if (totalMaxMemory > hardMemoryLimit) { - LOG(info, "flush.memory.maxmemory=%" PRId64 " cannot" + LOG(debug, "flush.memory.maxmemory=%" PRId64 " cannot" " be set above the hard limit of %ld so we cap it to the hard limit", config.maxmemory, hardMemoryLimit); totalMaxMemory = hardMemoryLimit; } size_t eachMaxMemory = config.each.maxmemory; if (eachMaxMemory > hardMemoryLimit) { - 
LOG(info, "flush.memory.each.maxmemory=%" PRId64 " cannot" + LOG(debug, "flush.memory.each.maxmemory=%" PRId64 " cannot" " be set above the hard limit of %ld so we cap it to the hard limit", config.maxmemory, hardMemoryLimit); eachMaxMemory = hardMemoryLimit; diff --git a/searchlib/src/tests/attribute/enum_comparator/enum_comparator_test.cpp b/searchlib/src/tests/attribute/enum_comparator/enum_comparator_test.cpp index 4e1e1c6d792..087968ff8d9 100644 --- a/searchlib/src/tests/attribute/enum_comparator/enum_comparator_test.cpp +++ b/searchlib/src/tests/attribute/enum_comparator/enum_comparator_test.cpp @@ -24,73 +24,108 @@ using TreeType = BTreeRoot<EnumIndex, BTreeNoLeafData, const vespalib::datastore::EntryComparatorWrapper>; using NodeAllocator = TreeType::NodeAllocatorType; -class Test : public vespalib::TestApp { -private: - void requireThatNumericComparatorIsWorking(); - void requireThatFloatComparatorIsWorking(); - void requireThatStringComparatorIsWorking(); - void requireThatComparatorWithTreeIsWorking(); - void requireThatFoldedComparatorIsWorking(); - -public: - Test() {} - int Main() override; -}; - -void -Test::requireThatNumericComparatorIsWorking() + +TEST("requireThatNumericLessIsWorking") +{ + NumericEnumStore es(false); + EnumIndex e1 = es.insert(10); + EnumIndex e2 = es.insert(30); + auto cmp1 = es.make_comparator(); + EXPECT_TRUE(cmp1.less(e1, e2)); + EXPECT_FALSE(cmp1.less(e2, e1)); + EXPECT_FALSE(cmp1.less(e1, e1)); + auto cmp2 = es.make_comparator(20); + EXPECT_TRUE(cmp2.less(EnumIndex(), e2)); + EXPECT_FALSE(cmp2.less(e2, EnumIndex())); +} + +TEST("requireThatNumericEqualIsWorking") { NumericEnumStore es(false); EnumIndex e1 = es.insert(10); EnumIndex e2 = es.insert(30); auto cmp1 = es.make_comparator(); - EXPECT_TRUE(cmp1(e1, e2)); - EXPECT_TRUE(!cmp1(e2, e1)); - EXPECT_TRUE(!cmp1(e1, e1)); + EXPECT_FALSE(cmp1.equal(e1, e2)); + EXPECT_FALSE(cmp1.equal(e2, e1)); + EXPECT_TRUE(cmp1.equal(e1, e1)); auto cmp2 = es.make_comparator(20); - 
EXPECT_TRUE(cmp2(EnumIndex(), e2)); - EXPECT_TRUE(!cmp2(e2, EnumIndex())); + EXPECT_FALSE(cmp2.equal(EnumIndex(), e2)); + EXPECT_FALSE(cmp2.equal(e2, EnumIndex())); + EXPECT_TRUE(cmp2.equal(EnumIndex(), EnumIndex())); } -void -Test::requireThatFloatComparatorIsWorking() +TEST("requireThatFloatLessIsWorking") { FloatEnumStore es(false); EnumIndex e1 = es.insert(10.5); EnumIndex e2 = es.insert(30.5); EnumIndex e3 = es.insert(std::numeric_limits<float>::quiet_NaN()); auto cmp1 = es.make_comparator(); - EXPECT_TRUE(cmp1(e1, e2)); - EXPECT_TRUE(!cmp1(e2, e1)); - EXPECT_TRUE(!cmp1(e1, e1)); - EXPECT_TRUE(cmp1(e3, e1)); // nan - EXPECT_TRUE(!cmp1(e1, e3)); // nan - EXPECT_TRUE(!cmp1(e3, e3)); // nan + EXPECT_TRUE(cmp1.less(e1, e2)); + EXPECT_FALSE(cmp1.less(e2, e1)); + EXPECT_FALSE(cmp1.less(e1, e1)); + EXPECT_TRUE(cmp1.less(e3, e1)); // nan + EXPECT_FALSE(cmp1.less(e1, e3)); // nan + EXPECT_FALSE(cmp1.less(e3, e3)); // nan auto cmp2 = es.make_comparator(20.5); - EXPECT_TRUE(cmp2(EnumIndex(), e2)); - EXPECT_TRUE(!cmp2(e2, EnumIndex())); + EXPECT_TRUE(cmp2.less(EnumIndex(), e2)); + EXPECT_FALSE(cmp2.less(e2, EnumIndex())); } -void -Test::requireThatStringComparatorIsWorking() +TEST("requireThatFloatEqualIsWorking") +{ + FloatEnumStore es(false); + EnumIndex e1 = es.insert(10.5); + EnumIndex e2 = es.insert(30.5); + EnumIndex e3 = es.insert(std::numeric_limits<float>::quiet_NaN()); + auto cmp1 = es.make_comparator(); + EXPECT_FALSE(cmp1.equal(e1, e2)); + EXPECT_FALSE(cmp1.equal(e2, e1)); + EXPECT_TRUE(cmp1.equal(e1, e1)); + EXPECT_FALSE(cmp1.equal(e3, e1)); // nan + EXPECT_FALSE(cmp1.equal(e1, e3)); // nan + EXPECT_TRUE(cmp1.equal(e3, e3)); // nan + auto cmp2 = es.make_comparator(20.5); + EXPECT_FALSE(cmp2.equal(EnumIndex(), e2)); + EXPECT_FALSE(cmp2.equal(e2, EnumIndex())); + EXPECT_TRUE(cmp2.equal(EnumIndex(), EnumIndex())); +} + +TEST("requireThatStringLessIsWorking") { StringEnumStore es(false); EnumIndex e1 = es.insert("Aa"); EnumIndex e2 = es.insert("aa"); EnumIndex e3 
= es.insert("aB"); auto cmp1 = es.make_comparator(); - EXPECT_TRUE(cmp1(e1, e2)); // similar folded, fallback to regular - EXPECT_TRUE(!cmp1(e2, e1)); - EXPECT_TRUE(!cmp1(e1, e1)); - EXPECT_TRUE(cmp1(e2, e3)); // folded compare + EXPECT_TRUE(cmp1.less(e1, e2)); // similar folded, fallback to regular + EXPECT_FALSE(cmp1.less(e2, e1)); + EXPECT_FALSE(cmp1.less(e1, e1)); + EXPECT_TRUE(cmp1.less(e2, e3)); // folded compare EXPECT_TRUE(strcmp("aa", "aB") > 0); // regular auto cmp2 = es.make_comparator("AB"); - EXPECT_TRUE(cmp2(EnumIndex(), e3)); - EXPECT_TRUE(!cmp2(e3, EnumIndex())); + EXPECT_TRUE(cmp2.less(EnumIndex(), e3)); + EXPECT_FALSE(cmp2.less(e3, EnumIndex())); } -void -Test::requireThatComparatorWithTreeIsWorking() +TEST("requireThatStringEqualIsWorking") +{ + StringEnumStore es(false); + EnumIndex e1 = es.insert("Aa"); + EnumIndex e2 = es.insert("aa"); + EnumIndex e3 = es.insert("aB"); + auto cmp1 = es.make_comparator(); + EXPECT_FALSE(cmp1.equal(e1, e2)); // similar folded, fallback to regular + EXPECT_FALSE(cmp1.equal(e2, e1)); + EXPECT_TRUE(cmp1.equal(e1, e1)); + EXPECT_FALSE(cmp1.equal(e2, e3)); // folded compare + auto cmp2 = es.make_comparator("AB"); + EXPECT_FALSE(cmp2.equal(EnumIndex(), e3)); + EXPECT_FALSE(cmp2.equal(e3, EnumIndex())); + EXPECT_TRUE(cmp2.equal(EnumIndex(), EnumIndex())); +} + +TEST("requireThatComparatorWithTreeIsWorking") { NumericEnumStore es(false); vespalib::GenerationHandler g; @@ -98,7 +133,7 @@ Test::requireThatComparatorWithTreeIsWorking() NodeAllocator m; for (int32_t v = 100; v > 0; --v) { auto cmp = es.make_comparator(v); - EXPECT_TRUE(!t.find(EnumIndex(), m, cmp).valid()); + EXPECT_FALSE(t.find(EnumIndex(), m, cmp).valid()); EnumIndex idx = es.insert(v); t.insert(idx, BTreeNoLeafData(), m, cmp); } @@ -115,8 +150,7 @@ Test::requireThatComparatorWithTreeIsWorking() m.trimHoldLists(g.getFirstUsedGeneration()); } -void -Test::requireThatFoldedComparatorIsWorking() +TEST("requireThatFoldedLessIsWorking") { StringEnumStore 
es(false); EnumIndex e1 = es.insert("Aa"); @@ -124,33 +158,42 @@ Test::requireThatFoldedComparatorIsWorking() EnumIndex e3 = es.insert("aB"); EnumIndex e4 = es.insert("Folded"); auto cmp1 = es.make_folded_comparator(); - EXPECT_TRUE(!cmp1(e1, e2)); // similar folded - EXPECT_TRUE(!cmp1(e2, e1)); // similar folded - EXPECT_TRUE(cmp1(e2, e3)); // folded compare - EXPECT_TRUE(!cmp1(e3, e2)); // folded compare + EXPECT_FALSE(cmp1.less(e1, e2)); // similar folded + EXPECT_FALSE(cmp1.less(e2, e1)); // similar folded + EXPECT_TRUE(cmp1.less(e2, e3)); // folded compare + EXPECT_FALSE(cmp1.less(e3, e2)); // folded compare auto cmp2 = es.make_folded_comparator("fol", false); auto cmp3 = es.make_folded_comparator("fol", true); - EXPECT_TRUE(cmp2(EnumIndex(), e4)); - EXPECT_TRUE(!cmp2(e4, EnumIndex())); - EXPECT_TRUE(!cmp3(EnumIndex(), e4)); // similar when prefix - EXPECT_TRUE(!cmp3(e4, EnumIndex())); // similar when prefix + EXPECT_TRUE(cmp2.less(EnumIndex(), e4)); + EXPECT_FALSE(cmp2.less(e4, EnumIndex())); + EXPECT_FALSE(cmp3.less(EnumIndex(), e4)); // similar when prefix + EXPECT_FALSE(cmp3.less(e4, EnumIndex())); // similar when prefix } -int -Test::Main() +TEST("requireThatFoldedEqualIsWorking") { - TEST_INIT("comparator_test"); - - requireThatNumericComparatorIsWorking(); - requireThatFloatComparatorIsWorking(); - requireThatStringComparatorIsWorking(); - requireThatComparatorWithTreeIsWorking(); - requireThatFoldedComparatorIsWorking(); + StringEnumStore es(false); + EnumIndex e1 = es.insert("Aa"); + EnumIndex e2 = es.insert("aa"); + EnumIndex e3 = es.insert("aB"); + EnumIndex e4 = es.insert("Folded"); + auto cmp1 = es.make_folded_comparator(); + EXPECT_TRUE(cmp1.equal(e1, e1)); // similar folded + EXPECT_TRUE(cmp1.equal(e2, e1)); // similar folded + EXPECT_TRUE(cmp1.equal(e2, e1)); + EXPECT_FALSE(cmp1.equal(e2, e3)); // folded compare + EXPECT_FALSE(cmp1.equal(e3, e2)); // folded compare + auto cmp2 = es.make_folded_comparator("fol", false); + auto cmp3 = 
es.make_folded_comparator("fol", true); + EXPECT_FALSE(cmp2.equal(EnumIndex(), e4)); + EXPECT_FALSE(cmp2.equal(e4, EnumIndex())); + EXPECT_TRUE(cmp2.equal(EnumIndex(), EnumIndex())); + EXPECT_FALSE(cmp3.equal(EnumIndex(), e4)); // similar when prefix + EXPECT_FALSE(cmp3.equal(e4, EnumIndex())); // similar when prefix + EXPECT_TRUE(cmp3.equal(EnumIndex(), EnumIndex())); // similar when prefix - TEST_DONE(); } } -TEST_APPHOOK(search::Test); - +TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/tests/attribute/reference_attribute/reference_attribute_test.cpp b/searchlib/src/tests/attribute/reference_attribute/reference_attribute_test.cpp index b6b8e0a60e8..a8a34b0ba8a 100644 --- a/searchlib/src/tests/attribute/reference_attribute/reference_attribute_test.cpp +++ b/searchlib/src/tests/attribute/reference_attribute/reference_attribute_test.cpp @@ -476,7 +476,7 @@ struct ReferenceAttributeSearchTest : public ReferenceAttributeTest { } void expect_search_result(const std::string& term, const FakeResult& expected) { - auto ctx = _attr->getSearch(std::make_unique<QueryTermSimple>(term, QueryTermSimple::WORD), + auto ctx = _attr->getSearch(std::make_unique<QueryTermSimple>(term, QueryTermSimple::Type::WORD), SearchContextParams()); TermFieldMatchData tfmd; auto itr = ctx->createIterator(&tfmd, false); diff --git a/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp b/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp index 11da9701b92..1af70946fa2 100644 --- a/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp +++ b/searchlib/src/tests/attribute/searchcontext/searchcontext_test.cpp @@ -61,6 +61,7 @@ using fef::TermFieldMatchDataPosition; using queryeval::HitCollector; using queryeval::SearchIterator; using queryeval::SimpleResult; +using TermType = search::QueryTermSimple::Type; class DocSet : public std::set<uint32_t> { @@ -117,7 +118,7 @@ public: static void addReservedDoc(AttributeVector &ptr); static void 
addDocs(AttributeVector & ptr, uint32_t numDocs); template <typename V, typename T> - static SearchContextPtr getSearch(const V & vec, const T & term, QueryTermSimple::SearchTerm termType=QueryTermSimple::WORD); + static SearchContextPtr getSearch(const V & vec, const T & term, TermType termType=TermType::WORD); private: typedef std::map<vespalib::string, Config> ConfigMap; // Map of all config objects @@ -137,14 +138,14 @@ private: template <typename V, typename T> void fillPostingList(PostingList<V, T> & pl); static void buildTermQuery(std::vector<char> & buffer, const vespalib::string & index, const vespalib::string & term, - QueryTermSimple::SearchTerm termType=QueryTermSimple::WORD); + TermType termType=TermType::WORD); ResultSetPtr performSearch(SearchIterator & sb, uint32_t numDocs); template <typename V, typename T> - ResultSetPtr performSearch(const V & vec, const T & term, QueryTermSimple::SearchTerm termType=QueryTermSimple::WORD); + ResultSetPtr performSearch(const V & vec, const T & term, TermType termType=TermType::WORD); template <typename V> void performSearch(const V & vec, const vespalib::string & term, - const DocSet & expected, QueryTermSimple::SearchTerm termType); + const DocSet & expected, TermType termType); void checkResultSet(const ResultSet & rs, const DocSet & exp, bool bitVector); template<typename T, typename A> @@ -236,7 +237,7 @@ private: // test prefix search void performPrefixSearch(const StringAttribute & vec, const vespalib::string & term, - const DocSet & expected, QueryTermSimple::SearchTerm termType); + const DocSet & expected, TermType termType); void testPrefixSearch(const AttributePtr & ptr); void testPrefixSearch(); @@ -390,7 +391,7 @@ SearchContextTest::fillPostingList(PostingList<V, T> & pl) } void -SearchContextTest::buildTermQuery(std::vector<char> & buffer, const vespalib::string & index, const vespalib::string & term, QueryTermSimple::SearchTerm termType) +SearchContextTest::buildTermQuery(std::vector<char> & buffer, 
const vespalib::string & index, const vespalib::string & term, TermType termType) { uint32_t indexLen = index.size(); uint32_t termLen = term.size(); @@ -398,8 +399,8 @@ SearchContextTest::buildTermQuery(std::vector<char> & buffer, const vespalib::st uint32_t p = 0; buffer.resize(queryPacketSize); switch (termType) { - case QueryTermSimple::PREFIXTERM: buffer[p++] = ParseItem::ITEM_PREFIXTERM; break; - case QueryTermSimple::REGEXP: buffer[p++] = ParseItem::ITEM_REGEXP; break; + case TermType::PREFIXTERM: buffer[p++] = ParseItem::ITEM_PREFIXTERM; break; + case TermType::REGEXP: buffer[p++] = ParseItem::ITEM_REGEXP; break; default: buffer[p++] = ParseItem::ITEM_TERM; break; @@ -415,7 +416,7 @@ SearchContextTest::buildTermQuery(std::vector<char> & buffer, const vespalib::st template <typename V, typename T> SearchContextPtr -SearchContextTest::getSearch(const V & vec, const T & term, QueryTermSimple::SearchTerm termType) +SearchContextTest::getSearch(const V & vec, const T & term, TermType termType) { std::vector<char> query; vespalib::asciistream ss; @@ -441,7 +442,7 @@ SearchContextTest::performSearch(SearchIterator & sb, uint32_t numDocs) template <typename V, typename T> ResultSetPtr -SearchContextTest::performSearch(const V & vec, const T & term, QueryTermSimple::SearchTerm termType) +SearchContextTest::performSearch(const V & vec, const T & term, TermType termType) { TermFieldMatchData dummy; SearchContextPtr sc = getSearch(vec, term, termType); @@ -454,7 +455,7 @@ SearchContextTest::performSearch(const V & vec, const T & term, QueryTermSimple: template <typename V> void SearchContextTest::performSearch(const V & vec, const vespalib::string & term, - const DocSet & expected, QueryTermSimple::SearchTerm termType) + const DocSet & expected, TermType termType) { #if 0 std::cout << "performSearch[" << term << "]: {"; @@ -1100,7 +1101,7 @@ void SearchContextTest::performRangeSearch(const VectorType & vec, const vespalib::string & term, const DocSet & expected) { - 
performSearch(vec, term, expected, QueryTermSimple::WORD); + performSearch(vec, term, expected, TermType::WORD); } template <typename VectorType, typename ValueType> @@ -1307,7 +1308,7 @@ void SearchContextTest::performCaseInsensitiveSearch(const StringAttribute & vec, const vespalib::string & term, const DocSet & expected) { - performSearch(vec, term, expected, QueryTermSimple::WORD); + performSearch(vec, term, expected, TermType::WORD); } void @@ -1403,8 +1404,8 @@ SearchContextTest::testRegexSearch(const AttributePtr & ptr) } for (uint32_t i = 0; i < terms.size(); ++i) { - performSearch(vec, terms[i], expected[i], QueryTermSimple::REGEXP); - performSearch(vec, terms[i], empty, QueryTermSimple::WORD); + performSearch(vec, terms[i], expected[i], TermType::REGEXP); + performSearch(vec, terms[i], empty, TermType::WORD); } } @@ -1432,7 +1433,7 @@ SearchContextTest::testRegexSearch() void SearchContextTest::performPrefixSearch(const StringAttribute & vec, const vespalib::string & term, - const DocSet & expected, QueryTermSimple::SearchTerm termType) + const DocSet & expected, TermType termType) { performSearch(vec, term, expected, termType); } @@ -1477,11 +1478,11 @@ SearchContextTest::testPrefixSearch(const AttributePtr & ptr) for (uint32_t i = 0; i < 4; ++i) { for (uint32_t j = 0; j < 3; ++j) { if (j == 0 || ptr->getConfig().fastSearch()) { - performPrefixSearch(vec, terms[i][j], expected[i], QueryTermSimple::PREFIXTERM); - performPrefixSearch(vec, terms[i][j], empty, QueryTermSimple::WORD); + performPrefixSearch(vec, terms[i][j], expected[i], TermType::PREFIXTERM); + performPrefixSearch(vec, terms[i][j], empty, TermType::WORD); } else { - performPrefixSearch(vec, terms[i][j], empty, QueryTermSimple::PREFIXTERM); - performPrefixSearch(vec, terms[i][j], empty, QueryTermSimple::WORD); + performPrefixSearch(vec, terms[i][j], empty, TermType::PREFIXTERM); + performPrefixSearch(vec, terms[i][j], empty, TermType::WORD); } } } @@ -1779,15 +1780,15 @@ 
SearchContextTest::requireThatFlagAttributeHandlesTheByteRange() fa.append(5, 127, 1); fa.commit(true); - performSearch(fa, "-128", DocSet().put(1), QueryTermSimple::WORD); - performSearch(fa, "127", DocSet().put(5), QueryTermSimple::WORD); - performSearch(fa, ">-128", DocSet().put(2).put(3).put(4).put(5), QueryTermSimple::WORD); - performSearch(fa, "<127", DocSet().put(1).put(2).put(3).put(4), QueryTermSimple::WORD); - performSearch(fa, "[-128;-8]", DocSet().put(1).put(2), QueryTermSimple::WORD); - performSearch(fa, "[-8;8]", DocSet().put(2).put(3), QueryTermSimple::WORD); - performSearch(fa, "[8;127]", DocSet().put(3).put(4).put(5), QueryTermSimple::WORD); - performSearch(fa, "[-129;-8]", DocSet().put(1).put(2), QueryTermSimple::WORD); - performSearch(fa, "[8;128]", DocSet().put(3).put(4).put(5), QueryTermSimple::WORD); + performSearch(fa, "-128", DocSet().put(1), TermType::WORD); + performSearch(fa, "127", DocSet().put(5), TermType::WORD); + performSearch(fa, ">-128", DocSet().put(2).put(3).put(4).put(5), TermType::WORD); + performSearch(fa, "<127", DocSet().put(1).put(2).put(3).put(4), TermType::WORD); + performSearch(fa, "[-128;-8]", DocSet().put(1).put(2), TermType::WORD); + performSearch(fa, "[-8;8]", DocSet().put(2).put(3), TermType::WORD); + performSearch(fa, "[8;127]", DocSet().put(3).put(4).put(5), TermType::WORD); + performSearch(fa, "[-129;-8]", DocSet().put(1).put(2), TermType::WORD); + performSearch(fa, "[8;128]", DocSet().put(3).put(4).put(5), TermType::WORD); } void @@ -1838,7 +1839,7 @@ public: _attr.commit(); } search::AttributeVector::SearchContext::UP create_search_context(const std::string& term) const { - return _attr.getSearch(std::make_unique<search::QueryTermSimple>(term, search::QueryTermSimple::WORD), + return _attr.getSearch(std::make_unique<search::QueryTermSimple>(term, search::TermType::WORD), SearchContextParams().useBitVector(true)); } SimpleResult search_context(const std::string& term) const { diff --git 
a/searchlib/src/tests/attribute/searchcontextelementiterator/searchcontextelementiterator_test.cpp b/searchlib/src/tests/attribute/searchcontextelementiterator/searchcontextelementiterator_test.cpp index 22fd409320c..7f9577eb13f 100644 --- a/searchlib/src/tests/attribute/searchcontextelementiterator/searchcontextelementiterator_test.cpp +++ b/searchlib/src/tests/attribute/searchcontextelementiterator/searchcontextelementiterator_test.cpp @@ -109,7 +109,7 @@ TEST(ElementIteratorTest, require_that_searchcontext) fef::TermFieldMatchData tfmd; SearchContextParams params; - ISearchContext::UP sc = attribute->createSearchContext(std::make_unique<QueryTermSimple>("1", QueryTermSimple::SearchTerm::WORD), params); + ISearchContext::UP sc = attribute->createSearchContext(std::make_unique<QueryTermSimple>("1", QueryTermSimple::Type::WORD), params); SearchContextElementIterator elemIt(sc->createIterator(&tfmd, false), *sc); verifyElementIterator(elemIt); } diff --git a/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp b/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp index 87d9f081ffc..aaae2772687 100644 --- a/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp +++ b/searchlib/src/tests/attribute/stringattribute/stringattribute_test.cpp @@ -386,7 +386,7 @@ testSingleValue(Attribute & svsa, Config &cfg) TEST("testSingleValue") { EXPECT_EQUAL(24u, sizeof(AttributeVector::SearchContext)); - EXPECT_EQUAL(72u, sizeof(SingleValueStringAttribute::StringSingleImplSearchContext)); + EXPECT_EQUAL(56u, sizeof(SingleValueStringAttribute::StringSingleImplSearchContext)); { Config cfg(BasicType::STRING, CollectionType::SINGLE); SingleValueStringAttribute svsa("svsa", cfg); diff --git a/searchlib/src/tests/features/constant/constant_test.cpp b/searchlib/src/tests/features/constant/constant_test.cpp index 140c93125b0..9c8480c1da2 100644 --- a/searchlib/src/tests/features/constant/constant_test.cpp +++ 
b/searchlib/src/tests/features/constant/constant_test.cpp @@ -45,7 +45,7 @@ struct ExecFixture bool setup() { return test.setup(); } const Value &extractTensor(uint32_t docid) { Value::CREF value = test.resolveObjectFeature(docid); - ASSERT_TRUE(value.get().is_tensor()); + ASSERT_TRUE(value.get().type().has_dimensions()); return value.get(); } const Value &executeTensor(uint32_t docId = 1) { @@ -53,7 +53,7 @@ struct ExecFixture } double extractDouble(uint32_t docid) { Value::CREF value = test.resolveObjectFeature(docid); - ASSERT_TRUE(value.get().is_double()); + ASSERT_TRUE(value.get().type().is_double()); return value.get().as_double(); } double executeDouble(uint32_t docId = 1) { diff --git a/searchlib/src/tests/features/tensor/tensor_test.cpp b/searchlib/src/tests/features/tensor/tensor_test.cpp index 53049c4a385..5d7698822eb 100644 --- a/searchlib/src/tests/features/tensor/tensor_test.cpp +++ b/searchlib/src/tests/features/tensor/tensor_test.cpp @@ -152,7 +152,7 @@ struct ExecFixture } const Value &extractTensor(uint32_t docid) { Value::CREF value = test.resolveObjectFeature(docid); - ASSERT_TRUE(value.get().is_tensor()); + ASSERT_TRUE(value.get().type().has_dimensions()); return value.get(); } const Value &execute(uint32_t docId = 1) { diff --git a/searchlib/src/tests/query/streaming_query_test.cpp b/searchlib/src/tests/query/streaming_query_test.cpp index dbd186fdcb5..5ce34cfcc3f 100644 --- a/searchlib/src/tests/query/streaming_query_test.cpp +++ b/searchlib/src/tests/query/streaming_query_test.cpp @@ -11,6 +11,7 @@ using namespace search; using namespace search::query; using namespace search::streaming; +using TermType = QueryTerm::Type; void assertHit(const Hit & h, size_t expWordpos, size_t expContext, int32_t weight) { EXPECT_EQUAL(h.wordpos(), expWordpos); @@ -23,201 +24,255 @@ TEST("testQueryLanguage") { int64_t ia(0), ib(0); double da(0), db(0); - QueryTerm q(factory.create(), "7", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - 
EXPECT_EQUAL(ia, 7); - EXPECT_EQUAL(ib, 7); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, 7); - EXPECT_EQUAL(db, 7); - - q = QueryTerm(factory.create(), "-7", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, -7); - EXPECT_EQUAL(ib, -7); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, -7); - EXPECT_EQUAL(db, -7); - - q = QueryTerm(factory.create(), "7.5", "index", QueryTerm::WORD); - EXPECT_TRUE(!q.getAsIntegerTerm(ia, ib)); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, 7.5); - EXPECT_EQUAL(db, 7.5); - - q = QueryTerm(factory.create(), "-7.5", "index", QueryTerm::WORD); - EXPECT_TRUE(!q.getAsIntegerTerm(ia, ib)); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, -7.5); - EXPECT_EQUAL(db, -7.5); - - q = QueryTerm(factory.create(), "<7", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); - EXPECT_EQUAL(ib, 6); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, -std::numeric_limits<double>::max()); - EXPECT_LESS(db, 7); - EXPECT_GREATER(db, 6.99); - - q = QueryTerm(factory.create(), "[;7]", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); - EXPECT_EQUAL(ib, 7); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, -std::numeric_limits<double>::max()); - EXPECT_EQUAL(db, 7); - - q = QueryTerm(factory.create(), ">7", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, 8); - EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_GREATER(da, 7); - EXPECT_LESS(da, 7.01); - EXPECT_EQUAL(db, std::numeric_limits<double>::max()); - - q = QueryTerm(factory.create(), "[7;]", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, 7); - EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); - 
EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, 7); - EXPECT_EQUAL(db, std::numeric_limits<double>::max()); - - q = QueryTerm(factory.create(), "[-7;7]", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, -7); - EXPECT_EQUAL(ib, 7); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, -7); - EXPECT_EQUAL(db, 7); - - q = QueryTerm(factory.create(), "[-7.1;7.1]", "index", QueryTerm::WORD); - EXPECT_FALSE(q.getAsIntegerTerm(ia, ib)); // This is dubious and perhaps a regression. - EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); - EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, -7.1); - EXPECT_EQUAL(db, 7.1); - - q = QueryTerm(factory.create(), "[500.0;1.7976931348623157E308]", "index", QueryTerm::WORD); - EXPECT_FALSE(q.getAsIntegerTerm(ia, ib)); // This is dubious and perhaps a regression. - EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); - EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, 500.0); - EXPECT_EQUAL(db, std::numeric_limits<double>::max()); + { + QueryTerm q(factory.create(), "7", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, 7); + EXPECT_EQUAL(ib, 7); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, 7); + EXPECT_EQUAL(db, 7); + } + + { + QueryTerm q(factory.create(), "-7", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, -7); + EXPECT_EQUAL(ib, -7); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -7); + EXPECT_EQUAL(db, -7); + } + + { + QueryTerm q(factory.create(), "7.5", "index", TermType::WORD); + EXPECT_TRUE(!q.getAsIntegerTerm(ia, ib)); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, 7.5); + EXPECT_EQUAL(db, 7.5); + } + + { + QueryTerm q(factory.create(), "-7.5", "index", TermType::WORD); + EXPECT_TRUE(!q.getAsIntegerTerm(ia, ib)); 
+ EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -7.5); + EXPECT_EQUAL(db, -7.5); + } + + { + QueryTerm q(factory.create(), "<7", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); + EXPECT_EQUAL(ib, 6); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -std::numeric_limits<double>::max()); + EXPECT_LESS(db, 7); + EXPECT_GREATER(db, 6.99); + } + + { + QueryTerm q(factory.create(), "[;7]", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); + EXPECT_EQUAL(ib, 7); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -std::numeric_limits<double>::max()); + EXPECT_EQUAL(db, 7); + } + + { + QueryTerm q(factory.create(), ">7", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, 8); + EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_GREATER(da, 7); + EXPECT_LESS(da, 7.01); + EXPECT_EQUAL(db, std::numeric_limits<double>::max()); + } + + { + QueryTerm q(factory.create(), "[7;]", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, 7); + EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, 7); + EXPECT_EQUAL(db, std::numeric_limits<double>::max()); + } + + { + QueryTerm q(factory.create(), "[-7;7]", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, -7); + EXPECT_EQUAL(ib, 7); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -7); + EXPECT_EQUAL(db, 7); + } + + { + QueryTerm q(factory.create(), "[-7.1;7.1]", "index", TermType::WORD); + EXPECT_FALSE(q.getAsIntegerTerm(ia, ib)); // This is dubious and perhaps a regression. 
+ EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); + EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -7.1); + EXPECT_EQUAL(db, 7.1); + } + + { + QueryTerm q(factory.create(), "[500.0;1.7976931348623157E308]", "index", TermType::WORD); + EXPECT_FALSE(q.getAsIntegerTerm(ia, ib)); // This is dubious and perhaps a regression. + EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); + EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, 500.0); + EXPECT_EQUAL(db, std::numeric_limits<double>::max()); + } const double minusSeven(-7), seven(7); - q = QueryTerm(factory.create(), "<-7;7]", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, -6); - EXPECT_EQUAL(ib, 7); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, std::nextafterf(minusSeven, seven)); - EXPECT_EQUAL(db, seven); - - q = QueryTerm(factory.create(), "<-7;7>", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, -6); - EXPECT_EQUAL(ib, 6); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, std::nextafterf(minusSeven, seven)); - EXPECT_EQUAL(db, std::nextafterf(seven, minusSeven)); - - q = QueryTerm(factory.create(), "<1;2>", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, 2); - EXPECT_EQUAL(ib, 1); - - q = QueryTerm(factory.create(), "[-7;7>", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, -7); - EXPECT_EQUAL(ib, 6); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, minusSeven); - EXPECT_EQUAL(db, std::nextafterf(seven, minusSeven)); - - q = QueryTerm(factory.create(), "<-7", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); - EXPECT_EQUAL(ib, -8); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, 
-std::numeric_limits<double>::max()); - EXPECT_LESS(db, -7); - EXPECT_GREATER(db, -7.01); - - q = QueryTerm(factory.create(), "[;-7]", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); - EXPECT_EQUAL(ib, -7); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, -std::numeric_limits<double>::max()); - EXPECT_EQUAL(db, -7); - - q = QueryTerm(factory.create(), "<;-7]", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); - EXPECT_EQUAL(ib, -7); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, -std::numeric_limits<double>::max()); - EXPECT_EQUAL(db, -7); - - q = QueryTerm(factory.create(), ">-7", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, -6); - EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_GREATER(da, -7); - EXPECT_LESS(da, -6.99); - EXPECT_EQUAL(db, std::numeric_limits<double>::max()); - - q = QueryTerm(factory.create(), "[-7;]", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, -7); - EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, -7); - EXPECT_EQUAL(db, std::numeric_limits<double>::max()); - - q = QueryTerm(factory.create(), "[-7;>", "index", QueryTerm::WORD); - EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); - EXPECT_EQUAL(ia, -7); - EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); - EXPECT_TRUE(q.getAsDoubleTerm(da, db)); - EXPECT_EQUAL(da, -7); - EXPECT_EQUAL(db, std::numeric_limits<double>::max()); - - q = QueryTerm(factory.create(), "a", "index", QueryTerm::WORD); - EXPECT_TRUE(!q.getAsIntegerTerm(ia, ib)); - EXPECT_TRUE(!q.getAsDoubleTerm(da, db)); - - q = QueryTerm(factory.create(), "word", "index", QueryTerm::WORD); - EXPECT_TRUE(!q.isPrefix()); - EXPECT_TRUE(!q.isSubstring()); - 
EXPECT_TRUE(!q.isSuffix()); - - q = QueryTerm(factory.create(), "prefix", "index", QueryTerm::PREFIXTERM); - EXPECT_TRUE(q.isPrefix()); - EXPECT_TRUE(!q.isSubstring()); - EXPECT_TRUE(!q.isSuffix()); - - q = QueryTerm(factory.create(), "substring", "index", QueryTerm::SUBSTRINGTERM); - EXPECT_TRUE(!q.isPrefix()); - EXPECT_TRUE(q.isSubstring()); - EXPECT_TRUE(!q.isSuffix()); - - q = QueryTerm(factory.create(), "suffix", "index", QueryTerm::SUFFIXTERM); - EXPECT_TRUE(!q.isPrefix()); - EXPECT_TRUE(!q.isSubstring()); - EXPECT_TRUE(q.isSuffix()); - - q = QueryTerm(factory.create(), "regexp", "index", QueryTerm::REGEXP); - EXPECT_TRUE(!q.isPrefix()); - EXPECT_TRUE(!q.isSubstring()); - EXPECT_TRUE(!q.isSuffix()); - EXPECT_TRUE(q.isRegex()); + { + QueryTerm q(factory.create(), "<-7;7]", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, -6); + EXPECT_EQUAL(ib, 7); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, std::nextafterf(minusSeven, seven)); + EXPECT_EQUAL(db, seven); + } + + { + QueryTerm q(factory.create(), "<-7;7>", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, -6); + EXPECT_EQUAL(ib, 6); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, std::nextafterf(minusSeven, seven)); + EXPECT_EQUAL(db, std::nextafterf(seven, minusSeven)); + } + + { + QueryTerm q(factory.create(), "<1;2>", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, 2); + EXPECT_EQUAL(ib, 1); + } + + { + QueryTerm q(factory.create(), "[-7;7>", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, -7); + EXPECT_EQUAL(ib, 6); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, minusSeven); + EXPECT_EQUAL(db, std::nextafterf(seven, minusSeven)); + } + + { + QueryTerm q(factory.create(), "<-7", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); + 
EXPECT_EQUAL(ib, -8); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -std::numeric_limits<double>::max()); + EXPECT_LESS(db, -7); + EXPECT_GREATER(db, -7.01); + } + + { + QueryTerm q(factory.create(), "[;-7]", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); + EXPECT_EQUAL(ib, -7); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -std::numeric_limits<double>::max()); + EXPECT_EQUAL(db, -7); + } + + { + QueryTerm q(factory.create(), "<;-7]", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, std::numeric_limits<int64_t>::min()); + EXPECT_EQUAL(ib, -7); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -std::numeric_limits<double>::max()); + EXPECT_EQUAL(db, -7); + } + + { + QueryTerm q(factory.create(), ">-7", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, -6); + EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_GREATER(da, -7); + EXPECT_LESS(da, -6.99); + EXPECT_EQUAL(db, std::numeric_limits<double>::max()); + } + + { + QueryTerm q(factory.create(), "[-7;]", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, -7); + EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -7); + EXPECT_EQUAL(db, std::numeric_limits<double>::max()); + } + + { + QueryTerm q(factory.create(), "[-7;>", "index", TermType::WORD); + EXPECT_TRUE(q.getAsIntegerTerm(ia, ib)); + EXPECT_EQUAL(ia, -7); + EXPECT_EQUAL(ib, std::numeric_limits<int64_t>::max()); + EXPECT_TRUE(q.getAsDoubleTerm(da, db)); + EXPECT_EQUAL(da, -7); + EXPECT_EQUAL(db, std::numeric_limits<double>::max()); + } + + { + QueryTerm q(factory.create(), "a", "index", TermType::WORD); + EXPECT_TRUE(!q.getAsIntegerTerm(ia, ib)); + EXPECT_TRUE(!q.getAsDoubleTerm(da, db)); + } + + { + QueryTerm 
q(factory.create(), "word", "index", TermType::WORD); + EXPECT_TRUE(!q.isPrefix()); + EXPECT_TRUE(!q.isSubstring()); + EXPECT_TRUE(!q.isSuffix()); + } + + { + QueryTerm q(factory.create(), "prefix", "index", TermType::PREFIXTERM); + EXPECT_TRUE(q.isPrefix()); + EXPECT_TRUE(!q.isSubstring()); + EXPECT_TRUE(!q.isSuffix()); + } + + { + QueryTerm q(factory.create(), "substring", "index", TermType::SUBSTRINGTERM); + EXPECT_TRUE(!q.isPrefix()); + EXPECT_TRUE(q.isSubstring()); + EXPECT_TRUE(!q.isSuffix()); + } + + { + QueryTerm q(factory.create(), "suffix", "index", TermType::SUFFIXTERM); + EXPECT_TRUE(!q.isPrefix()); + EXPECT_TRUE(!q.isSubstring()); + EXPECT_TRUE(q.isSuffix()); + } + + { + QueryTerm q(factory.create(), "regexp", "index", TermType::REGEXP); + EXPECT_TRUE(!q.isPrefix()); + EXPECT_TRUE(!q.isSubstring()); + EXPECT_TRUE(!q.isSuffix()); + EXPECT_TRUE(q.isRegex()); + } } class AllowRewrite : public QueryNodeResultFactory @@ -426,7 +481,7 @@ TEST("testHit") { } void assertInt8Range(const std::string &term, bool expAdjusted, int64_t expLow, int64_t expHigh) { - QueryTermSimple q(term, QueryTermSimple::WORD); + QueryTermSimple q(term, TermType::WORD); QueryTermSimple::RangeResult<int8_t> res = q.getRange<int8_t>(); EXPECT_EQUAL(true, res.valid); EXPECT_EQUAL(expAdjusted, res.adjusted); @@ -435,7 +490,7 @@ void assertInt8Range(const std::string &term, bool expAdjusted, int64_t expLow, } void assertInt32Range(const std::string &term, bool expAdjusted, int64_t expLow, int64_t expHigh) { - QueryTermSimple q(term, QueryTermSimple::WORD); + QueryTermSimple q(term, TermType::WORD); QueryTermSimple::RangeResult<int32_t> res = q.getRange<int32_t>(); EXPECT_EQUAL(true, res.valid); EXPECT_EQUAL(expAdjusted, res.adjusted); @@ -444,7 +499,7 @@ void assertInt32Range(const std::string &term, bool expAdjusted, int64_t expLow, } void assertInt64Range(const std::string &term, bool expAdjusted, int64_t expLow, int64_t expHigh) { - QueryTermSimple q(term, QueryTermSimple::WORD); + 
QueryTermSimple q(term, TermType::WORD); QueryTermSimple::RangeResult<int64_t> res = q.getRange<int64_t>(); EXPECT_EQUAL(true, res.valid); EXPECT_EQUAL(expAdjusted, res.adjusted); @@ -547,7 +602,7 @@ TEST("require that ascending range can be specified with limit only") { double high_double = 0.0; QueryNodeResultFactory eqnr; - QueryTerm ascending_query(eqnr.create(), "[;;500]", "index", QueryTerm::WORD); + QueryTerm ascending_query(eqnr.create(), "[;;500]", "index", TermType::WORD); EXPECT_TRUE(ascending_query.getAsIntegerTerm(low_integer, high_integer)); EXPECT_TRUE(ascending_query.getAsDoubleTerm(low_double, high_double)); @@ -565,7 +620,7 @@ TEST("require that descending range can be specified with limit only") { double high_double = 0.0; QueryNodeResultFactory eqnr; - QueryTerm descending_query(eqnr.create(), "[;;-500]", "index", QueryTerm::WORD); + QueryTerm descending_query(eqnr.create(), "[;;-500]", "index", TermType::WORD); EXPECT_TRUE(descending_query.getAsIntegerTerm(low_integer, high_integer)); EXPECT_TRUE(descending_query.getAsDoubleTerm(low_double, high_double)); @@ -578,7 +633,7 @@ TEST("require that descending range can be specified with limit only") { TEST("require that correctly specified diversity can be parsed") { QueryNodeResultFactory eqnr; - QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78]", "index", QueryTerm::WORD); + QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78]", "index", TermType::WORD); EXPECT_TRUE(descending_query.isValid()); EXPECT_EQUAL(-500, descending_query.getRangeLimit()); EXPECT_EQUAL("ab56", descending_query.getDiversityAttribute()); @@ -589,7 +644,7 @@ TEST("require that correctly specified diversity can be parsed") { TEST("require that correctly specified diversity with cutoff groups can be parsed") { QueryNodeResultFactory eqnr; - QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78;93]", "index", QueryTerm::WORD); + QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78;93]", "index", 
TermType::WORD); EXPECT_TRUE(descending_query.isValid()); EXPECT_EQUAL(-500, descending_query.getRangeLimit()); EXPECT_EQUAL("ab56", descending_query.getDiversityAttribute()); @@ -600,7 +655,7 @@ TEST("require that correctly specified diversity with cutoff groups can be parse TEST("require that correctly specified diversity with cutoff groups can be parsed") { QueryNodeResultFactory eqnr; - QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78;13]", "index", QueryTerm::WORD); + QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78;13]", "index", TermType::WORD); EXPECT_TRUE(descending_query.isValid()); EXPECT_EQUAL(-500, descending_query.getRangeLimit()); EXPECT_EQUAL("ab56", descending_query.getDiversityAttribute()); @@ -611,7 +666,7 @@ TEST("require that correctly specified diversity with cutoff groups can be parse TEST("require that correctly specified diversity with incorrect cutoff groups can be parsed") { QueryNodeResultFactory eqnr; - QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78;a13.9]", "index", QueryTerm::WORD); + QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78;a13.9]", "index", TermType::WORD); EXPECT_TRUE(descending_query.isValid()); EXPECT_EQUAL(-500, descending_query.getRangeLimit()); EXPECT_EQUAL("ab56", descending_query.getDiversityAttribute()); @@ -622,7 +677,7 @@ TEST("require that correctly specified diversity with incorrect cutoff groups ca TEST("require that correctly specified diversity with cutoff strategy can be parsed") { QueryNodeResultFactory eqnr; - QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78;93;anything but strict]", "index", QueryTerm::WORD); + QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78;93;anything but strict]", "index", TermType::WORD); EXPECT_TRUE(descending_query.isValid()); EXPECT_EQUAL(-500, descending_query.getRangeLimit()); EXPECT_EQUAL("ab56", descending_query.getDiversityAttribute()); @@ -633,7 +688,7 @@ TEST("require that correctly specified diversity with 
cutoff strategy can be par TEST("require that correctly specified diversity with strict cutoff strategy can be parsed") { QueryNodeResultFactory eqnr; - QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78;93;strict]", "index", QueryTerm::WORD); + QueryTerm descending_query(eqnr.create(), "[;;-500;ab56;78;93;strict]", "index", TermType::WORD); EXPECT_TRUE(descending_query.isValid()); EXPECT_EQUAL(-500, descending_query.getRangeLimit()); EXPECT_EQUAL("ab56", descending_query.getDiversityAttribute()); @@ -644,12 +699,12 @@ TEST("require that correctly specified diversity with strict cutoff strategy can TEST("require that incorrectly specified diversity can be parsed") { QueryNodeResultFactory eqnr; - QueryTerm descending_query(eqnr.create(), "[;;-500;ab56]", "index", QueryTerm::WORD); + QueryTerm descending_query(eqnr.create(), "[;;-500;ab56]", "index", TermType::WORD); EXPECT_FALSE(descending_query.isValid()); } TEST("require that we do not break the stack on bad query") { - QueryTermSimple term("<form><iframe+	 +src=\\\"javascript:alert(1)\\\" 	;>", QueryTerm::WORD); + QueryTermSimple term("<form><iframe+	 +src=\\\"javascript:alert(1)\\\" 	;>", TermType::WORD); EXPECT_FALSE(term.isValid()); } @@ -731,5 +786,10 @@ TEST("testSameElementEvaluate") { EXPECT_TRUE(sameElem->evaluate()); } +TEST("Control the size of query terms") { + EXPECT_EQUAL(104u, sizeof(QueryTermSimple)); + EXPECT_EQUAL(120u, sizeof(QueryTermUCS4)); + EXPECT_EQUAL(264u, sizeof(QueryTerm)); +} TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/searchlib/src/tests/queryeval/queryeval.cpp b/searchlib/src/tests/queryeval/queryeval.cpp index 23647b208a3..f82bfabb7c3 100644 --- a/searchlib/src/tests/queryeval/queryeval.cpp +++ b/searchlib/src/tests/queryeval/queryeval.cpp @@ -350,7 +350,7 @@ public: _a.update(docId, 1); } _a.commit(); - _sc = _a.getSearch(std::make_unique<search::QueryTermSimple>("1", search::QueryTermSimple::WORD), + _sc = _a.getSearch(std::make_unique<search::QueryTermSimple>("1", 
search::QueryTermSimple::Type::WORD), SearchContextParams().useBitVector(true)); } SearchIterator::UP diff --git a/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp b/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp index 62b598d730b..45cffde8a8c 100644 --- a/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp +++ b/searchlib/src/vespa/searchlib/aggregation/aggregation.cpp @@ -91,7 +91,7 @@ AggregationResult::Configure::execute(vespalib::Identifiable &obj) AggregationResult & AggregationResult::setExpression(ExpressionNode::UP expr) { - _expressionTree.reset(new ExpressionTree(std::move(expr))); + _expressionTree = std::make_shared<ExpressionTree>(std::move(expr)); prepare(&_expressionTree->getResult(), false); return *this; } diff --git a/searchlib/src/vespa/searchlib/aggregation/group.h b/searchlib/src/vespa/searchlib/aggregation/group.h index f6b6bc732af..5b425de24e6 100644 --- a/searchlib/src/vespa/searchlib/aggregation/group.h +++ b/searchlib/src/vespa/searchlib/aggregation/group.h @@ -170,8 +170,8 @@ public: Group(); Group(const Group & rhs); Group & operator =(const Group & rhs); - Group(Group &&) = default; - Group & operator = (Group &&) = default; + Group(Group &&) noexcept = default; + Group & operator = (Group &&) noexcept = default; ~Group(); int cmpId(const Group &rhs) const { return _id->cmpFast(*rhs._id); } diff --git a/searchlib/src/vespa/searchlib/aggregation/groupinglevel.h b/searchlib/src/vespa/searchlib/aggregation/groupinglevel.h index 16d004f807d..ad53ff20fc2 100644 --- a/searchlib/src/vespa/searchlib/aggregation/groupinglevel.h +++ b/searchlib/src/vespa/searchlib/aggregation/groupinglevel.h @@ -77,8 +77,8 @@ private: vespalib::CloneablePtr<Grouper> _grouper; public: GroupingLevel(); - GroupingLevel(GroupingLevel &&) = default; - GroupingLevel & operator =(GroupingLevel &&) = default; + GroupingLevel(GroupingLevel &&) noexcept = default; + GroupingLevel & operator =(GroupingLevel &&) noexcept = default; 
GroupingLevel(const GroupingLevel &); GroupingLevel & operator =(const GroupingLevel &); ~GroupingLevel(); diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 70a59f1575a..5ba38d803c8 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -591,7 +591,7 @@ public: void visit(RangeTerm &n) override { const string stack = StackDumpCreator::create(n); const string term = queryeval::termAsString(n); - QueryTermSimple parsed_term(term, QueryTermSimple::WORD); + QueryTermSimple parsed_term(term, QueryTermSimple::Type::WORD); if (parsed_term.getMaxPerGroup() > 0) { const IAttributeVector *diversity(getRequestContext().getAttribute(parsed_term.getDiversityAttribute())); if (check_valid_diversity_attr(diversity)) { @@ -636,9 +636,9 @@ public: extractTerm(const query::Node &node, bool isInteger) { vespalib::string term = queryeval::termAsString(node); if (isInteger) { - return std::make_unique<QueryTermSimple>(term, QueryTermSimple::WORD); + return std::make_unique<QueryTermSimple>(term, QueryTermSimple::Type::WORD); } - return std::make_unique<QueryTermUCS4>(term, QueryTermSimple::WORD); + return std::make_unique<QueryTermUCS4>(term, QueryTermSimple::Type::WORD); } template <typename WS, typename NODE> diff --git a/searchlib/src/vespa/searchlib/attribute/configconverter.cpp b/searchlib/src/vespa/searchlib/attribute/configconverter.cpp index f2e2f8271de..991f1f03ee7 100644 --- a/searchlib/src/vespa/searchlib/attribute/configconverter.cpp +++ b/searchlib/src/vespa/searchlib/attribute/configconverter.cpp @@ -11,6 +11,7 @@ namespace { using search::attribute::CollectionType; using search::attribute::BasicType; using vespalib::eval::ValueType; +using vespalib::eval::CellType; typedef std::map<AttributesConfig::Attribute::Datatype, BasicType::Type> DataTypeMap; 
typedef std::map<AttributesConfig::Attribute::Collectiontype, CollectionType::Type> CollectionTypeMap; @@ -102,7 +103,7 @@ ConfigConverter::convert(const AttributesConfig::Attribute & cfg) if (!cfg.tensortype.empty()) { retval.setTensorType(ValueType::from_spec(cfg.tensortype)); } else { - retval.setTensorType(ValueType::tensor_type({})); + retval.setTensorType(ValueType::double_type()); } } return retval; diff --git a/searchlib/src/vespa/searchlib/attribute/enum_store_dictionary.cpp b/searchlib/src/vespa/searchlib/attribute/enum_store_dictionary.cpp index 19d30317c7b..ed16dc2d8d8 100644 --- a/searchlib/src/vespa/searchlib/attribute/enum_store_dictionary.cpp +++ b/searchlib/src/vespa/searchlib/attribute/enum_store_dictionary.cpp @@ -122,7 +122,7 @@ EnumStoreDictionary<DictionaryT>::find_matching_enums(const vespalib::datastore: { std::vector<IEnumStore::EnumHandle> result; auto itr = this->_dict.getFrozenView().find(Index(), cmp); - while (itr.valid() && !cmp(Index(), itr.getKey())) { + while (itr.valid() && !cmp.less(Index(), itr.getKey())) { result.push_back(itr.getKey().ref()); ++itr; } @@ -169,7 +169,7 @@ UniqueStoreAddResult EnumStoreFoldedDictionary::add(const EntryComparator& comp, std::function<EntryRef(void)> insertEntry) { auto it = _dict.lowerBound(EntryRef(), comp); - if (it.valid() && !comp(EntryRef(), it.getKey())) { + if (it.valid() && !comp.less(EntryRef(), it.getKey())) { // Entry already exists return UniqueStoreAddResult(it.getKey(), false); } @@ -177,7 +177,7 @@ EnumStoreFoldedDictionary::add(const EntryComparator& comp, std::function<EntryR _dict.insert(it, newRef, EntryRef().ref()); // Maybe move posting list reference from next entry ++it; - if (it.valid() && EntryRef(it.getData()).valid() && !(*_folded_compare)(newRef, it.getKey())) { + if (it.valid() && EntryRef(it.getData()).valid() && !_folded_compare->less(newRef, it.getKey())) { EntryRef posting_list_ref(it.getData()); _dict.thaw(it); it.writeData(EntryRef().ref()); @@ -198,7 +198,7 @@ 
EnumStoreFoldedDictionary::remove(const EntryComparator& comp, EntryRef ref) _dict.remove(it); // Maybe copy posting list reference to next entry if (posting_list_ref.valid()) { - if (it.valid() && !EntryRef(it.getData()).valid() && !(*_folded_compare)(ref, it.getKey())) { + if (it.valid() && !EntryRef(it.getData()).valid() && !_folded_compare->less(ref, it.getKey())) { this->_dict.thaw(it); it.writeData(posting_list_ref.ref()); } else { diff --git a/searchlib/src/vespa/searchlib/attribute/enumattribute.hpp b/searchlib/src/vespa/searchlib/attribute/enumattribute.hpp index 58a220922fd..0e14c567345 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/enumattribute.hpp @@ -33,7 +33,7 @@ void EnumAttribute<B>::load_enum_store(LoadedVector& loaded) EnumIndex index = loader.insert(value.getValue(), value._pidx.ref()); for (size_t i(0), m(loaded.size()); i < m; ++i, loaded.next()) { value = loaded.read(); - if (!EnumStore::ComparatorType::equal(prev, value.getValue())) { + if (!EnumStore::ComparatorType::equal_helper(prev, value.getValue())) { loader.set_ref_count_for_last_value(prevRefCount); index = loader.insert(value.getValue(), value._pidx.ref()); prev = value.getValue(); diff --git a/searchlib/src/vespa/searchlib/attribute/enumcomparator.cpp b/searchlib/src/vespa/searchlib/attribute/enumcomparator.cpp index 115eaa90841..a428ac77d87 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumcomparator.cpp +++ b/searchlib/src/vespa/searchlib/attribute/enumcomparator.cpp @@ -26,10 +26,9 @@ EnumStoreComparator<EntryT>::EnumStoreComparator(const DataStoreType& data_store template <typename EntryT> bool -EnumStoreComparator<EntryT>::equal(const EntryT& lhs, const EntryT& rhs) +EnumStoreComparator<EntryT>::equal_helper(const EntryT& lhs, const EntryT& rhs) { - return !vespalib::datastore::UniqueStoreComparatorHelper<EntryT>::less(lhs, rhs) && - !vespalib::datastore::UniqueStoreComparatorHelper<EntryT>::less(rhs, 
lhs); + return vespalib::datastore::UniqueStoreComparatorHelper<EntryT>::equal(lhs, rhs); } EnumStoreStringComparator::EnumStoreStringComparator(const DataStoreType& data_store) diff --git a/searchlib/src/vespa/searchlib/attribute/enumcomparator.h b/searchlib/src/vespa/searchlib/attribute/enumcomparator.h index aaf47987a5d..0215053ba3a 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumcomparator.h +++ b/searchlib/src/vespa/searchlib/attribute/enumcomparator.h @@ -21,7 +21,7 @@ public: EnumStoreComparator(const DataStoreType& data_store, const EntryT& fallback_value, bool prefix = false); EnumStoreComparator(const DataStoreType& data_store); - static bool equal(const EntryT& lhs, const EntryT& rhs); + static bool equal_helper(const EntryT& lhs, const EntryT& rhs); }; /** @@ -51,9 +51,12 @@ public: return compare(lhs, rhs) == 0; } - bool operator() (const vespalib::datastore::EntryRef lhs, const vespalib::datastore::EntryRef rhs) const override { + bool less(const vespalib::datastore::EntryRef lhs, const vespalib::datastore::EntryRef rhs) const override { return compare(get(lhs), get(rhs)) < 0; } + bool equal(const vespalib::datastore::EntryRef lhs, const vespalib::datastore::EntryRef rhs) const override { + return compare(get(lhs), get(rhs)) == 0; + } }; @@ -95,12 +98,15 @@ public: return compare_folded(lhs, rhs) == 0; } - bool operator() (const vespalib::datastore::EntryRef lhs, const vespalib::datastore::EntryRef rhs) const override { + bool less(const vespalib::datastore::EntryRef lhs, const vespalib::datastore::EntryRef rhs) const override { if (use_prefix()) { return compare_folded_prefix(get(lhs), get(rhs), _prefix_len) < 0; } return compare_folded(get(lhs), get(rhs)) < 0; } + bool equal(const vespalib::datastore::EntryRef lhs, const vespalib::datastore::EntryRef rhs) const override { + return compare_folded(get(lhs), get(rhs)) == 0; + } }; extern template class EnumStoreComparator<int8_t>; diff --git 
a/searchlib/src/vespa/searchlib/attribute/enumstore.cpp b/searchlib/src/vespa/searchlib/attribute/enumstore.cpp index 0ad8a7d7c5b..ecd55138df1 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumstore.cpp +++ b/searchlib/src/vespa/searchlib/attribute/enumstore.cpp @@ -35,7 +35,7 @@ EnumStoreT<const char*>::load_unique_value(const void* src, if (prev_idx.valid()) { auto cmp = make_comparator(value); - assert(cmp(prev_idx, Index())); + assert(cmp.less(prev_idx, Index())); } return sz; } diff --git a/searchlib/src/vespa/searchlib/attribute/enumstore.hpp b/searchlib/src/vespa/searchlib/attribute/enumstore.hpp index 54c756ee437..c1098f079e6 100644 --- a/searchlib/src/vespa/searchlib/attribute/enumstore.hpp +++ b/searchlib/src/vespa/searchlib/attribute/enumstore.hpp @@ -66,7 +66,7 @@ EnumStoreT<EntryT>::load_unique_value(const void* src, size_t available, Index& if (prev_idx.valid()) { auto cmp = make_comparator(*value); - assert(cmp(prev_idx, Index())); + assert(cmp.less(prev_idx, Index())); } return sizeof(EntryType); } @@ -159,8 +159,8 @@ bool EnumStoreT<EntryT>::is_folded_change(Index idx1, Index idx2) const { auto cmp = make_folded_comparator(); - assert(!cmp(idx2, idx1)); - return cmp(idx1, idx2); + assert(!cmp.less(idx2, idx1)); + return cmp.less(idx1, idx2); } template <typename EntryT> diff --git a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp index 1fd1cd09bea..ed19e6ae0ba 100644 --- a/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multinumericpostattribute.hpp @@ -103,7 +103,7 @@ MultiValueNumericPostingAttribute<B, M>::DocumentWeightAttributeAdapter::lookup( auto comp = self._enumStore.make_comparator(int_term); dictItr.lower_bound(dictionary_snapshot, EnumIndex(), comp); - if (dictItr.valid() && !comp(EnumIndex(), dictItr.getKey())) { + if (dictItr.valid() && !comp.less(EnumIndex(), 
dictItr.getKey())) { vespalib::datastore::EntryRef pidx(dictItr.getData()); if (pidx.valid()) { const PostingList &plist = self.getPostingList(); diff --git a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp index 25d7858ea81..f97a4e281a8 100644 --- a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp @@ -115,7 +115,7 @@ MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::lookup( auto comp = self._enumStore.make_folded_comparator(term.c_str()); dictItr.lower_bound(dictionary_snapshot, enumstore::Index(), comp); - if (dictItr.valid() && !comp(enumstore::Index(), dictItr.getKey())) { + if (dictItr.valid() && !comp.less(enumstore::Index(), dictItr.getKey())) { vespalib::datastore::EntryRef pidx(dictItr.getData()); if (pidx.valid()) { const PostingList &plist = self.getPostingList(); @@ -134,7 +134,7 @@ MultiValueStringPostingAttributeT<B, T>::DocumentWeightAttributeAdapter::collect Dictionary::ConstIterator dictItr(vespalib::btree::BTreeNode::Ref(), dictionary.getAllocator()); auto comp = self._enumStore.make_folded_comparator(); dictItr.lower_bound(dictionary_snapshot, enum_idx, comp); - while (dictItr.valid() && !comp(enum_idx, dictItr.getKey())) { + while (dictItr.valid() && !comp.less(enum_idx, dictItr.getKey())) { callback(dictItr.getKey()); ++dictItr; } diff --git a/searchlib/src/vespa/searchlib/attribute/postinglistattribute.cpp b/searchlib/src/vespa/searchlib/attribute/postinglistattribute.cpp index f5b4accfedc..02e94b8281e 100644 --- a/searchlib/src/vespa/searchlib/attribute/postinglistattribute.cpp +++ b/searchlib/src/vespa/searchlib/attribute/postinglistattribute.cpp @@ -232,7 +232,7 @@ handle_load_posting_lists(LoadedVector& loaded) LoadedValueType prev = value.getValue(); for (size_t i(0), m(loaded.size()); i < m; i++, loaded.next()) { value = 
loaded.read(); - if (FoldedComparatorType::equal(prev, value.getValue())) { + if (FoldedComparatorType::equal_helper(prev, value.getValue())) { // for single value attributes loaded[numDocs] is used // for default value but we don't want to add an // invalid docId to the posting list. diff --git a/searchlib/src/vespa/searchlib/attribute/postinglistattribute.h b/searchlib/src/vespa/searchlib/attribute/postinglistattribute.h index 5f2eb02ecd2..eab8d1576fd 100644 --- a/searchlib/src/vespa/searchlib/attribute/postinglistattribute.h +++ b/searchlib/src/vespa/searchlib/attribute/postinglistattribute.h @@ -28,7 +28,7 @@ public: _cmp(cmp) { } - bool operator<(const EnumPostingPair &rhs) const { return (*_cmp)(_idx, rhs._idx); } + bool operator<(const EnumPostingPair &rhs) const { return _cmp->less(_idx, rhs._idx); } IEnumStore::Index getEnumIdx() const { return _idx; } }; diff --git a/searchlib/src/vespa/searchlib/attribute/postinglistsearchcontext.cpp b/searchlib/src/vespa/searchlib/attribute/postinglistsearchcontext.cpp index 972baa267ce..1c9b8dbf7b2 100644 --- a/searchlib/src/vespa/searchlib/attribute/postinglistsearchcontext.cpp +++ b/searchlib/src/vespa/searchlib/attribute/postinglistsearchcontext.cpp @@ -46,7 +46,7 @@ PostingListSearchContext::lookupTerm(const vespalib::datastore::EntryComparator { _lowerDictItr.lower_bound(_frozenDictionary.getRoot(), EnumIndex(), comp); _upperDictItr = _lowerDictItr; - if (_upperDictItr.valid() && !comp(EnumIndex(), _upperDictItr.getKey())) { + if (_upperDictItr.valid() && !comp.less(EnumIndex(), _upperDictItr.getKey())) { ++_upperDictItr; _uniqueValues = 1u; } @@ -59,7 +59,7 @@ PostingListSearchContext::lookupRange(const vespalib::datastore::EntryComparator { _lowerDictItr.lower_bound(_frozenDictionary.getRoot(), EnumIndex(), low); _upperDictItr = _lowerDictItr; - if (_upperDictItr.valid() && !high(EnumIndex(), _upperDictItr.getKey())) { + if (_upperDictItr.valid() && !high.less(EnumIndex(), _upperDictItr.getKey())) { 
_upperDictItr.seekPast(EnumIndex(), high); } _uniqueValues = _upperDictItr - _lowerDictItr; diff --git a/searchlib/src/vespa/searchlib/attribute/reference.h b/searchlib/src/vespa/searchlib/attribute/reference.h index 8d7e37f585b..426ce9ea314 100644 --- a/searchlib/src/vespa/searchlib/attribute/reference.h +++ b/searchlib/src/vespa/searchlib/attribute/reference.h @@ -29,14 +29,21 @@ public: _revMapIdx() { } - bool operator<(const Reference &rhs) const { + bool operator < (const Reference &rhs) const { return _gid < rhs._gid; } + bool operator == (const Reference &rhs) const { + return _gid == rhs._gid; + } const GlobalId &gid() const { return _gid; } uint32_t lid() const { return _lid; } EntryRef revMapIdx() const { return _revMapIdx; } void setLid(uint32_t targetLid) const { _lid = targetLid; } void setRevMapIdx(EntryRef newRevMapIdx) const { _revMapIdx = newRevMapIdx; } + size_t hash() const noexcept { + GlobalId::hash hasher; + return hasher(_gid); + } }; } diff --git a/searchlib/src/vespa/searchlib/attribute/stringbase.cpp b/searchlib/src/vespa/searchlib/attribute/stringbase.cpp index d64e03c67a4..56a644a68b1 100644 --- a/searchlib/src/vespa/searchlib/attribute/stringbase.cpp +++ b/searchlib/src/vespa/searchlib/attribute/stringbase.cpp @@ -225,13 +225,15 @@ StringAttribute::StringSearchContext::StringSearchContext(QueryTermSimple::UP qT const StringAttribute & toBeSearched) : SearchContext(toBeSearched), _queryTerm(static_cast<QueryTermUCS4 *>(qTerm.release())), - _termUCS4(queryTerm()->getUCS4Term()), + _termUCS4(nullptr), _regex(), _isPrefix(_queryTerm->isPrefix()), _isRegex(_queryTerm->isRegex()) { if (isRegex()) { _regex = vespalib::Regex::from_pattern(_queryTerm->getTerm(), vespalib::Regex::Options::IgnoreCase); + } else { + _queryTerm->term(_termUCS4); } } @@ -261,16 +263,6 @@ StringAttribute::clearDoc(DocId doc) return removed; } -namespace { - -class DirectAccessor { -public: - DirectAccessor() { } - const char * get(const char * v) const { return v; } 
-}; - -} - bool StringAttribute::applyWeight(DocId doc, const FieldValue & fv, const ArithmeticValueUpdate & wAdjust) { diff --git a/searchlib/src/vespa/searchlib/attribute/stringbase.h b/searchlib/src/vespa/searchlib/attribute/stringbase.h index d72f7002086..b8fef783d58 100644 --- a/searchlib/src/vespa/searchlib/attribute/stringbase.h +++ b/searchlib/src/vespa/searchlib/attribute/stringbase.h @@ -157,7 +157,7 @@ protected: const vespalib::Regex & getRegex() const { return _regex; } private: std::unique_ptr<QueryTermUCS4> _queryTerm; - std::vector<ucs4_t> _termUCS4; + const ucs4_t *_termUCS4; vespalib::Regex _regex; bool _isPrefix; bool _isRegex; diff --git a/searchlib/src/vespa/searchlib/expression/expressiontree.h b/searchlib/src/vespa/searchlib/expression/expressiontree.h index 89ab4de879b..057a7801637 100644 --- a/searchlib/src/vespa/searchlib/expression/expressiontree.h +++ b/searchlib/src/vespa/searchlib/expression/expressiontree.h @@ -10,11 +10,10 @@ namespace document { class DocumentType; class Document; } -namespace search { -namespace attribute { class IAttributeContext; } +namespace search::attribute { class IAttributeContext; } -namespace expression { +namespace search::expression { class AttributeNode; class DocumentAccessorNode; @@ -45,11 +44,11 @@ public: ExpressionTree(const ExpressionNode & root); ExpressionTree(ExpressionNode::UP root); ExpressionTree(const ExpressionTree & rhs); - ExpressionTree(ExpressionTree &&) = default; + ExpressionTree(ExpressionTree &&) noexcept = default; ~ExpressionTree(); ExpressionTree & operator = (ExpressionNode::UP rhs); ExpressionTree & operator = (const ExpressionTree & rhs); - ExpressionTree & operator = (ExpressionTree &&) = default; + ExpressionTree & operator = (ExpressionTree &&) noexcept = default; bool execute(DocId docId, HitRank rank) const; bool execute(const document::Document & doc, HitRank rank) const; @@ -79,5 +78,4 @@ private: ArrayAtLookupList _arrayAtLookupNodes; }; -} // namespace expression -} 
// namespace search +} diff --git a/searchlib/src/vespa/searchlib/features/attributefeature.cpp b/searchlib/src/vespa/searchlib/features/attributefeature.cpp index 80d9a305ef4..2c1ae7c557a 100644 --- a/searchlib/src/vespa/searchlib/features/attributefeature.cpp +++ b/searchlib/src/vespa/searchlib/features/attributefeature.cpp @@ -530,7 +530,7 @@ AttributeBlueprint::setup(const fef::IIndexEnvironment & env, "the given key of a weighted set attribute, or" "the tensor of a tensor attribute", output_type); const fef::FieldInfo * fInfo = env.getFieldByName(_attrName); - if (_tensorType.is_tensor() || isSingleValueBoolField(*fInfo)) { + if (_tensorType.has_dimensions() || isSingleValueBoolField(*fInfo)) { _numOutputs = 1; } else { describeOutput("weight", "The weight associated with the given key in a weighted set attribute."); @@ -558,7 +558,7 @@ fef::FeatureExecutor & AttributeBlueprint::createExecutor(const fef::IQueryEnvironment &env, vespalib::Stash &stash) const { const IAttributeVector * attribute = lookupAttribute(_attrKey, _attrName, env); - if (_tensorType.is_tensor()) { + if (_tensorType.has_dimensions()) { return createTensorAttributeExecutor(attribute, _attrName, _tensorType, stash); } else { return createAttributeExecutor(_numOutputs, attribute, _attrName, _extra, stash); diff --git a/searchlib/src/vespa/searchlib/features/queryfeature.cpp b/searchlib/src/vespa/searchlib/features/queryfeature.cpp index c6196fcbc7f..60bd77e4883 100644 --- a/searchlib/src/vespa/searchlib/features/queryfeature.cpp +++ b/searchlib/src/vespa/searchlib/features/queryfeature.cpp @@ -137,7 +137,7 @@ createTensorExecutor(const IQueryEnvironment &env, FeatureExecutor & QueryBlueprint::createExecutor(const IQueryEnvironment &env, vespalib::Stash &stash) const { - if (_valueType.is_tensor()) { + if (_valueType.has_dimensions()) { return createTensorExecutor(env, _key, _valueType, stash); } else { std::vector<feature_t> values; diff --git 
a/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h b/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h index f4a5b0b8d0a..475075671cd 100644 --- a/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h +++ b/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h @@ -8,6 +8,7 @@ #include <vespa/vespalib/stllike/string.h> using vespalib::eval::FastValueBuilderFactory; +using vespalib::eval::CellType; namespace search::features { @@ -29,7 +30,7 @@ public: TensorFromAttributeExecutor(const search::attribute::IAttributeVector *attribute, const vespalib::string &dimension) : _attribute(attribute), - _type(vespalib::eval::ValueType::tensor_type({{dimension}})), + _type(vespalib::eval::ValueType::make_type(CellType::DOUBLE, {{dimension}})), _attrBuffer(), _addr_ref(), _tensor() diff --git a/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp b/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp index e4f0a010ae2..76a6e908fcb 100644 --- a/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp @@ -20,6 +20,7 @@ using search::attribute::WeightedConstCharContent; using search::attribute::WeightedStringContent; using vespalib::eval::FastValueBuilderFactory; using vespalib::eval::ValueType; +using vespalib::eval::CellType; using search::fef::FeatureType; namespace search { @@ -45,7 +46,7 @@ TensorFromLabelsBlueprint::setup(const search::fef::IIndexEnvironment &env, } describeOutput("tensor", "The tensor created from the given array source (attribute field or query parameter)", - FeatureType::object(ValueType::tensor_type({{_dimension}}))); + FeatureType::object(ValueType::make_type(CellType::DOUBLE, {{_dimension}}))); return validSource; } @@ -60,13 +61,13 @@ createAttributeExecutor(const search::fef::IQueryEnvironment &env, if (attribute == NULL) { LOG(warning, "The attribute 
vector '%s' was not found in the attribute manager." " Returning empty tensor.", attrName.c_str()); - return ConstantTensorExecutor::createEmpty(ValueType::tensor_type({{dimension}}), stash); + return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash); } if (attribute->getCollectionType() != search::attribute::CollectionType::ARRAY || attribute->isFloatingPointType()) { LOG(warning, "The attribute vector '%s' is NOT of type array of string or integer." " Returning empty tensor.", attrName.c_str()); - return ConstantTensorExecutor::createEmpty(ValueType::tensor_type({{dimension}}), stash); + return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash); } // Note that for array attribute vectors the default weight is 1.0 for all values. // This means we can get the attribute content as weighted content and build @@ -86,7 +87,7 @@ createQueryExecutor(const search::fef::IQueryEnvironment &env, const vespalib::string &queryKey, const vespalib::string &dimension, vespalib::Stash &stash) { - ValueType type = ValueType::tensor_type({{dimension}}); + ValueType type = ValueType::make_type(CellType::DOUBLE, {{dimension}}); search::fef::Property prop = env.getProperties().lookup(queryKey); if (prop.found() && !prop.get().empty()) { std::vector<vespalib::string> vector; @@ -115,7 +116,7 @@ TensorFromLabelsBlueprint::createExecutor(const search::fef::IQueryEnvironment & } else if (_sourceType == QUERY_SOURCE) { return createQueryExecutor(env, _sourceParam, _dimension, stash); } - return ConstantTensorExecutor::createEmpty(ValueType::tensor_type({{_dimension}}), stash); + return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{_dimension}}), stash); } } // namespace features diff --git a/searchlib/src/vespa/searchlib/features/tensor_from_weighted_set_feature.cpp b/searchlib/src/vespa/searchlib/features/tensor_from_weighted_set_feature.cpp index 
88309120882..50fab518402 100644 --- a/searchlib/src/vespa/searchlib/features/tensor_from_weighted_set_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/tensor_from_weighted_set_feature.cpp @@ -21,6 +21,7 @@ using search::attribute::WeightedConstCharContent; using search::attribute::WeightedStringContent; using vespalib::eval::FastValueBuilderFactory; using vespalib::eval::ValueType; +using vespalib::eval::CellType; using search::fef::FeatureType; namespace search { @@ -58,7 +59,7 @@ TensorFromWeightedSetBlueprint::setup(const search::fef::IIndexEnvironment &env, } describeOutput("tensor", "The tensor created from the given weighted set source (attribute field or query parameter)", - FeatureType::object(ValueType::tensor_type({{_dimension}}))); + FeatureType::object(ValueType::make_type(CellType::DOUBLE, {{_dimension}}))); return validSource; } @@ -74,13 +75,13 @@ createAttributeExecutor(const search::fef::IQueryEnvironment &env, if (attribute == NULL) { LOG(warning, "The attribute vector '%s' was not found in the attribute manager." " Returning empty tensor.", attrName.c_str()); - return ConstantTensorExecutor::createEmpty(ValueType::tensor_type({{dimension}}), stash); + return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash); } if (attribute->getCollectionType() != search::attribute::CollectionType::WSET || attribute->isFloatingPointType()) { LOG(warning, "The attribute vector '%s' is NOT of type weighted set of string or integer." 
" Returning empty tensor.", attrName.c_str()); - return ConstantTensorExecutor::createEmpty(ValueType::tensor_type({{dimension}}), stash); + return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash); } if (attribute->isIntegerType()) { // Using WeightedStringContent ensures that the integer values are converted @@ -97,7 +98,7 @@ createQueryExecutor(const search::fef::IQueryEnvironment &env, const vespalib::string &queryKey, const vespalib::string &dimension, vespalib::Stash &stash) { - ValueType type = ValueType::tensor_type({{dimension}}); + ValueType type = ValueType::make_type(CellType::DOUBLE, {{dimension}}); search::fef::Property prop = env.getProperties().lookup(queryKey); if (prop.found() && !prop.get().empty()) { WeightedStringVector vector; @@ -127,7 +128,7 @@ TensorFromWeightedSetBlueprint::createExecutor(const search::fef::IQueryEnvironm } else if (_sourceType == QUERY_SOURCE) { return createQueryExecutor(env, _sourceParam, _dimension, stash); } - return ConstantTensorExecutor::createEmpty(ValueType::tensor_type({{_dimension}}), stash); + return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{_dimension}}), stash); } } // namespace features diff --git a/searchlib/src/vespa/searchlib/query/query_term_simple.cpp b/searchlib/src/vespa/searchlib/query/query_term_simple.cpp index 46f2c6198e5..2c5e977928c 100644 --- a/searchlib/src/vespa/searchlib/query/query_term_simple.cpp +++ b/searchlib/src/vespa/searchlib/query/query_term_simple.cpp @@ -25,7 +25,7 @@ void QueryTermSimple::visitMembers(vespalib::ObjectVisitor & visitor) const { visit(visitor, "term", _term); - visit(visitor, "type", _type); + visit(visitor, "type", static_cast<uint32_t>(_type)); } template <typename N> @@ -189,17 +189,6 @@ bool QueryTermSimple::getAsDoubleTerm(double & lower, double & upper) const return getAsNumericTerm(lower, upper, DoubleDecoder()); } -QueryTermSimple::QueryTermSimple() : - _type(WORD), - 
_rangeLimit(0), - _maxPerGroup(0), - _diversityCutoffGroups(std::numeric_limits<uint32_t>::max()), - _diversityCutoffStrict(false), - _valid(true), - _term(), - _diversityAttribute() -{ } - QueryTermSimple::~QueryTermSimple() = default; namespace { @@ -213,15 +202,15 @@ bool isFullRange(vespalib::stringref s) { } -QueryTermSimple::QueryTermSimple(const string & term_, SearchTerm type) : - _type(type), - _rangeLimit(0), - _maxPerGroup(0), - _diversityCutoffGroups(std::numeric_limits<uint32_t>::max()), - _diversityCutoffStrict(false), - _valid(true), - _term(term_), - _diversityAttribute() +QueryTermSimple::QueryTermSimple(const string & term_, Type type) + : _rangeLimit(0), + _maxPerGroup(0), + _diversityCutoffGroups(std::numeric_limits<uint32_t>::max()), + _type(type), + _diversityCutoffStrict(false), + _valid(true), + _term(term_), + _diversityAttribute() { if (isFullRange(_term)) { stringref rest(_term.c_str() + 1, _term.size() - 2); @@ -272,7 +261,7 @@ QueryTermSimple::getAsNumericTerm(T & lower, T & upper, D d) const bool valid(empty()); size_t sz(_term.size()); if (sz) { - char *err(NULL); + char *err(nullptr); T low(lower); T high(upper); const char * q = _term.c_str(); @@ -320,8 +309,8 @@ QueryTermSimple::getClassName() const } -void visit(vespalib::ObjectVisitor &self, const vespalib::string &name, - const search::QueryTermSimple *obj) +void +visit(vespalib::ObjectVisitor &self, const vespalib::string &name, const search::QueryTermSimple *obj) { if (obj != 0) { self.openStruct(name, obj->getClassName()); @@ -332,8 +321,8 @@ void visit(vespalib::ObjectVisitor &self, const vespalib::string &name, } } -void visit(vespalib::ObjectVisitor &self, const vespalib::string &name, - const search::QueryTermSimple &obj) +void +visit(vespalib::ObjectVisitor &self, const vespalib::string &name, const search::QueryTermSimple &obj) { visit(self, name, &obj); } diff --git a/searchlib/src/vespa/searchlib/query/query_term_simple.h 
b/searchlib/src/vespa/searchlib/query/query_term_simple.h index 11dbdefdaab..93b19212926 100644 --- a/searchlib/src/vespa/searchlib/query/query_term_simple.h +++ b/searchlib/src/vespa/searchlib/query/query_term_simple.h @@ -15,13 +15,13 @@ public: typedef std::unique_ptr<QueryTermSimple> UP; typedef vespalib::string string; typedef vespalib::stringref stringref; - enum SearchTerm { - WORD, - PREFIXTERM, - SUBSTRINGTERM, - EXACTSTRINGTERM, - SUFFIXTERM, - REGEXP + enum class Type : uint8_t { + WORD = 0, + PREFIXTERM = 1, + SUBSTRINGTERM = 2, + EXACTSTRINGTERM = 3, + SUFFIXTERM = 4, + REGEXP = 5 }; template <typename N> @@ -34,12 +34,11 @@ public: bool isEqual() const { return low == high; } }; - QueryTermSimple(const QueryTermSimple &) = default; - QueryTermSimple & operator = (const QueryTermSimple &) = default; - QueryTermSimple(QueryTermSimple &&) = default; - QueryTermSimple & operator = (QueryTermSimple &&) = default; - QueryTermSimple(); - QueryTermSimple(const string & term_, SearchTerm type); + QueryTermSimple(const QueryTermSimple &) = delete; + QueryTermSimple & operator = (const QueryTermSimple &) = delete; + QueryTermSimple(QueryTermSimple &&) = delete; + QueryTermSimple & operator = (QueryTermSimple &&) = delete; + QueryTermSimple(const string & term_, Type type); virtual ~QueryTermSimple(); /** * Extracts the content of this query term as a range with low and high values. 
@@ -54,12 +53,12 @@ public: bool getAsIntegerTerm(int64_t & lower, int64_t & upper) const; bool getAsDoubleTerm(double & lower, double & upper) const; const char * getTerm() const { return _term.c_str(); } - bool isPrefix() const { return (_type == PREFIXTERM); } - bool isSubstring() const { return (_type == SUBSTRINGTERM); } - bool isExactstring() const { return (_type == EXACTSTRINGTERM); } - bool isSuffix() const { return (_type == SUFFIXTERM); } - bool isWord() const { return (_type == WORD); } - bool isRegex() const { return (_type == REGEXP); } + bool isPrefix() const { return (_type == Type::PREFIXTERM); } + bool isSubstring() const { return (_type == Type::SUBSTRINGTERM); } + bool isExactstring() const { return (_type == Type::EXACTSTRINGTERM); } + bool isSuffix() const { return (_type == Type::SUFFIXTERM); } + bool isWord() const { return (_type == Type::WORD); } + bool isRegex() const { return (_type == Type::REGEXP); } bool empty() const { return _term.empty(); } virtual void visitMembers(vespalib::ObjectVisitor &visitor) const; vespalib::string getClassName() const; @@ -72,10 +71,10 @@ private: RangeResult<N> getIntegerRange() const; template <typename N> RangeResult<N> getFloatRange() const; - SearchTerm _type; int _rangeLimit; uint32_t _maxPerGroup; uint32_t _diversityCutoffGroups; + Type _type; bool _diversityCutoffStrict; bool _valid; string _term; diff --git a/searchlib/src/vespa/searchlib/query/query_term_ucs4.cpp b/searchlib/src/vespa/searchlib/query/query_term_ucs4.cpp index 86cda7e6786..be0398e1a50 100644 --- a/searchlib/src/vespa/searchlib/query/query_term_ucs4.cpp +++ b/searchlib/src/vespa/searchlib/query/query_term_ucs4.cpp @@ -3,37 +3,25 @@ #include "query_term_ucs4.h" #include <vespa/vespalib/objects/visit.h> #include <vespa/vespalib/text/utf8.h> +#include <mutex> namespace search { -QueryTermUCS4::UCS4StringT -QueryTermUCS4::getUCS4Term() const { - UCS4StringT ucs4; - const string & term = getTermString(); - ucs4.reserve(term.size() + 1); 
- vespalib::Utf8Reader r(term); - while (r.hasMore()) { - ucs4_t u = r.getChar(); - ucs4.push_back(u); - } - ucs4.push_back(0); - return ucs4; +namespace { + std::mutex _globalMutex; } -QueryTermUCS4::QueryTermUCS4() : - QueryTermSimple(), - _cachedTermLen(0), - _termUCS4() -{ - _termUCS4.push_back(0); +QueryTermUCS4::~QueryTermUCS4() { + ucs4_t * ucs4 = _termUCS4.load(std::memory_order_relaxed); + if (ucs4 != nullptr) { + delete [] ucs4; + } } -QueryTermUCS4::~QueryTermUCS4() = default; - -QueryTermUCS4::QueryTermUCS4(const string & termS, SearchTerm type) : +QueryTermUCS4::QueryTermUCS4(const string & termS, Type type) : QueryTermSimple(termS, type), - _cachedTermLen(0), - _termUCS4() + _termUCS4(nullptr), + _cachedTermLen(0) { vespalib::Utf8Reader r(termS); while (r.hasMore()) { @@ -43,6 +31,27 @@ QueryTermUCS4::QueryTermUCS4(const string & termS, SearchTerm type) : } } +const ucs4_t * +QueryTermUCS4::fillUCS4() { + /* + * Double checked locking...... + * This is a 'dirty' optimisation, but this is done to avoid writing a lot of data and blow the cpu caches with something + * you do not really need most of the time. That matters when qps is very high and query is wide, and hits are few. 
+ */ + std::lock_guard guard(_globalMutex); + ucs4_t * ucs4 = _termUCS4.load(std::memory_order_relaxed); + if (ucs4 != nullptr) return ucs4; + ucs4 = new ucs4_t[_cachedTermLen + 1]; + vespalib::Utf8Reader r(getTermString()); + uint32_t i(0); + while (r.hasMore()) { + ucs4[i++] = r.getChar(); + } + ucs4[_cachedTermLen] = 0; + _termUCS4.store(ucs4); + return ucs4; +} + void QueryTermUCS4::visitMembers(vespalib::ObjectVisitor & visitor) const { diff --git a/searchlib/src/vespa/searchlib/query/query_term_ucs4.h b/searchlib/src/vespa/searchlib/query/query_term_ucs4.h index 8a270d47777..00ac59d729e 100644 --- a/searchlib/src/vespa/searchlib/query/query_term_ucs4.h +++ b/searchlib/src/vespa/searchlib/query/query_term_ucs4.h @@ -2,10 +2,8 @@ #pragma once #include "query_term_simple.h" -#include <vespa/vespalib/util/memory.h> -#include <vespa/vespalib/objects/objectvisitor.h> #include <vespa/fastlib/text/unicodeutil.h> -#include <vector> +#include <atomic> namespace search { @@ -14,29 +12,27 @@ namespace search { */ class QueryTermUCS4 : public QueryTermSimple { public: - typedef std::vector<ucs4_t> UCS4StringT; typedef std::unique_ptr<QueryTermUCS4> UP; - QueryTermUCS4(const QueryTermUCS4 &) = default; - QueryTermUCS4 & operator = (const QueryTermUCS4 &) = default; - QueryTermUCS4(QueryTermUCS4 &&) = default; - QueryTermUCS4 & operator = (QueryTermUCS4 &&) = default; - QueryTermUCS4(); - QueryTermUCS4(const string & term_, SearchTerm type); - ~QueryTermUCS4(); - size_t getTermLen() const { return _cachedTermLen; } - size_t term(const char * & t) const { t = getTerm(); return _cachedTermLen; } - UCS4StringT getUCS4Term() const; + QueryTermUCS4(const QueryTermUCS4 &) = delete; + QueryTermUCS4 & operator = (const QueryTermUCS4 &) = delete; + QueryTermUCS4(QueryTermUCS4 &&) = delete; + QueryTermUCS4 & operator = (QueryTermUCS4 &&) = delete; + QueryTermUCS4(const string & term_, Type type); + ~QueryTermUCS4() override; + uint32_t getTermLen() const { return _cachedTermLen; } + 
uint32_t term(const char * & t) const { t = getTerm(); return _cachedTermLen; } void visitMembers(vespalib::ObjectVisitor &visitor) const override; - size_t term(const ucs4_t * & t) { - if (_termUCS4.empty()) { - _termUCS4 = getUCS4Term(); + uint32_t term(const ucs4_t * & t) { + t = _termUCS4.load(std::memory_order_relaxed); + if (t == nullptr) { + t = fillUCS4(); } - t = &_termUCS4[0]; return _cachedTermLen; } private: - size_t _cachedTermLen; - UCS4StringT _termUCS4; + const ucs4_t * fillUCS4(); + std::atomic<ucs4_t *> _termUCS4; + uint32_t _cachedTermLen; }; } diff --git a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp index 66466b030d0..ec1b26ec143 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/querynode.cpp @@ -91,22 +91,23 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor index = parent->getIndex() + "." + index; } vespalib::stringref term = queryRep.getTerm(); - QueryTerm::SearchTerm sTerm(QueryTerm::WORD); + using TermType = QueryTerm::Type; + TermType sTerm(TermType::WORD); switch (type) { case ParseItem::ITEM_REGEXP: - sTerm = QueryTerm::REGEXP; + sTerm = TermType::REGEXP; break; case ParseItem::ITEM_PREFIXTERM: - sTerm = QueryTerm::PREFIXTERM; + sTerm = TermType::PREFIXTERM; break; case ParseItem::ITEM_SUBSTRINGTERM: - sTerm = QueryTerm::SUBSTRINGTERM; + sTerm = TermType::SUBSTRINGTERM; break; case ParseItem::ITEM_EXACTSTRINGTERM: - sTerm = QueryTerm::EXACTSTRINGTERM; + sTerm = TermType::EXACTSTRINGTERM; break; case ParseItem::ITEM_SUFFIXTERM: - sTerm = QueryTerm::SUFFIXTERM; + sTerm = TermType::SUFFIXTERM; break; default: break; @@ -118,16 +119,16 @@ QueryNode::Build(const QueryNode * parent, const QueryNodeResultFactory & factor // But it will do for now as only correct sddocname queries are sent down. 
qn.reset(new TrueNode()); } else { - std::unique_ptr<QueryTerm> qt(new QueryTerm(factory.create(), ssTerm, ssIndex, sTerm)); + auto qt = std::make_unique<QueryTerm>(factory.create(), ssTerm, ssIndex, sTerm); qt->setWeight(queryRep.GetWeight()); qt->setUniqueId(queryRep.getUniqueId()); if ( qt->encoding().isBase10Integer() || ! qt->encoding().isFloat() || ! factory.getRewriteFloatTerms() || !allowRewrite || (ssTerm.find('.') == vespalib::string::npos)) { qn = std::move(qt); } else { - std::unique_ptr<PhraseQueryNode> phrase(new PhraseQueryNode()); - phrase->push_back(UP(new QueryTerm(factory.create(), ssTerm.substr(0, ssTerm.find('.')), ssIndex, QueryTerm::WORD))); - phrase->push_back(UP(new QueryTerm(factory.create(), ssTerm.substr(ssTerm.find('.') + 1), ssIndex, QueryTerm::WORD))); - std::unique_ptr<EquivQueryNode> orqn(new EquivQueryNode()); + auto phrase = std::make_unique<PhraseQueryNode>(); + phrase->push_back(std::make_unique<QueryTerm>(factory.create(), ssTerm.substr(0, ssTerm.find('.')), ssIndex, TermType::WORD)); + phrase->push_back(std::make_unique<QueryTerm>(factory.create(), ssTerm.substr(ssTerm.find('.') + 1), ssIndex, TermType::WORD)); + auto orqn = std::make_unique<EquivQueryNode>(); orqn->push_back(std::move(qt)); orqn->push_back(std::move(phrase)); qn.reset(orqn.release()); diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp index 943920c9dc6..69250d84cab 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.cpp @@ -40,22 +40,6 @@ static CharInfo _G_charTable; namespace search::streaming { -QueryTerm::QueryTerm() : - QueryTermUCS4(), - _index(), - _encoding(), - _result(), - _hitList(), - _weight(100), - _uniqueId(0), - _fieldInfo() -{ } - -QueryTerm::QueryTerm(const QueryTerm &) = default; -QueryTerm & QueryTerm::operator = (const QueryTerm &) = default; -QueryTerm::QueryTerm(QueryTerm 
&&) noexcept = default; -QueryTerm & QueryTerm::operator = (QueryTerm &&) noexcept = default; - QueryTerm::~QueryTerm() = default; void @@ -70,7 +54,7 @@ QueryTerm::visitMembers(vespalib::ObjectVisitor & visitor) const visit(visitor, "uniqueid", _uniqueId); } -QueryTerm::QueryTerm(std::unique_ptr<QueryNodeResultBase> org, const string & termS, const string & indexS, SearchTerm type) : +QueryTerm::QueryTerm(std::unique_ptr<QueryNodeResultBase> org, const string & termS, const string & indexS, Type type) : QueryTermUCS4(termS, type), _index(indexS), _encoding(0x01), diff --git a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h index 65e966ca0f2..134945e36d6 100644 --- a/searchlib/src/vespa/searchlib/query/streaming/queryterm.h +++ b/searchlib/src/vespa/searchlib/query/streaming/queryterm.h @@ -26,7 +26,7 @@ public: class EncodingBitMap { public: - EncodingBitMap(unsigned bm=0) : _enc(bm) { } + EncodingBitMap(uint8_t bm=0) : _enc(bm) { } bool isFloat() const { return _enc & Float; } bool isBase10Integer() const { return _enc & Base10Integer; } bool isAscii7Bit() const { return _enc & Ascii7Bit; } @@ -35,7 +35,7 @@ public: void setFloat(bool v) { if (v) _enc |= Float; else _enc &= ~Float; } private: enum { Ascii7Bit=0x01, Base10Integer=0x02, Float=0x04 }; - unsigned _enc; + uint8_t _enc; }; class FieldInfo { public: @@ -53,12 +53,11 @@ public: uint32_t _hitCount; uint32_t _fieldLength; }; - QueryTerm(); - QueryTerm(std::unique_ptr<QueryNodeResultBase> resultBase, const string & term, const string & index, SearchTerm type); - QueryTerm(const QueryTerm &); - QueryTerm & operator = (const QueryTerm &); - QueryTerm(QueryTerm &&) noexcept; - QueryTerm & operator = (QueryTerm &&) noexcept; + QueryTerm(std::unique_ptr<QueryNodeResultBase> resultBase, const string & term, const string & index, Type type); + QueryTerm(const QueryTerm &) = delete; + QueryTerm & operator = (const QueryTerm &) = delete; + 
QueryTerm(QueryTerm &&) = delete; + QueryTerm & operator = (QueryTerm &&) = delete; ~QueryTerm(); bool evaluate() const override; const HitList & evaluateHits(HitList & hl) const override; diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute_saver.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute_saver.cpp index 362e1b45266..cc43a694a69 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute_saver.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute_saver.cpp @@ -62,13 +62,11 @@ void DenseTensorAttributeSaver::save_tensor_store(BufferWriter& writer) const { const uint32_t docIdLimit(_refs.size()); - const uint32_t cellSize = _tensorStore.getCellSize(); for (uint32_t lid = 0; lid < docIdLimit; ++lid) { if (_refs[lid].valid()) { auto raw = _tensorStore.getRawBuffer(_refs[lid]); writer.write(&tensorIsPresent, sizeof(tensorIsPresent)); - size_t numCells = _tensorStore.getNumCells(); - size_t rawLen = numCells * cellSize; + size_t rawLen = _tensorStore.getBufSize(); writer.write(static_cast<const char *>(raw), rawLen); } else { writer.write(&tensorIsNotPresent, sizeof(tensorIsNotPresent)); diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp index e99ba196224..13796d35dec 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp @@ -6,9 +6,10 @@ #include <vespa/vespalib/util/memory_allocator.h> using vespalib::datastore::Handle; +using vespalib::eval::CellType; +using vespalib::eval::CellTypeUtils; using vespalib::eval::Value; using vespalib::eval::ValueType; -using CellType = vespalib::eval::CellType; namespace search::tensor { @@ -17,14 +18,6 @@ namespace { constexpr size_t MIN_BUFFER_ARRAYS = 1024; constexpr size_t DENSE_TENSOR_ALIGNMENT = 32; -size_t size_of(CellType type) { - switch (type) { - case CellType::DOUBLE: return 
sizeof(double); - case CellType::FLOAT: return sizeof(float); - } - abort(); -} - size_t my_align(size_t size, size_t alignment) { size += alignment - 1; return (size - (size % alignment)); @@ -34,7 +27,7 @@ size_t my_align(size_t size, size_t alignment) { DenseTensorStore::TensorSizeCalc::TensorSizeCalc(const ValueType &type) : _numCells(1u), - _cellSize(size_of(type.cell_type())) + _cell_type(type.cell_type()) { for (const auto &dim: type.dimensions()) { _numCells *= dim.size; diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h index aa5a7993eaf..dad28642e67 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h @@ -24,10 +24,12 @@ public: struct TensorSizeCalc { size_t _numCells; // product of dimension sizes - uint32_t _cellSize; // size of a cell (e.g. double => 8, float => 4) + vespalib::eval::CellType _cell_type; TensorSizeCalc(const ValueType &type); - size_t bufSize() const { return (_numCells * _cellSize); } + size_t bufSize() const { + return vespalib::eval::CellTypeUtils::mem_size(_cell_type, _numCells); + } size_t alignedSize() const; }; @@ -60,7 +62,6 @@ public: const ValueType &type() const { return _type; } size_t getNumCells() const { return _tensorSizeCalc._numCells; } - uint32_t getCellSize() const { return _tensorSizeCalc._cellSize; } size_t getBufSize() const { return _tensorSizeCalc.bufSize(); } const void *getRawBuffer(RefType ref) const; vespalib::datastore::Handle<char> allocRawBuffer(); diff --git a/searchlib/src/vespa/searchlib/test/imported_attribute_fixture.cpp b/searchlib/src/vespa/searchlib/test/imported_attribute_fixture.cpp index 68d2dc5472b..61eaae40e90 100644 --- a/searchlib/src/vespa/searchlib/test/imported_attribute_fixture.cpp +++ b/searchlib/src/vespa/searchlib/test/imported_attribute_fixture.cpp @@ -2,6 +2,7 @@ #include "imported_attribute_fixture.h" #include 
"mock_gid_to_lid_mapping.h" +#include <vespa/searchlib/query/query_term_ucs4.h> #include <vespa/vespalib/util/stringfmt.h> #include <future> @@ -55,7 +56,7 @@ GlobalId dummy_gid(uint32_t doc_index) { } std::unique_ptr<QueryTermSimple> word_term(vespalib::stringref term) { - return std::make_unique<QueryTermSimple>(term, QueryTermSimple::WORD); + return std::make_unique<QueryTermUCS4>(term, QueryTermSimple::Type::WORD); } diff --git a/security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java b/security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java index d4e74e22e40..4d6b160db18 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/AutoReloadingX509KeyManager.java @@ -12,6 +12,8 @@ import javax.net.ssl.X509ExtendedKeyManager; import java.io.IOException; import java.io.UncheckedIOException; import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; import java.nio.file.Path; import java.security.KeyStore; import java.security.Principal; @@ -72,8 +74,8 @@ public class AutoReloadingX509KeyManager extends X509ExtendedKeyManager implemen return KeyStoreBuilder.withType(KeyStoreType.PKCS12) .withKeyEntry( CERTIFICATE_ALIAS, - KeyUtils.fromPemEncodedPrivateKey(com.yahoo.vespa.jdk8compat.Files.readString(privateKey)), - X509CertificateUtils.certificateListFromPem(com.yahoo.vespa.jdk8compat.Files.readString(certificateChain))) + KeyUtils.fromPemEncodedPrivateKey(new String(Files.readAllBytes(privateKey), StandardCharsets.UTF_8)), + X509CertificateUtils.certificateListFromPem(new String(Files.readAllBytes(certificateChain), StandardCharsets.UTF_8))) .build(); } catch (IOException e) { throw new UncheckedIOException(e); diff --git a/security-utils/src/main/java/com/yahoo/security/tls/ConfigFileBasedTlsContext.java 
b/security-utils/src/main/java/com/yahoo/security/tls/ConfigFileBasedTlsContext.java index acc70d50d6a..bc1f1dcc6f6 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/ConfigFileBasedTlsContext.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/ConfigFileBasedTlsContext.java @@ -14,9 +14,12 @@ import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; import java.io.IOException; import java.io.UncheckedIOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; import java.nio.file.Path; import java.security.KeyStore; import java.time.Duration; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Set; @@ -82,7 +85,7 @@ public class ConfigFileBasedTlsContext implements TlsContext { private static KeyStore loadTruststore(Path caCertificateFile) { try { return KeyStoreBuilder.withType(KeyStoreType.PKCS12) - .withCertificateEntries("cert", X509CertificateUtils.certificateListFromPem(com.yahoo.vespa.jdk8compat.Files.readString(caCertificateFile))) + .withCertificateEntries("cert", X509CertificateUtils.certificateListFromPem(new String(Files.readAllBytes(caCertificateFile), StandardCharsets.UTF_8))) .build(); } catch (IOException e) { throw new UncheckedIOException(e); @@ -94,8 +97,8 @@ public class ConfigFileBasedTlsContext implements TlsContext { return KeyStoreBuilder.withType(KeyStoreType.PKCS12) .withKeyEntry( "default", - KeyUtils.fromPemEncodedPrivateKey(com.yahoo.vespa.jdk8compat.Files.readString(privateKeyFile)), - X509CertificateUtils.certificateListFromPem(com.yahoo.vespa.jdk8compat.Files.readString(certificatesFile))) + KeyUtils.fromPemEncodedPrivateKey(new String(Files.readAllBytes(privateKeyFile), StandardCharsets.UTF_8)), + X509CertificateUtils.certificateListFromPem(new String(Files.readAllBytes(certificatesFile), StandardCharsets.UTF_8))) .build(); } catch (IOException e) { throw new UncheckedIOException(e); @@ -111,7 +114,7 @@ public class 
ConfigFileBasedTlsContext implements TlsContext { HostnameVerification hostnameVerification = options.isHostnameValidationDisabled() ? HostnameVerification.DISABLED : HostnameVerification.ENABLED; PeerAuthorizerTrustManager authorizerTrustManager = options.getAuthorizedPeers() .map(authorizedPeers -> new PeerAuthorizerTrustManager(authorizedPeers, mode, hostnameVerification, mutableTrustManager)) - .orElseGet(() -> new PeerAuthorizerTrustManager(new AuthorizedPeers(com.yahoo.vespa.jdk8compat.Set.of()), AuthorizationMode.DISABLE, hostnameVerification, mutableTrustManager)); + .orElseGet(() -> new PeerAuthorizerTrustManager(new AuthorizedPeers(Collections.emptySet()), AuthorizationMode.DISABLE, hostnameVerification, mutableTrustManager)); SSLContext sslContext = new SslContextBuilder() .withKeyManager(mutableKeyManager) .withTrustManager(authorizerTrustManager) diff --git a/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java b/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java index 250596628ee..56f2ecb8efc 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/DefaultTlsContext.java @@ -11,6 +11,7 @@ import javax.net.ssl.SSLParameters; import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Set; import java.util.logging.Level; @@ -61,7 +62,7 @@ public class DefaultTlsContext implements TlsContext { String.format("None of the accepted ciphers are supported (supported=%s, accepted=%s)", supportedCiphers, acceptedCiphers)); } - log.log(Level.FINE, () -> String.format("Allowed cipher suites that are supported: %s", com.yahoo.vespa.jdk8compat.List.of(allowedCiphers))); + log.log(Level.FINE, () -> String.format("Allowed cipher suites that are supported: %s", Arrays.asList(allowedCiphers))); return allowedCiphers; } @@ 
-139,7 +140,7 @@ public class DefaultTlsContext implements TlsContext { builder.withTrustManagerFactory(truststore -> new PeerAuthorizerTrustManager(authorizedPeers, mode, hostnameVerification, truststore)); } else { builder.withTrustManagerFactory(truststore -> new PeerAuthorizerTrustManager( - new AuthorizedPeers(com.yahoo.vespa.jdk8compat.Set.of()), AuthorizationMode.DISABLE, hostnameVerification, truststore)); + new AuthorizedPeers(Collections.emptySet()), AuthorizationMode.DISABLE, hostnameVerification, truststore)); } return builder.build(); } diff --git a/security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java b/security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java index c60f13f9729..a3b438fcc65 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/KeyManagerUtils.java @@ -30,7 +30,7 @@ public class KeyManagerUtils { .filter(manager -> manager instanceof X509ExtendedKeyManager) .map(X509ExtendedKeyManager.class::cast) .findFirst() - .orElseThrow(() -> new RuntimeException("No X509ExtendedKeyManager in " + com.yahoo.vespa.jdk8compat.List.of(keyManagers))); + .orElseThrow(() -> new RuntimeException("No X509ExtendedKeyManager in " + Arrays.asList(keyManagers))); } catch (GeneralSecurityException e) { throw new RuntimeException(e); } diff --git a/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java b/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java index eef05d4f4f2..1f78dc9d481 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/TlsContext.java @@ -5,6 +5,8 @@ import javax.net.ssl.SSLContext; import javax.net.ssl.SSLEngine; import javax.net.ssl.SSLParameters; import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; import java.util.Set; import static java.util.stream.Collectors.toSet; @@ -23,7 
+25,7 @@ public interface TlsContext extends AutoCloseable { * For TLSv1.3 we allow the DEFAULT group ciphers. * Note that we _only_ allow AEAD ciphers for either TLS version. */ - Set<String> ALLOWED_CIPHER_SUITES = com.yahoo.vespa.jdk8compat.Set.of( + Set<String> ALLOWED_CIPHER_SUITES = Collections.unmodifiableSet(new HashSet<>(Arrays.asList( "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", // Java 12 @@ -32,10 +34,10 @@ public interface TlsContext extends AutoCloseable { "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "TLS_AES_128_GCM_SHA256", // TLSv1.3 "TLS_AES_256_GCM_SHA384", // TLSv1.3 - "TLS_CHACHA20_POLY1305_SHA256"); // TLSv1.3, Java 12 + "TLS_CHACHA20_POLY1305_SHA256"))); // TLSv1.3, Java 12 // TODO Enable TLSv1.3 after upgrading to JDK 17 - Set<String> ALLOWED_PROTOCOLS = com.yahoo.vespa.jdk8compat.Set.of("TLSv1.2"); + Set<String> ALLOWED_PROTOCOLS = Collections.singleton("TLSv1.2"); String SSL_CONTEXT_VERSION = "TLS"; // Use SSLContext implementations that supports all TLS versions /** diff --git a/security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java b/security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java index 17f56011261..cb8c6e53555 100644 --- a/security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java +++ b/security-utils/src/main/java/com/yahoo/security/tls/TrustManagerUtils.java @@ -29,7 +29,7 @@ public class TrustManagerUtils { .filter(manager -> manager instanceof X509ExtendedTrustManager) .map(X509ExtendedTrustManager.class::cast) .findFirst() - .orElseThrow(() -> new RuntimeException("No X509ExtendedTrustManager in " + com.yahoo.vespa.jdk8compat.List.of(trustManagers))); + .orElseThrow(() -> new RuntimeException("No X509ExtendedTrustManager in " + Arrays.asList(trustManagers))); } catch (GeneralSecurityException e) { throw new RuntimeException(e); } diff --git 
a/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/Collection.java b/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/Collection.java deleted file mode 100644 index fbfea01b2c7..00000000000 --- a/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/Collection.java +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.jdk8compat; - -import java.util.function.IntFunction; - -/** - * Backport of new {@link java.util.Collection} methods added after JDK8 - * - * @author bjorncs - */ -public interface Collection { - static <T> T[] toArray(java.util.Collection<T> collection, IntFunction<T[]> generator) { - return collection.toArray(generator.apply(collection.size())); - } - -} diff --git a/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/Files.java b/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/Files.java deleted file mode 100644 index cc3bd698cd5..00000000000 --- a/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/Files.java +++ /dev/null @@ -1,24 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.jdk8compat; - -import java.io.IOException; -import java.nio.charset.StandardCharsets; -import java.nio.file.OpenOption; -import java.nio.file.Path; - -/** - * Backport of new {@link java.nio.file.Files} methods added after JDK8 - * - * @author bjorncs - */ -public interface Files { - - static String readString(Path path) throws IOException { - byte[] bytes = java.nio.file.Files.readAllBytes(path); - return new String(bytes, StandardCharsets.UTF_8); - } - - static Path writeString(Path path, CharSequence string, OpenOption... 
options) throws IOException { - return java.nio.file.Files.write(path, string.toString().getBytes(), options); - } -} diff --git a/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/List.java b/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/List.java deleted file mode 100644 index f57834e93cb..00000000000 --- a/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/List.java +++ /dev/null @@ -1,17 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.jdk8compat; - -import java.util.Arrays; - -/** - * Backport of new {@link java.util.List} methods added after JDK8 - * - * @author bjorncs - */ -public interface List { - @SafeVarargs - @SuppressWarnings("varargs") - static <E> java.util.List<E> of(E... elements) { - return Arrays.asList(elements); - } -} diff --git a/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/Set.java b/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/Set.java deleted file mode 100644 index b2c998bb716..00000000000 --- a/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/Set.java +++ /dev/null @@ -1,18 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.jdk8compat; - -import java.util.Arrays; -import java.util.HashSet; - -/** - * Backport of new {@link java.util.Set} methods added after JDK8 - * - * @author bjorncs - */ -public interface Set { - @SafeVarargs - @SuppressWarnings("varargs") - static <E> java.util.Set<E> of(E... 
elements) { - return new HashSet<>(Arrays.asList(elements)); - } -} diff --git a/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/package-info.java b/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/package-info.java deleted file mode 100644 index 40d74321438..00000000000 --- a/security-utils/src/main/java/com/yahoo/vespa/jdk8compat/package-info.java +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -/** - * JDK8 port of types and methods added in later JDK versions. - * TODO Remove this package once vespa-http-client/security-utils no longer targets JDK8 - * - * @author bjorncs - */ -package com.yahoo.vespa.jdk8compat;
\ No newline at end of file diff --git a/security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java b/security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java index 22710e7f393..024149a7282 100644 --- a/security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java +++ b/security-utils/src/test/java/com/yahoo/security/tls/AutoReloadingX509KeyManagerTest.java @@ -15,6 +15,7 @@ import org.mockito.Mockito; import javax.security.auth.x500.X500Principal; import java.io.IOException; import java.math.BigInteger; +import java.nio.file.Files; import java.nio.file.Path; import java.security.KeyPair; import java.security.Principal; @@ -41,12 +42,12 @@ public class AutoReloadingX509KeyManagerTest { public void crypto_material_is_reloaded_when_scheduler_task_is_executed() throws IOException { KeyPair keyPair = KeyUtils.generateKeypair(KeyAlgorithm.EC); Path privateKeyFile = tempDirectory.newFile().toPath(); - com.yahoo.vespa.jdk8compat.Files.writeString(privateKeyFile, KeyUtils.toPem(keyPair.getPrivate())); + Files.write(privateKeyFile, KeyUtils.toPem(keyPair.getPrivate()).getBytes()); Path certificateFile = tempDirectory.newFile().toPath(); BigInteger serialNumberInitialCertificate = BigInteger.ONE; X509Certificate initialCertificate = generateCertificate(keyPair, serialNumberInitialCertificate); - com.yahoo.vespa.jdk8compat.Files.writeString(certificateFile, X509CertificateUtils.toPem(initialCertificate)); + Files.write(certificateFile, X509CertificateUtils.toPem(initialCertificate).getBytes()); ScheduledExecutorService scheduler = Mockito.mock(ScheduledExecutorService.class); ArgumentCaptor<Runnable> updaterTaskCaptor = ArgumentCaptor.forClass(Runnable.class); @@ -61,7 +62,7 @@ public class AutoReloadingX509KeyManagerTest { BigInteger serialNumberUpdatedCertificate = BigInteger.TEN; X509Certificate updatedCertificate = generateCertificate(keyPair, serialNumberUpdatedCertificate); - 
com.yahoo.vespa.jdk8compat.Files.writeString(certificateFile, X509CertificateUtils.toPem(updatedCertificate)); + Files.write(certificateFile, X509CertificateUtils.toPem(updatedCertificate).getBytes()); updaterTaskCaptor.getValue().run(); // run update task in ReloadingX509KeyManager diff --git a/security-utils/src/test/java/com/yahoo/security/tls/ConfigFileBasedTlsContextTest.java b/security-utils/src/test/java/com/yahoo/security/tls/ConfigFileBasedTlsContextTest.java index 54a1e3847f9..f1c8acbaf3b 100644 --- a/security-utils/src/test/java/com/yahoo/security/tls/ConfigFileBasedTlsContextTest.java +++ b/security-utils/src/test/java/com/yahoo/security/tls/ConfigFileBasedTlsContextTest.java @@ -35,17 +35,17 @@ public class ConfigFileBasedTlsContextTest { public void can_create_sslcontext_from_credentials() throws IOException, InterruptedException { KeyPair keyPair = KeyUtils.generateKeypair(EC); Path privateKeyFile = tempDirectory.newFile().toPath(); - com.yahoo.vespa.jdk8compat.Files.writeString(privateKeyFile, KeyUtils.toPem(keyPair.getPrivate())); + Files.write(privateKeyFile, KeyUtils.toPem(keyPair.getPrivate()).getBytes()); X509Certificate certificate = X509CertificateBuilder .fromKeypair(keyPair, new X500Principal("CN=dummy"), EPOCH, EPOCH.plus(1, DAYS), SHA256_WITH_ECDSA, BigInteger.ONE) .build(); Path certificateChainFile = tempDirectory.newFile().toPath(); String certificatePem = X509CertificateUtils.toPem(certificate); - com.yahoo.vespa.jdk8compat.Files.writeString(certificateChainFile, certificatePem); + Files.write(certificateChainFile, certificatePem.getBytes()); Path caCertificatesFile = tempDirectory.newFile().toPath(); - com.yahoo.vespa.jdk8compat.Files.writeString(caCertificatesFile, certificatePem); + Files.write(caCertificatesFile, certificatePem.getBytes()); TransportSecurityOptions options = new TransportSecurityOptions.Builder() .withCertificates(certificateChainFile, privateKeyFile) diff --git 
a/security-utils/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java b/security-utils/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java index 8fd2ca065c7..43389ade275 100644 --- a/security-utils/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java +++ b/security-utils/src/test/java/com/yahoo/security/tls/TransportSecurityOptionsTest.java @@ -8,6 +8,7 @@ import java.nio.charset.StandardCharsets; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Arrays; import java.util.Collections; import static org.junit.Assert.assertEquals; @@ -21,7 +22,7 @@ public class TransportSecurityOptionsTest { private static final TransportSecurityOptions OPTIONS = new TransportSecurityOptions.Builder() .withCertificates(Paths.get("certs.pem"), Paths.get("myhost.key")) .withCaCertificates(Paths.get("my_cas.pem")) - .withAcceptedCiphers(com.yahoo.vespa.jdk8compat.List.of("TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" , "TLS_AES_256_GCM_SHA384")) + .withAcceptedCiphers(Arrays.asList("TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", "TLS_AES_256_GCM_SHA384")) .withAcceptedProtocols(Collections.singletonList("TLSv1.2")) .withHostnameValidationDisabled(true) .build(); diff --git a/security-utils/src/test/java/com/yahoo/security/tls/json/TransportSecurityOptionsJsonSerializerTest.java b/security-utils/src/test/java/com/yahoo/security/tls/json/TransportSecurityOptionsJsonSerializerTest.java index e14b3d99212..35fd25b6a62 100644 --- a/security-utils/src/test/java/com/yahoo/security/tls/json/TransportSecurityOptionsJsonSerializerTest.java +++ b/security-utils/src/test/java/com/yahoo/security/tls/json/TransportSecurityOptionsJsonSerializerTest.java @@ -68,7 +68,7 @@ public class TransportSecurityOptionsJsonSerializerTest { TransportSecurityOptions options = new TransportSecurityOptions.Builder() .withCertificates(Paths.get("certs.pem"), Paths.get("myhost.key")) 
.withCaCertificates(Paths.get("my_cas.pem")) - .withAcceptedCiphers(com.yahoo.vespa.jdk8compat.List.of("TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" , "TLS_AES_256_GCM_SHA384")) + .withAcceptedCiphers(Arrays.asList("TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384" , "TLS_AES_256_GCM_SHA384")) .withAcceptedProtocols(Collections.singletonList("TLSv1.2")) .withHostnameValidationDisabled(true) .build(); diff --git a/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp b/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp index b0cb058d762..de8a6c707e5 100644 --- a/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/hitcollector.cpp @@ -162,7 +162,7 @@ HitCollector::getFeatureSet(IRankProgram &rankProgram, for (uint32_t j = 0; j < names.size(); ++j) { if (resolver.is_object(j)) { auto obj = resolver.resolve(j).as_object(docId); - if (! obj.get().is_double()) { + if (! obj.get().type().is_double()) { vespalib::nbostream buf; encode_value(obj.get(), buf); f[j].set_data(vespalib::Memory(buf.peek(), buf.size())); diff --git a/vespalib/CMakeLists.txt b/vespalib/CMakeLists.txt index c51e42176dc..cc82091568e 100644 --- a/vespalib/CMakeLists.txt +++ b/vespalib/CMakeLists.txt @@ -42,6 +42,7 @@ vespa_define_module( src/tests/datastore/array_store_config src/tests/datastore/buffer_type src/tests/datastore/datastore + src/tests/datastore/simple_hash_map src/tests/datastore/unique_store src/tests/datastore/unique_store_dictionary src/tests/datastore/unique_store_string_allocator @@ -131,9 +132,11 @@ vespa_define_module( src/tests/tutorial/threads src/tests/typify src/tests/util/bfloat16 + src/tests/util/file_area_freelist src/tests/util/generationhandler src/tests/util/generationhandler_stress src/tests/util/md5 + src/tests/util/mmap_file_allocator src/tests/util/mmap_file_allocator_factory src/tests/util/rcuvector src/tests/util/reusable_set diff --git a/vespalib/src/tests/alloc/alloc_test.cpp 
b/vespalib/src/tests/alloc/alloc_test.cpp index d37abb15c2f..52b11e7dbc0 100644 --- a/vespalib/src/tests/alloc/alloc_test.cpp +++ b/vespalib/src/tests/alloc/alloc_test.cpp @@ -2,7 +2,7 @@ #include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/util/alloc.h> -#include <vespa/vespalib/util/mmap_file_allocator.h> +#include <vespa/vespalib/util/memory_allocator.h> #include <vespa/vespalib/util/exceptions.h> #include <vespa/vespalib/util/size_literals.h> #include <cstddef> @@ -297,30 +297,4 @@ TEST("auto alloced mmap alloc can not be shrinked below HUGEPAGE_SIZE/2 + 1 ") { EXPECT_EQUAL(SZ, buf.size()); } -TEST("mmap file allocator works") -{ - MmapFileAllocator allocator("mmap-file-allocator-dir"); - auto alloc = Alloc::alloc_with_allocator(&allocator); - auto buf = alloc.create(0); - EXPECT_EQUAL(0u, allocator.get_end_offset()); - EXPECT_EQUAL(0u, buf.size()); - EXPECT_TRUE(buf.get() == nullptr); - buf = alloc.create(4); - EXPECT_LESS_EQUAL(4u, buf.size()); - EXPECT_TRUE(buf.get() != nullptr); - memcpy(buf.get(), "1st", 4); - auto buf2 = alloc.create(5); - EXPECT_LESS_EQUAL(5u, buf2.size()); - EXPECT_TRUE(buf2.get() != nullptr); - EXPECT_TRUE(buf.get() != buf2.get()); - memcpy(buf2.get(), "fine", 5); - EXPECT_FALSE(buf.resize_inplace(5)); - EXPECT_FALSE(buf.resize_inplace(3)); - EXPECT_NOT_EQUAL(0u, allocator.get_end_offset()); - int result = msync(buf.get(), buf.size(), MS_SYNC); - EXPECT_EQUAL(0, result); - result = msync(buf2.get(), buf2.size(), MS_SYNC); - EXPECT_EQUAL(0, result); -} - TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/tests/datastore/simple_hash_map/CMakeLists.txt b/vespalib/src/tests/datastore/simple_hash_map/CMakeLists.txt new file mode 100644 index 00000000000..c790481ebbc --- /dev/null +++ b/vespalib/src/tests/datastore/simple_hash_map/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+vespa_add_executable(vespalib_simple_hash_map_test_app + SOURCES + simple_hash_map_test.cpp + DEPENDS + vespalib + GTest::GTest +) +vespa_add_test(NAME vespalib_simple_hash_map_test_app COMMAND vespalib_simple_hash_map_test_app) diff --git a/vespalib/src/tests/datastore/simple_hash_map/simple_hash_map_test.cpp b/vespalib/src/tests/datastore/simple_hash_map/simple_hash_map_test.cpp new file mode 100644 index 00000000000..ad4ac1518b2 --- /dev/null +++ b/vespalib/src/tests/datastore/simple_hash_map/simple_hash_map_test.cpp @@ -0,0 +1,232 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/datastore/simple_hash_map.h> +#include <vespa/vespalib/datastore/unique_store_allocator.h> +#include <vespa/vespalib/datastore/unique_store_comparator.h> + +#include <vespa/vespalib/util/lambdatask.h> +#include <vespa/vespalib/util/rand48.h> +#include <vespa/vespalib/util/size_literals.h> +#include <vespa/vespalib/util/threadstackexecutor.h> +#include <vespa/vespalib/gtest/gtest.h> + +#include <vespa/vespalib/datastore/unique_store_allocator.hpp> + +#include <vespa/log/log.h> +LOG_SETUP("vespalib_datastore_simple_hash_test"); + +using vespalib::datastore::EntryRef; +using RefT = vespalib::datastore::EntryRefT<22>; +using MyAllocator = vespalib::datastore::UniqueStoreAllocator<uint32_t, RefT>; +using MyDataStore = vespalib::datastore::DataStoreT<RefT>; +using MyCompare = vespalib::datastore::UniqueStoreComparator<uint32_t, RefT>; +using MyHashMap = vespalib::datastore::SimpleHashMap; +using GenerationHandler = vespalib::GenerationHandler; +using vespalib::makeLambdaTask; + +struct DataStoreSimpleHashTest : public ::testing::Test +{ + GenerationHandler _generationHandler; + MyAllocator _allocator; + MyDataStore& _store; + MyHashMap _hash_map; + vespalib::ThreadStackExecutor _writer; // 1 write thread + vespalib::ThreadStackExecutor _readers; // multiple reader threads + vespalib::Rand48 _rnd; 
+ uint32_t _keyLimit; + std::atomic<long> _read_seed; + std::atomic<long> _done_write_work; + std::atomic<long> _done_read_work; + std::atomic<long> _found_count; + std::atomic<int> _stop_read; + bool _report_work; + + DataStoreSimpleHashTest(); + ~DataStoreSimpleHashTest(); + void commit(); + void insert(uint32_t key); + void remove(uint32_t key); + + void read_work(uint32_t cnt); + void read_work(); + void write_work(uint32_t cnt); +}; + + +DataStoreSimpleHashTest::DataStoreSimpleHashTest() + : _generationHandler(), + _allocator(), + _store(_allocator.get_data_store()), + _hash_map(std::make_unique<MyCompare>(_store)), + _writer(1, 128_Ki), + _readers(4, 128_Ki), + _rnd(), + _keyLimit(1000000), + _read_seed(50), + _done_write_work(0), + _done_read_work(0), + _found_count(0), + _stop_read(0), + _report_work(false) +{ + _rnd.srand48(32); +} + + +DataStoreSimpleHashTest::~DataStoreSimpleHashTest() +{ + _readers.sync(); + _readers.shutdown(); + _writer.sync(); + _writer.shutdown(); + commit(); + if (_report_work) { + LOG(info, + "read_work=%ld, write_work=%ld, found_count=%ld", + _done_read_work.load(), _done_write_work.load(), _found_count.load()); + } +} + + +void +DataStoreSimpleHashTest::commit() +{ + _store.transferHoldLists(_generationHandler.getCurrentGeneration()); + _hash_map.transfer_hold_lists(_generationHandler.getCurrentGeneration()); + _generationHandler.incGeneration(); + _store.trimHoldLists(_generationHandler.getFirstUsedGeneration()); + _hash_map.trim_hold_lists(_generationHandler.getFirstUsedGeneration()); +} + +void +DataStoreSimpleHashTest::insert(uint32_t key) +{ + MyCompare comp(_store, key); +std::function<EntryRef(void)> insert_entry([this, key]() -> EntryRef { return _allocator.allocate(key); }); + auto& result = _hash_map.add(comp, insert_entry); + auto ref = result.first.load_relaxed(); + auto &wrapped_entry = _allocator.get_wrapped(ref); + EXPECT_EQ(key, wrapped_entry.value()); +} + +void +DataStoreSimpleHashTest::remove(uint32_t key) +{ 
+ MyCompare comp(_store, key); + auto result = _hash_map.remove(comp, EntryRef()); + if (result != nullptr) { + auto ref = result->first.load_relaxed(); + auto &wrapped_entry = _allocator.get_wrapped(ref); + EXPECT_EQ(key, wrapped_entry.value()); + _allocator.hold(ref); + } +} + + +void +DataStoreSimpleHashTest::read_work(uint32_t cnt) +{ + vespalib::Rand48 rnd; + long found = 0; + rnd.srand48(++_read_seed); + uint32_t i; + for (i = 0; i < cnt && _stop_read.load() == 0; ++i) { + auto guard = _generationHandler.takeGuard(); + uint32_t key = rnd.lrand48() % (_keyLimit + 1); + MyCompare comp(_store, key); + auto result = _hash_map.find(comp, EntryRef()); + if (result != nullptr) { + auto ref = result->first.load_relaxed(); + auto &wrapped_entry = _allocator.get_wrapped(ref); + EXPECT_EQ(key, wrapped_entry.value()); + ++found; + } + } + _done_read_work += i; + _found_count += found; + LOG(info, "done %u read work", i); +} + + +void +DataStoreSimpleHashTest::read_work() +{ + read_work(std::numeric_limits<uint32_t>::max()); +} + + +void +DataStoreSimpleHashTest::write_work(uint32_t cnt) +{ + vespalib::Rand48 &rnd(_rnd); + for (uint32_t i = 0; i < cnt; ++i) { + uint32_t key = rnd.lrand48() % _keyLimit; + if ((rnd.lrand48() & 1) == 0) { + insert(key); + } else { + remove(key); + } + commit(); + } + _done_write_work += cnt; + _stop_read = 1; + LOG(info, "done %u write work", cnt); +} + + +TEST_F(DataStoreSimpleHashTest, smoke_test) +{ + EXPECT_EQ(0, _hash_map.size()); + insert(1); + EXPECT_EQ(1, _hash_map.size()); + remove(2); + EXPECT_EQ(1, _hash_map.size()); + insert(1); + EXPECT_EQ(1, _hash_map.size()); + insert(5); + EXPECT_EQ(2, _hash_map.size()); + insert(4); + EXPECT_EQ(3, _hash_map.size()); + remove(3); + EXPECT_EQ(3, _hash_map.size()); + remove(5); + EXPECT_EQ(2, _hash_map.size()); + commit(); + MyCompare comp3(_store, 3); + auto result3 = _hash_map.find(comp3, EntryRef()); + EXPECT_TRUE(result3 == nullptr); + MyCompare comp4(_store, 4); + auto result4 = 
_hash_map.find(comp4, EntryRef()); + EXPECT_TRUE(result4 != nullptr); + auto ref4 = result4->first.load_relaxed(); + auto& wrapped_entry4 = _allocator.get_wrapped(ref4); + EXPECT_EQ(4, wrapped_entry4.value()); +} + +TEST_F(DataStoreSimpleHashTest, single_threaded_reader_without_updates) +{ + _report_work = true; + write_work(10); + _stop_read = 0; + read_work(10); +} + +TEST_F(DataStoreSimpleHashTest, single_threaded_reader_during_updates) +{ + uint32_t cnt = 1000000; + _report_work = true; + _writer.execute(makeLambdaTask([this, cnt]() { write_work(cnt); })); + _readers.execute(makeLambdaTask([this]() { read_work(); })); +} + +TEST_F(DataStoreSimpleHashTest, multi_threaded_reader_during_updates) +{ + uint32_t cnt = 1000000; + _report_work = true; + _writer.execute(makeLambdaTask([this, cnt]() { write_work(cnt); })); + for (size_t i = 0; i < 4; ++i) { + _readers.execute(makeLambdaTask([this]() { read_work(); })); + } +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/datastore/unique_store_dictionary/unique_store_dictionary_test.cpp b/vespalib/src/tests/datastore/unique_store_dictionary/unique_store_dictionary_test.cpp index 6a9215c3eb9..71cb1864ce7 100644 --- a/vespalib/src/tests/datastore/unique_store_dictionary/unique_store_dictionary_test.cpp +++ b/vespalib/src/tests/datastore/unique_store_dictionary/unique_store_dictionary_test.cpp @@ -25,9 +25,15 @@ public: Comparator(uint32_t to_find) : _to_find(to_find) {} - bool operator()(const EntryRef lhs, const EntryRef rhs) const override { + bool less(const EntryRef lhs, const EntryRef rhs) const override { return resolve(lhs).ref() < resolve(rhs).ref(); } + bool equal(const EntryRef lhs, const EntryRef rhs) const override { + return resolve(lhs).ref() == resolve(rhs).ref(); + } + size_t hash(const EntryRef rhs) const override { + return rhs.ref(); + } }; struct DictionaryReadTest : public ::testing::Test { diff --git a/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp 
b/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp index 4e4129feb78..5d70c95b1d9 100644 --- a/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp +++ b/vespalib/src/tests/util/bfloat16/bfloat16_test.cpp @@ -34,6 +34,18 @@ TEST(BFloat16Test, normal_usage) { EXPECT_EQ(float(b2), 0x110); } +TEST(BFloat16Test, has_range_of_int_8) { + for (int i = -128; i < 128; ++i) { + int8_t byte = i; + float flt = byte; + EXPECT_EQ(byte, i); + EXPECT_EQ(flt, i); + BFloat16 value = flt; + float recover = value; + EXPECT_EQ(recover, flt); + } +} + TEST(BFloat16Test, with_nbostream) { nbostream buf; for (BFloat16 value : simple_values) { diff --git a/vespalib/src/tests/util/file_area_freelist/CMakeLists.txt b/vespalib/src/tests/util/file_area_freelist/CMakeLists.txt new file mode 100644 index 00000000000..a0961adbba8 --- /dev/null +++ b/vespalib/src/tests/util/file_area_freelist/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_file_area_freelist_test_app TEST + SOURCES + file_area_freelist_test.cpp + DEPENDS + vespalib + GTest::GTest +) +vespa_add_test(NAME vespalib_file_area_freelist_test_app COMMAND vespalib_file_area_freelist_test_app) diff --git a/vespalib/src/tests/util/file_area_freelist/file_area_freelist_test.cpp b/vespalib/src/tests/util/file_area_freelist/file_area_freelist_test.cpp new file mode 100644 index 00000000000..aecad4e12cf --- /dev/null +++ b/vespalib/src/tests/util/file_area_freelist/file_area_freelist_test.cpp @@ -0,0 +1,73 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/vespalib/util/file_area_freelist.h> +#include <vespa/vespalib/gtest/gtest.h> + +using vespalib::alloc::FileAreaFreeList; + +class FileAreaFreeListTest : public ::testing::Test +{ +protected: + FileAreaFreeList _freelist; + static constexpr auto bad_offset = FileAreaFreeList::bad_offset; + +public: + FileAreaFreeListTest(); + ~FileAreaFreeListTest(); +}; + +FileAreaFreeListTest::FileAreaFreeListTest() + : _freelist() +{ +} + +FileAreaFreeListTest::~FileAreaFreeListTest() = default; + + +TEST_F(FileAreaFreeListTest, empty_freelist_is_ok) +{ + EXPECT_EQ(bad_offset, _freelist.alloc(1)); +} + +TEST_F(FileAreaFreeListTest, can_reuse_free_area) +{ + _freelist.free(4, 1); + EXPECT_EQ(4, _freelist.alloc(1)); + EXPECT_EQ(bad_offset, _freelist.alloc(1)); +} + +TEST_F(FileAreaFreeListTest, merge_area_with_next_area) +{ + _freelist.free(5, 1); + _freelist.free(4, 1); + EXPECT_EQ(4, _freelist.alloc(2)); + EXPECT_EQ(bad_offset, _freelist.alloc(1)); +} + +TEST_F(FileAreaFreeListTest, merge_area_with_previous_area) +{ + _freelist.free(3, 1); + _freelist.free(4, 1); + EXPECT_EQ(3, _freelist.alloc(2)); + EXPECT_EQ(bad_offset, _freelist.alloc(1)); +} + +TEST_F(FileAreaFreeListTest, merge_area_with_previous_and_next_area) +{ + _freelist.free(5, 1); + _freelist.free(3, 1); + _freelist.free(4, 1); + EXPECT_EQ(3, _freelist.alloc(3)); + EXPECT_EQ(bad_offset, _freelist.alloc(1)); +} + +TEST_F(FileAreaFreeListTest, can_use_part_of_free_area) +{ + _freelist.free(4, 2); + EXPECT_EQ(4, _freelist.alloc(1)); + EXPECT_EQ(5, _freelist.alloc(1)); + EXPECT_EQ(bad_offset, _freelist.alloc(1)); +} + + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/util/mmap_file_allocator/.gitignore b/vespalib/src/tests/util/mmap_file_allocator/.gitignore new file mode 100644 index 00000000000..a18e9aac589 --- /dev/null +++ b/vespalib/src/tests/util/mmap_file_allocator/.gitignore @@ -0,0 +1 @@ +/mmap-file-allocator-factory-dir diff --git 
a/vespalib/src/tests/util/mmap_file_allocator/CMakeLists.txt b/vespalib/src/tests/util/mmap_file_allocator/CMakeLists.txt new file mode 100644 index 00000000000..00ce5f52fd7 --- /dev/null +++ b/vespalib/src/tests/util/mmap_file_allocator/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(vespalib_mmap_file_allocator_test_app TEST + SOURCES + mmap_file_allocator_test.cpp + DEPENDS + vespalib + GTest::GTest +) +vespa_add_test(NAME vespalib_mmap_file_allocator_test_app COMMAND vespalib_mmap_file_allocator_test_app) diff --git a/vespalib/src/tests/util/mmap_file_allocator/mmap_file_allocator_test.cpp b/vespalib/src/tests/util/mmap_file_allocator/mmap_file_allocator_test.cpp new file mode 100644 index 00000000000..0d6e7718b86 --- /dev/null +++ b/vespalib/src/tests/util/mmap_file_allocator/mmap_file_allocator_test.cpp @@ -0,0 +1,94 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/vespalib/util/mmap_file_allocator.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <sys/mman.h> + +using vespalib::alloc::MemoryAllocator; +using vespalib::alloc::MmapFileAllocator; + +namespace { + +vespalib::string basedir("mmap-file-allocator-dir"); +vespalib::string hello("hello"); + +struct MyAlloc +{ + const MemoryAllocator& allocator; + void* data; + size_t size; + + MyAlloc(MemoryAllocator& allocator_in, MemoryAllocator::PtrAndSize buf) + : allocator(allocator_in), + data(buf.first), + size(buf.second) + { + } + + ~MyAlloc() + { + allocator.free(data, size); + } + + MemoryAllocator::PtrAndSize asPair() const noexcept { return std::make_pair(data, size); } +}; + +} + +class MmapFileAllocatorTest : public ::testing::Test +{ +protected: + MmapFileAllocator _allocator; + +public: + MmapFileAllocatorTest(); + ~MmapFileAllocatorTest(); +}; + +MmapFileAllocatorTest::MmapFileAllocatorTest() + : _allocator(basedir) +{ +} + +MmapFileAllocatorTest::~MmapFileAllocatorTest() = default; + +TEST_F(MmapFileAllocatorTest, zero_sized_allocation_is_handled) +{ + MyAlloc buf(_allocator, _allocator.alloc(0)); + EXPECT_EQ(nullptr, buf.data); + EXPECT_EQ(0u, buf.size); +} + +TEST_F(MmapFileAllocatorTest, mmap_file_allocator_works) +{ + MyAlloc buf(_allocator, _allocator.alloc(4)); + EXPECT_LE(4u, buf.size); + EXPECT_TRUE(buf.data != nullptr); + memcpy(buf.data, "1st", 4); + MyAlloc buf2(_allocator, _allocator.alloc(5)); + EXPECT_LE(5u, buf2.size); + EXPECT_TRUE(buf2.data != nullptr); + EXPECT_TRUE(buf.data != buf2.data); + memcpy(buf2.data, "fine", 5); + EXPECT_EQ(0u, _allocator.resize_inplace(buf.asPair(), 5)); + EXPECT_EQ(0u, _allocator.resize_inplace(buf.asPair(), 3)); + EXPECT_NE(0u, _allocator.get_end_offset()); + int result = msync(buf.data, buf.size, MS_SYNC); + EXPECT_EQ(0, result); + result = msync(buf2.data, buf2.size, MS_SYNC); + EXPECT_EQ(0, result); +} + +TEST_F(MmapFileAllocatorTest, reuse_file_offset_works) +{ + { + MyAlloc buf(_allocator, 
_allocator.alloc(hello.size() + 1)); + memcpy(buf.data, hello.c_str(), hello.size() + 1); + } + { + MyAlloc buf(_allocator, _allocator.alloc(hello.size() + 1)); + EXPECT_EQ(0, memcmp(buf.data, hello.c_str(), hello.size() + 1)); + } +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/vespa/vespalib/datastore/CMakeLists.txt b/vespalib/src/vespa/vespalib/datastore/CMakeLists.txt index 16c5d3973e8..1ee945f1b8f 100644 --- a/vespalib/src/vespa/vespalib/datastore/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/datastore/CMakeLists.txt @@ -8,6 +8,8 @@ vespa_add_library(vespalib_vespalib_datastore OBJECT datastore.cpp datastorebase.cpp entryref.cpp + fixed_size_hash_map.cpp + simple_hash_map.cpp unique_store.cpp unique_store_string_allocator.cpp DEPENDS diff --git a/vespalib/src/vespa/vespalib/datastore/atomic_entry_ref.h b/vespalib/src/vespa/vespalib/datastore/atomic_entry_ref.h index ed45723b6ef..3ec2d6b163e 100644 --- a/vespalib/src/vespa/vespalib/datastore/atomic_entry_ref.h +++ b/vespalib/src/vespa/vespalib/datastore/atomic_entry_ref.h @@ -37,6 +37,9 @@ public: EntryRef load_acquire() const noexcept { return EntryRef(_ref.load(std::memory_order_acquire)); } + EntryRef load_relaxed() const noexcept { + return EntryRef(_ref.load(std::memory_order_relaxed)); + } }; } diff --git a/vespalib/src/vespa/vespalib/datastore/entry_comparator.h b/vespalib/src/vespa/vespalib/datastore/entry_comparator.h index d0b5b307a9e..bd0eb318b18 100644 --- a/vespalib/src/vespa/vespalib/datastore/entry_comparator.h +++ b/vespalib/src/vespa/vespalib/datastore/entry_comparator.h @@ -19,7 +19,9 @@ public: /** * Returns true if the value represented by lhs ref is less than the value represented by rhs ref. 
*/ - virtual bool operator()(const EntryRef lhs, const EntryRef rhs) const = 0; + virtual bool less(const EntryRef lhs, const EntryRef rhs) const = 0; + virtual bool equal(const EntryRef lhs, const EntryRef rhs) const = 0; + virtual size_t hash(const EntryRef rhs) const = 0; }; } diff --git a/vespalib/src/vespa/vespalib/datastore/entry_comparator_wrapper.h b/vespalib/src/vespa/vespalib/datastore/entry_comparator_wrapper.h index 199d074b453..2856103b3e1 100644 --- a/vespalib/src/vespa/vespalib/datastore/entry_comparator_wrapper.h +++ b/vespalib/src/vespa/vespalib/datastore/entry_comparator_wrapper.h @@ -16,7 +16,7 @@ public: : _comp(comp) { } bool operator()(const EntryRef &lhs, const EntryRef &rhs) const { - return _comp(lhs, rhs); + return _comp.less(lhs, rhs); } }; diff --git a/vespalib/src/vespa/vespalib/datastore/fixed_size_hash_map.cpp b/vespalib/src/vespa/vespalib/datastore/fixed_size_hash_map.cpp new file mode 100644 index 00000000000..d852cd40b78 --- /dev/null +++ b/vespalib/src/vespa/vespalib/datastore/fixed_size_hash_map.cpp @@ -0,0 +1,174 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include "fixed_size_hash_map.h" +#include "entry_comparator.h" +#include <vespa/vespalib/util/array.hpp> +#include <cassert> +#include <stdexcept> + +namespace vespalib::datastore { + +FixedSizeHashMap::Node::Node(Node&&) +{ + throw std::runtime_error("vespalib::datastore::FixedSizeHashMap::Node move constructor should never be called"); +} + +void +FixedSizeHashMap::Node::on_free() +{ + _kv = std::make_pair(AtomicEntryRef(), AtomicEntryRef()); +} + +FixedSizeHashMap::FixedSizeHashMap(uint32_t modulo, uint32_t capacity, uint32_t num_stripes) + : _chain_heads(modulo), + _nodes(), + _modulo(modulo), + _count(0u), + _free_head(no_node_idx), + _free_count(0u), + _hold_count(0u), + _hold_1_list(), + _hold_2_list(), + _num_stripes(num_stripes) +{ + _nodes.reserve(capacity); +} + +FixedSizeHashMap::FixedSizeHashMap(uint32_t modulo, uint32_t capacity, uint32_t num_stripes, const FixedSizeHashMap &orig, const EntryComparator& comp) + : FixedSizeHashMap(modulo, capacity, num_stripes) +{ + for (const auto &chain_head : orig._chain_heads) { + for (uint32_t node_idx = chain_head.load_relaxed(); node_idx != no_node_idx;) { + auto& node = orig._nodes[node_idx]; + force_add(comp, node.get_kv()); + node_idx = node.get_next().load(std::memory_order_relaxed); + } + } +} + +FixedSizeHashMap::~FixedSizeHashMap() = default; + +void +FixedSizeHashMap::force_add(const EntryComparator& comp, const KvType& kv) +{ + size_t hash_idx = comp.hash(kv.first.load_relaxed()) / _num_stripes; + hash_idx %= _modulo; + auto& chain_head = _chain_heads[hash_idx]; + assert(_nodes.size() < _nodes.capacity()); + uint32_t node_idx = _nodes.size(); + new (_nodes.push_back_fast()) Node(kv, chain_head.load_relaxed()); + chain_head.set(node_idx); + ++_count; +} + +FixedSizeHashMap::KvType& +FixedSizeHashMap::add(const EntryComparator& comp, std::function<EntryRef(void)>& insert_entry) +{ + size_t hash_idx = comp.hash(EntryRef()) / _num_stripes; + hash_idx %= _modulo; + auto& chain_head = 
_chain_heads[hash_idx]; + uint32_t node_idx = chain_head.load_relaxed(); + while (node_idx != no_node_idx) { + auto& node = _nodes[node_idx]; + if (comp.equal(EntryRef(), node.get_kv().first.load_relaxed())) { + return node.get_kv(); + } + node_idx = node.get_next().load(std::memory_order_relaxed); + } + if (_free_head != no_node_idx) { + node_idx = _free_head; + auto& node = _nodes[node_idx]; + _free_head = node.get_next().load(std::memory_order_relaxed); + --_free_count; + node.get_kv().first.store_release(insert_entry()); + node.get_next().store(chain_head.load_relaxed()); + chain_head.set(node_idx); + ++_count; + return node.get_kv(); + } + assert(_nodes.size() < _nodes.capacity()); + node_idx = _nodes.size(); + new (_nodes.push_back_fast()) Node(std::make_pair(AtomicEntryRef(insert_entry()), AtomicEntryRef()), chain_head.load_relaxed()); + chain_head.set(node_idx); + ++_count; + return _nodes[node_idx].get_kv(); +} + +void +FixedSizeHashMap::transfer_hold_lists_slow(generation_t generation) +{ + auto &hold_2_list = _hold_2_list; + for (uint32_t node_idx : _hold_1_list) { + hold_2_list.push_back(std::make_pair(generation, node_idx)); + } + _hold_1_list.clear(); + +} + + +void +FixedSizeHashMap::trim_hold_lists_slow(generation_t usedGen) +{ + while (!_hold_2_list.empty()) { + auto& first = _hold_2_list.front(); + if (static_cast<sgeneration_t>(first.first - usedGen) >= 0) { + break; + } + uint32_t node_idx = first.second; + auto& node = _nodes[node_idx]; + node.get_next().store(_free_head, std::memory_order_relaxed); + _free_head = node_idx; + ++_free_count; + --_hold_count; + node.on_free(); + _hold_2_list.erase(_hold_2_list.begin()); + } +} + +FixedSizeHashMap::KvType* +FixedSizeHashMap::remove(const EntryComparator& comp, EntryRef key_ref) +{ + size_t hash_idx = comp.hash(key_ref) / _num_stripes; + hash_idx %= _modulo; + auto& chain_head = _chain_heads[hash_idx]; + uint32_t node_idx = chain_head.load_relaxed(); + uint32_t prev_node_idx = no_node_idx; + while 
(node_idx != no_node_idx) { + auto &node = _nodes[node_idx]; + uint32_t next = node.get_next().load(std::memory_order_relaxed); + if (comp.equal(key_ref, node.get_kv().first.load_relaxed())) { + if (prev_node_idx != no_node_idx) { + _nodes[prev_node_idx].get_next().store(next, std::memory_order_release); + } else { + chain_head.set(next); + } + --_count; + ++_hold_count; + _hold_1_list.push_back(node_idx); + return &_nodes[node_idx].get_kv(); + } + node_idx = next; + } + return nullptr; +} + +const FixedSizeHashMap::KvType* +FixedSizeHashMap::find(const EntryComparator& comp, EntryRef key_ref) const +{ + size_t hash_idx = comp.hash(key_ref) / _num_stripes; + hash_idx %= _modulo; + auto& chain_head = _chain_heads[hash_idx]; + uint32_t node_idx = chain_head.load_acquire(); + while (node_idx != no_node_idx) { + auto &node = _nodes[node_idx]; + EntryRef node_key_ref = node.get_kv().first.load_acquire(); + if (node_key_ref.valid() && comp.equal(key_ref, node_key_ref)) { + return &_nodes[node_idx].get_kv(); + } + uint32_t next = node.get_next().load(std::memory_order_acquire); + node_idx = next; + } + return nullptr; +} + +} diff --git a/vespalib/src/vespa/vespalib/datastore/fixed_size_hash_map.h b/vespalib/src/vespa/vespalib/datastore/fixed_size_hash_map.h new file mode 100644 index 00000000000..bafcf642a8d --- /dev/null +++ b/vespalib/src/vespa/vespalib/datastore/fixed_size_hash_map.h @@ -0,0 +1,119 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#pragma once + +#include "atomic_entry_ref.h" +#include <vespa/vespalib/util/array.h> +#include <vespa/vespalib/util/arrayref.h> +#include <vespa/vespalib/util/generationhandler.h> +#include <limits> +#include <atomic> +#include <deque> +#include <functional> + +namespace vespalib { class GenerationHolder; } +namespace vespalib::datastore { + +class EntryComparator; + +/* + * Fixed sized hash map over keys in data store, meant to support a faster + * dictionary for unique store with relation to lookups. + * + * Currently hardcoded key and data types, where key references an entry + * in a UniqueStore and value references a posting list + * (cf. search::attribute::PostingStore). + * + * This structure supports one writer and many readers. + * + * A reader must own an appropriate GenerationHandler::Guard to ensure + * that memory is held while it can be accessed by reader. + * + * The writer must update generation and call transfer_hold_lists and + * trim_hold_lists as needed to free up memory no longer needed by any + * readers. 
+ */ +class FixedSizeHashMap { +public: + static constexpr uint32_t no_node_idx = std::numeric_limits<uint32_t>::max(); + using KvType = std::pair<AtomicEntryRef, AtomicEntryRef>; + using generation_t = GenerationHandler::generation_t; + using sgeneration_t = GenerationHandler::sgeneration_t; + +private: + class ChainHead { + std::atomic<uint32_t> _node_idx; + + public: + ChainHead() + : _node_idx(no_node_idx) + { + } + // Writer thread + uint32_t load_relaxed() const noexcept { return _node_idx.load(std::memory_order_relaxed); } + void set(uint32_t node_idx) { _node_idx.store(node_idx, std::memory_order_release); } + + // Reader thread + uint32_t load_acquire() const noexcept { return _node_idx.load(std::memory_order_acquire); } + }; + class Node { + KvType _kv; + std::atomic<uint32_t> _next; + public: + Node() + : Node(std::make_pair(AtomicEntryRef(), AtomicEntryRef()), no_node_idx) + { + } + Node(KvType kv, uint32_t next) + : _kv(kv), + _next(next) + { + } + Node(Node &&rhs); // Must be defined, but must never be used. 
+ void on_free(); + std::atomic<uint32_t>& get_next() noexcept { return _next; } + const std::atomic<uint32_t>& get_next() const noexcept { return _next; } + KvType& get_kv() noexcept { return _kv; } + const KvType& get_kv() const noexcept { return _kv; } + }; + + Array<ChainHead> _chain_heads; + Array<Node> _nodes; + uint32_t _modulo; + uint32_t _count; + uint32_t _free_head; + uint32_t _free_count; + uint32_t _hold_count; + Array<uint32_t> _hold_1_list; + std::deque<std::pair<generation_t, uint32_t>> _hold_2_list; + uint32_t _num_stripes; + + void transfer_hold_lists_slow(generation_t generation); + void trim_hold_lists_slow(generation_t usedGen); + void force_add(const EntryComparator& comp, const KvType& kv); +public: + FixedSizeHashMap(uint32_t module, uint32_t capacity, uint32_t num_stripes); + FixedSizeHashMap(uint32_t module, uint32_t capacity, uint32_t num_stripes, const FixedSizeHashMap &orig, const EntryComparator& comp); + ~FixedSizeHashMap(); + + KvType& add(const EntryComparator& comp, std::function<EntryRef(void)>& insert_entry); + KvType* remove(const EntryComparator& comp, EntryRef key_ref); + const KvType* find(const EntryComparator& comp, EntryRef key_ref) const; + + void transfer_hold_lists(generation_t generation) { + if (!_hold_1_list.empty()) { + transfer_hold_lists_slow(generation); + } + } + + void trim_hold_lists(generation_t usedGen) { + if (!_hold_2_list.empty() && static_cast<sgeneration_t>(_hold_2_list.front().first - usedGen) < 0) { + trim_hold_lists_slow(usedGen); + } + } + + bool full() const noexcept { return _nodes.size() == _nodes.capacity() && _free_count == 0u; } + size_t size() const noexcept { return _count; } +}; + +} diff --git a/vespalib/src/vespa/vespalib/datastore/simple_hash_map.cpp b/vespalib/src/vespa/vespalib/datastore/simple_hash_map.cpp new file mode 100644 index 00000000000..90e1bc60e06 --- /dev/null +++ b/vespalib/src/vespa/vespalib/datastore/simple_hash_map.cpp @@ -0,0 +1,139 @@ +// Copyright Verizon Media. 
Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "simple_hash_map.h" +#include "fixed_size_hash_map.h" +#include "entry_comparator.h" + +namespace vespalib::datastore { + +class SimpleHashMapStripeHeld : public GenerationHeldBase +{ + std::unique_ptr<const FixedSizeHashMap> _data; +public: + SimpleHashMapStripeHeld(size_t size, std::unique_ptr<const FixedSizeHashMap> data); + ~SimpleHashMapStripeHeld(); +}; + +SimpleHashMapStripeHeld::SimpleHashMapStripeHeld(size_t size, std::unique_ptr<const FixedSizeHashMap> data) + : GenerationHeldBase(size), + _data(std::move(data)) +{ +} + +SimpleHashMapStripeHeld::~SimpleHashMapStripeHeld() = default; + +SimpleHashMap::SimpleHashMap(std::unique_ptr<const EntryComparator> comp) + : _gen_holder(), + _maps(), + _comp(std::move(comp)) +{ +} + +SimpleHashMap::~SimpleHashMap() +{ + for (size_t i = 0; i < num_stripes; ++i) { + auto map = _maps[i].load(std::memory_order_relaxed); + delete map; + } +} + +size_t +SimpleHashMap::get_stripe(const EntryComparator& comp, EntryRef key_ref) const +{ + return comp.hash(key_ref) % num_stripes; +} + +void +SimpleHashMap::alloc_stripe(size_t stripe) +{ + auto map = _maps[stripe].load(std::memory_order_relaxed); + if (map == nullptr) { + auto umap = std::make_unique<FixedSizeHashMap>(2u, 3u, num_stripes); + _maps[stripe].store(umap.release(), std::memory_order_release); + } else { + auto umap = std::make_unique<FixedSizeHashMap>(map->size() * 2 + 2, map->size() * 3 + 3, num_stripes, *map, *_comp); + _maps[stripe].store(umap.release(), std::memory_order_release); + hold_stripe(std::unique_ptr<const FixedSizeHashMap>(map)); + } +} + +void +SimpleHashMap::hold_stripe(std::unique_ptr<const FixedSizeHashMap> map) +{ + // TODO: Provider proper held size + auto hold = std::make_unique<SimpleHashMapStripeHeld>(0, std::move(map)); + _gen_holder.hold(std::move(hold)); +} + +SimpleHashMap::KvType& +SimpleHashMap::add(const EntryComparator& comp, 
std::function<EntryRef(void)>& insert_entry) +{ + size_t stripe = get_stripe(comp, EntryRef()); + auto map = _maps[stripe].load(std::memory_order_relaxed); + if (map == nullptr || map->full()) { + alloc_stripe(stripe); + map = _maps[stripe].load(std::memory_order_relaxed); + } + return map->add(comp, insert_entry); +} + +SimpleHashMap::KvType* +SimpleHashMap::remove(const EntryComparator& comp, EntryRef key_ref) +{ + size_t stripe = get_stripe(comp, key_ref); + auto map = _maps[stripe].load(std::memory_order_relaxed); + if (map == nullptr) { + return nullptr; + } + return map->remove(comp, key_ref); +} + +const SimpleHashMap::KvType* +SimpleHashMap::find(const EntryComparator& comp, EntryRef key_ref) const +{ + size_t stripe = get_stripe(comp, key_ref); + auto map = _maps[stripe].load(std::memory_order_relaxed); + if (map == nullptr) { + return nullptr; + } + return map->find(comp, key_ref); +} + +void +SimpleHashMap::transfer_hold_lists(generation_t generation) +{ + for (size_t i = 0; i < num_stripes; ++i) { + auto map = _maps[i].load(std::memory_order_relaxed); + if (map != nullptr) { + map->transfer_hold_lists(generation); + } + } + _gen_holder.transferHoldLists(generation); +} + +void +SimpleHashMap::trim_hold_lists(generation_t used_gen) +{ + for (size_t i = 0; i < num_stripes; ++i) { + auto map = _maps[i].load(std::memory_order_relaxed); + if (map != nullptr) { + map->trim_hold_lists(used_gen); + } + } + _gen_holder.trimHoldLists(used_gen); +} + +size_t +SimpleHashMap::size() const noexcept +{ + size_t result = 0; + for (size_t i = 0; i < num_stripes; ++i) { + auto map = _maps[i].load(std::memory_order_relaxed); + if (map != nullptr) { + result += map->size(); + } + } + return result; +} + +} diff --git a/vespalib/src/vespa/vespalib/datastore/simple_hash_map.h b/vespalib/src/vespa/vespalib/datastore/simple_hash_map.h new file mode 100644 index 00000000000..506c1a3ea3f --- /dev/null +++ b/vespalib/src/vespa/vespalib/datastore/simple_hash_map.h @@ -0,0 +1,57 @@ 
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include "atomic_entry_ref.h" +#include <atomic> +#include <vespa/vespalib/util/generationholder.h> +#include <functional> + +namespace vespalib::datastore { + +class FixedSizeHashMap; +class EntryComparator; + +/* + * Hash map over keys in data store, meant to support a faster + * dictionary for unique store with relation to lookups. + * + * Currently hardcoded key and data types, where key references an entry + * in a UniqueStore and value references a posting list + * (cf. search::attribute::PostingStore). + * + * This structure supports one writer and many readers. + * + * A reader must own an appropriate GenerationHandler::Guard to ensure + * that memory is held while it can be accessed by reader. + * + * The writer must update generation and call transfer_hold_lists and + * trim_hold_lists as needed to free up memory no longer needed by any + * readers. 
+ */ +class SimpleHashMap { +public: + using KvType = std::pair<AtomicEntryRef, AtomicEntryRef>; + using generation_t = GenerationHandler::generation_t; + using sgeneration_t = GenerationHandler::sgeneration_t; +private: + GenerationHolder _gen_holder; + static constexpr size_t num_stripes = 1; + std::atomic<FixedSizeHashMap *> _maps[num_stripes]; + std::unique_ptr<const EntryComparator> _comp; + + size_t get_stripe(const EntryComparator& comp, EntryRef key_ref) const; + void alloc_stripe(size_t stripe); + void hold_stripe(std::unique_ptr<const FixedSizeHashMap> map); +public: + SimpleHashMap(std::unique_ptr<const EntryComparator> comp); + ~SimpleHashMap(); + KvType& add(const EntryComparator& comp, std::function<EntryRef(void)> &insert_entry); + KvType* remove(const EntryComparator& comp, EntryRef key_ref); + const KvType* find(const EntryComparator& comp, EntryRef key_ref) const; + void transfer_hold_lists(generation_t generation); + void trim_hold_lists(generation_t used_gen); + size_t size() const noexcept; +}; + +} diff --git a/vespalib/src/vespa/vespalib/datastore/unique_store_comparator.h b/vespalib/src/vespa/vespalib/datastore/unique_store_comparator.h index 3226c4563cc..5280b8712e5 100644 --- a/vespalib/src/vespa/vespalib/datastore/unique_store_comparator.h +++ b/vespalib/src/vespa/vespalib/datastore/unique_store_comparator.h @@ -5,6 +5,7 @@ #include "entry_comparator.h" #include "unique_store_entry.h" #include "datastore.h" +#include <vespa/vespalib/stllike/hash_fun.h> #include <cmath> namespace vespalib::datastore { @@ -18,6 +19,13 @@ public: static bool less(const EntryT& lhs, const EntryT& rhs) { return lhs < rhs; } + static bool equal(const EntryT& lhs, const EntryT& rhs) { + return lhs == rhs; + } + static size_t hash(const EntryT& rhs) { + vespalib::hash<EntryT> hasher; + return hasher(rhs); + } }; /** @@ -37,6 +45,25 @@ public: return (lhs < rhs); } } + static bool equal(EntryT lhs, const EntryT rhs) { + if (std::isnan(lhs)) { + return 
std::isnan(rhs); + } else if (std::isnan(rhs)) { + return false; + } else { + return (lhs == rhs); + } + } + static size_t hash(EntryT rhs) { + if (std::isnan(rhs)) { + return 0; + } else { + union U { EntryT f; std::conditional_t<std::is_same_v<double, EntryT>, uint64_t, uint32_t> i; }; + U t; + t.f = rhs; + return t.i; + } + } }; /** @@ -93,11 +120,20 @@ public: { } - bool operator()(const EntryRef lhs, const EntryRef rhs) const override { + bool less(const EntryRef lhs, const EntryRef rhs) const override { const EntryType &lhsValue = get(lhs); const EntryType &rhsValue = get(rhs); return UniqueStoreComparatorHelper<EntryT>::less(lhsValue, rhsValue); } + bool equal(const EntryRef lhs, const EntryRef rhs) const override { + const EntryType &lhsValue = get(lhs); + const EntryType &rhsValue = get(rhs); + return UniqueStoreComparatorHelper<EntryT>::equal(lhsValue, rhsValue); + } + size_t hash(const EntryRef rhs) const override { + const EntryType &rhsValue = get(rhs); + return UniqueStoreComparatorHelper<EntryT>::hash(rhsValue); + } }; } diff --git a/vespalib/src/vespa/vespalib/datastore/unique_store_dictionary.hpp b/vespalib/src/vespa/vespalib/datastore/unique_store_dictionary.hpp index 8ecf71d08c7..6fcf15e69c1 100644 --- a/vespalib/src/vespa/vespalib/datastore/unique_store_dictionary.hpp +++ b/vespalib/src/vespa/vespalib/datastore/unique_store_dictionary.hpp @@ -29,7 +29,7 @@ UniqueStoreDictionary<DictionaryT, ParentT>:: ReadSnapshotImpl::count(const EntryComparator& comp) const { auto itr = _frozen_view.lowerBound(EntryRef(), comp); - if (itr.valid() && !comp(EntryRef(), itr.getKey())) { + if (itr.valid() && !comp.less(EntryRef(), itr.getKey())) { return 1u; } return 0u; @@ -43,7 +43,7 @@ ReadSnapshotImpl::count_in_range(const EntryComparator& low, { auto low_itr = _frozen_view.lowerBound(EntryRef(), low); auto high_itr = low_itr; - if (high_itr.valid() && !high(EntryRef(), high_itr.getKey())) { + if (high_itr.valid() && !high.less(EntryRef(), high_itr.getKey())) 
{ high_itr.seekPast(EntryRef(), high); } return high_itr - low_itr; @@ -94,7 +94,7 @@ UniqueStoreDictionary<DictionaryT, ParentT>::add(const EntryComparator &comp, std::function<EntryRef(void)> insertEntry) { auto itr = _dict.lowerBound(EntryRef(), comp); - if (itr.valid() && !comp(EntryRef(), itr.getKey())) { + if (itr.valid() && !comp.less(EntryRef(), itr.getKey())) { return UniqueStoreAddResult(itr.getKey(), false); } else { @@ -109,7 +109,7 @@ EntryRef UniqueStoreDictionary<DictionaryT, ParentT>::find(const EntryComparator &comp) { auto itr = _dict.lowerBound(EntryRef(), comp); - if (itr.valid() && !comp(EntryRef(), itr.getKey())) { + if (itr.valid() && !comp.less(EntryRef(), itr.getKey())) { return itr.getKey(); } else { return EntryRef(); diff --git a/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h b/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h index 140e38dbef1..9acacc0073f 100644 --- a/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h +++ b/vespalib/src/vespa/vespalib/datastore/unique_store_string_comparator.h @@ -44,11 +44,21 @@ public: { } - bool operator()(const EntryRef lhs, const EntryRef rhs) const override { + bool less(const EntryRef lhs, const EntryRef rhs) const override { const char *lhs_value = get(lhs); const char *rhs_value = get(rhs); return (strcmp(lhs_value, rhs_value) < 0); } + bool equal(const EntryRef lhs, const EntryRef rhs) const override { + const char *lhs_value = get(lhs); + const char *rhs_value = get(rhs); + return (strcmp(lhs_value, rhs_value) == 0); + } + size_t hash(const EntryRef rhs) const override { + const char *rhs_value = get(rhs); + vespalib::hash<const char *> hasher; + return hasher(rhs_value); + } }; } diff --git a/vespalib/src/vespa/vespalib/util/CMakeLists.txt b/vespalib/src/vespa/vespalib/util/CMakeLists.txt index 62d642b76b2..934f3d00d6f 100644 --- a/vespalib/src/vespa/vespalib/util/CMakeLists.txt +++ 
b/vespalib/src/vespa/vespalib/util/CMakeLists.txt @@ -22,6 +22,7 @@ vespa_add_library(vespalib_vespalib_util OBJECT error.cpp exception.cpp exceptions.cpp + file_area_freelist.cpp gencnt.cpp generationhandler.cpp generationholder.cpp diff --git a/vespalib/src/vespa/vespalib/util/file_area_freelist.cpp b/vespalib/src/vespa/vespalib/util/file_area_freelist.cpp new file mode 100644 index 00000000000..5edebcac1ad --- /dev/null +++ b/vespalib/src/vespa/vespalib/util/file_area_freelist.cpp @@ -0,0 +1,104 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "file_area_freelist.h" +#include <cassert> + +namespace vespalib::alloc { + +FileAreaFreeList::FileAreaFreeList() + : _free_areas(), + _free_sizes() +{ +} + +FileAreaFreeList::~FileAreaFreeList() = default; + +void +FileAreaFreeList::remove_from_size_set(uint64_t offset, size_t size) +{ + auto itr = _free_sizes.find(size); + assert(itr != _free_sizes.end()); + auto &offsets = itr->second; + auto erased_count = offsets.erase(offset); + assert(erased_count != 0u); + if (offsets.empty()) { + _free_sizes.erase(itr); + } +} + +std::pair<uint64_t, size_t> +FileAreaFreeList::prepare_reuse_area(size_t size) +{ + auto itr = _free_sizes.lower_bound(size); + if (itr == _free_sizes.end()) { + return std::make_pair(bad_offset, 0); // No free areas of sufficient size + } + auto old_size = itr->first; + assert(old_size >= size); + auto &offsets = itr->second; + assert(!offsets.empty()); + auto oitr = offsets.begin(); + auto offset = *oitr; + offsets.erase(oitr); + if (offsets.empty()) { + _free_sizes.erase(itr); + } + // Note: Caller must update _free_areas + return std::make_pair(offset, old_size); +} + +uint64_t +FileAreaFreeList::alloc(size_t size) +{ + auto reuse_candidate = prepare_reuse_area(size); + auto offset = reuse_candidate.first; + if (offset == bad_offset) { + return bad_offset; // No free areas of sufficient size + } + auto fa_itr = 
_free_areas.find(offset); + assert(fa_itr != _free_areas.end()); + fa_itr = _free_areas.erase(fa_itr); + auto old_size = reuse_candidate.second; + if (old_size > size) { + // Old area beyond what we reuse should still be a free area. + auto ins_res = _free_sizes[old_size - size].insert(offset + size); + assert(ins_res.second); + _free_areas.emplace_hint(fa_itr, offset + size, old_size - size); + } + return offset; +} + +void +FileAreaFreeList::free(uint64_t offset, size_t size) +{ + auto itr = _free_areas.lower_bound(offset); + if (itr != _free_areas.end() && itr->first <= offset + size) { + // Merge with next free area + assert(itr->first == offset + size); + remove_from_size_set(itr->first, itr->second); + size += itr->second; + itr = _free_areas.erase(itr); + } + bool adjusted_prev_area = false; + if (itr != _free_areas.begin()) { + --itr; + if (itr->first + itr->second >= offset) { + // Merge with previous free area + assert(itr->first + itr->second == offset); + remove_from_size_set(itr->first, itr->second); + offset = itr->first; + size += itr->second; + itr->second = size; + adjusted_prev_area = true; + } else { + ++itr; + } + } + if (!adjusted_prev_area) { + _free_areas.emplace_hint(itr, offset, size); + } + auto ins_res = _free_sizes[size].insert(offset); + assert(ins_res.second); +} + +} diff --git a/vespalib/src/vespa/vespalib/util/file_area_freelist.h b/vespalib/src/vespa/vespalib/util/file_area_freelist.h new file mode 100644 index 00000000000..f19dfdd7de3 --- /dev/null +++ b/vespalib/src/vespa/vespalib/util/file_area_freelist.h @@ -0,0 +1,29 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <cstddef> +#include <cstdint> +#include <limits> +#include <map> +#include <set> + +namespace vespalib::alloc { + +/* + * Class that tracks free areas in a file. 
+ */ +class FileAreaFreeList { + std::map<uint64_t, size_t> _free_areas; // map from offset to size + std::map<size_t, std::set<uint64_t>> _free_sizes; // map from size to set of offsets + void remove_from_size_set(uint64_t offset, size_t size); + std::pair<uint64_t, size_t> prepare_reuse_area(size_t size); +public: + FileAreaFreeList(); + ~FileAreaFreeList(); + uint64_t alloc(size_t size); + void free(uint64_t offset, size_t size); + static constexpr uint64_t bad_offset = std::numeric_limits<uint64_t>::max(); +}; + +} diff --git a/vespalib/src/vespa/vespalib/util/mmap_file_allocator.cpp b/vespalib/src/vespa/vespalib/util/mmap_file_allocator.cpp index 469a5c366b3..ff11caf31e9 100644 --- a/vespalib/src/vespa/vespalib/util/mmap_file_allocator.cpp +++ b/vespalib/src/vespa/vespalib/util/mmap_file_allocator.cpp @@ -12,7 +12,9 @@ namespace vespalib::alloc { MmapFileAllocator::MmapFileAllocator(const vespalib::string& dir_name) : _dir_name(dir_name), _file(_dir_name + "/swapfile"), - _end_offset(0) + _end_offset(0), + _allocations(), + _freelist() { mkdir(_dir_name, true); _file.open(O_RDWR | O_CREAT | O_TRUNC, false); @@ -26,16 +28,27 @@ MmapFileAllocator::~MmapFileAllocator() rmdir(_dir_name, true); } +uint64_t +MmapFileAllocator::alloc_area(size_t sz) const +{ + uint64_t offset = _freelist.alloc(sz); + if (offset != FileAreaFreeList::bad_offset) { + return offset; + } + offset = _end_offset; + _end_offset += sz; + _file.resize(_end_offset); + return offset; +} + MmapFileAllocator::PtrAndSize MmapFileAllocator::alloc(size_t sz) const { if (sz == 0) { return PtrAndSize(nullptr, 0); // empty allocation } - uint64_t offset = _end_offset; sz = round_up_to_page_size(sz); - _end_offset += sz; - _file.resize(_end_offset); + uint64_t offset = alloc_area(sz); void *buf = mmap(nullptr, sz, PROT_READ | PROT_WRITE, MAP_SHARED, @@ -44,7 +57,7 @@ MmapFileAllocator::alloc(size_t sz) const assert(buf != MAP_FAILED); assert(buf != nullptr); // Register allocation - auto ins_res = 
_allocations.insert(std::make_pair(buf, sz)); + auto ins_res = _allocations.insert(std::make_pair(buf, SizeAndOffset(sz, offset))); assert(ins_res.second); int retval = madvise(buf, sz, MADV_RANDOM); assert(retval == 0); @@ -64,12 +77,14 @@ MmapFileAllocator::free(PtrAndSize alloc) const auto itr = _allocations.find(alloc.first); assert(itr != _allocations.end()); assert(itr->first == alloc.first); - assert(itr->second == alloc.second); + assert(itr->second.size == alloc.second); + auto offset = itr->second.offset; _allocations.erase(itr); int retval = madvise(alloc.first, alloc.second, MADV_DONTNEED); assert(retval == 0); retval = munmap(alloc.first, alloc.second); assert(retval == 0); + _freelist.free(offset, alloc.second); } size_t diff --git a/vespalib/src/vespa/vespalib/util/mmap_file_allocator.h b/vespalib/src/vespa/vespalib/util/mmap_file_allocator.h index 308513153c1..0d459bc2134 100644 --- a/vespalib/src/vespa/vespalib/util/mmap_file_allocator.h +++ b/vespalib/src/vespa/vespalib/util/mmap_file_allocator.h @@ -3,6 +3,7 @@ #pragma once #include "memory_allocator.h" +#include "file_area_freelist.h" #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/stllike/hash_map.h> #include <vespa/vespalib/stllike/string.h> @@ -16,10 +17,25 @@ namespace vespalib::alloc { * have been freed. 
*/ class MmapFileAllocator : public MemoryAllocator { + struct SizeAndOffset { + size_t size; + uint64_t offset; + SizeAndOffset() + : SizeAndOffset(0u, 0u) + { + } + SizeAndOffset(size_t size_in, uint64_t offset_in) + : size(size_in), + offset(offset_in) + { + } + }; vespalib::string _dir_name; mutable File _file; mutable uint64_t _end_offset; - mutable hash_map<void *, size_t> _allocations; + mutable hash_map<void *, SizeAndOffset> _allocations; + mutable FileAreaFreeList _freelist; + uint64_t alloc_area(size_t sz) const; public: MmapFileAllocator(const vespalib::string& dir_name); ~MmapFileAllocator(); diff --git a/vsm/src/tests/searcher/searcher_test.cpp b/vsm/src/tests/searcher/searcher_test.cpp index 87bbefa6c06..2a57f01be83 100644 --- a/vsm/src/tests/searcher/searcher_test.cpp +++ b/vsm/src/tests/searcher/searcher_test.cpp @@ -21,6 +21,7 @@ using search::streaming::HitList; using search::streaming::QueryNodeResultFactory; using search::streaming::QueryTerm; using search::streaming::QueryTermList; +using TermType = QueryTerm::Type; using namespace vsm; template <typename T> @@ -58,17 +59,17 @@ private: for (size_t i = 0; i < terms.size(); ++i) { ParsedQueryTerm pqt = parseQueryTerm(terms[i]); ParsedTerm pt = parseTerm(pqt.second); - qtv.push_back(QueryTerm(eqnr.create(), pt.first, pqt.first.empty() ? "index" : pqt.first, pt.second)); + qtv.push_back(std::make_unique<QueryTerm>(eqnr.create(), pt.first, pqt.first.empty() ? 
"index" : pqt.first, pt.second)); } for (size_t i = 0; i < qtv.size(); ++i) { - qtl.push_back(&qtv[i]); + qtl.push_back(qtv[i].get()); } } public: typedef std::pair<std::string, std::string> ParsedQueryTerm; - typedef std::pair<std::string, QueryTerm::SearchTerm> ParsedTerm; + typedef std::pair<std::string, TermType> ParsedTerm; QueryNodeResultFactory eqnr; - std::vector<QueryTerm> qtv; + std::vector<QueryTerm::UP> qtv; QueryTermList qtl; Query(const StringList & terms); ~Query(); @@ -81,13 +82,13 @@ public: } static ParsedTerm parseTerm(const std::string & term) { if (term[0] == '*' && term[term.size() - 1] == '*') { - return std::make_pair(term.substr(1, term.size() - 2), QueryTerm::SUBSTRINGTERM); + return std::make_pair(term.substr(1, term.size() - 2), TermType::SUBSTRINGTERM); } else if (term[0] == '*') { - return std::make_pair(term.substr(1, term.size() - 1), QueryTerm::SUFFIXTERM); + return std::make_pair(term.substr(1, term.size() - 1), TermType::SUFFIXTERM); } else if (term[term.size() - 1] == '*') { - return std::make_pair(term.substr(0, term.size() - 1), QueryTerm::PREFIXTERM); + return std::make_pair(term.substr(0, term.size() - 1), TermType::PREFIXTERM); } else { - return std::make_pair(term, QueryTerm::WORD); + return std::make_pair(term, TermType::WORD); } } }; @@ -95,7 +96,7 @@ public: Query::Query(const StringList & terms) : eqnr(), qtv(), qtl() { setupQuery(terms); } -Query::~Query() {} +Query::~Query() = default; struct SnippetModifierSetup { @@ -127,7 +128,7 @@ void assertSnippetModifier(const StringList &query, const std::string &fv, const void assertSnippetModifier(SnippetModifierSetup &setup, const FieldValue &fv, const std::string &exp); void assertQueryTerms(const SnippetModifierManager &man, FieldIdT fId, const StringList &terms); void assertNumeric(FieldSearcher &fs, const StringList &query, const FieldValue &fv, const BoolList &exp); -std::vector<QueryTerm> performSearch(FieldSearcher &fs, const StringList &query, const FieldValue &fv); 
+std::vector<QueryTerm::UP> performSearch(FieldSearcher &fs, const StringList &query, const FieldValue &fv); void assertSearch(FieldSearcher &fs, const StringList &query, const FieldValue &fv, const HitsList &exp); bool assertCountWords(size_t numWords, const std::string &field); bool assertFieldInfo(FieldSearcher &fs, const StringList &query, const FieldValue &fv, const FieldInfoList &exp); @@ -284,8 +285,8 @@ bool assertMatchTermSuffix(const std::string & term, const std::string & word) { QueryNodeResultFactory eqnr; - QueryTerm qa(eqnr.create(), term, "index", QueryTerm::WORD); - QueryTerm qb(eqnr.create(), word, "index", QueryTerm::WORD); + QueryTerm qa(eqnr.create(), term, "index", TermType::WORD); + QueryTerm qb(eqnr.create(), word, "index", TermType::WORD); const ucs4_t * a; size_t alen = qa.term(a); const ucs4_t * b; @@ -303,7 +304,7 @@ assertNumeric(FieldSearcher & fs, const StringList & query, const FieldValue & f assertSearch(fs, query, fv, hl); } -std::vector<QueryTerm> +std::vector<QueryTerm::UP> performSearch(FieldSearcher & fs, const StringList & query, const FieldValue & fv) { Query q(query); @@ -319,17 +320,17 @@ performSearch(FieldSearcher & fs, const StringList & query, const FieldValue & f doc.setField(0, document::FieldValue::UP(fv.clone())); fs.search(doc); - return q.qtv; + return std::move(q.qtv); } void assertSearch(FieldSearcher & fs, const StringList & query, const FieldValue & fv, const HitsList & exp) { - std::vector<QueryTerm> qtv = performSearch(fs, query, fv); + auto qtv = performSearch(fs, query, fv); EXPECT_EQUAL(qtv.size(), exp.size()); ASSERT_TRUE(qtv.size() == exp.size()); for (size_t i = 0; i < qtv.size(); ++i) { - const HitList & hl = qtv[i].getHitList(); + const HitList & hl = qtv[i]->getHitList(); EXPECT_EQUAL(hl.size(), exp[i].size()); ASSERT_TRUE(hl.size() == exp[i].size()); for (size_t j = 0; j < hl.size(); ++j) { @@ -342,13 +343,13 @@ bool assertFieldInfo(FieldSearcher & fs, const StringList & query, const FieldValue & 
fv, const FieldInfoList & exp) { - std::vector<QueryTerm> qtv = performSearch(fs, query, fv); + auto qtv = performSearch(fs, query, fv); if (!EXPECT_EQUAL(qtv.size(), exp.size())) return false; bool retval = true; for (size_t i = 0; i < qtv.size(); ++i) { - if (!EXPECT_EQUAL(qtv[i].getFieldInfo(0).getHitOffset(), exp[i].getHitOffset())) retval = false; - if (!EXPECT_EQUAL(qtv[i].getFieldInfo(0).getHitCount(), exp[i].getHitCount())) retval = false; - if (!EXPECT_EQUAL(qtv[i].getFieldInfo(0).getFieldLength(), exp[i].getFieldLength())) retval = false; + if (!EXPECT_EQUAL(qtv[i]->getFieldInfo(0).getHitOffset(), exp[i].getHitOffset())) retval = false; + if (!EXPECT_EQUAL(qtv[i]->getFieldInfo(0).getHitCount(), exp[i].getHitCount())) retval = false; + if (!EXPECT_EQUAL(qtv[i]->getFieldInfo(0).getFieldLength(), exp[i].getFieldLength())) retval = false; } return retval; } @@ -467,13 +468,13 @@ testStrChrFieldSearcher(StrChrFieldSearcher & fs) ASSERT_TRUE(Query::parseQueryTerm("term").first == ""); ASSERT_TRUE(Query::parseQueryTerm("term").second == "term"); ASSERT_TRUE(Query::parseTerm("*substr*").first == "substr"); - ASSERT_TRUE(Query::parseTerm("*substr*").second == QueryTerm::SUBSTRINGTERM); + ASSERT_TRUE(Query::parseTerm("*substr*").second == TermType::SUBSTRINGTERM); ASSERT_TRUE(Query::parseTerm("*suffix").first == "suffix"); - ASSERT_TRUE(Query::parseTerm("*suffix").second == QueryTerm::SUFFIXTERM); + ASSERT_TRUE(Query::parseTerm("*suffix").second == TermType::SUFFIXTERM); ASSERT_TRUE(Query::parseTerm("prefix*").first == "prefix"); - ASSERT_TRUE(Query::parseTerm("prefix*").second == QueryTerm::PREFIXTERM); + ASSERT_TRUE(Query::parseTerm("prefix*").second == TermType::PREFIXTERM); ASSERT_TRUE(Query::parseTerm("term").first == "term"); - ASSERT_TRUE(Query::parseTerm("term").second == QueryTerm::WORD); + ASSERT_TRUE(Query::parseTerm("term").second == TermType::WORD); } TEST("suffix matching") { diff --git a/zkfacade/pom.xml b/zkfacade/pom.xml index 
179856f053f..d9bd377ffa1 100644 --- a/zkfacade/pom.xml +++ b/zkfacade/pom.xml @@ -65,6 +65,21 @@ <artifactId>zookeeper</artifactId> <version>${zookeeper.client.version}</version> </dependency> + <!-- snappy-java and metrics-core are included here + to be able to work with ZooKeeper 3.6.2 due to + class loading issues --> + <dependency> + <groupId>io.dropwizard.metrics</groupId> + <artifactId>metrics-core</artifactId> + <scope>compile</scope> + <version>3.2.5</version> + </dependency> + <dependency> + <groupId>org.xerial.snappy</groupId> + <artifactId>snappy-java</artifactId> + <scope>compile</scope> + <version>1.1.7</version> + </dependency> <dependency> <groupId>org.mockito</groupId> <artifactId>mockito-core</artifactId> diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java index adfd9bd051f..4cbb6c95cb4 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/Curator.java @@ -10,6 +10,7 @@ import com.yahoo.text.Utf8; import com.yahoo.vespa.curator.api.VespaCurator; import com.yahoo.vespa.curator.recipes.CuratorCounter; import com.yahoo.vespa.defaults.Defaults; +import com.yahoo.vespa.zookeeper.VespaSslContextProvider; import com.yahoo.vespa.zookeeper.VespaZooKeeperServer; import org.apache.curator.RetryPolicy; import org.apache.curator.framework.CuratorFramework; @@ -124,9 +125,15 @@ public class Curator implements VespaCurator, AutoCloseable { private static ZKClientConfig createClientConfig(Optional<File> clientConfigFile) { if (clientConfigFile.isPresent()) { boolean useSecureClient = Boolean.parseBoolean(getEnvironmentVariable("VESPA_USE_TLS_FOR_ZOOKEEPER_CLIENT").orElse("false")); - String config = "zookeeper.client.secure=" + useSecureClient + "\n"; + StringBuilder configBuilder = new StringBuilder("zookeeper.client.secure=").append(useSecureClient).append("\n"); + if (useSecureClient) { + 
configBuilder.append("zookeeper.ssl.context.supplier.class=").append(VespaSslContextProvider.class.getName()).append("\n") + .append("zookeeper.ssl.enabledProtocols=").append(VespaSslContextProvider.enabledTlsProtocolConfigValue()).append("\n") + .append("zookeeper.ssl.ciphersuites=").append(VespaSslContextProvider.enabledTlsCiphersConfigValue()).append("\n") + .append("zookeeper.ssl.clientAuth=NEED\n"); + } clientConfigFile.get().getParentFile().mkdirs(); - IOUtils.writeFile(clientConfigFile.get(), Utf8.toBytes(config)); + IOUtils.writeFile(clientConfigFile.get(), Utf8.toBytes(configBuilder.toString())); try { return new ZKClientConfig(clientConfigFile.get()); } catch (QuorumPeerConfig.ConfigException e) { diff --git a/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCuratorFramework.java b/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCuratorFramework.java index d538f7ce6e2..f968b2b5098 100644 --- a/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCuratorFramework.java +++ b/zkfacade/src/main/java/com/yahoo/vespa/curator/mock/MockCuratorFramework.java @@ -53,6 +53,7 @@ import org.apache.curator.framework.api.UnhandledErrorListener; import org.apache.curator.framework.api.VersionPathAndBytesable; import org.apache.curator.framework.api.WatchPathable; import org.apache.curator.framework.api.Watchable; +import org.apache.curator.framework.api.WatchesBuilder; import org.apache.curator.framework.api.transaction.CuratorMultiTransaction; import org.apache.curator.framework.api.transaction.CuratorTransaction; import org.apache.curator.framework.api.transaction.CuratorTransactionBridge; @@ -222,6 +223,9 @@ public class MockCuratorFramework implements CuratorFramework { public RemoveWatchesBuilder watches() { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @Override + public WatchesBuilder watchers() { throw new UnsupportedOperationException("Not implemented in MockCurator"); } + + @Override public 
WatcherRemoveCuratorFramework newWatcherRemoveCuratorFramework() { class MockWatcherRemoveCuratorFramework extends MockCuratorFramework implements WatcherRemoveCuratorFramework { @@ -245,7 +249,7 @@ public class MockCuratorFramework implements CuratorFramework { public SchemaSet getSchemaSet() { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @Override - public boolean isZk34CompatibilityMode() { return false; } + public CompletableFuture<Void> postSafeNotify(Object monitorHolder) { throw new UnsupportedOperationException("Not implemented in MockCurator"); } @Override public CompletableFuture<Void> runSafe(Runnable runnable) { throw new UnsupportedOperationException("Not implemented in MockCurator"); } diff --git a/zkfacade/src/main/java/org/apache/curator/framework/api/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/api/package-info.java index e3da4ab3efa..3bf1d2192bd 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/api/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/api/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator.framework.api; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/api/transaction/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/api/transaction/package-info.java index 94f8b12894e..487d9b87e4e 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/api/transaction/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/api/transaction/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. 
Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator.framework.api.transaction; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/listen/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/listen/package-info.java index 71ee8ccfff0..076022b2240 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/listen/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/listen/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator.framework.listen; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/package-info.java index 2999456bc9d..bfb643627d2 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator.framework; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/recipes/atomic/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/recipes/atomic/package-info.java index dd1dd7a1899..d40a238a5c6 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/recipes/atomic/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/recipes/atomic/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator.framework.recipes.atomic; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/recipes/barriers/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/recipes/barriers/package-info.java index 4e2aea367de..58c56692adb 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/recipes/barriers/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/recipes/barriers/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator.framework.recipes.barriers; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/recipes/cache/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/recipes/cache/package-info.java index ad6913d6381..7ed48f808c0 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/recipes/cache/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/recipes/cache/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator.framework.recipes.cache; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/recipes/locks/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/recipes/locks/package-info.java index 4307c09e30a..834f7eb11f5 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/recipes/locks/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/recipes/locks/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator.framework.recipes.locks; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/framework/state/package-info.java b/zkfacade/src/main/java/org/apache/curator/framework/state/package-info.java index 4a10e20318d..22c3075161a 100644 --- a/zkfacade/src/main/java/org/apache/curator/framework/state/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/framework/state/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator.framework.state; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/package-info.java b/zkfacade/src/main/java/org/apache/curator/package-info.java index 232a5fd46f3..7248986dcde 100644 --- a/zkfacade/src/main/java/org/apache/curator/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zkfacade/src/main/java/org/apache/curator/retry/package-info.java b/zkfacade/src/main/java/org/apache/curator/retry/package-info.java index f45a0d927a5..cebabfd75a0 100644 --- a/zkfacade/src/main/java/org/apache/curator/retry/package-info.java +++ b/zkfacade/src/main/java/org/apache/curator/retry/package-info.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -@ExportPackage(version = @Version(major = 4, minor = 3, micro = 0)) +@ExportPackage(version = @Version(major = 5, minor = 1, micro = 0)) package org.apache.curator.retry; import com.yahoo.osgi.annotation.ExportPackage; import com.yahoo.osgi.annotation.Version; diff --git a/zookeeper-server/zookeeper-server-common/src/main/java/com/yahoo/vespa/zookeeper/Configurator.java b/zookeeper-server/zookeeper-server-common/src/main/java/com/yahoo/vespa/zookeeper/Configurator.java index ba79969469a..ebf5032a4a7 100644 --- a/zookeeper-server/zookeeper-server-common/src/main/java/com/yahoo/vespa/zookeeper/Configurator.java +++ b/zookeeper-server/zookeeper-server-common/src/main/java/com/yahoo/vespa/zookeeper/Configurator.java @@ -7,14 +7,11 @@ import com.yahoo.security.KeyStoreBuilder; import com.yahoo.security.KeyStoreType; import com.yahoo.security.KeyStoreUtils; import com.yahoo.security.KeyUtils; -import com.yahoo.security.SslContextBuilder; import com.yahoo.security.X509CertificateUtils; -import com.yahoo.security.tls.TlsContext; import com.yahoo.security.tls.TransportSecurityOptions; import com.yahoo.text.Utf8; import com.yahoo.vespa.defaults.Defaults; -import javax.net.ssl.SSLContext; import java.io.FileWriter; import java.io.IOException; import 
java.nio.file.Files; @@ -24,8 +21,6 @@ import java.security.PrivateKey; import java.security.cert.X509Certificate; import java.util.List; import java.util.Optional; -import java.util.Set; -import java.util.TreeSet; import java.util.logging.Level; import java.util.stream.Collectors; @@ -95,9 +90,8 @@ public class Configurator { sb.append("metricsProvider.className=org.apache.zookeeper.metrics.impl.NullMetricsProvider\n"); ensureThisServerIsRepresented(config.myid(), config.server()); config.server().forEach(server -> addServerToCfg(sb, server, config.clientPort())); - SSLContext sslContext = new SslContextBuilder().build(); - sb.append(new TlsQuorumConfig(sslContext, jksKeyStoreFilePath).createConfig(config, transportSecurityOptions)); - sb.append(new TlsClientServerConfig(sslContext, jksKeyStoreFilePath).createConfig(config, transportSecurityOptions)); + sb.append(new TlsQuorumConfig(jksKeyStoreFilePath).createConfig(config, transportSecurityOptions)); + sb.append(new TlsClientServerConfig(jksKeyStoreFilePath).createConfig(config, transportSecurityOptions)); return sb.toString(); } @@ -178,10 +172,6 @@ public class Configurator { } private interface TlsConfig { - default Set<String> allowedCiphers(SSLContext sslContext) { return new TreeSet<>(TlsContext.getAllowedCipherSuites(sslContext)); } - - default Set<String> allowedProtocols(SSLContext sslContext) { return new TreeSet<>(TlsContext.getAllowedProtocols(sslContext)); } - default Optional<String> getEnvironmentVariable(String variableName) { return Optional.ofNullable(System.getenv().get(variableName)) .filter(var -> !var.isEmpty()); @@ -196,8 +186,6 @@ public class Configurator { Path jksKeyStoreFilePath(); - SSLContext sslContext(); - default String createCommonKeyStoreTrustStoreOptions(Optional<TransportSecurityOptions> transportSecurityOptions) { StringBuilder sb = new StringBuilder(); transportSecurityOptions.ifPresent(options -> { @@ -215,10 +203,9 @@ public class Configurator { StringBuilder sb = new 
StringBuilder(); sb.append(configFieldPrefix()).append(".hostnameVerification=false\n"); sb.append(configFieldPrefix()).append(".clientAuth=NEED\n"); - sb.append(configFieldPrefix()).append(".ciphersuites=").append(String.join(",", allowedCiphers(sslContext()))).append("\n"); - sb.append(configFieldPrefix()).append(".enabledProtocols=").append(String.join(",", allowedProtocols(sslContext()))).append("\n"); - sb.append(configFieldPrefix()).append(".protocol=").append(sslContext().getProtocol()).append("\n"); - + sb.append(configFieldPrefix()).append(".ciphersuites=").append(VespaSslContextProvider.enabledTlsCiphersConfigValue()).append("\n"); + sb.append(configFieldPrefix()).append(".enabledProtocols=").append(VespaSslContextProvider.enabledTlsProtocolConfigValue()).append("\n"); + sb.append(configFieldPrefix()).append(".protocol=").append(VespaSslContextProvider.sslContextVersion()).append("\n"); return sb.toString(); } @@ -226,11 +213,9 @@ public class Configurator { static class TlsClientServerConfig implements TlsConfig { - private final SSLContext sslContext; private final Path jksKeyStoreFilePath; - TlsClientServerConfig(SSLContext sslContext, Path jksKeyStoreFilePath) { - this.sslContext = sslContext; + TlsClientServerConfig(Path jksKeyStoreFilePath) { this.jksKeyStoreFilePath = jksKeyStoreFilePath; } @@ -269,19 +254,13 @@ public class Configurator { return jksKeyStoreFilePath; } - @Override - public SSLContext sslContext() { - return sslContext; - } } static class TlsQuorumConfig implements TlsConfig { - private final SSLContext sslContext; private final Path jksKeyStoreFilePath; - TlsQuorumConfig(SSLContext sslContext, Path jksKeyStoreFilePath) { - this.sslContext = sslContext; + TlsQuorumConfig(Path jksKeyStoreFilePath) { this.jksKeyStoreFilePath = jksKeyStoreFilePath; } @@ -329,11 +308,6 @@ public class Configurator { return jksKeyStoreFilePath; } - @Override - public SSLContext sslContext() { - return sslContext; - } - } } diff --git 
a/zookeeper-server/zookeeper-server-common/src/main/java/com/yahoo/vespa/zookeeper/VespaSslContextProvider.java b/zookeeper-server/zookeeper-server-common/src/main/java/com/yahoo/vespa/zookeeper/VespaSslContextProvider.java new file mode 100644 index 00000000000..dc254ade071 --- /dev/null +++ b/zookeeper-server/zookeeper-server-common/src/main/java/com/yahoo/vespa/zookeeper/VespaSslContextProvider.java @@ -0,0 +1,43 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.zookeeper; + +import com.yahoo.security.tls.TlsContext; +import com.yahoo.security.tls.TransportSecurityUtils; + +import javax.net.ssl.SSLContext; +import java.util.Collection; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; + +/** + * Provider for Vespa {@link SSLContext} instance to Zookeeper + misc utility methods for providing Vespa TLS specific ZK configuration. + * + * @author bjorncs + */ +public class VespaSslContextProvider implements Supplier<SSLContext> { + + private static final TlsContext tlsContext = TransportSecurityUtils.getSystemTlsContext().orElse(null); + + @Override + public SSLContext get() { + if (!tlsEnabled()) throw new IllegalStateException("Vespa TLS is not enabled"); + return tlsContext.context(); + } + + public static boolean tlsEnabled() { return tlsContext != null; } + + public static String enabledTlsProtocolConfigValue() { + // Fallback to all allowed protocols if we cannot determine which are actually supported by runtime + Collection<String> enabledProtocols = tlsEnabled() ? 
List.of(tlsContext.parameters().getProtocols()) : TlsContext.ALLOWED_PROTOCOLS; + return enabledProtocols.stream().sorted().collect(Collectors.joining(",")); + } + + public static String enabledTlsCiphersConfigValue() { + // Fallback to all allowed ciphers if we cannot determine which are actually supported by runtime + Collection<String> enabledCiphers = tlsEnabled() ? List.of(tlsContext.parameters().getCipherSuites()) : TlsContext.ALLOWED_CIPHER_SUITES; + return enabledCiphers.stream().sorted().collect(Collectors.joining(",")); + } + + public static String sslContextVersion() { return tlsEnabled() ? tlsContext.context().getProtocol() : TlsContext.SSL_CONTEXT_VERSION; } +} diff --git a/zookeeper-server/zookeeper-server-common/src/test/java/com/yahoo/vespa/zookeeper/ConfiguratorTest.java b/zookeeper-server/zookeeper-server-common/src/test/java/com/yahoo/vespa/zookeeper/ConfiguratorTest.java index 0f43fb45d9d..a7994531b93 100644 --- a/zookeeper-server/zookeeper-server-common/src/test/java/com/yahoo/vespa/zookeeper/ConfiguratorTest.java +++ b/zookeeper-server/zookeeper-server-common/src/test/java/com/yahoo/vespa/zookeeper/ConfiguratorTest.java @@ -218,7 +218,10 @@ public class ConfiguratorTest { private String commonTlsQuorumConfig() { return "ssl.quorum.hostnameVerification=false\n" + "ssl.quorum.clientAuth=NEED\n" + - "ssl.quorum.ciphersuites=TLS_AES_128_GCM_SHA256,TLS_AES_256_GCM_SHA384,TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384\n" + + "ssl.quorum.ciphersuites=TLS_AES_128_GCM_SHA256,TLS_AES_256_GCM_SHA384,TLS_CHACHA20_POLY1305_SHA256," + + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384," + + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256," + + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256\n" + "ssl.quorum.enabledProtocols=TLSv1.2\n" + 
"ssl.quorum.protocol=TLS\n"; } @@ -226,7 +229,10 @@ public class ConfiguratorTest { private String commonTlsClientServerConfig() { return "ssl.hostnameVerification=false\n" + "ssl.clientAuth=NEED\n" + - "ssl.ciphersuites=TLS_AES_128_GCM_SHA256,TLS_AES_256_GCM_SHA384,TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384\n" + + "ssl.ciphersuites=TLS_AES_128_GCM_SHA256,TLS_AES_256_GCM_SHA384,TLS_CHACHA20_POLY1305_SHA256," + + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384," + + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256," + + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256\n" + "ssl.enabledProtocols=TLSv1.2\n" + "ssl.protocol=TLS\n"; } |