diff options
39 files changed, 768 insertions, 132 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java index 1b6b43625f7..f66e987bea6 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/admin/monitoring/VespaMetricSet.java @@ -196,6 +196,10 @@ public class VespaMetricSet { metrics.add(new Metric("mem.direct.used.average")); metrics.add(new Metric("mem.direct.used.max")); metrics.add(new Metric("mem.direct.count.max")); + metrics.add(new Metric("mem.native.total.average")); + metrics.add(new Metric("mem.native.free.average")); + metrics.add(new Metric("mem.native.used.average")); + metrics.add(new Metric("mem.native.used.max")); metrics.add(new Metric("jdisc.gc.count.average")); metrics.add(new Metric("jdisc.gc.count.max")); diff --git a/config/src/vespa/config/subscription/configsubscriptionset.cpp b/config/src/vespa/config/subscription/configsubscriptionset.cpp index 1d5db778dda..48e2bdc2615 100644 --- a/config/src/vespa/config/subscription/configsubscriptionset.cpp +++ b/config/src/vespa/config/subscription/configsubscriptionset.cpp @@ -42,7 +42,7 @@ ConfigSubscriptionSet::acquireSnapshot(duration timeout, bool ignoreChange) steady_time now = steady_clock::now(); const steady_time deadline = now + timeout; - int64_t lastGeneration = _currentGeneration; + int64_t lastGeneration = getGeneration(); bool inSync = false; LOG(spam, "Going into nextConfig loop, time left is %f", vespalib::to_s(deadline - now)); @@ -55,7 +55,7 @@ ConfigSubscriptionSet::acquireSnapshot(duration timeout, bool ignoreChange) // Run nextUpdate on all subscribers to get them in sync. for (const auto & subscription : _subscriptionList) { - if (!subscription->nextUpdate(_currentGeneration, deadline) && !subscription->hasGenerationChanged()) { + if (!subscription->nextUpdate(getGeneration(), deadline) && !subscription->hasGenerationChanged()) { subscription->reset(); continue; } @@ -68,7 +68,7 @@ ConfigSubscriptionSet::acquireSnapshot(duration timeout, bool ignoreChange) LOG(spam, "Config subscription did not change, id(%s), defname(%s)", key.getConfigId().c_str(), key.getDefName().c_str()); } LOG(spam, "Previous generation is %" PRId64 ", updates is %" PRId64, lastGeneration, subscription->getGeneration()); - if (isGenerationNewer(subscription->getGeneration(), _currentGeneration)) { + if (isGenerationNewer(subscription->getGeneration(), getGeneration())) { numGenerationChanged++; } if (generation < 0) { @@ -88,10 +88,10 @@ ConfigSubscriptionSet::acquireSnapshot(duration timeout, bool ignoreChange) } } - bool updated = inSync && isGenerationNewer(lastGeneration, _currentGeneration); + bool updated = inSync && isGenerationNewer(lastGeneration, getGeneration()); if (updated) { - LOG(spam, "Config was updated from %" PRId64 " to %" PRId64, _currentGeneration, lastGeneration); - _currentGeneration = lastGeneration; + LOG(spam, "Config was updated from %" PRId64 " to %" PRId64, getGeneration(), lastGeneration); + _currentGeneration.store(lastGeneration, std::memory_order_relaxed); _state = CONFIGURED; for (const auto & subscription : _subscriptionList) { const ConfigKey & key(subscription->getKey()); diff --git a/config/src/vespa/config/subscription/configsubscriptionset.h b/config/src/vespa/config/subscription/configsubscriptionset.h index 4b6d970770d..8daf7ae91ea 100644 --- a/config/src/vespa/config/subscription/configsubscriptionset.h +++ b/config/src/vespa/config/subscription/configsubscriptionset.h @@ -39,7 +39,7 @@ public: * @return generation number */ int64_t getGeneration() const noexcept { - return _currentGeneration; + return _currentGeneration.load(std::memory_order_relaxed); } /** @@ -69,7 +69,7 @@ private: const vespalib::duration _maxNapTime; std::shared_ptr<IConfigContext> _context; // Context to keep alive managers. IConfigManager & _mgr; // The config manager that we use. - int64_t _currentGeneration; // Holds the current config generation. + std::atomic<int64_t> _currentGeneration; // Holds the current config generation. SubscriptionList _subscriptionList; // List of current subscriptions. std::atomic<SubscriberState> _state; // Current state of this subscriber. }; diff --git a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java index 62982d66978..4659c2acc36 100644 --- a/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java +++ b/container-disc/src/main/java/com/yahoo/container/jdisc/metric/MetricUpdater.java @@ -5,6 +5,8 @@ import com.google.inject.Inject; import com.yahoo.component.AbstractComponent; import com.yahoo.jdisc.Metric; import com.yahoo.jdisc.statistics.ContainerWatchdogMetrics; +import com.yahoo.nativec.NativeHeap; + import java.lang.management.BufferPoolMXBean; import java.lang.management.ManagementFactory; import java.nio.file.DirectoryStream; @@ -24,6 +26,9 @@ import java.util.TimerTask; */ public class MetricUpdater extends AbstractComponent { + private static final String NATIVE_FREE_MEMORY_BYTES = "mem.native.free"; + private static final String NATIVE_USED_MEMORY_BYTES = "mem.native.used"; + private static final String NATIVE_TOTAL_MEMORY_BYTES = "mem.native.total"; private static final String HEAP_FREE_MEMORY_BYTES = "mem.heap.free"; private static final String HEAP_USED_MEMORY_BYTES = "mem.heap.used"; private static final String HEAP_TOTAL_MEMORY_BYTES = "mem.heap.total"; @@ -116,6 +121,13 @@ public class MetricUpdater extends AbstractComponent { metric.set(DIRECT_COUNT, count, null); } + private void nativeHeapUsed() { + NativeHeap nativeHeap = NativeHeap.sample(); + metric.set(NATIVE_FREE_MEMORY_BYTES, nativeHeap.availableSize(), null); + metric.set(NATIVE_USED_MEMORY_BYTES, nativeHeap.usedSize(), null); + metric.set(NATIVE_TOTAL_MEMORY_BYTES, nativeHeap.totalSize(), null); + } + @Override public void run() { long freeMemory = runtime.freeMemory(); @@ -127,6 +139,7 @@ public class MetricUpdater extends AbstractComponent { metric.set(MEMORY_MAPPINGS_COUNT, count_mappings(), null); metric.set(OPEN_FILE_DESCRIPTORS, count_open_files(), null); directMemoryUsed(); + nativeHeapUsed(); containerWatchdogMetrics.emitMetrics(metric); garbageCollectionMetrics.emitMetrics(metric); diff --git a/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java b/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java index 68d73a7914a..f49ccf2c2f6 100644 --- a/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java +++ b/container-disc/src/test/java/com/yahoo/container/jdisc/metric/MetricUpdaterTest.java @@ -27,7 +27,7 @@ public class MetricUpdaterTest { ContainerWatchdogMetrics containerWatchdogMetrics = mock(ContainerWatchdogMetrics.class); new MetricUpdater(new MockScheduler(), metric, containerWatchdogMetrics); verify(containerWatchdogMetrics, times(1)).emitMetrics(any()); - verify(metric, times(9 + 2 * gcCount)).set(anyString(), any(), any()); + verify(metric, times(12 + 2 * gcCount)).set(anyString(), any(), any()); } private static class MockScheduler implements MetricUpdater.Scheduler { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/billing/BillingApiHandlerV2.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/billing/BillingApiHandlerV2.java index 612b584c7c0..0f3e5b7f76b 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/billing/BillingApiHandlerV2.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/billing/BillingApiHandlerV2.java @@ -28,12 +28,9 @@ import com.yahoo.vespa.hosted.controller.tenant.Tenant; import javax.ws.rs.BadRequestException; import java.math.BigDecimal; import java.time.Clock; -import java.time.Instant; import java.time.LocalDate; import java.time.ZoneOffset; -import java.time.chrono.ChronoZonedDateTime; import java.time.format.DateTimeFormatter; -import java.time.temporal.ChronoUnit; import java.util.Comparator; import java.util.List; import java.util.Optional; @@ -181,8 +178,7 @@ public class BillingApiHandlerV2 extends RestApiRequestHandler<BillingApiHandler var tenantName = TenantName.from(requestContext.pathParameters().getStringOrThrow("tenant")); var tenant = tenants.require(tenantName, CloudTenant.class); var untilAt = untilParameter(requestContext); - var usage = billing.createUncommittedBill(tenant.name(), untilAt.atZone(ZoneOffset.UTC).toLocalDate()); - + var usage = billing.createUncommittedBill(tenant.name(), untilAt); var slime = new Slime(); usageToSlime(slime.setObject(), usage); return slime; @@ -192,7 +188,7 @@ public class BillingApiHandlerV2 extends RestApiRequestHandler<BillingApiHandler private Slime accountant(RestApi.RequestContext requestContext) { var untilAt = untilParameter(requestContext); - var usagePerTenant = billing.createUncommittedBills(untilAt.atZone(ZoneOffset.UTC).toLocalDate()); + var usagePerTenant = billing.createUncommittedBills(untilAt); var response = new Slime(); var tenantsResponse = response.setObject().setArray("tenants"); @@ -214,7 +210,7 @@ public class BillingApiHandlerV2 extends RestApiRequestHandler<BillingApiHandler var tenant = tenants.require(tenantName, CloudTenant.class); var untilAt = untilParameter(requestContext); - var usage = billing.createUncommittedBill(tenant.name(), untilAt.atZone(ZoneOffset.UTC).toLocalDate()); + var usage = billing.createUncommittedBill(tenant.name(), untilAt); var slime = new Slime(); toSlime(slime.setObject(), usage); @@ -320,21 +316,15 @@ public class BillingApiHandlerV2 extends RestApiRequestHandler<BillingApiHandler // ---------- END INVOICE RENDERING ---------- - private Instant untilParameter(RestApi.RequestContext ctx) { + private LocalDate untilParameter(RestApi.RequestContext ctx) { return ctx.queryParameters().getString("until") .map(LocalDate::parse) .map(date -> date.plusDays(1)) - .map(date -> date.atStartOfDay(ZoneOffset.UTC)) - .map(ChronoZonedDateTime::toInstant) - .orElseGet(this::startOfDayTomorrowUTC); - } - - private Instant startOfDayTodayUTC() { - return LocalDate.now(clock.withZone(ZoneOffset.UTC)).atStartOfDay(ZoneOffset.UTC).toInstant(); + .orElseGet(this::tomorrow); } - private Instant startOfDayTomorrowUTC() { - return startOfDayTodayUTC().plus(1, ChronoUnit.DAYS); + private LocalDate tomorrow() { + return LocalDate.now(clock).plusDays(1); } private static String getInspectorFieldOrThrow(Inspector inspector, String field) { diff --git a/searchcore/src/vespa/searchcore/proton/documentmetastore/raw_document_meta_data.h b/searchcore/src/vespa/searchcore/proton/documentmetastore/raw_document_meta_data.h index e5f7dcc7192..b5e512fcd9e 100644 --- a/searchcore/src/vespa/searchcore/proton/documentmetastore/raw_document_meta_data.h +++ b/searchcore/src/vespa/searchcore/proton/documentmetastore/raw_document_meta_data.h @@ -5,6 +5,8 @@ #include <vespa/document/base/globalid.h> #include <vespa/document/bucket/bucketid.h> #include <persistence/spi/types.h> +#include <algorithm> +#include <atomic> #include <cassert> namespace proton { @@ -18,35 +20,41 @@ struct RawDocumentMetaData using BucketId = document::BucketId; using Timestamp = storage::spi::Timestamp; GlobalId _gid; - uint8_t _bucketUsedBits; - uint8_t _docSizeLow; - uint16_t _docSizeHigh; - Timestamp _timestamp; + std::atomic<uint32_t> _bucket_used_bits_and_doc_size; + std::atomic<uint64_t> _timestamp; + + static uint32_t capped_doc_size(uint32_t doc_size) { return std::min(0xffffffu, doc_size); } RawDocumentMetaData() noexcept : _gid(), - _bucketUsedBits(BucketId::minNumBits), - _docSizeLow(0), - _docSizeHigh(0), - _timestamp() + _bucket_used_bits_and_doc_size(BucketId::minNumBits), + _timestamp(0) { } RawDocumentMetaData(const GlobalId &gid, const BucketId &bucketId, const Timestamp ×tamp, uint32_t docSize) noexcept : _gid(gid), - _bucketUsedBits(bucketId.getUsedBits()), - _docSizeLow(docSize), - _docSizeHigh(docSize >> 8), + _bucket_used_bits_and_doc_size(bucketId.getUsedBits() | (capped_doc_size(docSize) << 8)), _timestamp(timestamp) { assert(bucketId.valid()); BucketId verId(gid.convertToBucketId()); - verId.setUsedBits(_bucketUsedBits); + verId.setUsedBits(bucketId.getUsedBits()); assert(bucketId.getRawId() == verId.getRawId() || bucketId.getRawId() == verId.getId()); - if (docSize >= (1u << 24)) { - _docSizeLow = 0xff; - _docSizeHigh = 0xffff; - } + } + + RawDocumentMetaData(const RawDocumentMetaData& rhs) + : _gid(rhs._gid), + _bucket_used_bits_and_doc_size(rhs._bucket_used_bits_and_doc_size.load(std::memory_order_relaxed)), + _timestamp(rhs._timestamp.load(std::memory_order_relaxed)) + { + } + + RawDocumentMetaData& operator=(const RawDocumentMetaData& rhs) { + _gid = rhs._gid; + _bucket_used_bits_and_doc_size.store(rhs._bucket_used_bits_and_doc_size.load(std::memory_order_relaxed), std::memory_order_relaxed); + _timestamp.store(rhs._timestamp.load(std::memory_order_relaxed), std::memory_order_relaxed); + return *this; } bool operator<(const GlobalId &rhs) const noexcept { return _gid < rhs; } @@ -57,17 +65,17 @@ struct RawDocumentMetaData const GlobalId &getGid() const { return _gid; } GlobalId &getGid() { return _gid; } void setGid(const GlobalId &rhs) { _gid = rhs; } - uint8_t getBucketUsedBits() const { return _bucketUsedBits; } + uint8_t getBucketUsedBits() const { return _bucket_used_bits_and_doc_size.load(std::memory_order_relaxed) & 0xffu; } BucketId getBucketId() const { BucketId ret(_gid.convertToBucketId()); - ret.setUsedBits(_bucketUsedBits); + ret.setUsedBits(getBucketUsedBits()); return ret; } void setBucketUsedBits(uint8_t bucketUsedBits) { assert(BucketId::validUsedBits(bucketUsedBits)); - _bucketUsedBits = bucketUsedBits; + _bucket_used_bits_and_doc_size.store((_bucket_used_bits_and_doc_size.load(std::memory_order_relaxed) & ~0xffu) | bucketUsedBits, std::memory_order_relaxed); } void setBucketId(const BucketId &bucketId) { @@ -77,15 +85,16 @@ struct RawDocumentMetaData verId.setUsedBits(bucketUsedBits); assert(bucketId.getRawId() == verId.getRawId() || bucketId.getRawId() == verId.getId()); - _bucketUsedBits = bucketUsedBits; + setBucketUsedBits(bucketUsedBits); } - Timestamp getTimestamp() const { return _timestamp; } + Timestamp getTimestamp() const { return Timestamp(_timestamp.load(std::memory_order_relaxed)); } + + void setTimestamp(const Timestamp ×tamp) { _timestamp.store(timestamp.getValue(), std::memory_order_relaxed); } - void setTimestamp(const Timestamp ×tamp) { _timestamp = timestamp; } + uint32_t getDocSize() const { return _bucket_used_bits_and_doc_size.load(std::memory_order_relaxed) >> 8; } + void setDocSize(uint32_t docSize) { _bucket_used_bits_and_doc_size.store((_bucket_used_bits_and_doc_size.load(std::memory_order_relaxed) & 0xffu) | (capped_doc_size(docSize) << 8), std::memory_order_relaxed); } - uint32_t getDocSize() const { return _docSizeLow + (static_cast<uint32_t>(_docSizeHigh) << 8); } - void setDocSize(uint32_t docSize) { _docSizeLow = docSize; _docSizeHigh = docSize >> 8; } }; } // namespace proton diff --git a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp index 756af216988..381915f53d4 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/matcher.cpp @@ -128,8 +128,7 @@ Matcher::getStats() { std::lock_guard<std::mutex> guard(_statsLock); MatchingStats stats = std::move(_stats); - _stats = MatchingStats(); - _stats.softDoomFactor(stats.softDoomFactor()); + _stats = MatchingStats(stats.softDoomFactor()); return stats; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/matching_stats.cpp b/searchcore/src/vespa/searchcore/proton/matching/matching_stats.cpp index 2c826094ddf..bbb9dba9a30 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matching_stats.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/matching_stats.cpp @@ -19,7 +19,7 @@ constexpr double MAX_CHANGE_FACTOR = 5; } // namespace proton::matching::<unnamed> -MatchingStats::MatchingStats() +MatchingStats::MatchingStats(double prev_soft_doom_factor) : _queries(0), _limited_queries(0), _docidSpaceCovered(0), @@ -28,7 +28,7 @@ MatchingStats::MatchingStats() _docsReRanked(0), _softDoomed(0), _doomOvertime(), - _softDoomFactor(INITIAL_SOFT_DOOM_FACTOR), + _softDoomFactor(prev_soft_doom_factor), _queryCollateralTime(), // TODO: Remove in Vespa 8 _querySetupTime(), _queryLatency(), @@ -89,16 +89,18 @@ MatchingStats::updatesoftDoomFactor(vespalib::duration hardLimit, vespalib::dura // It is merely a safety measure to avoid overflow on bad input as can happen with time senstive stuff // in any soft real time system. if ((hardLimit >= MIN_TIMEOUT) && (softLimit >= MIN_TIMEOUT)) { + double factor = softDoomFactor(); double diff = vespalib::to_s(softLimit - duration)/vespalib::to_s(hardLimit); if (duration < softLimit) { // Since softdoom factor can become very small, allow a minimum change of some size - diff = std::min(diff, _softDoomFactor*MAX_CHANGE_FACTOR); - _softDoomFactor += 0.01*diff; + diff = std::min(diff, factor*MAX_CHANGE_FACTOR); + factor += 0.01*diff; } else { - diff = std::max(diff, -_softDoomFactor*MAX_CHANGE_FACTOR); - _softDoomFactor += 0.02*diff; + diff = std::max(diff, -factor*MAX_CHANGE_FACTOR); + factor += 0.02*diff; } - _softDoomFactor = std::max(_softDoomFactor, 0.01); // Never go below 1% + factor = std::max(factor, 0.01); // Never go below 1% + softDoomFactor(factor); } return *this; } diff --git a/searchcore/src/vespa/searchcore/proton/matching/matching_stats.h b/searchcore/src/vespa/searchcore/proton/matching/matching_stats.h index 047c6fcaf13..eafa14870fa 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/matching_stats.h +++ b/searchcore/src/vespa/searchcore/proton/matching/matching_stats.h @@ -5,6 +5,7 @@ #include <vector> #include <cstddef> #include <vespa/vespalib/util/time.h> +#include <vespa/vespalib/datastore/atomic_value_wrapper.h> namespace proton::matching { @@ -124,7 +125,8 @@ private: size_t _docsReRanked; size_t _softDoomed; Avg _doomOvertime; - double _softDoomFactor; + using SoftDoomFactor = vespalib::datastore::AtomicValueWrapper<double>; + SoftDoomFactor _softDoomFactor; Avg _queryCollateralTime; // TODO: Remove in Vespa 8 Avg _querySetupTime; Avg _queryLatency; @@ -139,7 +141,7 @@ public: MatchingStats & operator = (const MatchingStats &) = delete; MatchingStats(MatchingStats &&) = default; MatchingStats & operator = (MatchingStats &&) = default; - MatchingStats(); + MatchingStats(double prev_soft_doom_factor = INITIAL_SOFT_DOOM_FACTOR); ~MatchingStats(); MatchingStats &queries(size_t value) { _queries = value; return *this; } @@ -165,8 +167,8 @@ public: vespalib::duration doomOvertime() const { return vespalib::from_s(_doomOvertime.max()); } - MatchingStats &softDoomFactor(double value) { _softDoomFactor = value; return *this; } - double softDoomFactor() const { return _softDoomFactor; } + MatchingStats &softDoomFactor(double value) { _softDoomFactor.store_relaxed(value); return *this; } + double softDoomFactor() const { return _softDoomFactor.load_relaxed(); } MatchingStats &updatesoftDoomFactor(vespalib::duration hardLimit, vespalib::duration softLimit, vespalib::duration duration); // TODO: Remove in Vespa 8 diff --git a/searchlib/src/vespa/searchlib/attribute/attributevector.cpp b/searchlib/src/vespa/searchlib/attribute/attributevector.cpp index 1edd5d9fe76..73ff91f135c 100644 --- a/searchlib/src/vespa/searchlib/attribute/attributevector.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attributevector.cpp @@ -528,7 +528,7 @@ AttributeVector::compactLidSpace(uint32_t wantedLidLimit) { } commit(); _committedDocIdLimit.store(wantedLidLimit, std::memory_order_release); - _compactLidSpaceGeneration = _genHandler.getCurrentGeneration(); + _compactLidSpaceGeneration.store(_genHandler.getCurrentGeneration(), std::memory_order_relaxed); incGeneration(); } @@ -536,7 +536,7 @@ AttributeVector::compactLidSpace(uint32_t wantedLidLimit) { bool AttributeVector::canShrinkLidSpace() const { return wantShrinkLidSpace() && - _compactLidSpaceGeneration < getFirstUsedGeneration(); + _compactLidSpaceGeneration.load(std::memory_order_relaxed) < getFirstUsedGeneration(); } diff --git a/searchlib/src/vespa/searchlib/attribute/attributevector.h b/searchlib/src/vespa/searchlib/attribute/attributevector.h index 5f336ab921f..b40f36ad6bd 100644 --- a/searchlib/src/vespa/searchlib/attribute/attributevector.h +++ b/searchlib/src/vespa/searchlib/attribute/attributevector.h @@ -501,7 +501,7 @@ private: std::atomic<uint32_t> _committedDocIdLimit; // docid limit for search uint32_t _uncommittedDocIdLimit; // based on queued changes uint64_t _createSerialNum; - uint64_t _compactLidSpaceGeneration; + std::atomic<uint64_t> _compactLidSpaceGeneration; bool _hasEnum; bool _loaded; bool _isUpdateableInMemoryOnly; diff --git a/searchlib/src/vespa/searchlib/common/packets.h b/searchlib/src/vespa/searchlib/common/packets.h index 07c942a997b..f13cbe24ce4 100644 --- a/searchlib/src/vespa/searchlib/common/packets.h +++ b/searchlib/src/vespa/searchlib/common/packets.h @@ -5,6 +5,7 @@ #include <vespa/vespalib/util/compressionconfig.h> #include <vespa/vespalib/util/memory.h> #include <vector> +#include <atomic> class FNET_DataBuffer; @@ -15,21 +16,21 @@ using vespalib::string; class FS4PersistentPacketStreamer { using CompressionConfig = vespalib::compression::CompressionConfig; - unsigned int _compressionLimit; - unsigned int _compressionLevel; - CompressionConfig::Type _compressionType; + std::atomic<unsigned int> _compressionLimit; + std::atomic<unsigned int> _compressionLevel; + std::atomic<CompressionConfig::Type> _compressionType; public: static FS4PersistentPacketStreamer Instance; FS4PersistentPacketStreamer(); - void SetCompressionLimit(unsigned int limit) { _compressionLimit = limit; } - void SetCompressionLevel(unsigned int level) { _compressionLevel = level; } - void SetCompressionType(CompressionConfig::Type compressionType) { _compressionType = compressionType; } - CompressionConfig::Type getCompressionType() const { return _compressionType; } - uint32_t getCompressionLimit() const { return _compressionLimit; } - uint32_t getCompressionLevel() const { return _compressionLevel; } + void SetCompressionLimit(unsigned int limit) { _compressionLimit.store(limit, std::memory_order_relaxed); } + void SetCompressionLevel(unsigned int level) { _compressionLevel.store(level, std::memory_order_relaxed); } + void SetCompressionType(CompressionConfig::Type compressionType) { _compressionType.store(compressionType, std::memory_order_relaxed); } + CompressionConfig::Type getCompressionType() const { return _compressionType.load(std::memory_order_relaxed); } + uint32_t getCompressionLimit() const { return _compressionLimit.load(std::memory_order_relaxed); } + uint32_t getCompressionLevel() const { return _compressionLevel.load(std::memory_order_relaxed); } }; //========================================================================== diff --git a/vespajlib/src/main/java/com/yahoo/nativec/GLibcVersion.java b/vespajlib/src/main/java/com/yahoo/nativec/GLibcVersion.java index 67ae30c84f5..2dfa4f6d11b 100644 --- a/vespajlib/src/main/java/com/yahoo/nativec/GLibcVersion.java +++ b/vespajlib/src/main/java/com/yahoo/nativec/GLibcVersion.java @@ -1,5 +1,10 @@ package com.yahoo.nativec; +/** + * Gives access to the C library version. + * + * @author baldersheim + */ public class GLibcVersion { private final static Throwable initException = NativeC.loadLibrary(GLibcVersion.class); public static Throwable init() { diff --git a/vespajlib/src/main/java/com/yahoo/nativec/MallInfo.java b/vespajlib/src/main/java/com/yahoo/nativec/MallInfo.java index a4f5486ccf1..eda6c7d1af7 100644 --- a/vespajlib/src/main/java/com/yahoo/nativec/MallInfo.java +++ b/vespajlib/src/main/java/com/yahoo/nativec/MallInfo.java @@ -2,7 +2,12 @@ package com.yahoo.nativec; import com.sun.jna.Structure; -public class MallInfo { +/** + * Gives access to the information provided by the C library mallinfo() function. + * + * @author baldersheim + */ +public class MallInfo extends NativeHeap { private final static Throwable initException = NativeC.loadLibrary(MallInfo.class); public static Throwable init() { return initException; @@ -23,8 +28,27 @@ public class MallInfo { public int keepcost; /* Top-most, releasable space (bytes) */ } private static native MallInfoStruct.ByValue mallinfo(); + + private final MallInfoStruct mallinfo; public MallInfo() { mallinfo = mallinfo(); } - private final MallInfoStruct mallinfo; + + @Override + public long usedSize() { + long v = mallinfo.uordblks; + return v << 20; // Due to too few bits in ancient mallinfo vespamalloc reports in 1M units + } + + @Override + public long totalSize() { + long v = mallinfo.arena; + return v << 20; // Due to too few bits in ancient mallinfo vespamalloc reports in 1M units + } + + @Override + public long availableSize() { + long v = mallinfo.fordblks; + return v << 20; // Due to too few bits in ancient mallinfo vespamalloc reports in 1M units + } } diff --git a/vespajlib/src/main/java/com/yahoo/nativec/MallInfo2.java b/vespajlib/src/main/java/com/yahoo/nativec/MallInfo2.java index 1ae3bc590e2..ea735046843 100644 --- a/vespajlib/src/main/java/com/yahoo/nativec/MallInfo2.java +++ b/vespajlib/src/main/java/com/yahoo/nativec/MallInfo2.java @@ -2,7 +2,12 @@ package com.yahoo.nativec; import com.sun.jna.Structure; -public class MallInfo2 { +/** + * Gives access to the information provided by the C library mallinfo2() function. + * + * @author baldersheim + */ +public class MallInfo2 extends NativeHeap { private final static Throwable initException = NativeC.loadLibrary(MallInfo2.class); public static Throwable init() { return initException; @@ -23,8 +28,24 @@ public class MallInfo2 { public long keepcost; /* Top-most, releasable space (bytes) */ } private static native MallInfo2Struct.ByValue mallinfo2(); + private final MallInfo2Struct mallinfo; + public MallInfo2() { mallinfo = mallinfo2(); } - private final MallInfo2Struct mallinfo; + + @Override + public long usedSize() { + return mallinfo.uordblks; + } + + @Override + public long totalSize() { + return mallinfo.arena; + } + + @Override + public long availableSize() { + return mallinfo.fordblks; + } } diff --git a/vespajlib/src/main/java/com/yahoo/nativec/NativeHeap.java b/vespajlib/src/main/java/com/yahoo/nativec/NativeHeap.java new file mode 100644 index 00000000000..ddff2e33230 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/nativec/NativeHeap.java @@ -0,0 +1,24 @@ +package com.yahoo.nativec; + +import com.sun.jna.Platform; + +/** + * Represents the native C heap if accessible + * + * @author baldersheim + */ +public class NativeHeap { + public long usedSize() { return 0; } + public long totalSize() { return 0; } + public long availableSize() { return 0; } + public static NativeHeap sample() { + if (Platform.isLinux()) { + GLibcVersion version = new GLibcVersion(); + if ((version.major() >= 3) || ((version.major() == 2) && (version.minor() >= 33))) { + return new MallInfo2(); + } + return new MallInfo(); + } + return new NativeHeap(); + } +} diff --git a/vespajlib/src/main/java/com/yahoo/nativec/PosixFAdvise.java b/vespajlib/src/main/java/com/yahoo/nativec/PosixFAdvise.java index 3e2c26d2ef2..0fdcbca5f14 100644 --- a/vespajlib/src/main/java/com/yahoo/nativec/PosixFAdvise.java +++ b/vespajlib/src/main/java/com/yahoo/nativec/PosixFAdvise.java @@ -2,6 +2,11 @@ package com.yahoo.nativec; import com.sun.jna.LastErrorException; +/** + * Gives access to the C library posix_fadvise() function. + * + * @author baldersheim + */ public class PosixFAdvise { public static final int POSIX_FADV_DONTNEED = 4; // See /usr/include/linux/fadvise.h private final static Throwable initException = NativeC.loadLibrary(PosixFAdvise.class); diff --git a/vespalib/src/tests/net/tls/openssl_impl/openssl_impl_test.cpp b/vespalib/src/tests/net/tls/openssl_impl/openssl_impl_test.cpp index 45e9c92343e..e20cd30c597 100644 --- a/vespalib/src/tests/net/tls/openssl_impl/openssl_impl_test.cpp +++ b/vespalib/src/tests/net/tls/openssl_impl/openssl_impl_test.cpp @@ -535,40 +535,40 @@ struct CertFixture : Fixture { CertFixture::~CertFixture() = default; struct PrintingCertificateCallback : CertificateVerificationCallback { - bool verify(const PeerCredentials& peer_creds) const override { + AuthorizationResult verify(const PeerCredentials& peer_creds) const override { if (!peer_creds.common_name.empty()) { fprintf(stderr, "Got a CN: %s\n", peer_creds.common_name.c_str()); } for (auto& dns : peer_creds.dns_sans) { fprintf(stderr, "Got a DNS SAN entry: %s\n", dns.c_str()); } - return true; + return AuthorizationResult::make_authorized_for_all_roles(); } }; // Single-use mock verifier struct MockCertificateCallback : CertificateVerificationCallback { mutable PeerCredentials creds; // only used in single thread testing context - bool verify(const PeerCredentials& peer_creds) const override { + AuthorizationResult verify(const PeerCredentials& peer_creds) const override { creds = peer_creds; - return true; + return AuthorizationResult::make_authorized_for_all_roles(); } }; struct AlwaysFailVerifyCallback : CertificateVerificationCallback { - bool verify([[maybe_unused]] const PeerCredentials& peer_creds) const override { + AuthorizationResult verify([[maybe_unused]] const PeerCredentials& peer_creds) const override { fprintf(stderr, "Rejecting certificate, none shall pass!\n"); - return false; + return AuthorizationResult::make_not_authorized(); } }; struct ExceptionThrowingCallback : CertificateVerificationCallback { - bool verify([[maybe_unused]] const PeerCredentials& peer_creds) const override { + AuthorizationResult verify([[maybe_unused]] const PeerCredentials& peer_creds) const override { throw std::runtime_error("oh no what is going on"); } }; -TEST_F("Certificate verification callback returning false breaks handshake", CertFixture) { +TEST_F("Certificate verification callback returning unauthorized breaks handshake", CertFixture) { auto ck = f.create_ca_issued_peer_cert({"hello.world.example.com"}, {}); f.reset_client_with_cert_opts(ck, std::make_shared<PrintingCertificateCallback>()); @@ -602,8 +602,40 @@ TEST_F("Certificate verification callback observes CN, DNS SANs and URI SANs", C ASSERT_EQUAL(2u, creds.dns_sans.size()); EXPECT_EQUAL("crash.wile.example.com", creds.dns_sans[0]); EXPECT_EQUAL("burn.wile.example.com", creds.dns_sans[1]); - ASSERT_EQUAL(1u, server_cb->creds.uri_sans.size()); - EXPECT_EQUAL("foo://bar.baz/zoid", server_cb->creds.uri_sans[0]); + ASSERT_EQUAL(1u, creds.uri_sans.size()); + EXPECT_EQUAL("foo://bar.baz/zoid", creds.uri_sans[0]); +} + +TEST_F("Peer credentials are propagated to CryptoCodec", CertFixture) { + auto cli_cert = f.create_ca_issued_peer_cert( + {{"rockets.wile.example.com"}}, + {{"DNS:crash.wile.example.com"}, {"DNS:burn.wile.example.com"}, + {"URI:foo://bar.baz/zoid"}}); + auto serv_cert = f.create_ca_issued_peer_cert( + {{"birdseed.roadrunner.example.com"}}, + {{"DNS:fake.tunnel.example.com"}}); + f.reset_client_with_cert_opts(cli_cert, std::make_shared<PrintingCertificateCallback>()); + auto server_cb = std::make_shared<MockCertificateCallback>(); + f.reset_server_with_cert_opts(serv_cert, server_cb); + ASSERT_TRUE(f.handshake()); + + auto& client_creds = f.server->peer_credentials(); + auto& server_creds = f.client->peer_credentials(); + + fprintf(stderr, "Client credentials (observed by server): %s\n", to_string(client_creds).c_str()); + fprintf(stderr, "Server credentials (observed by client): %s\n", to_string(server_creds).c_str()); + + EXPECT_EQUAL("rockets.wile.example.com", client_creds.common_name); + ASSERT_EQUAL(2u, client_creds.dns_sans.size()); + EXPECT_EQUAL("crash.wile.example.com", client_creds.dns_sans[0]); + EXPECT_EQUAL("burn.wile.example.com", client_creds.dns_sans[1]); + ASSERT_EQUAL(1u, client_creds.uri_sans.size()); + EXPECT_EQUAL("foo://bar.baz/zoid", client_creds.uri_sans[0]); + + EXPECT_EQUAL("birdseed.roadrunner.example.com", server_creds.common_name); + ASSERT_EQUAL(1u, server_creds.dns_sans.size()); + EXPECT_EQUAL("fake.tunnel.example.com", server_creds.dns_sans[0]); + ASSERT_EQUAL(0u, server_creds.uri_sans.size()); } TEST_F("Last occurring CN is given to verification callback if multiple CNs are present", CertFixture) { diff --git a/vespalib/src/tests/net/tls/policy_checking_certificate_verifier/policy_checking_certificate_verifier_test.cpp b/vespalib/src/tests/net/tls/policy_checking_certificate_verifier/policy_checking_certificate_verifier_test.cpp index 812d06868fd..8c9e50f17b4 100644 --- a/vespalib/src/tests/net/tls/policy_checking_certificate_verifier/policy_checking_certificate_verifier_test.cpp +++ b/vespalib/src/tests/net/tls/policy_checking_certificate_verifier/policy_checking_certificate_verifier_test.cpp @@ -124,7 +124,12 @@ PeerCredentials creds_with_cn(vespalib::stringref cn) { bool verify(AuthorizedPeers authorized_peers, const PeerCredentials& peer_creds) { auto verifier = create_verify_callback_from(std::move(authorized_peers)); - return verifier->verify(peer_creds); + return verifier->verify(peer_creds).success(); +} + +AssumedRoles verify_roles(AuthorizedPeers authorized_peers, const PeerCredentials& peer_creds) { + auto verifier = create_verify_callback_from(std::move(authorized_peers)); + return verifier->verify(peer_creds).steal_assumed_roles(); } TEST("Default-constructed AuthorizedPeers does not allow all authenticated peers") { @@ -137,6 +142,16 @@ TEST("Specially constructed set of policies allows all authenticated peers") { EXPECT_TRUE(verify(allow_all, creds_with_dns_sans({{"anything.goes"}}))); } +TEST("specially constructed set of policies returns wildcard role set") { + auto allow_all = AuthorizedPeers::allow_all_authenticated(); + EXPECT_EQUAL(verify_roles(allow_all, creds_with_dns_sans({{"anything.goes"}})), AssumedRoles::make_wildcard_role()); +} + +TEST("policy without explicit role set implicitly returns wildcard role set") { + auto authorized = authorized_peers({policy_with({required_san_dns("yolo.swag")})}); + EXPECT_EQUAL(verify_roles(authorized, creds_with_dns_sans({{"yolo.swag"}})), AssumedRoles::make_wildcard_role()); +} + TEST("Non-empty policies do not allow all authenticated peers") { auto allow_not_all = authorized_peers({policy_with({required_san_dns("hello.world")})}); EXPECT_FALSE(allow_not_all.allows_all_authenticated()); @@ -231,10 +246,11 @@ struct MultiPolicyMatchFixture { }; MultiPolicyMatchFixture::MultiPolicyMatchFixture() - : authorized(authorized_peers({policy_with({required_san_dns("hello.world")}), - policy_with({required_san_dns("foo.bar")}), - policy_with({required_san_dns("zoid.berg")}), - policy_with({required_san_uri("zoid://be.rg/")})})) + : authorized(authorized_peers({policy_with({required_san_dns("hello.world")}, assumed_roles({"r1"})), + policy_with({required_san_dns("foo.bar")}, assumed_roles({"r2"})), + policy_with({required_san_dns("zoid.berg")}, assumed_roles({"r2", "r3"})), + policy_with({required_san_dns("secret.sauce")}, AssumedRoles::make_wildcard_role()), + policy_with({required_san_uri("zoid://be.rg/")}, assumed_roles({"r4"}))})) {} MultiPolicyMatchFixture::~MultiPolicyMatchFixture() = default; @@ -246,14 +262,34 @@ TEST_F("peer verifies if it matches at least 1 policy of multiple", MultiPolicyM EXPECT_TRUE(verify(f.authorized, creds_with_uri_sans({{"zoid://be.rg/"}}))); } +TEST_F("role set is returned for single matched policy", MultiPolicyMatchFixture) { + EXPECT_EQUAL(verify_roles(f.authorized, creds_with_dns_sans({{"hello.world"}})), assumed_roles({"r1"})); + EXPECT_EQUAL(verify_roles(f.authorized, creds_with_dns_sans({{"foo.bar"}})), assumed_roles({"r2"})); + EXPECT_EQUAL(verify_roles(f.authorized, creds_with_dns_sans({{"zoid.berg"}})), assumed_roles({"r2", "r3"})); + EXPECT_EQUAL(verify_roles(f.authorized, creds_with_dns_sans({{"secret.sauce"}})), AssumedRoles::make_wildcard_role()); + EXPECT_EQUAL(verify_roles(f.authorized, creds_with_uri_sans({{"zoid://be.rg/"}})), assumed_roles({"r4"})); +} + TEST_F("peer verifies if it matches multiple policies", MultiPolicyMatchFixture) { EXPECT_TRUE(verify(f.authorized, creds_with_dns_sans({{"hello.world"}, {"zoid.berg"}}))); } +TEST_F("union role set is returned if multiple policies match", MultiPolicyMatchFixture) { + EXPECT_EQUAL(verify_roles(f.authorized, creds_with_dns_sans({{"hello.world"}, {"foo.bar"}, {"zoid.berg"}})), + assumed_roles({"r1", "r2", "r3"})); + // Wildcard role is tracked as a distinct role string + EXPECT_EQUAL(verify_roles(f.authorized, creds_with_dns_sans({{"hello.world"}, {"foo.bar"}, {"secret.sauce"}})), + assumed_roles({"r1", "r2", "*"})); +} + TEST_F("peer must match at least 1 of multiple policies", MultiPolicyMatchFixture) { EXPECT_FALSE(verify(f.authorized, creds_with_dns_sans({{"does.not.exist"}}))); } +TEST_F("empty role set is returned if no policies match", MultiPolicyMatchFixture) { + EXPECT_EQUAL(verify_roles(f.authorized, creds_with_dns_sans({{"does.not.exist"}})), AssumedRoles::make_empty()); +} + TEST("CN requirement without glob pattern is matched as exact string") { auto authorized = authorized_peers({policy_with({required_cn("hello.world")})}); EXPECT_TRUE(verify(authorized, creds_with_cn("hello.world"))); @@ -272,6 +308,64 @@ TEST("CN requirement can include glob wildcards") { EXPECT_FALSE(verify(authorized, creds_with_cn("world"))); } +TEST("AssumedRoles by default contains no roles") { + AssumedRoles roles; + EXPECT_TRUE(roles.empty()); + EXPECT_FALSE(roles.can_assume_role("foo")); + auto empty = AssumedRoles::make_empty(); + EXPECT_EQUAL(roles, empty); +} + +TEST("AssumedRoles can be constructed with an explicit set of roles") { + auto roles = AssumedRoles::make_for_roles({"foo", "bar"}); + EXPECT_TRUE(roles.can_assume_role("foo")); + EXPECT_TRUE(roles.can_assume_role("bar")); + EXPECT_FALSE(roles.can_assume_role("baz")); +} + +TEST("AssumedRoles wildcard role can assume any role") { + auto roles = AssumedRoles::make_wildcard_role(); + EXPECT_TRUE(roles.can_assume_role("foo")); + EXPECT_TRUE(roles.can_assume_role("bar")); +} + +TEST("AssumedRolesBuilder builds union set of added roles") { + AssumedRolesBuilder builder; + builder.add_union(AssumedRoles::make_for_roles({"hello", "world"})); + builder.add_union(AssumedRoles::make_for_roles({"hello", "moon"})); + builder.add_union(AssumedRoles::make_for_roles({"goodbye", "moon"})); + auto roles = builder.build_with_move(); + EXPECT_EQUAL(roles, AssumedRoles::make_for_roles({"hello", "goodbye", "moon", "world"})); +} + +TEST("AuthorizationResult is not authorized by default") { + AuthorizationResult result; + EXPECT_FALSE(result.success()); + EXPECT_TRUE(result.assumed_roles().empty()); +} + +TEST("AuthorizationResult can be explicitly created as not authorzed") { + auto result = AuthorizationResult::make_not_authorized(); + EXPECT_FALSE(result.success()); + EXPECT_TRUE(result.assumed_roles().empty()); +} + +TEST("AuthorizationResult can be pre-authorized for all roles") { + auto result = AuthorizationResult::make_authorized_for_all_roles(); + EXPECT_TRUE(result.success()); + EXPECT_FALSE(result.assumed_roles().empty()); + EXPECT_TRUE(result.assumed_roles().can_assume_role("foo")); +} + +TEST("AuthorizationResult can be pre-authorized for an explicit set of roles") { + auto result = AuthorizationResult::make_authorized_for_roles(AssumedRoles::make_for_roles({"elden", "ring"})); + EXPECT_TRUE(result.success()); + EXPECT_FALSE(result.assumed_roles().empty()); + EXPECT_TRUE(result.assumed_roles().can_assume_role("elden")); + EXPECT_TRUE(result.assumed_roles().can_assume_role("ring")); + EXPECT_FALSE(result.assumed_roles().can_assume_role("O you don't have the right")); +} + // TODO test CN _and_ SAN TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/vespalib/src/vespa/vespalib/net/tls/CMakeLists.txt b/vespalib/src/vespa/vespalib/net/tls/CMakeLists.txt index f1e64241533..424c2bd672f 100644 --- a/vespalib/src/vespa/vespalib/net/tls/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/net/tls/CMakeLists.txt @@ -1,7 +1,9 @@ # Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. vespa_add_library(vespalib_vespalib_net_tls OBJECT SOURCES + assumed_roles.cpp authorization_mode.cpp + authorization_result.cpp auto_reloading_tls_crypto_engine.cpp crypto_codec.cpp crypto_codec_adapter.cpp diff --git a/vespalib/src/vespa/vespalib/net/tls/assumed_roles.cpp b/vespalib/src/vespa/vespalib/net/tls/assumed_roles.cpp new file mode 100644 index 00000000000..672458d0024 --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/assumed_roles.cpp @@ -0,0 +1,95 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include "assumed_roles.h" +#include <vespa/vespalib/stllike/asciistream.h> +#include <algorithm> +#include <ostream> + +namespace vespalib::net::tls { + +const string AssumedRoles::WildcardRole("*"); + +AssumedRoles::AssumedRoles() = default; + +AssumedRoles::AssumedRoles(RoleSet assumed_roles) + : _assumed_roles(std::move(assumed_roles)) +{} + +AssumedRoles::AssumedRoles(const AssumedRoles&) = default; +AssumedRoles& AssumedRoles::operator=(const AssumedRoles&) = default; +AssumedRoles::AssumedRoles(AssumedRoles&&) noexcept = default; +AssumedRoles& AssumedRoles::operator=(AssumedRoles&&) noexcept = default; +AssumedRoles::~AssumedRoles() = default; + +bool AssumedRoles::can_assume_role(const string& role) const noexcept { + return (_assumed_roles.contains(role) || _assumed_roles.contains(WildcardRole)); +} + +std::vector<string> AssumedRoles::ordered_roles() const { + std::vector<string> roles; + for (const auto& r : _assumed_roles) { + roles.emplace_back(r); + } + std::sort(roles.begin(), roles.end()); + return roles; +} + +bool AssumedRoles::operator==(const AssumedRoles& rhs) const noexcept { + return (_assumed_roles == rhs._assumed_roles); +} + +void AssumedRoles::print(asciistream& os) const { + os << "AssumedRoles(roles: ["; + auto roles = ordered_roles(); + for (size_t i = 0; i < roles.size(); ++i) { + if (i > 0) { + os << ", "; + } + os << roles[i]; + } + os << "])"; +} + +asciistream& operator<<(asciistream& os, const AssumedRoles& res) { + res.print(os); + return os; +} + +std::ostream& operator<<(std::ostream& os, const AssumedRoles& res) { + os << to_string(res); + return os; +} + +string to_string(const AssumedRoles& res) { + asciistream os; + os << res; + return os.str(); +} + +AssumedRoles AssumedRoles::make_for_roles(RoleSet assumed_roles) { + return AssumedRoles(std::move(assumed_roles)); +} + +AssumedRoles AssumedRoles::make_wildcard_role() { + return AssumedRoles(RoleSet({WildcardRole})); +} + +AssumedRoles AssumedRoles::make_empty() { + return {}; +} + +AssumedRolesBuilder::AssumedRolesBuilder() = default; +AssumedRolesBuilder::~AssumedRolesBuilder() = default; + +void AssumedRolesBuilder::add_union(const AssumedRoles& roles) { + // TODO fix hash_set iterator range insert() + for (const auto& role : roles.unordered_roles()) { + _wip_roles.insert(role); + } +} + +AssumedRoles AssumedRolesBuilder::build_with_move() { + return AssumedRoles::make_for_roles(std::move(_wip_roles)); +} + +} + diff --git a/vespalib/src/vespa/vespalib/net/tls/assumed_roles.h b/vespalib/src/vespa/vespalib/net/tls/assumed_roles.h new file mode 100644 index 00000000000..00d800916fd --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/assumed_roles.h @@ -0,0 +1,80 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include <vespa/vespalib/stllike/hash_set.h> +#include <vespa/vespalib/stllike/string.h> +#include <vector> +#include <iosfwd> + +namespace vespalib { class asciistream; } + +namespace vespalib::net::tls { + +/** + * Encapsulates a set of roles that requests over a particular authenticated + * connection can assume, based on the authorization rules it matched during mTLS + * handshaking. + * + * If at least one role is a wildcard ('*') role, the connection can assume _any_ + * possible role. This is the default when no role constraints are specified in + * the TLS configuration file (legacy behavior). However, a default-constructed + * AssumedRoles instance does not allow any roles to be assumed. + */ +class AssumedRoles { +public: + using RoleSet = hash_set<string>; +private: + RoleSet _assumed_roles; + + static const string WildcardRole; + + explicit AssumedRoles(RoleSet assumed_roles); +public: + AssumedRoles(); + AssumedRoles(const AssumedRoles&); + AssumedRoles& operator=(const AssumedRoles&); + AssumedRoles(AssumedRoles&&) noexcept; + AssumedRoles& operator=(AssumedRoles&&) noexcept; + ~AssumedRoles(); + + [[nodiscard]] bool empty() const noexcept { + return _assumed_roles.empty(); + } + + /** + * Returns true iff `role` is present in the role set OR the role set contains + * the special wildcard role. + */ + [[nodiscard]] bool can_assume_role(const string& role) const noexcept; + + [[nodiscard]] const RoleSet& unordered_roles() const noexcept { + return _assumed_roles; + } + + [[nodiscard]] std::vector<string> ordered_roles() const; + + bool operator==(const AssumedRoles& rhs) const noexcept; + + void print(asciistream& os) const; + + static AssumedRoles make_for_roles(RoleSet assumed_roles); + static AssumedRoles make_wildcard_role(); // Allows assuming _all_ possible roles + static AssumedRoles make_empty(); // Matches _no_ possible roles +}; + +asciistream& operator<<(asciistream&, const AssumedRoles&); +std::ostream& operator<<(std::ostream&, const AssumedRoles&); +string to_string(const AssumedRoles&); + +class AssumedRolesBuilder { + AssumedRoles::RoleSet _wip_roles; +public: + AssumedRolesBuilder(); + ~AssumedRolesBuilder(); + + void add_union(const AssumedRoles& roles); + [[nodiscard]] bool empty() const noexcept { return _wip_roles.empty(); } + [[nodiscard]] AssumedRoles build_with_move(); +}; + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/authorization_result.cpp b/vespalib/src/vespa/vespalib/net/tls/authorization_result.cpp new file mode 100644 index 00000000000..069e971833c --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/authorization_result.cpp @@ -0,0 +1,62 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "authorization_result.h" +#include <vespa/vespalib/stllike/asciistream.h> +#include <ostream> + +namespace vespalib::net::tls { + +AuthorizationResult::AuthorizationResult() = default; + +AuthorizationResult::AuthorizationResult(AssumedRoles assumed_roles) + : _assumed_roles(std::move(assumed_roles)) +{} + +AuthorizationResult::AuthorizationResult(const AuthorizationResult&) = default; +AuthorizationResult& AuthorizationResult::operator=(const AuthorizationResult&) = default; +AuthorizationResult::AuthorizationResult(AuthorizationResult&&) noexcept = default; +AuthorizationResult& AuthorizationResult::operator=(AuthorizationResult&&) noexcept = default; +AuthorizationResult::~AuthorizationResult() = default; + +void AuthorizationResult::print(asciistream& os) const { + os << "AuthorizationResult("; + if (!success()) { + os << "NOT AUTHORIZED"; + } else { + os << _assumed_roles; + } + os << ')'; +} + +AuthorizationResult +AuthorizationResult::make_authorized_for_roles(AssumedRoles assumed_roles) { + return AuthorizationResult(std::move(assumed_roles)); +} + +AuthorizationResult +AuthorizationResult::make_authorized_for_all_roles() { + return AuthorizationResult(AssumedRoles::make_wildcard_role()); +} + +AuthorizationResult +AuthorizationResult::make_not_authorized() { + return {}; +} + +asciistream& operator<<(asciistream& os, const AuthorizationResult& res) { + res.print(os); + return os; +} + +std::ostream& operator<<(std::ostream& os, const AuthorizationResult& res) { + os << to_string(res); + return os; +} + +string to_string(const AuthorizationResult& res) { + asciistream os; + os << res; + return os.str(); +} + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/authorization_result.h b/vespalib/src/vespa/vespalib/net/tls/authorization_result.h new file mode 100644 index 00000000000..b92bbbca9dd --- /dev/null +++ b/vespalib/src/vespa/vespalib/net/tls/authorization_result.h @@ -0,0 +1,55 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#pragma once + +#include "assumed_roles.h" +#include <vespa/vespalib/stllike/string.h> +#include <iosfwd> + +namespace vespalib { class asciistream; } + +namespace vespalib::net::tls { + +/** + * The result of evaluating configured mTLS authorization rules against the + * credentials presented by a successfully authenticated peer certificate. + * + * This result contains the union set of all roles specified by the matching + * authorization rules. If no rules matched, the set will be empty. The role + * set will also be empty for a default-constructed instance. + */ +class AuthorizationResult { + AssumedRoles _assumed_roles; + + explicit AuthorizationResult(AssumedRoles assumed_roles); +public: + AuthorizationResult(); + AuthorizationResult(const AuthorizationResult&); + AuthorizationResult& operator=(const AuthorizationResult&); + AuthorizationResult(AuthorizationResult&&) noexcept; + AuthorizationResult& operator=(AuthorizationResult&&) noexcept; + ~AuthorizationResult(); + + // Returns true iff at least one assumed role has been granted. + [[nodiscard]] bool success() const noexcept { + return !_assumed_roles.empty(); + } + + [[nodiscard]] const AssumedRoles& assumed_roles() const noexcept { + return _assumed_roles; + } + [[nodiscard]] AssumedRoles steal_assumed_roles() noexcept { + return std::move(_assumed_roles); + } + + void print(asciistream& os) const; + + static AuthorizationResult make_authorized_for_roles(AssumedRoles assumed_roles); + static AuthorizationResult make_authorized_for_all_roles(); + static AuthorizationResult make_not_authorized(); +}; + +asciistream& operator<<(asciistream&, const AuthorizationResult&); +std::ostream& operator<<(std::ostream&, const AuthorizationResult&); +string to_string(const AuthorizationResult&); + +} diff --git a/vespalib/src/vespa/vespalib/net/tls/certificate_verification_callback.h b/vespalib/src/vespa/vespalib/net/tls/certificate_verification_callback.h index dec00486dcd..0c18ba1a789 100644 --- a/vespalib/src/vespa/vespalib/net/tls/certificate_verification_callback.h +++ b/vespalib/src/vespa/vespalib/net/tls/certificate_verification_callback.h @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once +#include "authorization_result.h" #include "peer_credentials.h" namespace vespalib::net::tls { @@ -13,15 +14,15 @@ struct CertificateVerificationCallback { virtual ~CertificateVerificationCallback() = default; // Return true iff the peer credentials pass verification, false otherwise. // Must be thread safe. - virtual bool verify(const PeerCredentials& peer_creds) const = 0; + [[nodiscard]] virtual AuthorizationResult verify(const PeerCredentials& peer_creds) const = 0; }; // Simplest possible certificate verification callback which accepts the certificate // iff all its pre-verification by OpenSSL has passed. This means its chain is valid // and it is signed by a trusted CA. struct AcceptAllPreVerifiedCertificates : CertificateVerificationCallback { - bool verify([[maybe_unused]] const PeerCredentials& peer_creds) const override { - return true; // yolo + AuthorizationResult verify([[maybe_unused]] const PeerCredentials& peer_creds) const override { + return AuthorizationResult::make_authorized_for_all_roles(); // yolo } }; diff --git a/vespalib/src/vespa/vespalib/net/tls/crypto_codec.h b/vespalib/src/vespa/vespalib/net/tls/crypto_codec.h index 7448bf49799..86ccaf3eb64 100644 --- a/vespalib/src/vespa/vespalib/net/tls/crypto_codec.h +++ b/vespalib/src/vespa/vespalib/net/tls/crypto_codec.h @@ -53,6 +53,8 @@ struct DecodeResult { }; struct TlsContext; +class PeerCredentials; +class AssumedRoles; // TODO move to different namespace, not dependent on TLS? @@ -175,6 +177,18 @@ public: */ virtual EncodeResult half_close(char* ciphertext, size_t ciphertext_size) noexcept = 0; + /** + * Credentials of the _remote peer_ as observed during certificate exchange. E.g. + * if this is a client codec, peer_credentials() returns the _server_ credentials + * and vice versa. + */ + [[nodiscard]] virtual const PeerCredentials& peer_credentials() const noexcept = 0; + + /** + * Union set of all assumed roles in the peer policy rules that fully matched the peer's credentials. + */ + [[nodiscard]] virtual const AssumedRoles& assumed_roles() const noexcept = 0; + /* * Creates an implementation defined CryptoCodec that provides at least TLSv1.2 * compliant handshaking and full duplex data transfer. diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h index 34ca31a8f6c..5be2146b349 100644 --- a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_crypto_codec_impl.h @@ -4,8 +4,10 @@ #include <vespa/vespalib/crypto/openssl_typedefs.h> #include <vespa/vespalib/net/socket_address.h> #include <vespa/vespalib/net/socket_spec.h> -#include <vespa/vespalib/net/tls/transport_security_options.h> +#include <vespa/vespalib/net/tls/assumed_roles.h> #include <vespa/vespalib/net/tls/crypto_codec.h> +#include <vespa/vespalib/net/tls/peer_credentials.h> +#include <vespa/vespalib/net/tls/transport_security_options.h> #include <memory> #include <optional> @@ -55,6 +57,8 @@ class OpenSslCryptoCodecImpl : public CryptoCodec { Mode _mode; std::optional<DeferredHandshakeParams> _deferred_handshake_params; std::optional<HandshakeResult> _deferred_handshake_result; + PeerCredentials _peer_credentials; + AssumedRoles _assumed_roles; public: ~OpenSslCryptoCodecImpl() override; @@ -95,6 +99,14 @@ public: char* plaintext, size_t plaintext_size) noexcept override; EncodeResult half_close(char* ciphertext, size_t ciphertext_size) noexcept override; + [[nodiscard]] const PeerCredentials& peer_credentials() const noexcept override { + return _peer_credentials; + } + + [[nodiscard]] const AssumedRoles& assumed_roles() const noexcept override { + return _assumed_roles; + } + const SocketAddress& peer_address() const noexcept { return _peer_address; } /* * If a client has sent a SNI extension field as part of the handshake, @@ -102,6 +114,15 @@ public: * call this for codecs in server mode. */ std::optional<vespalib::string> client_provided_sni_extension() const; + + // Only used by code bridging OpenSSL certificate verification callbacks and + // evaluation of custom authorization rules. + void set_peer_credentials(PeerCredentials peer_credentials) { + _peer_credentials = std::move(peer_credentials); + } + void set_assumed_roles(AssumedRoles assumed_roles) { + _assumed_roles = std::move(assumed_roles); + } private: OpenSslCryptoCodecImpl(std::shared_ptr<OpenSslTlsContextImpl> ctx, const SocketSpec& peer_spec, diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.cpp b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.cpp index 40e4e1adbcf..3810140854b 100644 --- a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.cpp +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.cpp @@ -451,14 +451,14 @@ int OpenSslTlsContextImpl::verify_cb_wrapper(int preverified_ok, ::X509_STORE_CT auto* self = static_cast<OpenSslTlsContextImpl*>(SSL_CTX_get_app_data(ssl_ctx)); LOG_ASSERT(self != nullptr); - if (self->verify_trusted_certificate(store_ctx, codec_impl->peer_address())) { + if (self->verify_trusted_certificate(store_ctx, *codec_impl)) { return 1; } ConnectionStatistics::get(SSL_in_accept_init(ssl) != 0).inc_invalid_peer_credentials(); return 0; } -bool OpenSslTlsContextImpl::verify_trusted_certificate(::X509_STORE_CTX* store_ctx, const SocketAddress& peer_address) { +bool OpenSslTlsContextImpl::verify_trusted_certificate(::X509_STORE_CTX* store_ctx, OpenSslCryptoCodecImpl& codec_impl) { const auto authz_mode = authorization_mode(); // TODO consider if we want to fill in peer credentials even if authorization is disabled if (authz_mode == AuthorizationMode::Disable) { @@ -477,18 +477,22 @@ bool OpenSslTlsContextImpl::verify_trusted_certificate(::X509_STORE_CTX* store_c return false; } try { - const bool verified_by_cb = _cert_verify_callback->verify(creds); - if (!verified_by_cb) { + auto authz_result = _cert_verify_callback->verify(creds); + if (!authz_result.success()) { // Buffer warnings on peer IP address to avoid log flooding. - LOGBT(warning, peer_address.ip_address(), + LOGBT(warning, codec_impl.peer_address().ip_address(), "Certificate verification of peer '%s' failed with %s", - peer_address.spec().c_str(), to_string(creds).c_str()); + codec_impl.peer_address().spec().c_str(), to_string(creds).c_str()); return (authz_mode != AuthorizationMode::Enforce); } + // Store away credentials and role set for later use by requests that arrive over this connection. + // TODO encapsulate as const shared_ptr to immutable object to better facilitate sharing? + codec_impl.set_peer_credentials(std::move(creds)); + codec_impl.set_assumed_roles(authz_result.steal_assumed_roles()); } catch (std::exception& e) { - LOGBT(error, peer_address.ip_address(), + LOGBT(error, codec_impl.peer_address().ip_address(), "Got exception during certificate verification callback for peer '%s': %s", - peer_address.spec().c_str(), e.what()); + codec_impl.peer_address().spec().c_str(), e.what()); return false; } // we don't expect any non-std::exception derived exceptions, so let them terminate the process. return true; diff --git a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h index b84a599dead..d9e161a7b0f 100644 --- a/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h +++ b/vespalib/src/vespa/vespalib/net/tls/impl/openssl_tls_context_impl.h @@ -12,6 +12,8 @@ namespace vespalib::net::tls::impl { +class OpenSslCryptoCodecImpl; + class OpenSslTlsContextImpl : public TlsContext { crypto::SslCtxPtr _ctx; AuthorizationMode _authorization_mode; @@ -47,7 +49,7 @@ private: void set_ssl_ctx_self_reference(); void set_accepted_cipher_suites(const std::vector<vespalib::string>& ciphers); - bool verify_trusted_certificate(::X509_STORE_CTX* store_ctx, const SocketAddress& peer_address); + bool verify_trusted_certificate(::X509_STORE_CTX* store_ctx, OpenSslCryptoCodecImpl& codec_impl); static int verify_cb_wrapper(int preverified_ok, ::X509_STORE_CTX* store_ctx); }; diff --git a/vespalib/src/vespa/vespalib/net/tls/peer_credentials.cpp b/vespalib/src/vespa/vespalib/net/tls/peer_credentials.cpp index e00d4804fbe..9a001e24fea 100644 --- a/vespalib/src/vespa/vespalib/net/tls/peer_credentials.cpp +++ b/vespalib/src/vespa/vespalib/net/tls/peer_credentials.cpp @@ -2,12 +2,15 @@ #include "peer_credentials.h" #include <vespa/vespalib/stllike/asciistream.h> -#include <iostream> -#include <sstream> +#include <ostream> namespace vespalib::net::tls { PeerCredentials::PeerCredentials() = default; +PeerCredentials::PeerCredentials(const PeerCredentials&) = default; +PeerCredentials& PeerCredentials::operator=(const PeerCredentials&) = default; +PeerCredentials::PeerCredentials(PeerCredentials&&) noexcept = default; +PeerCredentials& PeerCredentials::operator=(PeerCredentials&&) noexcept = default; PeerCredentials::~PeerCredentials() = default; std::ostream& operator<<(std::ostream& os, const PeerCredentials& creds) { @@ -15,17 +18,40 @@ std::ostream& operator<<(std::ostream& os, const PeerCredentials& creds) { return os; } -vespalib::string to_string(const PeerCredentials& creds) { - vespalib::asciistream os; - os << "PeerCredentials(CN '" << creds.common_name - << "', DNS SANs ["; - for (size_t i = 0; i < creds.dns_sans.size(); ++i) { +namespace { +void emit_comma_separated_string_list(asciistream& os, stringref title, + const std::vector<string>& strings, bool prefix_comma) +{ + if (prefix_comma) { + os << ", "; + } + os << title << " ["; + for (size_t i = 0; i < strings.size(); ++i) { if (i != 0) { os << ", "; } - os << '\'' << creds.dns_sans[i] << '\''; + os << '\'' << strings[i] << '\''; + } + os << ']'; +} +} + +vespalib::string to_string(const PeerCredentials& creds) { + asciistream os; + os << "PeerCredentials("; + bool emit_comma = false; + if (!creds.common_name.empty()) { + os << "CN '" << creds.common_name << "'"; + emit_comma = true; + } + if (!creds.dns_sans.empty()) { + emit_comma_separated_string_list(os, "DNS SANs", creds.dns_sans, emit_comma); + emit_comma = true; + } + if (!creds.uri_sans.empty()) { + emit_comma_separated_string_list(os, "URI SANs", creds.uri_sans, emit_comma); } - os << "])"; + os << ')'; return os.str(); } diff --git a/vespalib/src/vespa/vespalib/net/tls/peer_credentials.h b/vespalib/src/vespa/vespalib/net/tls/peer_credentials.h index 636d643a62f..b81772d2bce 100644 --- a/vespalib/src/vespa/vespalib/net/tls/peer_credentials.h +++ b/vespalib/src/vespa/vespalib/net/tls/peer_credentials.h @@ -18,6 +18,10 @@ struct PeerCredentials { std::vector<vespalib::string> uri_sans; PeerCredentials(); + PeerCredentials(const PeerCredentials&); + PeerCredentials& operator=(const PeerCredentials&); + PeerCredentials(PeerCredentials&&) noexcept; + PeerCredentials& operator=(PeerCredentials&&) noexcept; ~PeerCredentials(); }; diff --git a/vespalib/src/vespa/vespalib/net/tls/peer_policies.cpp b/vespalib/src/vespa/vespalib/net/tls/peer_policies.cpp index a476e23e6cb..a4e651f3f19 100644 --- a/vespalib/src/vespa/vespalib/net/tls/peer_policies.cpp +++ b/vespalib/src/vespa/vespalib/net/tls/peer_policies.cpp @@ -119,6 +119,21 @@ RequiredPeerCredential::RequiredPeerCredential(Field field, vespalib::string mus RequiredPeerCredential::~RequiredPeerCredential() = default; +PeerPolicy::PeerPolicy() = default; + +PeerPolicy::PeerPolicy(std::vector<RequiredPeerCredential> required_peer_credentials) + : _required_peer_credentials(std::move(required_peer_credentials)), + _assumed_roles(AssumedRoles::make_wildcard_role()) +{} + +PeerPolicy::PeerPolicy(std::vector<RequiredPeerCredential> required_peer_credentials, + AssumedRoles assumed_roles) + : _required_peer_credentials(std::move(required_peer_credentials)), + _assumed_roles(std::move(assumed_roles)) +{} + +PeerPolicy::~PeerPolicy() = default; + namespace { template <typename Collection> void print_joined(std::ostream& os, const Collection& coll, const char* sep) { diff --git a/vespalib/src/vespa/vespalib/net/tls/peer_policies.h b/vespalib/src/vespa/vespalib/net/tls/peer_policies.h index 4166efc4312..6eab8c2c9b2 100644 --- a/vespalib/src/vespa/vespalib/net/tls/peer_policies.h +++ b/vespalib/src/vespa/vespalib/net/tls/peer_policies.h @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once +#include "assumed_roles.h" #include <vespa/vespalib/stllike/string.h> #include <memory> #include <vector> @@ -49,18 +50,27 @@ public: class PeerPolicy { // _All_ credentials must match for the policy itself to match. std::vector<RequiredPeerCredential> _required_peer_credentials; + AssumedRoles _assumed_roles; public: - PeerPolicy() = default; - explicit PeerPolicy(std::vector<RequiredPeerCredential> required_peer_credentials_) - : _required_peer_credentials(std::move(required_peer_credentials_)) - {} + PeerPolicy(); + // This policy is created with a wildcard role set, i.e. full access. + explicit PeerPolicy(std::vector<RequiredPeerCredential> required_peer_credentials); + + PeerPolicy(std::vector<RequiredPeerCredential> required_peer_credentials, + AssumedRoles assumed_roles); - bool operator==(const PeerPolicy& rhs) const { - return (_required_peer_credentials == rhs._required_peer_credentials); + ~PeerPolicy(); + + bool operator==(const PeerPolicy& rhs) const noexcept { + return ((_required_peer_credentials == rhs._required_peer_credentials) && + (_assumed_roles == rhs._assumed_roles)); } - const std::vector<RequiredPeerCredential>& required_peer_credentials() const noexcept { + [[nodiscard]] const std::vector<RequiredPeerCredential>& required_peer_credentials() const noexcept { return _required_peer_credentials; } + [[nodiscard]] const AssumedRoles& assumed_roles() const noexcept { + return _assumed_roles; + } }; class AuthorizedPeers { diff --git a/vespalib/src/vespa/vespalib/net/tls/policy_checking_certificate_verifier.cpp b/vespalib/src/vespa/vespalib/net/tls/policy_checking_certificate_verifier.cpp index e2c45ad7358..65e14434ff1 100644 --- a/vespalib/src/vespa/vespalib/net/tls/policy_checking_certificate_verifier.cpp +++ b/vespalib/src/vespa/vespalib/net/tls/policy_checking_certificate_verifier.cpp @@ -61,7 +61,7 @@ public: ~PolicyConfiguredCertificateVerifier() override; - bool verify(const PeerCredentials& peer_creds) const override; + AuthorizationResult verify(const PeerCredentials& peer_creds) const override; }; PolicyConfiguredCertificateVerifier::PolicyConfiguredCertificateVerifier(AuthorizedPeers authorized_peers) noexcept @@ -69,16 +69,21 @@ PolicyConfiguredCertificateVerifier::PolicyConfiguredCertificateVerifier(Authori PolicyConfiguredCertificateVerifier::~PolicyConfiguredCertificateVerifier() = default; -bool PolicyConfiguredCertificateVerifier::verify(const PeerCredentials& peer_creds) const { +AuthorizationResult PolicyConfiguredCertificateVerifier::verify(const PeerCredentials& peer_creds) const { if (_authorized_peers.allows_all_authenticated()) { - return true; + return AuthorizationResult::make_authorized_for_all_roles(); } + AssumedRolesBuilder roles_builder; for (const auto& policy : _authorized_peers.peer_policies()) { if (matches_all_policy_requirements(peer_creds, policy)) { - return true; + roles_builder.add_union(policy.assumed_roles()); } } - return false; + if (!roles_builder.empty()) { + return AuthorizationResult::make_authorized_for_roles(roles_builder.build_with_move()); + } else { + return AuthorizationResult::make_not_authorized(); + } } std::shared_ptr<CertificateVerificationCallback> create_verify_callback_from(AuthorizedPeers authorized_peers) { diff --git a/vespalib/src/vespa/vespalib/test/peer_policy_utils.cpp b/vespalib/src/vespa/vespalib/test/peer_policy_utils.cpp index 724efa63e47..82d7b9ea07b 100644 --- a/vespalib/src/vespa/vespalib/test/peer_policy_utils.cpp +++ b/vespalib/src/vespa/vespalib/test/peer_policy_utils.cpp @@ -16,10 +16,23 @@ RequiredPeerCredential required_san_uri(vespalib::stringref pattern) { return {RequiredPeerCredential::Field::SAN_URI, pattern}; } +AssumedRoles assumed_roles(const std::vector<string>& roles) { + // TODO fix hash_set iterator range ctor to make this a one-liner + AssumedRoles::RoleSet role_set; + for (const auto& role : roles) { + role_set.insert(role); + } + return AssumedRoles::make_for_roles(std::move(role_set)); +} + PeerPolicy policy_with(std::vector<RequiredPeerCredential> creds) { return PeerPolicy(std::move(creds)); } +PeerPolicy policy_with(std::vector<RequiredPeerCredential> creds, AssumedRoles roles) { + return {std::move(creds), std::move(roles)}; +} + AuthorizedPeers authorized_peers(std::vector<PeerPolicy> peer_policies) { return AuthorizedPeers(std::move(peer_policies)); } diff --git a/vespalib/src/vespa/vespalib/test/peer_policy_utils.h b/vespalib/src/vespa/vespalib/test/peer_policy_utils.h index fe382f01b50..72e9fde20de 100644 --- a/vespalib/src/vespa/vespalib/test/peer_policy_utils.h +++ b/vespalib/src/vespa/vespalib/test/peer_policy_utils.h @@ -8,7 +8,9 @@ namespace vespalib::net::tls { RequiredPeerCredential required_cn(vespalib::stringref pattern); RequiredPeerCredential required_san_dns(vespalib::stringref pattern); RequiredPeerCredential required_san_uri(vespalib::stringref pattern); +AssumedRoles assumed_roles(const std::vector<string>& roles); PeerPolicy policy_with(std::vector<RequiredPeerCredential> creds); +PeerPolicy policy_with(std::vector<RequiredPeerCredential> creds, AssumedRoles roles); AuthorizedPeers authorized_peers(std::vector<PeerPolicy> peer_policies); } diff --git a/vespalib/src/vespa/vespalib/util/generationhandler.cpp b/vespalib/src/vespa/vespalib/util/generationhandler.cpp index a4b3dd6f5e6..7797978d187 100644 --- a/vespalib/src/vespa/vespalib/util/generationhandler.cpp +++ b/vespalib/src/vespa/vespalib/util/generationhandler.cpp @@ -125,7 +125,7 @@ GenerationHandler::updateFirstUsedGeneration() toFree->_next = _free; _free = toFree; } - _firstUsedGeneration = _first->_generation; + _firstUsedGeneration.store(_first->_generation, std::memory_order_relaxed); } GenerationHandler::GenerationHandler() @@ -215,7 +215,7 @@ GenerationHandler::getGenerationRefCount(generation_t gen) const { if (static_cast<sgeneration_t>(gen - _generation) > 0) return 0u; - if (static_cast<sgeneration_t>(_firstUsedGeneration - gen) > 0) + if (static_cast<sgeneration_t>(getFirstUsedGeneration() - gen) > 0) return 0u; for (GenerationHold *hold = _first; hold != nullptr; hold = hold->_next) { if (hold->_generation.load(std::memory_order_relaxed) == gen) diff --git a/vespalib/src/vespa/vespalib/util/generationhandler.h b/vespalib/src/vespa/vespalib/util/generationhandler.h index 0c4b49a2d5b..2aeb4c2f886 100644 --- a/vespalib/src/vespa/vespalib/util/generationhandler.h +++ b/vespalib/src/vespa/vespalib/util/generationhandler.h @@ -73,7 +73,7 @@ public: private: generation_t _generation; - generation_t _firstUsedGeneration; + std::atomic<generation_t> _firstUsedGeneration; std::atomic<GenerationHold *> _last; // Points to "current generation" entry GenerationHold *_first; // Points to "firstUsedGeneration" entry GenerationHold *_free; // List of free entries @@ -109,7 +109,7 @@ public: * if writer hasn't updated first used generation after last reader left. */ generation_t getFirstUsedGeneration() const { - return _firstUsedGeneration; + return _firstUsedGeneration.load(std::memory_order_relaxed); } /** |