diff options
author | Harald Musum <musum@verizonmedia.com> | 2020-06-19 11:16:02 +0200 |
---|---|---|
committer | Harald Musum <musum@verizonmedia.com> | 2020-06-19 11:16:02 +0200 |
commit | 5b6d5916c999069ac5097e18ef755f713e98f411 (patch) | |
tree | 79165b521e45d934b2541c2da44b5ed574900ef0 | |
parent | 5f5df5235de6f7c2532c19833e64adc113ffe1ed (diff) | |
parent | a74a8703830bb4d8656ea626a085c6966de5c06d (diff) |
Merge branch 'master' into hmusum/configserver-refactoring-13
78 files changed, 1632 insertions, 172 deletions
diff --git a/bundle-plugin/src/main/java/com/yahoo/container/plugin/bundle/AnalyzeBundle.java b/bundle-plugin/src/main/java/com/yahoo/container/plugin/bundle/AnalyzeBundle.java index 0626c786822..35db7b5fef3 100644 --- a/bundle-plugin/src/main/java/com/yahoo/container/plugin/bundle/AnalyzeBundle.java +++ b/bundle-plugin/src/main/java/com/yahoo/container/plugin/bundle/AnalyzeBundle.java @@ -49,11 +49,14 @@ public class AnalyzeBundle { PublicPackages pp = publicPackages(jarFile); exports.addAll(pp.exports); globals.addAll(pp.globals); + + // TODO: remove and clean up everything related to global packages. + if (! pp.globals.isEmpty()) throw new RuntimeException("Found global packages in bundle " + jarFile.getAbsolutePath()); } return new PublicPackages(exports, globals); } - public static PublicPackages publicPackages(File jarFile) { + static PublicPackages publicPackages(File jarFile) { try { Optional<Manifest> jarManifest = JarFiles.getManifest(jarFile); if (jarManifest.isPresent()) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java index 651572b3e36..0fb1407830a 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java @@ -155,7 +155,7 @@ public class Deployment implements com.yahoo.config.provision.Deployment { } session.waitUntilActivated(timeoutBudget); - log.log(Level.INFO, session.logPre() + "activated successfully using " + + log.log(Level.INFO, session.logPre() + "Session " + session.getSessionId() + " activated successfully using " + hostProvisioner.map(provisioner -> provisioner.getClass().getSimpleName()).orElse("no host provisioner") + ". Config generation " + session.getMetaData().getGeneration() + (previousActiveSession != null ? ". Based on previous active session " + previousActiveSession.getSessionId() : "") + diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java index b7b5c6380ef..543f9c2e303 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java @@ -49,12 +49,15 @@ public class ApplicationPackageMaintainer extends ConfigServerMaintainer { @Override protected void maintain() { + log.fine(() -> "Running"); if (! distributeApplicationPackage.value()) return; try (var fileDownloader = new FileDownloader(createConnectionPool(configserverConfig), downloadDirectory)){ for (var applicationId : applicationRepository.listApplications()) { + log.fine(() -> "Verifying application package for " + applicationId); RemoteSession session = applicationRepository.getActiveSession(applicationId); FileReference applicationPackage = session.getApplicationPackageReference(); + log.fine(() -> "Verifying application package file reference " + applicationPackage + " for session " + session.getSessionId()); if (applicationPackage != null && missingOnDisk(applicationPackage)) { log.fine(() -> "Downloading missing application package for application " + applicationId + " - session " + session.getSessionId()); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java index 2397dba6b5e..aa851a95335 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java @@ -103,7 +103,9 @@ public class PreparedModelsBuilder extends ModelsBuilder<PreparedModelsBuilder.P modelVersion, wantedNodeVespaVersion); - log.log(Level.FINE, "Create and validate model " + modelVersion + " for " + applicationId); + + log.log(properties.zone().system().isCd() ? Level.INFO : Level.FINE, + "Create and validate model " + modelVersion + " for " + applicationId + ", previous model is " + modelOf(modelVersion)); ValidationParameters validationParameters = new ValidationParameters(params.ignoreValidationErrors() ? IgnoreValidationErrors.TRUE : IgnoreValidationErrors.FALSE); ModelCreateResult result = modelFactory.createAndValidateModel(modelContext, validationParameters); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java index a29b105f43f..2e101762fc4 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java @@ -84,7 +84,7 @@ public final class PrepareParams { private boolean verbose = false; private boolean isBootstrap = false; private ApplicationId applicationId = ApplicationId.defaultId(); - private TimeoutBudget timeoutBudget = new TimeoutBudget(Clock.systemUTC(), Duration.ofSeconds(30)); + private TimeoutBudget timeoutBudget = new TimeoutBudget(Clock.systemUTC(), Duration.ofSeconds(60)); private Optional<Version> vespaVersion = Optional.empty(); private List<ContainerEndpoint> containerEndpoints = null; private Optional<String> tlsSecretsKeyName = Optional.empty(); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java index 3c83263f781..4542b8267e8 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java @@ -126,10 +126,10 @@ public class SessionPreparer { Preparation preparation = new Preparation(hostValidator, logger, params, currentActiveApplicationSet, tenantPath, serverDbSessionDir, applicationPackage, sessionZooKeeperClient); - // Note: Done before pre-processing, requires that to be done by users of the distributed package + preparation.preprocess(); + var distributedApplicationPackage = preparation.distributeApplicationPackage(); - preparation.preprocess(); try { AllocatedHosts allocatedHosts = preparation.buildModels(now); preparation.makeResult(allocatedHosts); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionStateWatcher.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionStateWatcher.java index dee11bcd332..0a8d6cdde69 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionStateWatcher.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionStateWatcher.java @@ -51,7 +51,6 @@ public class SessionStateWatcher { private void sessionChanged(Session.Status status) { long sessionId = remoteSession.getSessionId(); - log.log(Level.FINE, remoteSession.logPre() + "Session change: Session " + remoteSession.getSessionId() + " changed status to " + status); // valid for NEW -> PREPARE transitions, not ACTIVATE -> PREPARE. if (status.equals(Session.Status.PREPARE)) { @@ -91,6 +90,8 @@ public class SessionStateWatcher { ChildData node = fileCache.getCurrentData(); if (node != null) { newStatus = Session.Status.parse(Utf8.toString(node.getData())); + log.log(Level.FINE, remoteSession.logPre() + "Session change: Remote session " + remoteSession.getSessionId() + + " changed status from " + currentStatus.name() + " to " + newStatus.name()); sessionChanged(newStatus); } } catch (Exception e) { diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java index 222910d5bc6..0b5f2538892 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java @@ -4,6 +4,7 @@ package com.yahoo.vespa.hosted.controller.api.integration; import com.yahoo.vespa.hosted.controller.api.integration.aws.ApplicationRoleService; import com.yahoo.vespa.hosted.controller.api.integration.aws.AwsEventFetcher; import com.yahoo.vespa.hosted.controller.api.integration.aws.ResourceTagger; +import com.yahoo.vespa.hosted.controller.api.integration.billing.PlanController; import com.yahoo.vespa.hosted.controller.api.integration.certificates.EndpointCertificateProvider; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServer; import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationStore; @@ -76,4 +77,6 @@ public interface ServiceRegistry { SystemMonitor systemMonitor(); + PlanController planController(); + } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/CostCalculator.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/CostCalculator.java new file mode 100644 index 00000000000..628beec8450 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/CostCalculator.java @@ -0,0 +1,19 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.api.integration.billing; + +import com.yahoo.config.provision.NodeResources; +import com.yahoo.vespa.hosted.controller.api.integration.resource.CostInfo; + + +/** + * @author ogronnesby + */ +public interface CostCalculator { + + /** Calculate the cost for the given usage */ + CostInfo calculate(ResourceUsage usage); + + /** Estimate the cost for the given resources */ + double calculate(NodeResources resources); + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/Plan.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/Plan.java new file mode 100644 index 00000000000..75a88136c45 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/Plan.java @@ -0,0 +1,23 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.api.integration.billing; + +/** + * A Plan decides two different things: + * + * - How to map from usage to a sum of money that is owed. + * - Limits on how much resources can be used. + * + * @author ogronnesby + */ +public interface Plan { + + /** The ID of the plan as used in APIs and storage systems */ + String id(); + + /** The calculator used to calculate a bill for usage */ + CostCalculator calculator(); + + /** The quota limits associated with the plan */ + Object quota(); + +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/PlanController.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/PlanController.java new file mode 100644 index 00000000000..f13c251d212 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/PlanController.java @@ -0,0 +1,10 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.api.integration.billing; + +import com.yahoo.config.provision.TenantName; + +public interface PlanController { + + Plan getPlan(TenantName tenant); + +}
\ No newline at end of file diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/ResourceUsage.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/ResourceUsage.java new file mode 100644 index 00000000000..cbfd2b6ff50 --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/ResourceUsage.java @@ -0,0 +1,54 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.hosted.controller.api.integration.billing; + +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.zone.ZoneId; + +import java.math.BigDecimal; + +/** + * @author olaa + */ +public class ResourceUsage { + + private final ApplicationId applicationId; + private final ZoneId zoneId; + private final Plan plan; + private final BigDecimal cpuMillis; + private final BigDecimal memoryMillis; + private final BigDecimal diskMillis; + + public ResourceUsage(ApplicationId applicationId, ZoneId zoneId, Plan plan, + BigDecimal cpuMillis, BigDecimal memoryMillis, BigDecimal diskMillis) { + this.applicationId = applicationId; + this.zoneId = zoneId; + this.cpuMillis = cpuMillis; + this.memoryMillis = memoryMillis; + this.diskMillis = diskMillis; + this.plan = plan; + } + + public ApplicationId getApplicationId() { + return applicationId; + } + + public ZoneId getZoneId() { + return zoneId; + } + + public BigDecimal getCpuMillis() { + return cpuMillis; + } + + public BigDecimal getMemoryMillis() { + return memoryMillis; + } + + public BigDecimal getDiskMillis() { + return diskMillis; + } + + public Plan getPlan() { + return plan; + } +} diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/package-info.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/package-info.java new file mode 100644 index 00000000000..ae31f4a782d --- /dev/null +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/package-info.java @@ -0,0 +1,5 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@ExportPackage +package com.yahoo.vespa.hosted.controller.api.integration.billing; + +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java index 68dff26529f..2fdf442dbe0 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java @@ -219,8 +219,10 @@ enum PathGroup { /** Paths used for invoice management */ hostedAccountant(PathPrefix.api, "/billing/v1/invoice/{*}", - "/billing/v1/billing"); + "/billing/v1/billing"), + /** Path used for listing endpoint certificate request info */ + endpointCertificateRequestInfo(PathPrefix.none, "/certificateRequests/"); final List<String> pathSpecs; final PathPrefix prefix; diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java index 9a5a0ad0e77..83adba6f59b 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java @@ -166,7 +166,12 @@ enum Policy { /** Invoice management */ hostedAccountant(Privilege.grant(Action.all()) .on(PathGroup.hostedAccountant) - .in(SystemName.PublicCd)); + .in(SystemName.PublicCd)), + + /** Listing endpoint certificate request info */ + endpointCertificateRequestInfo(Privilege.grant(Action.read) + .on(PathGroup.endpointCertificateRequestInfo) + .in(SystemName.all())); private final Set<Privilege> privileges; diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java index bf5ba4001fa..b9d534019db 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java @@ -61,6 +61,7 @@ public enum RoleDefinition { /** Admin — the administrative function for user management etc. */ administrator(Policy.tenantUpdate, Policy.tenantManager, + Policy.tenantDelete, Policy.applicationManager, Policy.paymentInstrumentRead, Policy.paymentInstrumentUpdate, diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificateManager.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificateManager.java index 64549825b04..425364f6741 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificateManager.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificateManager.java @@ -6,7 +6,6 @@ import com.google.common.io.BaseEncoding; import com.yahoo.config.application.api.DeploymentInstanceSpec; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ClusterSpec; -import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.zone.RoutingMethod; import com.yahoo.config.provision.zone.ZoneApi; @@ -51,7 +50,6 @@ import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; import java.util.stream.Collectors; -import java.util.stream.Stream; /** * Looks up stored endpoint certificate metadata, provisions new certificates if none is found, @@ -323,10 +321,10 @@ public class EndpointCertificateManager { } /** Create a common name based on a hash of the ApplicationId. This should always be less than 64 characters long. */ + @SuppressWarnings("UnstableApiUsage") private static String commonNameHashOf(ApplicationId application, SystemName system) { var hashCode = Hashing.sha1().hashString(application.serializedForm(), Charset.defaultCharset()); var base32encoded = BaseEncoding.base32().omitPadding().lowerCase().encode(hashCode.asBytes()); return 'v' + base32encoded + Endpoint.dnsSuffix(system); } - } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueue.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueue.java index ee5d50414c1..786547d4a67 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueue.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueue.java @@ -1,7 +1,6 @@ // Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.controller.dns; -import java.util.logging.Level; import com.yahoo.vespa.hosted.controller.api.integration.dns.NameService; import java.util.ArrayList; @@ -10,6 +9,8 @@ import java.util.Collections; import java.util.List; import java.util.Objects; import java.util.concurrent.LinkedBlockingDeque; +import java.util.function.UnaryOperator; +import java.util.logging.Level; import java.util.logging.Logger; /** @@ -39,12 +40,14 @@ public class NameServiceQueue { return Collections.unmodifiableCollection(requests); } - /** Returns a copy of this containing only the n most recent requests */ + /** Returns a copy of this containing the last n requests */ public NameServiceQueue last(int n) { - requireNonNegative(n); - if (requests.size() <= n) return this; - List<NameServiceRequest> requests = new ArrayList<>(this.requests); - return new NameServiceQueue(requests.subList(requests.size() - n, requests.size())); + return resize(n, (requests) -> requests.subList(requests.size() - n, requests.size())); + } + + /** Returns a copy of this containing the first n requests */ + public NameServiceQueue first(int n) { + return resize(n, (requests) -> requests.subList(0, n)); } /** Returns a copy of this with given request queued according to priority */ @@ -58,6 +61,7 @@ public class NameServiceQueue { return queue; } + /** Returns a copy of this with given request added */ public NameServiceQueue with(NameServiceRequest request) { return with(request, Priority.normal); } @@ -91,6 +95,13 @@ public class NameServiceQueue { return requests.toString(); } + private NameServiceQueue resize(int n, UnaryOperator<List<NameServiceRequest>> resizer) { + requireNonNegative(n); + if (requests.size() <= n) return this; + List<NameServiceRequest> requests = new ArrayList<>(this.requests); + return new NameServiceQueue(resizer.apply(requests)); + } + private static void requireNonNegative(int n) { if (n < 0) throw new IllegalArgumentException("n must be >= 0, got " + n); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/NameServiceDispatcher.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/NameServiceDispatcher.java index 73daef7c2b0..e7eaf083a57 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/NameServiceDispatcher.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/NameServiceDispatcher.java @@ -45,7 +45,7 @@ public class NameServiceDispatcher extends ControllerMaintainer { var remaining = queue.dispatchTo(nameService, requestCount); if (queue == remaining) return; // Queue unchanged - var dispatched = queue.last(requestCount); + var dispatched = queue.first(requestCount); if (!dispatched.requests().isEmpty()) { log.log(Level.INFO, "Dispatched name service request(s) in " + Duration.between(instant, clock.instant()) + diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java index dc3dbabcc07..bd0143ef879 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java @@ -8,6 +8,8 @@ import com.yahoo.vespa.flags.FetchVector; import com.yahoo.vespa.flags.FlagSource; import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.hosted.controller.Application; +import com.yahoo.vespa.hosted.controller.api.integration.ServiceRegistry; +import com.yahoo.vespa.hosted.controller.api.integration.billing.PlanController; import com.yahoo.vespa.hosted.controller.api.integration.organization.BillingInfo; import com.yahoo.vespa.hosted.controller.api.integration.user.Roles; import com.yahoo.vespa.hosted.controller.api.integration.user.UserId; @@ -34,11 +36,13 @@ public class CloudAccessControl implements AccessControl { private final UserManagement userManagement; private final BooleanFlag enablePublicSignup; + private final PlanController planController; @Inject - public CloudAccessControl(UserManagement userManagement, FlagSource flagSource) { + public CloudAccessControl(UserManagement userManagement, FlagSource flagSource, ServiceRegistry serviceRegistry) { this.userManagement = userManagement; this.enablePublicSignup = Flags.ENABLE_PUBLIC_SIGNUP_FLOW.bindTo(flagSource); + planController = serviceRegistry.planController(); } @Override @@ -97,12 +101,17 @@ public class CloudAccessControl implements AccessControl { @Override public void deleteTenant(TenantName tenant, Credentials credentials) { - // TODO: allow only if 0 resources, 0 balance + if(!(allowedByPrivilegedRole((Auth0Credentials) credentials) || isTrial(tenant))) + throw new ForbiddenException("Please contact the Vespa team for assistance in deleting non-trial tenants"); for (TenantRole role : Roles.tenantRoles(tenant)) userManagement.deleteRole(role); } + private boolean isTrial(TenantName tenant) { + return planController.getPlan(tenant).id().equals("trial"); + } + @Override public void createApplication(TenantAndApplicationId id, Credentials credentials) { for (Role role : Roles.applicationRoles(id.tenant(), id.application())) diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueueTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueueTest.java index d0362ae98b8..30ed9b5432c 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueueTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueueTest.java @@ -74,13 +74,21 @@ public class NameServiceQueueTest { assertTrue(queue.requests().isEmpty()); assertTrue("Removed " + r1, nameService.findRecords(Record.Type.CNAME, r1.name()).isEmpty()); - // Keep n most recent requests + // Keep n last requests queue = queue.with(req1).with(req2).with(req3).with(req4).with(req6) .last(2); assertEquals(List.of(req4, req6), List.copyOf(queue.requests())); assertSame(queue, queue.last(2)); assertSame(queue, queue.last(10)); assertTrue(queue.last(0).requests().isEmpty()); + + // Keep n first requests + queue = NameServiceQueue.EMPTY.with(req1).with(req2).with(req3).with(req4).with(req6) + .first(3); + assertEquals(List.of(req1, req2, req3), List.copyOf(queue.requests())); + assertSame(queue, queue.first(3)); + assertSame(queue, queue.first(10)); + assertTrue(queue.first(0).requests().isEmpty()); } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java index 4f81e443d9c..b7e7c9814e3 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java @@ -12,6 +12,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.aws.MockAwsEventFetcher import com.yahoo.vespa.hosted.controller.api.integration.aws.MockResourceTagger; import com.yahoo.vespa.hosted.controller.api.integration.aws.NoopApplicationRoleService; import com.yahoo.vespa.hosted.controller.api.integration.aws.ResourceTagger; +import com.yahoo.vespa.hosted.controller.api.integration.billing.PlanController; import com.yahoo.vespa.hosted.controller.api.integration.certificates.EndpointCertificateMock; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServer; import com.yahoo.vespa.hosted.controller.api.integration.dns.MemoryNameService; @@ -59,6 +60,7 @@ public class ServiceRegistryMock extends AbstractComponent implements ServiceReg private final MockRunDataStore mockRunDataStore = new MockRunDataStore(); private final MockResourceTagger mockResourceTagger = new MockResourceTagger(); private final ApplicationRoleService applicationRoleService = new NoopApplicationRoleService(); + private final PlanController planController = (tenantName) -> null; public ServiceRegistryMock(SystemName system) { this.zoneRegistryMock = new ZoneRegistryMock(system); @@ -201,4 +203,9 @@ public class ServiceRegistryMock extends AbstractComponent implements ServiceReg return endpointCertificateMock; } + @Override + public PlanController planController() { + return planController; + } + } diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index 3a9aabc83ba..3e81521550a 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -19,6 +19,7 @@ vespa_define_module( src/tests/eval/gbdt src/tests/eval/inline_operation src/tests/eval/interpreted_function + src/tests/eval/multiply_add src/tests/eval/node_tools src/tests/eval/node_types src/tests/eval/param_usage diff --git a/eval/src/apps/tensor_conformance/generate.cpp b/eval/src/apps/tensor_conformance/generate.cpp index 7d48307b786..df1c06593cb 100644 --- a/eval/src/apps/tensor_conformance/generate.cpp +++ b/eval/src/apps/tensor_conformance/generate.cpp @@ -100,6 +100,8 @@ void generate_tensor_map(TestBuilder &dst) { generate_op1_map("relu(a)", operation::Relu::f, Sub2(Div16(N())), dst); generate_op1_map("sigmoid(a)", operation::Sigmoid::f, Sub2(Div16(N())), dst); generate_op1_map("elu(a)", operation::Elu::f, Sub2(Div16(N())), dst); + // TODO(havardpe): add erf when supported by Java + // generate_op1_map("erf(a)", operation::Erf::f, Sub2(Div16(N())), dst); generate_op1_map("a in [1,5,7,13,42]", MyIn::f, N(), dst); generate_map_expr("map(a,f(a)((a+1)*2))", MyOp::f, Div16(N()), dst); } diff --git a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp index fe9396398da..de5a3fbf395 100644 --- a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp +++ b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp @@ -65,6 +65,7 @@ TEST(InlineOperationTest, op1_lambdas_are_recognized) { EXPECT_EQ(as_op1("relu(a)"), &Relu::f); EXPECT_EQ(as_op1("sigmoid(a)"), &Sigmoid::f); EXPECT_EQ(as_op1("elu(a)"), &Elu::f); + EXPECT_EQ(as_op1("erf(a)"), &Erf::f); //------------------------------------------- EXPECT_EQ(as_op1("1/a"), &Inv::f); EXPECT_EQ(as_op1("1.0/a"), &Inv::f); diff --git a/eval/src/tests/eval/multiply_add/CMakeLists.txt b/eval/src/tests/eval/multiply_add/CMakeLists.txt new file mode 100644 index 00000000000..c50aa4f50a2 --- /dev/null +++ b/eval/src/tests/eval/multiply_add/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_multiply_add_test_app TEST + SOURCES + multiply_add_test.cpp + DEPENDS + vespaeval + gtest +) +vespa_add_test(NAME eval_multiply_add_test_app COMMAND eval_multiply_add_test_app) diff --git a/eval/src/tests/eval/multiply_add/multiply_add_test.cpp b/eval/src/tests/eval/multiply_add/multiply_add_test.cpp new file mode 100644 index 00000000000..35cab0a6030 --- /dev/null +++ b/eval/src/tests/eval/multiply_add/multiply_add_test.cpp @@ -0,0 +1,44 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/eval/eval/function.h> +#include <vespa/eval/eval/llvm/compiled_function.h> +#include <vespa/eval/eval/interpreted_function.h> +#include <vespa/eval/tensor/default_tensor_engine.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib::eval; + +using Engine = vespalib::tensor::DefaultTensorEngine; + +double gcc_fun(double a, double b) { + return (a * 3) + b; +} + +TEST(MultiplyAddTest, multiply_add_gives_same_result) { + auto fun = Function::parse("a*3+b"); + CompiledFunction cfun(*fun, PassParams::ARRAY); + NodeTypes node_types = NodeTypes(*fun, {ValueType::double_type(), ValueType::double_type()}); + InterpretedFunction ifun(Engine::ref(), *fun, node_types); + auto llvm_fun = cfun.get_function(); + //------------------------------------------------------------------------- + double a = -1.0/3.0; + double b = 1.0; + std::vector<double> ab({a, b}); + SimpleParams params(ab); + InterpretedFunction::Context ictx(ifun); + //------------------------------------------------------------------------- + const Value &result_value = ifun.eval(ictx, params); + double ifun_res = result_value.as_double(); + double llvm_res = llvm_fun(&ab[0]); + double gcc_res = gcc_fun(a, b); + fprintf(stderr, "ifun_res: %a\n", ifun_res); + fprintf(stderr, "llvm_res: %a\n", llvm_res); + fprintf(stderr, "gcc_res: %a\n", gcc_res); + EXPECT_EQ(ifun_res, llvm_res); + EXPECT_DOUBLE_EQ(llvm_res + 1.0, gcc_res + 1.0); + if (llvm_res != gcc_res) { + fprintf(stderr, "WARNING: diverging results caused by fused multiply add\n"); + } +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/tests/eval/node_tools/node_tools_test.cpp b/eval/src/tests/eval/node_tools/node_tools_test.cpp index ca89650127e..13185065f57 100644 --- a/eval/src/tests/eval/node_tools/node_tools_test.cpp +++ b/eval/src/tests/eval/node_tools/node_tools_test.cpp @@ -99,6 +99,7 @@ TEST("require that call node types can be copied") { TEST_DO(verify_copy("relu(a)")); TEST_DO(verify_copy("sigmoid(a)")); TEST_DO(verify_copy("elu(a)")); + TEST_DO(verify_copy("erf(a)")); } TEST("require that tensor node types can NOT be copied (yet)") { diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index 7912ec213bc..f595c58ef29 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -191,6 +191,7 @@ TEST("require that various operations resolve appropriate type") { TEST_DO(verify_op1("relu(%s)")); // Relu TEST_DO(verify_op1("sigmoid(%s)")); // Sigmoid TEST_DO(verify_op1("elu(%s)")); // Elu + TEST_DO(verify_op1("erf(%s)")); // Erf } TEST("require that map resolves correct type") { diff --git a/eval/src/vespa/eval/eval/call_nodes.cpp b/eval/src/vespa/eval/eval/call_nodes.cpp index 69a9151a2bb..2fc25bdbc77 100644 --- a/eval/src/vespa/eval/eval/call_nodes.cpp +++ b/eval/src/vespa/eval/eval/call_nodes.cpp @@ -42,6 +42,7 @@ CallRepo::CallRepo() : _map() { add(nodes::Relu()); add(nodes::Sigmoid()); add(nodes::Elu()); + add(nodes::Erf()); } } // namespace vespalib::eval::nodes diff --git a/eval/src/vespa/eval/eval/call_nodes.h b/eval/src/vespa/eval/eval/call_nodes.h index 8210616750e..c5a41756005 100644 --- a/eval/src/vespa/eval/eval/call_nodes.h +++ b/eval/src/vespa/eval/eval/call_nodes.h @@ -138,6 +138,7 @@ struct IsNan : CallHelper<IsNan> { IsNan() : Helper("isNan", 1) {} }; struct Relu : CallHelper<Relu> { Relu() : Helper("relu", 1) {} }; struct Sigmoid : CallHelper<Sigmoid> { Sigmoid() : Helper("sigmoid", 1) {} }; struct Elu : CallHelper<Elu> { Elu() : Helper("elu", 1) {} }; +struct Erf : CallHelper<Erf> { Erf() : Helper("erf", 1) {} }; //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index cd31f92f96a..31167be5fe1 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -85,6 +85,7 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { void visit(const Relu &) override { add_byte(59); } void visit(const Sigmoid &) override { add_byte(60); } void visit(const Elu &) override { add_byte(61); } + void visit(const Erf &) override { add_byte(62); } // traverse bool open(const Node &node) override { node.accept(*this); return true; } diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 89f1789e97b..6f9bee025c9 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -644,6 +644,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const Elu &) override { make_call_1("vespalib_eval_elu"); } + void visit(const Erf &) override { + make_call_1("erf"); + } }; FunctionBuilder::~FunctionBuilder() { } diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index e80633b5c41..c5b5ca59401 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -344,6 +344,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const Elu &node) override { make_map(node, operation::Elu::f); } + void visit(const Erf &node) override { + make_map(node, operation::Erf::f); + } //------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp index 7bbe095c060..1c194736138 100644 --- a/eval/src/vespa/eval/eval/node_tools.cpp +++ b/eval/src/vespa/eval/eval/node_tools.cpp @@ -180,6 +180,7 @@ struct CopyNode : NodeTraverser, NodeVisitor { void visit(const Relu &node) override { copy_call(node); } void visit(const Sigmoid &node) override { copy_call(node); } void visit(const Elu &node) override { copy_call(node); } + void visit(const Erf &node) override { copy_call(node); } // traverse nodes bool open(const Node &) override { return !error; } diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index 468b9a58655..cbc96e719e0 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -274,6 +274,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { void visit(const Relu &node) override { resolve_op1(node); } void visit(const Sigmoid &node) override { resolve_op1(node); } void visit(const Elu &node) override { resolve_op1(node); } + void visit(const Erf &node) override { resolve_op1(node); } //------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h index d3e066c8f53..95a5bec8be7 100644 --- a/eval/src/vespa/eval/eval/node_visitor.h +++ b/eval/src/vespa/eval/eval/node_visitor.h @@ -83,6 +83,7 @@ struct NodeVisitor { virtual void visit(const nodes::Relu &) = 0; virtual void visit(const nodes::Sigmoid &) = 0; virtual void visit(const nodes::Elu &) = 0; + virtual void visit(const nodes::Erf &) = 0; virtual ~NodeVisitor() {} }; @@ -150,6 +151,7 @@ struct EmptyNodeVisitor : NodeVisitor { void visit(const nodes::Relu &) override {} void visit(const nodes::Sigmoid &) override {} void visit(const nodes::Elu &) override {} + void visit(const nodes::Erf &) override {} }; } // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp index fa8be4d20bc..b97ac3f2261 100644 --- a/eval/src/vespa/eval/eval/operation.cpp +++ b/eval/src/vespa/eval/eval/operation.cpp @@ -49,6 +49,7 @@ double IsNan::f(double a) { return std::isnan(a) ? 1.0 : 0.0; } double Relu::f(double a) { return std::max(a, 0.0); } double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); } double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; } +double Erf::f(double a) { return std::erf(a); } //----------------------------------------------------------------------------- double Inv::f(double a) { return (1.0 / a); } double Square::f(double a) { return (a * a); } @@ -106,6 +107,7 @@ std::map<vespalib::string,op1_t> make_op1_map() { add_op1(map, "relu(a)", Relu::f); add_op1(map, "sigmoid(a)", Sigmoid::f); add_op1(map, "elu(a)", Elu::f); + add_op1(map, "erf(a)", Erf::f); //------------------------------------- add_op1(map, "1/a", Inv::f); add_op1(map, "a*a", Square::f); diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h index 02d3322f867..3170c868214 100644 --- a/eval/src/vespa/eval/eval/operation.h +++ b/eval/src/vespa/eval/eval/operation.h @@ -48,6 +48,7 @@ struct IsNan { static double f(double a); }; struct Relu { static double f(double a); }; struct Sigmoid { static double f(double a); }; struct Elu { static double f(double a); }; +struct Erf { static double f(double a); }; //----------------------------------------------------------------------------- struct Inv { static double f(double a); }; struct Square { static double f(double a); }; diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp index 709234a1a2c..b1dfa6d3c9c 100644 --- a/eval/src/vespa/eval/eval/test/eval_spec.cpp +++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp @@ -151,6 +151,7 @@ EvalSpec::add_function_call_cases() { add_rule({"a", -1.0, 1.0}, "relu(a)", [](double a){ return std::max(a, 0.0); }); add_rule({"a", -1.0, 1.0}, "sigmoid(a)", [](double a){ return 1.0 / (1.0 + std::exp(-1.0 * a)); }); add_rule({"a", -1.0, 1.0}, "elu(a)", [](double a){ return (a < 0) ? std::exp(a)-1 : a; }); + add_rule({"a", -1.0, 1.0}, "erf(a)", [](double a){ return std::erf(a); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "atan2(a,b)", [](double a, double b){ return std::atan2(a, b); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "ldexp(a,b)", [](double a, double b){ return std::ldexp(a, b); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "pow(a,b)", [](double a, double b){ return std::pow(a, b); }); @@ -162,11 +163,11 @@ EvalSpec::add_function_call_cases() { void EvalSpec::add_tensor_operation_cases() { add_rule({"a", -1.0, 1.0}, "map(a,f(x)(sin(x)))", [](double x){ return std::sin(x); }); - add_rule({"a", -1.0, 1.0}, "map(a,f(x)(x+x*3))", [](double x){ return (x + (x * 3)); }); + add_rule({"a", -1.0, 1.0}, "map(a,f(x)(x*x*3))", [](double x){ return ((x * x) * 3); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); }); - add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y*3))", [](double x, double y){ return (x + (y * 3)); }); + add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x*y*3))", [](double x, double y){ return ((x * y) * 3); }); add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); }); - add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x+y*3))", [](double x, double y){ return (x + (y * 3)); }); + add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x*y*3))", [](double x, double y){ return ((x * y) * 3); }); add_rule({"a", -1.0, 1.0}, "reduce(a,avg)", [](double a){ return a; }); add_rule({"a", -1.0, 1.0}, "reduce(a,count)", [](double){ return 1.0; }); add_rule({"a", -1.0, 1.0}, "reduce(a,prod)", [](double a){ return a; }); diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index 41c6dd21e24..95e720cd1a2 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -412,6 +412,7 @@ struct TestContext { TEST_DO(test_map_op("relu(a)", operation::Relu::f, Sub2(Div16(N())))); TEST_DO(test_map_op("sigmoid(a)", operation::Sigmoid::f, Sub2(Div16(N())))); TEST_DO(test_map_op("elu(a)", operation::Elu::f, Sub2(Div16(N())))); + TEST_DO(test_map_op("erf(a)", operation::Erf::f, Sub2(Div16(N())))); TEST_DO(test_map_op("a in [1,5,7,13,42]", MyIn::f, N())); TEST_DO(test_map_op("(a+1)*2", MyOp::f, Div16(N()))); } diff --git a/eval/src/vespa/eval/eval/visit_stuff.cpp b/eval/src/vespa/eval/eval/visit_stuff.cpp index 821e609ebd0..9306a720837 100644 --- a/eval/src/vespa/eval/eval/visit_stuff.cpp +++ b/eval/src/vespa/eval/eval/visit_stuff.cpp @@ -35,6 +35,7 @@ vespalib::string name_of(map_fun_t fun) { if (fun == operation::Relu::f) return "relu"; if (fun == operation::Sigmoid::f) return "sigmoid"; if (fun == operation::Elu::f) return "elu"; + if (fun == operation::Erf::f) return "erf"; return "[other map function]"; } diff --git a/jdisc_core/src/main/java/com/yahoo/jdisc/handler/UnsafeContentInputStream.java b/jdisc_core/src/main/java/com/yahoo/jdisc/handler/UnsafeContentInputStream.java index 748c2951a6a..1662ed5b46a 100644 --- a/jdisc_core/src/main/java/com/yahoo/jdisc/handler/UnsafeContentInputStream.java +++ b/jdisc_core/src/main/java/com/yahoo/jdisc/handler/UnsafeContentInputStream.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jdisc.handler; +import java.io.IOException; import java.io.InputStream; import java.nio.ByteBuffer; import java.util.Objects; @@ -19,6 +20,8 @@ public class UnsafeContentInputStream extends InputStream { private final ReadableContentChannel content; private ByteBuffer buf = ByteBuffer.allocate(0); + private byte [] marked; + private int readSinceMarked; /** * <p>Constructs a new ContentInputStream that reads from the given {@link ReadableContentChannel}.</p> @@ -37,7 +40,15 @@ public class UnsafeContentInputStream extends InputStream { if (buf == null) { return -1; } - return ((int)buf.get()) & 0xFF; + byte b = buf.get(); + if (marked != null) { + if (readSinceMarked < marked.length) { + marked[readSinceMarked++] = b; + } else { + marked = null; + } + } + return ((int)b) & 0xFF; } @Override @@ -79,4 +90,28 @@ public class UnsafeContentInputStream extends InputStream { } } + + @Override + public synchronized void mark(int readlimit) { + marked = new byte[readlimit]; + readSinceMarked = 0; + } + + @Override + public synchronized void reset() throws IOException { + if (marked == null) { + throw new IOException("mark has not been called, or too much has been read since marked."); + } + ByteBuffer newBuf = ByteBuffer.allocate(readSinceMarked + buf.remaining()); + newBuf.put(marked, 0, readSinceMarked); + newBuf.put(buf); + newBuf.flip(); + buf = newBuf; + marked = null; + } + + @Override + public boolean markSupported() { + return true; + } } diff --git a/jdisc_core/src/test/java/com/yahoo/jdisc/handler/UnsafeContentInputStreamTestCase.java b/jdisc_core/src/test/java/com/yahoo/jdisc/handler/UnsafeContentInputStreamTestCase.java index c00fab6cb56..c96450c1bd2 100644 --- a/jdisc_core/src/test/java/com/yahoo/jdisc/handler/UnsafeContentInputStreamTestCase.java +++ b/jdisc_core/src/test/java/com/yahoo/jdisc/handler/UnsafeContentInputStreamTestCase.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.jdisc.handler; +import com.yahoo.text.Utf8; import org.junit.Test; import java.io.BufferedReader; @@ -8,11 +9,11 @@ import java.io.IOException; import java.io.InputStream; import java.io.InputStreamReader; import java.nio.ByteBuffer; -import java.util.concurrent.Future; import static org.junit.Assert.assertTrue; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; /** * @author Simon Thoresen Hult @@ -32,7 +33,37 @@ public class UnsafeContentInputStreamTestCase { assertNull(reader.readLine()); } - @SuppressWarnings("deprecation") + @Test + public void testMark() throws IOException { + BufferedContentChannel channel = new BufferedContentChannel(); + FastContentWriter writer = new FastContentWriter(channel); + writer.write("Hello "); + writer.write("World!"); + writer.close(); + + InputStream stream = asInputStream(channel); + assertTrue(stream.markSupported()); + int first = stream.read(); + assertEquals('H', first); + stream.mark(10); + byte [] buf = new byte[8]; + stream.read(buf); + assertEquals("ello Wor", Utf8.toString(buf)); + stream.reset(); + stream.mark(5); + buf = new byte [9]; + stream.read(buf); + assertEquals("ello Worl", Utf8.toString(buf)); + try { + stream.reset(); + fail("UnsafeContentInputStream.reset expected to fail when your read past readLimit."); + } catch (IOException e) { + assertEquals("mark has not been called, or too much has been read since marked.", e.getMessage()); + } catch (Throwable t) { + fail("Did not expect " + t); + } + } + @Test public void requireThatCompletionsAreCalledWithDeprecatedContentWriter() throws IOException { BufferedContentChannel channel = new BufferedContentChannel(); @@ -63,7 +94,6 @@ public class UnsafeContentInputStreamTestCase { assertTrue(writer.isDone()); } - @SuppressWarnings("deprecation") @Test public void requireThatCloseDrainsStreamWithDeprecatedContentWriter() { BufferedContentChannel channel = new BufferedContentChannel(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index d14ad033a69..3d36b1bfffc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -2,11 +2,15 @@ package ai.vespa.rankingexpression.importer.onnx; +import ai.vespa.rankingexpression.importer.operations.ConstantOfShape; +import ai.vespa.rankingexpression.importer.operations.Expand; import ai.vespa.rankingexpression.importer.operations.Gather; +import ai.vespa.rankingexpression.importer.operations.OnnxConstant; import ai.vespa.rankingexpression.importer.operations.OnnxCast; import ai.vespa.rankingexpression.importer.operations.Gemm; import ai.vespa.rankingexpression.importer.operations.ConcatReduce; import ai.vespa.rankingexpression.importer.operations.OnnxConcat; +import ai.vespa.rankingexpression.importer.operations.Range; import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; import ai.vespa.rankingexpression.importer.operations.Slice; @@ -81,11 +85,15 @@ class GraphImporter { case "cast": return new OnnxCast(modelName, nodeName, inputs, attributes); case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes); + case "constant": return new OnnxConstant(modelName, nodeName, inputs, attributes); + case "constantofshape": return new ConstantOfShape(modelName, nodeName, inputs, attributes); case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble())); + case "erf": return new Map(modelName, nodeName, inputs, ScalarFunctions.erf()); case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); + case "expand": return new Expand(modelName, nodeName, inputs); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); case "gather": return new Gather(modelName, nodeName, inputs, attributes); case "gemm": return new Gemm(modelName, nodeName, inputs, attributes); @@ -100,6 +108,7 @@ class GraphImporter { case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); + case "range": return new Range(modelName, nodeName, inputs); case "reshape": return new Reshape(modelName, nodeName, inputs, attributes); case "reducel1": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.abs(), null); case "reducel2": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), ScalarFunctions.sqrt()); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java new file mode 100644 index 00000000000..887e350b430 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java @@ -0,0 +1,83 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +public class ConstantOfShape extends IntermediateOperation { + + private final AttributeMap attributeMap; + + private TensorType.Value valueTypeOfTensor = TensorType.Value.DOUBLE; + private double valueToFillWith = 0.0; + + + public ConstantOfShape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + + Optional<Value> value = attributeMap.get("value"); + if (value.isPresent()) { + Tensor t = value.get().asTensor(); + valueTypeOfTensor = t.type().valueType(); + valueToFillWith = t.valueIterator().next(); + } + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(1)) return null; + + IntermediateOperation input = inputs.get(0); + if (input.getConstantValue().isEmpty()) { + throw new IllegalArgumentException("ConstantOfShape: 'shape' input must be a constant."); + } + Tensor shape = input.getConstantValue().get().asTensor(); + if (shape.type().dimensions().size() > 1) { + throw new IllegalArgumentException("ConstantOfShape: 'shape' input must be a tensor with 0 or 1 dimensions."); + } + + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(valueTypeOfTensor); + Iterator<Double> iter = shape.valueIterator(); + for (int i = 0; iter.hasNext(); i++) { + builder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, iter.next().longValue())); + } + return builder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputTypesPresent(1)) return null; + ExpressionNode valueExpr = new ConstantNode(new DoubleValue(valueToFillWith)); + TensorFunction function = Generate.bound(type.type(), wrapScalar(valueExpr)); + return function; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + addConstraintsFrom(type, renamer); + } + + @Override + public ConstantOfShape withInputs(List<IntermediateOperation> inputs) { + return new ConstantOfShape(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "ConstantOfShape"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java new file mode 100644 index 00000000000..30a7bc3bbad --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java @@ -0,0 +1,122 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +public class Expand extends IntermediateOperation { + + public Expand(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) return null; + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + + Optional<Value> shapeValue = inputs.get(1).getConstantValue(); + if (shapeValue.isEmpty()) + throw new IllegalArgumentException("Expand " + name + ": shape must be a constant."); + + Tensor shape = shapeValue.get().asTensor(); + if (shape.type().rank() != 1) + throw new IllegalArgumentException("Expand " + name + ": shape must be a 1-d tensor."); + + OrderedTensorType inputType = inputs.get(0).type().get(); + + int inputRank = inputType.rank(); + int shapeSize = shape.type().dimensions().get(0).size().get().intValue(); + int sizeDiff = shapeSize - inputRank; + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(inputType.type().valueType()); + Iterator<Double> iter = shape.valueIterator(); + + // Add any extra dimensions + for (int i = 0; i < sizeDiff; ++i) { + typeBuilder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, iter.next().intValue())); + } + + // Dimensions are matched innermost + for (int i = sizeDiff; i < shapeSize; i++) { + int shapeDimSize = iter.next().intValue(); + int inputDimSize = inputType.dimensions().get(i - sizeDiff).size().get().intValue(); + if (shapeDimSize != inputDimSize && shapeDimSize != 1 && inputDimSize != 1) { + throw new IllegalArgumentException("Expand " + name + ": dimension sizes of input and shape " + + "are not compatible. Either they must be equal or one must be of size 1."); + } + int dimSize = Math.max(shapeDimSize, inputDimSize); + typeBuilder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, dimSize)); + } + + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(2)) return null; + + IntermediateOperation input = inputs.get(0); + OrderedTensorType inputType = input.type().get(); + OrderedTensorType type = type().get(); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + + int sizeDiff = type().get().rank() - inputType.rank(); + for (int i = sizeDiff; i < type().get().rank(); ++i) { + String inputDimensionName = inputType.dimensions().get(i - sizeDiff).name(); + String typeDimensionName = type.dimensionNames().get(i); + long inputDimensionSize = inputType.dimensions().get(i - sizeDiff).size().get(); + + ExpressionNode index; + if (inputDimensionSize == 1) { + index = new ConstantNode(new DoubleValue(0.0)); + } else { + index = new EmbracedNode(new ReferenceNode(typeDimensionName)); + } + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(index))); + } + + TensorFunction<Reference> externalRef = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(externalRef, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + return Generate.bound(type.type(), wrapScalar(sliceExpression)); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + addConstraintsFrom(type, renamer); + } + + @Override + public Expand withInputs(List<IntermediateOperation> inputs) { + return new Expand(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Expand"; } + +} + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java index 2a34ae53d5e..91ff5d9cdd8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java @@ -105,6 +105,9 @@ public class Gather extends IntermediateOperation { private ExpressionNode createSliceExpression(List<Slice.DimensionValue<Reference>> dimensionValues, String referenceName) { TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(referenceName)); + if (dimensionValues.isEmpty()) { + return new TensorFunctionNode(inputIndices); + } Slice<Reference> sliceIndices = new Slice<>(inputIndices, dimensionValues); return new TensorFunctionNode(sliceIndices); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java new file mode 100644 index 00000000000..3c5ddf48cfc --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java @@ -0,0 +1,91 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; + +public class OnnxConstant extends IntermediateOperation { + + private final AttributeMap attributeMap; + private final Value value; + + public OnnxConstant(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + this.value = value(); + setConstantValueFunction(type -> new TensorValue(this.value.asTensor())); + } + + @Override + protected OrderedTensorType lazyGetType() { + OrderedTensorType type; + if (value instanceof TensorValue) { + type = OrderedTensorType.fromSpec(value.type().toString()).rename(vespaName() + "_"); + } else { + type = OrderedTensorType.fromDimensionList(TensorType.Value.DOUBLE, Collections.emptyList()); + } + return type; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; // will be added by function() since this is constant. + } + + @Override + public Optional<Value> getConstantValue() { + return Optional.of(new TensorValue(value.asTensor().withType(type().get().type()))); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + addConstraintsFrom(type, renamer); + } + + @Override + public boolean isConstant() { + return true; + } + + @Override + public OnnxConstant withInputs(List<IntermediateOperation> inputs) { + return new OnnxConstant(modelName(), name(), inputs, attributeMap); + } + + @Override + public String operationName() { return "Constant"; } + + @Override + public String toString() { + return "Constant(" + type + ")"; + } + + @Override + public String toFullString() { + return "\t" + type + ":\tConstant(" + type + ")"; + } + + private Value value() { + Optional<Value> value = attributeMap.get("value"); + if (value.isEmpty()) { + value = attributeMap.get("value_float"); + if (value.isEmpty()) { + value = attributeMap.get("value_int"); + } + } + if (value.isEmpty()) { + throw new IllegalArgumentException("Node '" + name + "' of type " + + "constant has missing or non-supported 'value' attribute"); + } + return value.get(); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java new file mode 100644 index 00000000000..6df686cf910 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java @@ -0,0 +1,86 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +public class Range extends IntermediateOperation { + + private double start; + private double limit; + private double delta; + private long elements; + + public Range(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + private double getConstantInput(int index, String name) { + IntermediateOperation input = inputs.get(index); + if (input.getConstantValue().isEmpty()) { + throw new IllegalArgumentException("Range: " + name + " input must be a constant."); + } + Tensor value = input.getConstantValue().get().asTensor(); + if ( ! input.getConstantValue().get().hasDouble()) { + throw new IllegalArgumentException("Range: " + name + " input must be a scalar."); + } + return value.asDouble(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if ( ! allInputTypesPresent(3)) return null; + + start = getConstantInput(0, "start"); // must be constant because we need to know type + limit = getConstantInput(1, "limit"); + delta = getConstantInput(2, "delta"); + elements = (long) Math.ceil((limit - start) / delta); + + OrderedTensorType type = new OrderedTensorType.Builder() + .add(TensorType.Dimension.indexed(vespaName(), elements)) + .build(); + return type; + } + + @Override + protected TensorFunction lazyGetFunction() { + if ( ! allInputTypesPresent(3)) return null; + String dimensionName = type().get().dimensionNames().get(0); + ExpressionNode startExpr = new ConstantNode(new DoubleValue(start)); + ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta)); + ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName)); + ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr); + ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr); + TensorFunction function = Generate.bound(type.type(), wrapScalar(addExpr)); + return function; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + addConstraintsFrom(type, renamer); + } + + @Override + public Range withInputs(List<IntermediateOperation> inputs) { + return new Range(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Range"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java index 8696d0f1858..69283f10711 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -78,14 +78,12 @@ public class Select extends IntermediateOperation { List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions(); - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); - // These tensors should have the same dimension names - renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(false), this); - renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this); + for (int i = 0; i < aDimensions.size(); ++i) { + String aDim = aDimensions.get(i).name(); + String bDim = bDimensions.get(i).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this); + } } @Override diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index d5dff7fb1b7..7b9868d71f5 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; @@ -66,6 +67,9 @@ public class OnnxOperationsTestCase { assertEval("leakyrelu", x, evaluate("max(0.01 * x, x)", x)); assertEval("leakyrelu", x, evaluate("max(0.001 * x, x)", x), createAttribute("alpha", 0.001f)); + x = evaluate("tensor(d0[7]):[-40.0, -0.5, -0.1, 0.0, 0.1, 0.5, 40.0]"); + assertEval("erf", x, evaluate("erf(x)", x)); + x = evaluate("tensor(d0[3]):[0.01, 1.0, 10.0]"); assertEval("log", x, evaluate("log(x)", x)); assertEval("sqrt", x, evaluate("sqrt(x)", x)); @@ -405,9 +409,14 @@ public class OnnxOperationsTestCase { @Test public void testGather1() throws ParseException { - // 1 dim input, 1 dim indices + // 1 dim input, 0 dim indices Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); - Tensor y = evaluate("tensor(d0[3]):[0,2,4]"); + Tensor y = evaluate("tensor():[0]"); + assertEval("gather", x, y, evaluate("tensor():[1]")); + + // 1 dim input, 1 dim indices + x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + y = evaluate("tensor(d0[3]):[0,2,4]"); assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,3,5]")); // 2 dim input, 1 dim indices - axis 0 @@ -533,6 +542,43 @@ public class OnnxOperationsTestCase { assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[3,6]"), createAttribute("axis", 1), 2); } + @Test + public void testRange11() throws ParseException { + Tensor start = evaluate("tensor():[3]"); + Tensor limit = evaluate("tensor():[9]"); + Tensor delta = evaluate("tensor():[3]"); + assertEval("range", start, limit, delta, evaluate("tensor(d0[2]):[3,6]")); + + start = evaluate("tensor():[10]"); + limit = evaluate("tensor():[4]"); + delta = evaluate("tensor():[-2]"); + assertEval("range", start, limit, delta, evaluate("tensor(d0[3]):[10,8,6]")); + assertEval("range", start, limit, delta, evaluate("tensor(d0[3]):[10,8,6]")); + } + + @Test + public void testConstant12() throws ParseException { + assertEval("constant", evaluate("tensor(d0[3]):[1,2,3]"), createAttribute("value", evaluate("tensor(d0[3]):[1,2,3]"))); + assertEval("constant", evaluate("tensor<float>():[313.0]"), createAttribute("value_float", 313.0f)); + assertEval("constant", evaluate("tensor():[42]"), createAttribute("value_int", 42)); + } + + @Test + public void testConstantOfShape9() throws ParseException { + Tensor shape = evaluate("tensor(d0[3]):[1,2,3]"); + assertEval("constantofshape", shape, evaluate("tensor(d0[1],d1[2],d2[3]):[0,0,0,0,0,0]")); + assertEval("constantofshape", shape, evaluate("tensor<float>(d0[1],d1[2],d2[3]):[1,1,1,1,1,1]"), createAttribute("value", evaluate("tensor<float>(d0[1]):[1]"))); + } + + @Test + public void testExpand8() throws ParseException { + Tensor input = evaluate("tensor(d0[3],d1[1]):[1,2,3]"); + Tensor shape = evaluate("tensor(d0[2]):[3,4]"); + assertEval("expand", input, shape, evaluate("tensor(d0[3],d1[4]):[1,1,1,1,2,2,2,2,3,3,3,3]")); + shape = evaluate("tensor(d0[3]):[2,1,4]"); + assertEval("expand", input, shape, evaluate("tensor(d0[2],d1[3],d2[4]):[1,1,1,1,2,2,2,2,3,3,3,3,1,1,1,1,2,2,2,2,3,3,3,3]")); + } + private Tensor evaluate(String expr) throws ParseException { return evaluate(expr, null, null, null); } @@ -558,6 +604,10 @@ public class OnnxOperationsTestCase { return renameToStandardType(op, tensor); } + private void assertEval(String opName, Tensor expected, AttributeConverter attr) { + assertEval(opName, null, null, null, null, null, expected, attr, 0); + } + private void assertEval(String opName, Tensor x, Tensor expected) { assertEval(opName, x, null, null, null, null, expected, null, 0); } @@ -667,6 +717,10 @@ public class OnnxOperationsTestCase { return new Attributes().attr(name, vals).build(); } + static AttributeConverter createAttribute(String name, Tensor val) { + return new Attributes().attr(name, val).build(); + } + static Attributes createAttributes() { return new Attributes(); } @@ -700,9 +754,14 @@ public class OnnxOperationsTestCase { Attributes attr(String name, Tensor tensor) { Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder(); - builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);; tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get())); - tensor.valueIterator().forEachRemaining(builder::addDoubleData); + if (tensor.type().valueType() == TensorType.Value.FLOAT) { + builder.setDataType(Onnx.TensorProto.DataType.FLOAT); + tensor.valueIterator().forEachRemaining(d -> builder.addFloatData(d.floatValue())); + } else { + builder.setDataType(Onnx.TensorProto.DataType.DOUBLE); + tensor.valueIterator().forEachRemaining(builder::addDoubleData); + } Onnx.TensorProto val = builder.build(); nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(TENSOR).setT(val).build()); return this; diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 6bce791914c..c22d906e2b2 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1024,6 +1024,7 @@ "public static final int SQRT", "public static final int TAN", "public static final int TANH", + "public static final int ERF", "public static final int ATAN2", "public static final int FMOD", "public static final int LDEXP", @@ -1373,6 +1374,7 @@ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function sqrt", "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function tan", "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function tanh", + "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function erf", "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function atan2", "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function fmod", "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function ldexp", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index c3c1c371a68..99afb3b38d0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.tensor.functions.ScalarFunctions; + import java.io.Serializable; import static java.lang.Math.*; @@ -36,6 +38,7 @@ public enum Function implements Serializable { sqrt { public double evaluate(double x, double y) { return sqrt(x); } }, tan { public double evaluate(double x, double y) { return tan(x); } }, tanh { public double evaluate(double x, double y) { return tanh(x); } }, + erf { public double evaluate(double x, double y) { return ScalarFunctions.Erf.erf(x); } }, atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } }, fmod(2) { public double evaluate(double x, double y) { return x % y; } }, diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 8aa10bf7b34..5f27bbcbeee 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -115,6 +115,7 @@ TOKEN : <SQRT: "sqrt"> | <TAN: "tan"> | <TANH: "tanh"> | + <ERF: "erf"> | <ATAN2: "atan2"> | <FMOD: "fmod"> | @@ -727,7 +728,8 @@ Function unaryFunctionName() : { } <SQUARE> { return Function.square; } | <SQRT> { return Function.sqrt; } | <TAN> { return Function.tan; } | - <TANH> { return Function.tanh; } + <TANH> { return Function.tanh; } | + <ERF> { return Function.erf; } } Function binaryFunctionName() : { } diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index 7a4c6c9e56a..0a8b59c7d7e 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -33,8 +33,8 @@ using search::AttributeGuard; using search::AttributeVector; using search::attribute::DistanceMetric; using search::attribute::HnswIndexParams; -using search::queryeval::NearestNeighborBlueprint; using search::queryeval::GlobalFilter; +using search::queryeval::NearestNeighborBlueprint; using search::tensor::DefaultNearestNeighborIndexFactory; using search::tensor::DenseTensorAttribute; using search::tensor::DocVectorAccess; @@ -44,6 +44,7 @@ using search::tensor::HnswNode; using search::tensor::NearestNeighborIndex; using search::tensor::NearestNeighborIndexFactory; using search::tensor::NearestNeighborIndexSaver; +using search::tensor::PrepareResult; using search::tensor::TensorAttribute; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; @@ -97,6 +98,12 @@ public: } }; +class MockPrepareResult : public PrepareResult { +public: + uint32_t docid; + MockPrepareResult(uint32_t docid_in) : docid(docid_in) {} +}; + class MockNearestNeighborIndex : public NearestNeighborIndex { private: using Entry = std::pair<uint32_t, DoubleVector>; @@ -105,6 +112,8 @@ private: const DocVectorAccess& _vectors; EntryVector _adds; EntryVector _removes; + mutable EntryVector _prepare_adds; + EntryVector _complete_adds; generation_t _transfer_gen; generation_t _trim_gen; mutable size_t _memory_usage_cnt; @@ -115,6 +124,8 @@ public: : _vectors(vectors), _adds(), _removes(), + _prepare_adds(), + _complete_adds(), _transfer_gen(std::numeric_limits<generation_t>::max()), _trim_gen(std::numeric_limits<generation_t>::max()), _memory_usage_cnt(0), @@ -124,6 +135,8 @@ public: void clear() { _adds.clear(); _removes.clear(); + _prepare_adds.clear(); + _complete_adds.clear(); } int get_index_value() const { return _index_value; @@ -134,10 +147,13 @@ public: void expect_empty_add() const { EXPECT_TRUE(_adds.empty()); } + void expect_entry(uint32_t exp_docid, const DoubleVector& exp_vector, const EntryVector& entries) const { + EXPECT_EQUAL(1u, entries.size()); + EXPECT_EQUAL(exp_docid, entries.back().first); + EXPECT_EQUAL(exp_vector, entries.back().second); + } void expect_add(uint32_t exp_docid, const DoubleVector& exp_vector) const { - EXPECT_EQUAL(1u, _adds.size()); - EXPECT_EQUAL(exp_docid, _adds.back().first); - EXPECT_EQUAL(exp_vector, _adds.back().second); + expect_entry(exp_docid, exp_vector, _adds); } void expect_adds(const EntryVector &exp_adds) const { EXPECT_EQUAL(exp_adds, _adds); @@ -146,9 +162,13 @@ public: EXPECT_TRUE(_removes.empty()); } void expect_remove(uint32_t exp_docid, const DoubleVector& exp_vector) const { - EXPECT_EQUAL(1u, _removes.size()); - EXPECT_EQUAL(exp_docid, _removes.back().first); - EXPECT_EQUAL(exp_vector, _removes.back().second); + expect_entry(exp_docid, exp_vector, _removes); + } + void expect_prepare_add(uint32_t exp_docid, const DoubleVector& exp_vector) const { + expect_entry(exp_docid, exp_vector, _prepare_adds); + } + void expect_complete_add(uint32_t exp_docid, const DoubleVector& exp_vector) const { + expect_entry(exp_docid, exp_vector, _complete_adds); } generation_t get_transfer_gen() const { return _transfer_gen; } generation_t get_trim_gen() const { return _trim_gen; } @@ -158,14 +178,21 @@ public: auto vector = _vectors.get_vector(docid).typify<double>(); _adds.emplace_back(docid, DoubleVector(vector.begin(), vector.end())); } - std::unique_ptr<search::tensor::PrepareResult> prepare_add_document(uint32_t, - vespalib::tensor::TypedCells, - vespalib::GenerationHandler::Guard) const override { - return std::unique_ptr<search::tensor::PrepareResult>(); + std::unique_ptr<PrepareResult> prepare_add_document(uint32_t docid, + vespalib::tensor::TypedCells vector, + vespalib::GenerationHandler::Guard guard) const override { + (void) guard; + auto d_vector = vector.typify<double>(); + _prepare_adds.emplace_back(docid, DoubleVector(d_vector.begin(), d_vector.end())); + return std::make_unique<MockPrepareResult>(docid); } void complete_add_document(uint32_t docid, - std::unique_ptr<search::tensor::PrepareResult>) override { - add_document(docid); + std::unique_ptr<PrepareResult> prepare_result) override { + auto* mock_result = dynamic_cast<MockPrepareResult*>(prepare_result.get()); + assert(mock_result); + EXPECT_EQUAL(docid, mock_result->docid); + auto vector = _vectors.get_vector(docid).typify<double>(); + _complete_adds.emplace_back(docid, DoubleVector(vector.begin(), vector.end())); } void remove_document(uint32_t docid) override { auto vector = _vectors.get_vector(docid).typify<double>(); @@ -342,6 +369,16 @@ struct Fixture { set_tensor_internal(docid, *createTensor(spec)); } + std::unique_ptr<PrepareResult> prepare_set_tensor(uint32_t docid, const TensorSpec& spec) const { + return _tensorAttr->prepare_set_tensor(docid, *createTensor(spec)); + } + + void complete_set_tensor(uint32_t docid, const TensorSpec& spec, std::unique_ptr<PrepareResult> prepare_result) { + ensureSpace(docid); + _tensorAttr->complete_set_tensor(docid, *createTensor(spec), std::move(prepare_result)); + _attr->commit(); + } + void set_empty_tensor(uint32_t docid) { set_tensor_internal(docid, *_tensorAttr->getEmptyTensor()); } @@ -687,6 +724,30 @@ TEST_F("setTensor() updates nearest neighbor index", DenseTensorAttributeMockInd index.expect_add(1, {7, 9}); } +TEST_F("nearest neighbor index can be updated in two phases", DenseTensorAttributeMockIndex) +{ + auto& index = f.mock_index(); + { + auto vec_a = vec_2d(3, 5); + auto prepare_result = f.prepare_set_tensor(1, vec_a); + index.expect_prepare_add(1, {3, 5}); + f.complete_set_tensor(1, vec_a, std::move(prepare_result)); + f.assertGetTensor(vec_a, 1); + index.expect_complete_add(1, {3, 5}); + } + index.clear(); + { + // Replaces previous value. + auto vec_b = vec_2d(7, 9); + auto prepare_result = f.prepare_set_tensor(1, vec_b); + index.expect_prepare_add(1, {7, 9}); + f.complete_set_tensor(1, vec_b, std::move(prepare_result)); + index.expect_remove(1, {3, 5}); + f.assertGetTensor(vec_b, 1); + index.expect_complete_add(1, {7, 9}); + } +} + TEST_F("clearDoc() updates nearest neighbor index", DenseTensorAttributeMockIndex) { auto& index = f.mock_index(); diff --git a/searchlib/src/tests/hitcollector/hitcollector_test.cpp b/searchlib/src/tests/hitcollector/hitcollector_test.cpp index 31a24d2a8f1..2274314c7da 100644 --- a/searchlib/src/tests/hitcollector/hitcollector_test.cpp +++ b/searchlib/src/tests/hitcollector/hitcollector_test.cpp @@ -55,7 +55,7 @@ void checkResult(const ResultSet & rs, const std::vector<RankedHit> & exp) for (uint32_t i = 0; i < exp.size(); ++i) { EXPECT_EQUAL(rh[i]._docId, exp[i]._docId); - EXPECT_EQUAL(rh[i]._rankValue, exp[i]._rankValue); + EXPECT_EQUAL(rh[i]._rankValue + 1.0, exp[i]._rankValue + 1.0); } } else { ASSERT_TRUE(rs.getArray() == nullptr); @@ -328,7 +328,7 @@ TEST("testScaling") { finalScores[3] = 300; finalScores[4] = 400; - testScaling(initScores, std::move(finalScores), exp); + TEST_DO(testScaling(initScores, std::move(finalScores), exp)); } { // scale down and adjust up exp[0]._rankValue = 200; // scaled @@ -342,7 +342,7 @@ TEST("testScaling") { finalScores[3] = 500; finalScores[4] = 600; - testScaling(initScores, std::move(finalScores), exp); + TEST_DO(testScaling(initScores, std::move(finalScores), exp)); } { // scale up and adjust down @@ -357,7 +357,7 @@ TEST("testScaling") { finalScores[3] = 3250; finalScores[4] = 4500; - testScaling(initScores, std::move(finalScores), exp); + TEST_DO(testScaling(initScores, std::move(finalScores), exp)); } { // minimal scale (second phase range = 0 (4 - 4) -> 1) exp[0]._rankValue = 1; // scaled @@ -371,7 +371,7 @@ TEST("testScaling") { finalScores[3] = 4; finalScores[4] = 4; - testScaling(initScores, std::move(finalScores), exp); + TEST_DO(testScaling(initScores, std::move(finalScores), exp)); } { // minimal scale (first phase range = 0 (4000 - 4000) -> 1) std::vector<feature_t> is(initScores); @@ -387,7 +387,7 @@ TEST("testScaling") { finalScores[3] = 400; finalScores[4] = 500; - testScaling(is, std::move(finalScores), exp); + TEST_DO(testScaling(is, std::move(finalScores), exp)); } } diff --git a/searchlib/src/tests/tensor/hnsw_index/.gitignore b/searchlib/src/tests/tensor/hnsw_index/.gitignore new file mode 100644 index 00000000000..bc9bc27160b --- /dev/null +++ b/searchlib/src/tests/tensor/hnsw_index/.gitignore @@ -0,0 +1 @@ +/mt_stress_hnsw_app diff --git a/searchlib/src/tests/tensor/hnsw_index/CMakeLists.txt b/searchlib/src/tests/tensor/hnsw_index/CMakeLists.txt index 04c7312a63f..b6a87502fdf 100644 --- a/searchlib/src/tests/tensor/hnsw_index/CMakeLists.txt +++ b/searchlib/src/tests/tensor/hnsw_index/CMakeLists.txt @@ -7,3 +7,11 @@ vespa_add_executable(searchlib_hnsw_index_test_app TEST gtest ) vespa_add_test(NAME searchlib_hnsw_index_test_app COMMAND searchlib_hnsw_index_test_app) + +vespa_add_executable(mt_stress_hnsw_app TEST + SOURCES + stress_hnsw_mt.cpp + DEPENDS + searchlib + gtest +) diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index 33ea0f2df5b..7dc0efc106d 100644 --- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -616,8 +616,8 @@ TEST_F(TwoPhaseTest, two_phase_add) complete_add(7, std::move(up)); // 1 filtered out because it was removed - // TODO: 5 filtered out because it was updated - expect_levels(7, {{2}, {4,5}}); + // 5 filtered out because it was updated + expect_levels(7, {{2}, {4}}); } diff --git a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp new file mode 100644 index 00000000000..4dec9550f6f --- /dev/null +++ b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp @@ -0,0 +1,348 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <sys/types.h> +#include <sys/stat.h> +#include <fcntl.h> +#include <stdio.h> +#include <unistd.h> +#include <chrono> +#include <cstdlib> +#include <future> +#include <vector> + +#include <vespa/eval/tensor/dense/typed_cells.h> +#include <vespa/searchlib/common/bitvector.h> +#include <vespa/searchlib/tensor/distance_functions.h> +#include <vespa/searchlib/tensor/doc_vector_access.h> +#include <vespa/searchlib/tensor/hnsw_index.h> +#include <vespa/searchlib/tensor/inv_log_level_generator.h> +#include <vespa/searchlib/tensor/random_level_generator.h> +#include <vespa/vespalib/data/input.h> +#include <vespa/vespalib/data/memory_input.h> +#include <vespa/vespalib/data/slime/slime.h> +#include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/util/blockingthreadstackexecutor.h> +#include <vespa/vespalib/util/generationhandler.h> +#include <vespa/vespalib/util/lambdatask.h> + +#include <vespa/log/log.h> +LOG_SETUP("stress_hnsw_mt"); + +using namespace search::tensor; +using namespace vespalib::slime; +using search::BitVector; +using vespalib::GenerationHandler; +using vespalib::MemoryUsage; +using vespalib::Slime; + +#define NUM_DIMS 128 +#define NUM_POSSIBLE_V 1000000 +#define NUM_POSSIBLE_DOCS 30000 +#define NUM_OPS 1000000 + +class RndGen { +private: + std::mt19937_64 urng; + std::uniform_real_distribution<double> uf; +public: + RndGen() : urng(0x1234deadbeef5678uLL), uf(0.0, 1.0) {} + + double nextUniform() { + return uf(urng); + } +}; + +using ConstVectorRef = vespalib::ConstArrayRef<float>; + +struct MallocPointVector { + float v[NUM_DIMS]; + operator ConstVectorRef() const { return ConstVectorRef(v, NUM_DIMS); } +}; +static MallocPointVector *aligned_alloc_pv(size_t num) { + size_t num_bytes = num * sizeof(MallocPointVector); + double mega_bytes = num_bytes / (1024.0*1024.0); + fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes); + char *mem = (char *)malloc(num_bytes + 512); + mem += 512; + size_t val = (size_t)mem; + size_t unalign = val % 512; + mem -= unalign; + return reinterpret_cast<MallocPointVector *>(mem); +} + +void read_vector_file(MallocPointVector *p) { + std::string data_set = "sift"; + std::string data_dir = "."; + char *home = getenv("HOME"); + if (home) { + data_dir = home; + data_dir += "/" + data_set; + } + std::string fn = data_dir + "/" + data_set + "_base.fvecs"; + int fd = open(fn.c_str(), O_RDONLY); + if (fd < 0) { + perror(fn.c_str()); + exit(1); + } + int d; + size_t rv; + fprintf(stderr, "reading %u vectors from %s\n", NUM_POSSIBLE_V, fn.c_str()); + for (uint32_t i = 0; i < NUM_POSSIBLE_V; ++i) { + rv = read(fd, &d, 4); + ASSERT_EQ(rv, 4u); + ASSERT_EQ(d, NUM_DIMS); + rv = read(fd, &p[i].v, NUM_DIMS*sizeof(float)); + ASSERT_EQ(rv, sizeof(MallocPointVector)); + } + close(fd); + fprintf(stderr, "reading %u vectors OK\n", NUM_POSSIBLE_V); +} + +class MyDocVectorStore : public DocVectorAccess { +private: + MallocPointVector *_vectors; +public: + MyDocVectorStore() { + _vectors = aligned_alloc_pv(NUM_POSSIBLE_DOCS); + } + MyDocVectorStore& set(uint32_t docid, ConstVectorRef vec) { + assert(docid < NUM_POSSIBLE_DOCS); + memcpy(&_vectors[docid], vec.cbegin(), sizeof(MallocPointVector)); + return *this; + } + vespalib::tensor::TypedCells get_vector(uint32_t docid) const override { + assert(docid < NUM_POSSIBLE_DOCS); + ConstVectorRef ref(_vectors[docid]); + return vespalib::tensor::TypedCells(ref); + } +}; + +using FloatSqEuclideanDistance = SquaredEuclideanDistance<float>; +using HnswIndexUP = std::unique_ptr<HnswIndex>; + +class Stressor : public ::testing::Test { +private: + struct LoadedVectors { + MallocPointVector *pv_storage; + void load() { + pv_storage = aligned_alloc_pv(size()); + read_vector_file(pv_storage); + } + size_t size() const { return NUM_POSSIBLE_V; } + vespalib::ConstArrayRef<float> operator[] (size_t i) { + return pv_storage[i]; + } + } loaded_vectors; +public: + BitVector::UP in_progress; + std::mutex in_progress_lock; + BitVector::UP existing_ids; + RndGen rng; + MyDocVectorStore vectors; + GenerationHandler gen_handler; + HnswIndexUP index; + vespalib::BlockingThreadStackExecutor multi_prepare_workers; + vespalib::BlockingThreadStackExecutor write_thread; + + using PrepUP = std::unique_ptr<PrepareResult>; + using ReadGuard = GenerationHandler::Guard; + using PrepareFuture = std::future<PrepUP>; + + // union of data required by tasks + struct TaskBase : vespalib::Executor::Task { + Stressor &parent; + uint32_t docid; + ConstVectorRef vec; + PrepareFuture prepare_future; + ReadGuard read_guard; + + TaskBase(Stressor &p, uint32_t d, ConstVectorRef v, PrepareFuture f, ReadGuard g) + : parent(p), docid(d), vec(v), prepare_future(std::move(f)), read_guard(g) + {} + TaskBase(Stressor &p, uint32_t d, ConstVectorRef v, ReadGuard g) // prepare add + : TaskBase(p, d, v, PrepareFuture(), g) {} + TaskBase(Stressor &p, uint32_t d, ConstVectorRef v, PrepareFuture r) // complete add+update + : TaskBase(p, d, v, std::move(r), ReadGuard()) {} + TaskBase(Stressor &p, uint32_t d) // complete remove + : TaskBase(p, d, ConstVectorRef(), PrepareFuture(), ReadGuard()) {} + + ~TaskBase() {} + }; + + struct PrepareAddTask : TaskBase { + using TaskBase::TaskBase; + std::promise<PrepUP> result_promise; + auto get_result_future() { + return result_promise.get_future(); + } + void run() override { + auto v = vespalib::tensor::TypedCells(vec); + auto up = parent.index->prepare_add_document(docid, v, read_guard); + result_promise.set_value(std::move(up)); + } + }; + + struct CompleteAddTask : TaskBase { + using TaskBase::TaskBase; + void run() override { + auto prepare_result = prepare_future.get(); + parent.vectors.set(docid, vec); + parent.index->complete_add_document(docid, std::move(prepare_result)); + parent.existing_ids->setBit(docid); + parent.commit(docid); + } + }; + + struct CompleteRemoveTask : TaskBase { + using TaskBase::TaskBase; + void run() override { + parent.index->remove_document(docid); + parent.existing_ids->clearBit(docid); + parent.commit(docid); + } + }; + + struct CompleteUpdateTask : TaskBase { + using TaskBase::TaskBase; + void run() override { + auto prepare_result = prepare_future.get(); + parent.index->remove_document(docid); + parent.vectors.set(docid, vec); + parent.index->complete_add_document(docid, std::move(prepare_result)); + EXPECT_EQ(parent.existing_ids->testBit(docid), true); + parent.commit(docid); + } + }; + + Stressor() + : loaded_vectors(), + in_progress(BitVector::create(NUM_POSSIBLE_DOCS)), + existing_ids(BitVector::create(NUM_POSSIBLE_DOCS)), + rng(), + vectors(), + gen_handler(), + index(), + multi_prepare_workers(10, 128*1024, 50), + write_thread(1, 128*1024, 500) + { + loaded_vectors.load(); + } + + ~Stressor() {} + + void init() { + uint32_t m = 16; + index = std::make_unique<HnswIndex>(vectors, std::make_unique<FloatSqEuclideanDistance>(), + std::make_unique<InvLogLevelGenerator>(m), + HnswIndex::Config(2*m, m, 200, true)); + } + size_t get_rnd(size_t size) { + return rng.nextUniform() * size; + } + void add_document(uint32_t docid) { + size_t vec_num = get_rnd(loaded_vectors.size()); + ConstVectorRef vec = loaded_vectors[vec_num]; + auto guard = take_read_guard(); + auto prepare_task = std::make_unique<PrepareAddTask>(*this, docid, vec, guard); + auto complete_task = std::make_unique<CompleteAddTask>(*this, docid, vec, prepare_task->get_result_future()); + auto r = multi_prepare_workers.execute(std::move(prepare_task)); + ASSERT_EQ(r.get(), nullptr); + r = write_thread.execute(std::move(complete_task)); + ASSERT_EQ(r.get(), nullptr); + } + void remove_document(uint32_t docid) { + auto task = std::make_unique<CompleteRemoveTask>(*this, docid); + auto r = write_thread.execute(std::move(task)); + ASSERT_EQ(r.get(), nullptr); + } + void update_document(uint32_t docid) { + size_t vec_num = get_rnd(loaded_vectors.size()); + ConstVectorRef vec = loaded_vectors[vec_num]; + auto guard = take_read_guard(); + auto prepare_task = std::make_unique<PrepareAddTask>(*this, docid, vec, guard); + auto complete_task = std::make_unique<CompleteUpdateTask>(*this, docid, vec, prepare_task->get_result_future()); + auto r = multi_prepare_workers.execute(std::move(prepare_task)); + ASSERT_EQ(r.get(), nullptr); + r = write_thread.execute(std::move(complete_task)); + ASSERT_EQ(r.get(), nullptr); + } + void commit(uint32_t docid) { + index->transfer_hold_lists(gen_handler.getCurrentGeneration()); + gen_handler.incGeneration(); + gen_handler.updateFirstUsedGeneration(); + index->trim_hold_lists(gen_handler.getFirstUsedGeneration()); + std::lock_guard<std::mutex> guard(in_progress_lock); + in_progress->clearBit(docid); + // printf("commit: %u\n", docid); + } + void gen_operation() { + uint32_t docid = get_rnd(NUM_POSSIBLE_DOCS); + { + std::lock_guard<std::mutex> guard(in_progress_lock); + while (in_progress->testBit(docid)) { + docid = get_rnd(NUM_POSSIBLE_DOCS); + } + in_progress->setBit(docid); + } + if (existing_ids->testBit(docid)) { + if (get_rnd(100) < 70) { + // printf("start remove op: %u\n", docid); + remove_document(docid); + } else { + // printf("start update op: %u\n", docid); + update_document(docid); + } + } else { + // printf("start add op: %u\n", docid); + add_document(docid); + } + } + GenerationHandler::Guard take_read_guard() { + return gen_handler.takeGuard(); + } + MemoryUsage memory_usage() const { + return index->memory_usage(); + } + uint32_t count_in_progress() { + std::lock_guard<std::mutex> guard(in_progress_lock); + in_progress->invalidateCachedCount(); + return in_progress->countTrueBits(); + } + std::string json_state() { + Slime actualSlime; + SlimeInserter inserter(actualSlime); + index->get_state(inserter); + vespalib::SimpleBuffer buf; + vespalib::slime::JsonFormat::encode(actualSlime, buf, false); + return buf.get().make_string(); + } +}; + + +TEST_F(Stressor, stress) +{ + init(); + for (int i = 0; i < NUM_OPS; ++i) { + gen_operation(); + if (i % 1000 == 0) { + uint32_t cnt = count_in_progress(); + fprintf(stderr, "generating operations %d / %d; in progress: %u ops\n", + i, NUM_OPS, cnt); + auto r = write_thread.execute(vespalib::makeLambdaTask([&]() { + EXPECT_TRUE(index->check_link_symmetry()); + })); + EXPECT_EQ(r.get(), nullptr); + } + } + fprintf(stderr, "waiting for queued operations...\n"); + multi_prepare_workers.sync(); + write_thread.sync(); + EXPECT_EQ(count_in_progress(), 0); + EXPECT_TRUE(index->check_link_symmetry()); + fprintf(stderr, "HNSW index state after test:\n%s\n", json_state().c_str()); + existing_ids->invalidateCachedCount(); + fprintf(stderr, "Expected valid nodes: %u\n", existing_ids->countTrueBits()); + fprintf(stderr, "all done.\n"); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp b/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp index bcc886fccad..2db6437664e 100644 --- a/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp @@ -37,7 +37,7 @@ using V = std::vector<uint32_t>; void populate(HnswGraph &graph) { // no 0 graph.make_node_for_document(1, 1); - graph.make_node_for_document(2, 2); + auto er = graph.make_node_for_document(2, 2); // no 3 graph.make_node_for_document(4, 2); graph.make_node_for_document(5, 0); @@ -49,7 +49,7 @@ void populate(HnswGraph &graph) { graph.set_link_array(6, 0, V{1, 2, 4}); graph.set_link_array(2, 1, V{4}); graph.set_link_array(4, 1, V{2}); - graph.set_entry_node({2, 1}); + graph.set_entry_node({2, er, 1}); } void modify(HnswGraph &graph) { @@ -63,7 +63,7 @@ void modify(HnswGraph &graph) { graph.set_link_array(4, 1, V{7}); graph.set_link_array(7, 1, V{4}); - graph.set_entry_node({4, 1}); + graph.set_entry_node({4, graph.get_node_ref(4), 1}); } diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp index c9ed4039655..76533839de7 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp @@ -5,6 +5,7 @@ #include "nearest_neighbor_index.h" #include "nearest_neighbor_index_saver.h" #include "tensor_attribute.hpp" +#include <vespa/eval/tensor/dense/dense_tensor_view.h> #include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/fastlib/io/bufferedfile.h> @@ -18,6 +19,7 @@ LOG_SETUP(".searchlib.tensor.dense_tensor_attribute"); using search::attribute::LoadUtils; using vespalib::eval::ValueType; using vespalib::slime::ObjectInserter; +using vespalib::tensor::DenseTensorView; using vespalib::tensor::MutableDenseTensorView; using vespalib::tensor::Tensor; @@ -77,6 +79,15 @@ can_use_index_save_file(const search::attribute::Config &config, const search::a } void +DenseTensorAttribute::internal_set_tensor(DocId docid, const Tensor& tensor) +{ + checkTensorType(tensor); + consider_remove_from_index(docid); + EntryRef ref = _denseTensorStore.setTensor(tensor); + setTensorRef(docid, ref); +} + +void DenseTensorAttribute::consider_remove_from_index(DocId docid) { if (_index && _refVector[docid].valid()) { @@ -126,15 +137,32 @@ DenseTensorAttribute::clearDoc(DocId docId) void DenseTensorAttribute::setTensor(DocId docId, const Tensor &tensor) { - checkTensorType(tensor); - consider_remove_from_index(docId); - EntryRef ref = _denseTensorStore.setTensor(tensor); - setTensorRef(docId, ref); + internal_set_tensor(docId, tensor); if (_index) { _index->add_document(docId); } } +std::unique_ptr<PrepareResult> +DenseTensorAttribute::prepare_set_tensor(DocId docid, const Tensor& tensor) const +{ + if (_index) { + const auto* view = dynamic_cast<const DenseTensorView*>(&tensor); + assert(view); + return _index->prepare_add_document(docid, view->cellsRef(), getGenerationHandler().takeGuard()); + } + return std::unique_ptr<PrepareResult>(); +} + +void +DenseTensorAttribute::complete_set_tensor(DocId docid, const Tensor& tensor, + std::unique_ptr<PrepareResult> prepare_result) +{ + internal_set_tensor(docid, tensor); + if (_index) { + _index->complete_add_document(docid, std::move(prepare_result)); + } +} std::unique_ptr<Tensor> DenseTensorAttribute::getTensor(DocId docId) const diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h index f0383627ea2..7fd06357114 100644 --- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h @@ -23,6 +23,7 @@ private: DenseTensorStore _denseTensorStore; std::unique_ptr<NearestNeighborIndex> _index; + void internal_set_tensor(DocId docid, const Tensor& tensor); void consider_remove_from_index(DocId docid); vespalib::MemoryUsage memory_usage() const override; @@ -33,6 +34,8 @@ public: // Implements AttributeVector and ITensorAttribute uint32_t clearDoc(DocId docId) override; void setTensor(DocId docId, const Tensor &tensor) override; + std::unique_ptr<PrepareResult> prepare_set_tensor(DocId docid, const Tensor& tensor) const override; + void complete_set_tensor(DocId docid, const Tensor& tensor, std::unique_ptr<PrepareResult> prepare_result) override; std::unique_ptr<Tensor> getTensor(DocId docId) const override; void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override; bool onLoad() override; diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_graph.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_graph.cpp index 37e3ea1adbd..564676a2d44 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_graph.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_graph.cpp @@ -13,13 +13,14 @@ HnswGraph::HnswGraph() links(HnswIndex::make_default_link_store_config()), entry_docid_and_level() { + node_refs.ensure_size(1, AtomicEntryRef()); EntryNode entry; set_entry_node(entry); } HnswGraph::~HnswGraph() {} -void +HnswGraph::NodeRef HnswGraph::make_node_for_document(uint32_t docid, uint32_t num_levels) { node_refs.ensure_size(docid + 1, AtomicEntryRef()); @@ -29,15 +30,23 @@ HnswGraph::make_node_for_document(uint32_t docid, uint32_t num_levels) vespalib::Array<AtomicEntryRef> levels(num_levels, AtomicEntryRef()); auto node_ref = nodes.add(levels); node_refs[docid].store_release(node_ref); + return node_ref; } void HnswGraph::remove_node_for_document(uint32_t docid) { auto node_ref = node_refs[docid].load_acquire(); - nodes.remove(node_ref); + assert(node_ref.valid()); + auto levels = nodes.get(node_ref); vespalib::datastore::EntryRef invalid; node_refs[docid].store_release(invalid); + // Ensure data referenced through the old ref can be recycled: + nodes.remove(node_ref); + for (size_t i = 0; i < levels.size(); ++i) { + auto old_links_ref = levels[i].load_acquire(); + links.remove(old_links_ref); + } } void @@ -47,6 +56,7 @@ HnswGraph::set_link_array(uint32_t docid, uint32_t level, const LinkArrayRef& ne auto node_ref = node_refs[docid].load_acquire(); assert(node_ref.valid()); auto levels = nodes.get_writable(node_ref); + assert(level < levels.size()); auto old_links_ref = levels[level].load_acquire(); levels[level].store_release(new_links_ref); links.remove(old_links_ref); diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_graph.h b/searchlib/src/vespa/searchlib/tensor/hnsw_graph.h index 125692af627..8b40eb87bae 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_graph.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_graph.h @@ -23,6 +23,7 @@ struct HnswGraph { // Provides mapping from document id -> node reference. // The reference is used to lookup the node data in NodeStore. using NodeRefVector = vespalib::RcuVector<AtomicEntryRef>; + using NodeRef = vespalib::datastore::EntryRef; // This stores the level arrays for all nodes. // Each node consists of an array of levels (from level 0 to n) where each entry is a reference to the link array at that level. @@ -45,33 +46,64 @@ struct HnswGraph { ~HnswGraph(); - void make_node_for_document(uint32_t docid, uint32_t num_levels); + NodeRef make_node_for_document(uint32_t docid, uint32_t num_levels); void remove_node_for_document(uint32_t docid); + NodeRef get_node_ref(uint32_t docid) const { + return node_refs[docid].load_acquire(); + } + + bool still_valid(uint32_t docid, NodeRef node_ref) const { + return node_ref.valid() && (get_node_ref(docid) == node_ref); + } + + LevelArrayRef get_level_array(NodeRef node_ref) const { + if (node_ref.valid()) { + return nodes.get(node_ref); + } + return LevelArrayRef(); + } + LevelArrayRef get_level_array(uint32_t docid) const { - auto node_ref = node_refs[docid].load_acquire(); - assert(node_ref.valid()); - return nodes.get(node_ref); + auto node_ref = get_node_ref(docid); + return get_level_array(node_ref); + } + + LinkArrayRef get_link_array(LevelArrayRef levels, uint32_t level) const { + if (level < levels.size()) { + auto links_ref = levels[level].load_acquire(); + if (links_ref.valid()) { + return links.get(links_ref); + } + } + return LinkArrayRef(); } LinkArrayRef get_link_array(uint32_t docid, uint32_t level) const { auto levels = get_level_array(docid); - assert(level < levels.size()); - return links.get(levels[level].load_acquire()); + return get_link_array(levels, level); + } + + LinkArrayRef get_link_array(NodeRef node_ref, uint32_t level) const { + auto levels = get_level_array(node_ref); + return get_link_array(levels, level); } - + void set_link_array(uint32_t docid, uint32_t level, const LinkArrayRef& new_links); struct EntryNode { uint32_t docid; + NodeRef node_ref; int32_t level; EntryNode() : docid(0), // Note that docid 0 is reserved and never used + node_ref(), level(-1) {} - EntryNode(uint32_t docid_in, int32_t level_in) + EntryNode(uint32_t docid_in, NodeRef node_ref_in, int32_t level_in) : docid(docid_in), + node_ref(node_ref_in), level(level_in) {} }; @@ -80,15 +112,43 @@ struct HnswGraph { uint64_t value = node.level; value <<= 32; value |= node.docid; + if (node.node_ref.valid()) { + assert(node.level >= 0); + assert(node.docid > 0); + } else { + assert(node.level == -1); + assert(node.docid == 0); + } entry_docid_and_level.store(value, std::memory_order_release); } + uint64_t get_entry_atomic() const { + return entry_docid_and_level.load(std::memory_order_acquire); + } + EntryNode get_entry_node() const { EntryNode entry; - uint64_t value = entry_docid_and_level.load(std::memory_order_acquire); - entry.docid = (uint32_t)value; - entry.level = (int32_t)(value >> 32); - return entry; + while (true) { + uint64_t value = get_entry_atomic(); + entry.docid = (uint32_t)value; + entry.node_ref = get_node_ref(entry.docid); + entry.level = (int32_t)(value >> 32); + if ((entry.docid == 0) + && (entry.level == -1) + && ! entry.node_ref.valid()) + { + // invalid in every way + return entry; + } + if ((entry.docid > 0) + && (entry.level > -1) + && entry.node_ref.valid() + && (get_entry_atomic() == value)) + { + // valid in every way + return entry; + } + } } size_t size() const { return node_refs.size(); } diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index a03614a785e..36d970dfd01 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -72,10 +72,10 @@ HnswIndex::max_links_for_level(uint32_t level) const } bool -HnswIndex::have_closer_distance(HnswCandidate candidate, const LinkArrayRef& result) const +HnswIndex::have_closer_distance(HnswCandidate candidate, const HnswCandidateVector& result) const { - for (uint32_t result_docid : result) { - double dist = calc_distance(candidate.docid, result_docid); + for (const auto & neighbor : result) { + double dist = calc_distance(candidate.docid, neighbor.docid); if (dist < candidate.distance) { return true; } @@ -91,7 +91,7 @@ HnswIndex::select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_ SelectResult result; for (const auto & candidate : sorted) { if (result.used.size() < max_links) { - result.used.push_back(candidate.docid); + result.used.push_back(candidate); } else { result.unused.push_back(candidate.docid); } @@ -114,7 +114,7 @@ HnswIndex::select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint result.unused.push_back(candidate.docid); continue; } - result.used.push_back(candidate.docid); + result.used.push_back(candidate); if (result.used.size() == max_links) { while (!nearest.empty()) { candidate = nearest.top(); @@ -148,7 +148,12 @@ HnswIndex::shrink_if_needed(uint32_t docid, uint32_t level) neighbors.emplace_back(neighbor_docid, dist); } auto split = select_neighbors(neighbors, max_links); - _graph.set_link_array(docid, level, split.used); + LinkArray new_links; + new_links.reserve(split.used.size()); + for (const auto & neighbor : split.used) { + new_links.push_back(neighbor.docid); + } + _graph.set_link_array(docid, level, new_links); for (uint32_t removed_docid : split.unused) { remove_link_to(removed_docid, docid, level); } @@ -201,10 +206,13 @@ HnswIndex::find_nearest_in_layer(const TypedCells& input, const HnswCandidate& e bool keep_searching = true; while (keep_searching) { keep_searching = false; - for (uint32_t neighbor_docid : _graph.get_link_array(nearest.docid, level)) { + for (uint32_t neighbor_docid : _graph.get_link_array(nearest.node_ref, level)) { + auto neighbor_ref = _graph.get_node_ref(neighbor_docid); double dist = calc_distance(input, neighbor_docid); - if (dist < nearest.distance) { - nearest = HnswCandidate(neighbor_docid, dist); + if (_graph.still_valid(neighbor_docid, neighbor_ref) + && dist < nearest.distance) + { + nearest = HnswCandidate(neighbor_docid, neighbor_ref, dist); keep_searching = true; } } @@ -239,16 +247,20 @@ HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find, break; } candidates.pop(); - for (uint32_t neighbor_docid : _graph.get_link_array(cand.docid, level)) { - if ((neighbor_docid >= doc_id_limit) || visited.is_marked(neighbor_docid)) { + for (uint32_t neighbor_docid : _graph.get_link_array(cand.node_ref, level)) { + auto neighbor_ref = _graph.get_node_ref(neighbor_docid); + if ((! neighbor_ref.valid()) + || (neighbor_docid >= doc_id_limit) + || visited.is_marked(neighbor_docid)) + { continue; } visited.mark(neighbor_docid); double dist_to_input = calc_distance(input, neighbor_docid); if (dist_to_input < limit_dist) { - candidates.emplace(neighbor_docid, dist_to_input); + candidates.emplace(neighbor_docid, neighbor_ref, dist_to_input); if ((!filter) || filter->testBit(neighbor_docid)) { - best_neighbors.emplace(neighbor_docid, dist_to_input); + best_neighbors.emplace(neighbor_docid, neighbor_ref, dist_to_input); if (best_neighbors.size() > neighbors_to_find) { best_neighbors.pop(); limit_dist = best_neighbors.top().distance; @@ -287,11 +299,13 @@ HnswIndex::internal_prepare_add(uint32_t docid, TypedCells input_vector) const PreparedAddDoc op(docid, level); auto entry = _graph.get_entry_node(); if (entry.docid == 0) { + // graph has no entry point return op; } int search_level = entry.level; double entry_dist = calc_distance(input_vector, entry.docid); - HnswCandidate entry_point(entry.docid, entry_dist); + // TODO: check if entry docid/node_ref is still valid here + HnswCandidate entry_point(entry.docid, entry.node_ref, entry_dist); while (search_level > op.max_level) { entry_point = find_nearest_in_layer(input_vector, entry_point, search_level); --search_level; @@ -305,22 +319,35 @@ HnswIndex::internal_prepare_add(uint32_t docid, TypedCells input_vector) const while (search_level >= 0) { search_layer(input_vector, _cfg.neighbors_to_explore_at_construction(), best_neighbors, search_level); auto neighbors = select_neighbors(best_neighbors.peek(), _cfg.max_links_on_inserts()); - auto use = neighbors.used; - op.connections[search_level].assign(use.begin(), use.end()); + op.connections[search_level].reserve(neighbors.used.size()); + for (const auto & neighbor : neighbors.used) { + auto neighbor_levels = _graph.get_level_array(neighbor.node_ref); + if (size_t(search_level) < neighbor_levels.size()) { + op.connections[search_level].emplace_back(neighbor.docid, neighbor.node_ref); + } else { + LOG(warning, "in prepare_add(%u), selected neighbor %u is missing level %d (has %zu levels)", + docid, neighbor.docid, search_level, neighbor_levels.size()); + } + } --search_level; } return op; } HnswIndex::LinkArray -HnswIndex::filter_valid_docids(const LinkArrayRef &docids) +HnswIndex::filter_valid_docids(uint32_t level, const PreparedAddDoc::Links &neighbors, uint32_t self_docid) { LinkArray valid; - valid.reserve(docids.size()); - for (uint32_t docid : docids) { - auto node_ref = _graph.node_refs[docid].load_acquire(); - if (node_ref.valid()) { - valid.push_back(docid); + valid.reserve(neighbors.size()); + for (const auto & neighbor : neighbors) { + uint32_t docid = neighbor.first; + HnswGraph::NodeRef node_ref = neighbor.second; + if (_graph.still_valid(docid, node_ref)) { + assert(docid != self_docid); + auto levels = _graph.get_level_array(node_ref); + if (level < levels.size()) { + valid.push_back(docid); + } } } return valid; @@ -329,13 +356,13 @@ HnswIndex::filter_valid_docids(const LinkArrayRef &docids) void HnswIndex::internal_complete_add(uint32_t docid, PreparedAddDoc &op) { - _graph.make_node_for_document(docid, op.max_level + 1); + auto node_ref = _graph.make_node_for_document(docid, op.max_level + 1); for (int level = 0; level <= op.max_level; ++level) { - auto neighbors = filter_valid_docids(op.connections[level]); + auto neighbors = filter_valid_docids(level, op.connections[level], docid); connect_new_node(docid, neighbors, level); } if (op.max_level > get_entry_level()) { - _graph.set_entry_node({docid, op.max_level}); + _graph.set_entry_node({docid, node_ref, op.max_level}); } } @@ -392,19 +419,18 @@ void HnswIndex::remove_document(uint32_t docid) { bool need_new_entrypoint = (docid == get_entry_docid()); - LinkArray empty; LevelArrayRef node_levels = _graph.get_level_array(docid); for (int level = node_levels.size(); level-- > 0; ) { LinkArrayRef my_links = _graph.get_link_array(docid, level); for (uint32_t neighbor_id : my_links) { if (need_new_entrypoint) { - _graph.set_entry_node({neighbor_id, level}); + auto entry_node_ref = _graph.get_node_ref(neighbor_id); + _graph.set_entry_node({neighbor_id, entry_node_ref, level}); need_new_entrypoint = false; } remove_link_to(neighbor_id, docid, level); } mutual_reconnect(my_links, level); - _graph.set_link_array(docid, level, empty); } if (need_new_entrypoint) { HnswGraph::EntryNode entry; @@ -530,12 +556,14 @@ HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVecto { FurthestPriQ best_neighbors; auto entry = _graph.get_entry_node(); - if (entry.level < 0) { + if (entry.docid == 0) { + // graph has no entry point return best_neighbors; } int search_level = entry.level; double entry_dist = calc_distance(vector, entry.docid); - HnswCandidate entry_point(entry.docid, entry_dist); + // TODO: check if entry docid/node_ref is still valid here + HnswCandidate entry_point(entry.docid, entry.node_ref, entry_dist); while (search_level > 0) { entry_point = find_nearest_in_layer(vector, entry_point, search_level); --search_level; @@ -568,13 +596,13 @@ HnswIndex::set_node(uint32_t docid, const HnswNode &node) { size_t num_levels = node.size(); assert(num_levels > 0); - _graph.make_node_for_document(docid, num_levels); + auto node_ref = _graph.make_node_for_document(docid, num_levels); for (size_t level = 0; level < num_levels; ++level) { connect_new_node(docid, node.level(level), level); } int max_level = num_levels - 1; if (get_entry_level() < max_level) { - _graph.set_entry_node({docid, max_level}); + _graph.set_entry_node({docid, node_ref, max_level}); } } @@ -593,6 +621,8 @@ HnswIndex::check_link_symmetry() const auto neighbor_links = _graph.get_link_array(neighbor_docid, level); if (! has_link_to(neighbor_links, docid)) { all_sym = false; + LOG(warning, "check_link_symmetry: docid %zu links to %u on level %u, but no backlink", + docid, neighbor_docid, level); } } ++level; diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index df7453023cf..ab3eced8fdc 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -91,10 +91,11 @@ protected: * where the candidate is located. * Used by select_neighbors_heuristic(). */ - bool have_closer_distance(HnswCandidate candidate, const LinkArrayRef& curr_result) const; + bool have_closer_distance(HnswCandidate candidate, const HnswCandidateVector& curr_result) const; struct SelectResult { - LinkArray used; + HnswCandidateVector used; LinkArray unused; + ~SelectResult() {} }; SelectResult select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint32_t max_links) const; SelectResult select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_t max_links) const; @@ -123,7 +124,8 @@ protected: struct PreparedAddDoc : public PrepareResult { uint32_t docid; int32_t max_level; - std::vector<LinkArray> connections; + using Links = std::vector<std::pair<uint32_t, HnswGraph::NodeRef>>; + std::vector<Links> connections; PreparedAddDoc(uint32_t docid_in, int32_t max_level_in) : docid(docid_in), max_level(max_level_in), connections(max_level+1) {} @@ -131,7 +133,7 @@ protected: PreparedAddDoc(PreparedAddDoc&& other) = default; }; PreparedAddDoc internal_prepare_add(uint32_t docid, TypedCells input_vector) const; - LinkArray filter_valid_docids(const LinkArrayRef &docids); + LinkArray filter_valid_docids(uint32_t level, const PreparedAddDoc::Links &neighbors, uint32_t me); void internal_complete_add(uint32_t docid, PreparedAddDoc &op); public: HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func, diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.cpp index 9f49c0647c6..ac98b28d105 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.cpp @@ -39,7 +39,8 @@ HnswIndexLoader::load(const fileutil::LoadedBuffer& buf) } if (_failed) return false; _graph.node_refs.ensure_size(num_nodes); - _graph.set_entry_node({entry_docid, entry_level}); + auto entry_node_ref = _graph.get_node_ref(entry_docid); + _graph.set_entry_node({entry_docid, entry_node_ref, entry_level}); return true; } diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index_utils.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index_utils.h index b11d3f36a7a..99266505780 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index_utils.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index_utils.h @@ -5,6 +5,7 @@ #include <cstdint> #include <queue> #include <vector> +#include "hnsw_graph.h" namespace search::tensor { @@ -13,8 +14,12 @@ namespace search::tensor { */ struct HnswCandidate { uint32_t docid; + HnswGraph::NodeRef node_ref; double distance; - HnswCandidate(uint32_t docid_in, double distance_in) : docid(docid_in), distance(distance_in) {} + HnswCandidate(uint32_t docid_in, double distance_in) + : docid(docid_in), node_ref(), distance(distance_in) {} + HnswCandidate(uint32_t docid_in, HnswGraph::NodeRef node_ref_in, double distance_in) + : docid(docid_in), node_ref(node_ref_in), distance(distance_in) {} }; struct GreaterDistance { diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedHandlerV3.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedHandlerV3.java index 03916949cae..a932ca935e0 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedHandlerV3.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedHandlerV3.java @@ -84,7 +84,7 @@ public class FeedHandlerV3 extends LoggingRequestHandler { SourceSessionParams sourceSessionParams = sourceSessionParams(request); clientFeederByClientId.put(clientId, new ClientFeederV3(retainSource(sessionCache, sourceSessionParams), - new FeedReaderFactory(), + new FeedReaderFactory(true), //TODO make error debugging configurable docTypeManager, clientId, metric, diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedReaderFactory.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedReaderFactory.java index 6a3229e86b7..81b08d5fb25 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedReaderFactory.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedReaderFactory.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.http.server; import com.yahoo.document.DocumentTypeManager; import com.yahoo.document.json.JsonFeedReader; +import com.yahoo.text.Utf8; import com.yahoo.vespa.http.client.config.FeedParams; import com.yahoo.vespaxmlparser.FeedReader; import com.yahoo.vespaxmlparser.VespaXMLFeedReader; @@ -14,6 +15,12 @@ import java.io.InputStream; * @author dybis */ public class FeedReaderFactory { + private static final int MARK_READLIMIT = 200; + + private final boolean debug; + public FeedReaderFactory(boolean debug) { + this.debug = debug; + } /** * Creates FeedReader @@ -28,10 +35,22 @@ public class FeedReaderFactory { FeedParams.DataFormat dataFormat) { switch (dataFormat) { case XML_UTF8: + byte [] peek = null; + int bytesPeeked = 0; try { + if (debug && inputStream.markSupported()) { + peek = new byte[MARK_READLIMIT]; + inputStream.mark(MARK_READLIMIT); + bytesPeeked = inputStream.read(peek); + inputStream.reset(); + } return new VespaXMLFeedReader(inputStream, docTypeManager); } catch (Exception e) { - throw new RuntimeException("Could not create VespaXMLFeedReader", e); + if (bytesPeeked > 0) { + throw new RuntimeException("Could not create VespaXMLFeedReader. First characters are: '" + Utf8.toString(peek, 0, bytesPeeked) + "'", e); + } else { + throw new RuntimeException("Could not create VespaXMLFeedReader.", e); + } } case JSON_UTF8: return new JsonFeedReader(inputStream, docTypeManager); diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStream.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStream.java index c9f255f026e..c8ae79deebd 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStream.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStream.java @@ -13,6 +13,7 @@ public class ByteLimitedInputStream extends InputStream { private final InputStream wrappedStream; private int remaining; + private int remainingWhenMarked; public ByteLimitedInputStream(InputStream wrappedStream, int limit) { this.wrappedStream = wrappedStream; @@ -78,4 +79,21 @@ public class ByteLimitedInputStream extends InputStream { } } + @Override + public synchronized void mark(int readlimit) { + wrappedStream.mark(readlimit); + remainingWhenMarked = remaining; + } + + @Override + public synchronized void reset() throws IOException { + wrappedStream.reset(); + remaining = remainingWhenMarked; + } + + @Override + public boolean markSupported() { + return wrappedStream.markSupported(); + } + } diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/FeedReaderFactoryTestCase.java b/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/FeedReaderFactoryTestCase.java new file mode 100644 index 00000000000..47f057013b7 --- /dev/null +++ b/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/FeedReaderFactoryTestCase.java @@ -0,0 +1,40 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.vespa.http.server; + +import com.yahoo.document.DocumentTypeManager; +import com.yahoo.text.Utf8; +import com.yahoo.vespa.http.client.config.FeedParams; +import org.junit.Test; + +import java.io.ByteArrayInputStream; +import java.io.InputStream; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class FeedReaderFactoryTestCase { + DocumentTypeManager manager = new DocumentTypeManager(); + + private InputStream createStream(String s) { + return new ByteArrayInputStream(Utf8.toBytes(s)); + } + + @Test + public void testXmlExceptionWithDebug() { + try { + new FeedReaderFactory(true).createReader(createStream("Some malformed xml"), manager, FeedParams.DataFormat.XML_UTF8); + fail(); + } catch (RuntimeException e) { + assertEquals("Could not create VespaXMLFeedReader. First characters are: 'Some malformed xml'", e.getMessage()); + } + } + @Test + public void testXmlException() { + try { + new FeedReaderFactory(false).createReader(createStream("Some malformed xml"), manager, FeedParams.DataFormat.XML_UTF8); + fail(); + } catch (RuntimeException e) { + assertEquals("Could not create VespaXMLFeedReader.", e.getMessage()); + } + } +} diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStreamTestCase.java b/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStreamTestCase.java index 3aa3cdcb3a8..3dd8145ec73 100644 --- a/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStreamTestCase.java +++ b/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStreamTestCase.java @@ -8,8 +8,8 @@ import java.io.IOException; import java.io.InputStream; import java.nio.charset.StandardCharsets; -import static org.hamcrest.core.Is.is; -import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /** * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> @@ -29,63 +29,78 @@ public class ByteLimitedInputStreamTestCase { public void requireThatBasicsWork() throws IOException { ByteLimitedInputStream stream = create("abcdefghijklmnopqr".getBytes(StandardCharsets.US_ASCII), 9); - assertThat(stream.available(), is(9)); - assertThat(stream.read(), is(97)); - assertThat(stream.available(), is(8)); - assertThat(stream.read(), is(98)); - assertThat(stream.available(), is(7)); - assertThat(stream.read(), is(99)); - assertThat(stream.available(), is(6)); - assertThat(stream.read(), is(100)); - assertThat(stream.available(), is(5)); - assertThat(stream.read(), is(101)); - assertThat(stream.available(), is(4)); - assertThat(stream.read(), is(102)); - assertThat(stream.available(), is(3)); - assertThat(stream.read(), is(103)); - assertThat(stream.available(), is(2)); - assertThat(stream.read(), is(104)); - assertThat(stream.available(), is(1)); - assertThat(stream.read(), is(105)); - assertThat(stream.available(), is(0)); - assertThat(stream.read(), is(-1)); - assertThat(stream.available(), is(0)); - assertThat(stream.read(), is(-1)); - assertThat(stream.available(), is(0)); - assertThat(stream.read(), is(-1)); - assertThat(stream.available(), is(0)); - assertThat(stream.read(), is(-1)); - assertThat(stream.available(), is(0)); - assertThat(stream.read(), is(-1)); - assertThat(stream.available(), is(0)); + assertEquals(9, stream.available()); + assertEquals(97, stream.read()); + assertEquals(8, stream.available()); + assertEquals(98, stream.read()); + assertEquals(7, stream.available()); + assertEquals(99, stream.read()); + assertEquals(6, stream.available()); + assertEquals(100, stream.read()); + assertEquals(5, stream.available()); + assertEquals(101, stream.read()); + assertEquals(4, stream.available()); + assertEquals(102, stream.read()); + assertEquals(3, stream.available()); + assertEquals(103, stream.read()); + assertEquals(2, stream.available()); + assertEquals(104, stream.read()); + assertEquals(1, stream.available()); + assertEquals(105, stream.read()); + assertEquals(0, stream.available()); + assertEquals(-1, stream.read()); + assertEquals(0, stream.available()); + assertEquals(-1, stream.read()); + assertEquals(0, stream.available()); + assertEquals(-1, stream.read()); + assertEquals(0, stream.available()); + assertEquals(-1, stream.read()); + assertEquals(0, stream.available()); + assertEquals(-1, stream.read()); + assertEquals(0, stream.available()); } @Test public void requireThatChunkedReadWorks() throws IOException { ByteLimitedInputStream stream = create("abcdefghijklmnopqr".getBytes(StandardCharsets.US_ASCII), 9); - assertThat(stream.available(), is(9)); + assertEquals(9, stream.available()); byte[] toBuf = new byte[4]; - assertThat(stream.read(toBuf), is(4)); - assertThat(toBuf[0], is((byte) 97)); - assertThat(toBuf[1], is((byte) 98)); - assertThat(toBuf[2], is((byte) 99)); - assertThat(toBuf[3], is((byte) 100)); - assertThat(stream.available(), is(5)); + assertEquals(4, stream.read(toBuf)); + assertEquals(97, toBuf[0]); + assertEquals(98, toBuf[1]); + assertEquals(99, toBuf[2]); + assertEquals(100, toBuf[3]); + assertEquals(5, stream.available()); - assertThat(stream.read(toBuf), is(4)); - assertThat(toBuf[0], is((byte) 101)); - assertThat(toBuf[1], is((byte) 102)); - assertThat(toBuf[2], is((byte) 103)); - assertThat(toBuf[3], is((byte) 104)); - assertThat(stream.available(), is(1)); + assertEquals(4, stream.read(toBuf)); + assertEquals(101, toBuf[0]); + assertEquals(102, toBuf[1]); + assertEquals(103, toBuf[2]); + assertEquals(104, toBuf[3]); + assertEquals(1, stream.available()); - assertThat(stream.read(toBuf), is(1)); - assertThat(toBuf[0], is((byte) 105)); - assertThat(stream.available(), is(0)); + assertEquals(1, stream.read(toBuf)); + assertEquals(105, toBuf[0]); + assertEquals(0, stream.available()); - assertThat(stream.read(toBuf), is(-1)); - assertThat(stream.available(), is(0)); + assertEquals(-1, stream.read(toBuf)); + assertEquals(0, stream.available()); + } + + @Test + public void requireMarkWorks() throws IOException { + InputStream stream = create("abcdefghijklmnopqr".getBytes(StandardCharsets.US_ASCII), 9); + assertEquals(97, stream.read()); + assertTrue(stream.markSupported()); + stream.mark(5); + assertEquals(98, stream.read()); + assertEquals(99, stream.read()); + stream.reset(); + assertEquals(98, stream.read()); + assertEquals(99, stream.read()); + assertEquals(100, stream.read()); + assertEquals(101, stream.read()); } } diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/vespaxmlparser/MockFeedReaderFactory.java b/vespaclient-container-plugin/src/test/java/com/yahoo/vespaxmlparser/MockFeedReaderFactory.java index 9a61af7266f..df1d5505632 100644 --- a/vespaclient-container-plugin/src/test/java/com/yahoo/vespaxmlparser/MockFeedReaderFactory.java +++ b/vespaclient-container-plugin/src/test/java/com/yahoo/vespaxmlparser/MockFeedReaderFactory.java @@ -13,6 +13,10 @@ import java.io.InputStream; */ public class MockFeedReaderFactory extends FeedReaderFactory { + public MockFeedReaderFactory() { + super(true); + } + @Override public FeedReader createReader( InputStream inputStream, diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index d9467a41f78..154b6871392 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2075,6 +2075,22 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.ScalarFunctions$Erf": { + "superClass": "java.lang.Object", + "interfaces": [ + "java.util.function.DoubleUnaryOperator" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>()", + "public double applyAsDouble(double)", + "public java.lang.String toString()", + "public static double erf(double)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.ScalarFunctions$Exp": { "superClass": "java.lang.Object", "interfaces": [ @@ -2506,6 +2522,7 @@ "public static java.util.function.DoubleUnaryOperator square()", "public static java.util.function.DoubleUnaryOperator tan()", "public static java.util.function.DoubleUnaryOperator tanh()", + "public static java.util.function.DoubleUnaryOperator erf()", "public static java.util.function.DoubleUnaryOperator elu()", "public static java.util.function.DoubleUnaryOperator elu(double)", "public static java.util.function.DoubleUnaryOperator leakyrelu()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index d9204e24d68..c19b07cf96f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -50,6 +50,7 @@ public class ScalarFunctions { public static DoubleUnaryOperator square() { return new Square(); } public static DoubleUnaryOperator tan() { return new Tan(); } public static DoubleUnaryOperator tanh() { return new Tanh(); } + public static DoubleUnaryOperator erf() { return new Erf(); } public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); } @@ -330,6 +331,30 @@ public class ScalarFunctions { public String toString() { return "f(a)(tanh(a))"; } } + public static class Erf implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return erf(operand); } + @Override + public String toString() { return "f(a)(erf(a))"; } + + // Use Horner's method + // From https://introcs.cs.princeton.edu/java/21function/ErrorFunction.java.html + public static double erf(double v) { + double t = 1.0 / (1.0 + 0.5 * Math.abs(v)); + double ans = 1 - t * Math.exp(-v*v - 1.26551223 + + t * ( 1.00002368 + + t * ( 0.37409196 + + t * ( 0.09678418 + + t * (-0.18628806 + + t * ( 0.27886807 + + t * (-1.13520398 + + t * ( 1.48851587 + + t * (-0.82215223 + + t * ( 0.17087277)))))))))); + if (v >= 0) return ans; + else return -ans; + } + } // Variable-length operators ----------------------------------------------------------------------------- |