summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHarald Musum <musum@verizonmedia.com>2020-06-19 11:16:02 +0200
committerHarald Musum <musum@verizonmedia.com>2020-06-19 11:16:02 +0200
commit5b6d5916c999069ac5097e18ef755f713e98f411 (patch)
tree79165b521e45d934b2541c2da44b5ed574900ef0
parent5f5df5235de6f7c2532c19833e64adc113ffe1ed (diff)
parenta74a8703830bb4d8656ea626a085c6966de5c06d (diff)
Merge branch 'master' into hmusum/configserver-refactoring-13
-rw-r--r--bundle-plugin/src/main/java/com/yahoo/container/plugin/bundle/AnalyzeBundle.java5
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java2
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java3
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java4
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java2
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java4
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionStateWatcher.java3
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java3
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/CostCalculator.java19
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/Plan.java23
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/PlanController.java10
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/ResourceUsage.java54
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/package-info.java5
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java4
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java7
-rw-r--r--controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java1
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificateManager.java4
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueue.java23
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/NameServiceDispatcher.java2
-rw-r--r--controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java13
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueueTest.java10
-rw-r--r--controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java7
-rw-r--r--eval/CMakeLists.txt1
-rw-r--r--eval/src/apps/tensor_conformance/generate.cpp2
-rw-r--r--eval/src/tests/eval/inline_operation/inline_operation_test.cpp1
-rw-r--r--eval/src/tests/eval/multiply_add/CMakeLists.txt9
-rw-r--r--eval/src/tests/eval/multiply_add/multiply_add_test.cpp44
-rw-r--r--eval/src/tests/eval/node_tools/node_tools_test.cpp1
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp1
-rw-r--r--eval/src/vespa/eval/eval/call_nodes.cpp1
-rw-r--r--eval/src/vespa/eval/eval/call_nodes.h1
-rw-r--r--eval/src/vespa/eval/eval/key_gen.cpp1
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp3
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp3
-rw-r--r--eval/src/vespa/eval/eval/node_tools.cpp1
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp1
-rw-r--r--eval/src/vespa/eval/eval/node_visitor.h2
-rw-r--r--eval/src/vespa/eval/eval/operation.cpp2
-rw-r--r--eval/src/vespa/eval/eval/operation.h1
-rw-r--r--eval/src/vespa/eval/eval/test/eval_spec.cpp7
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_conformance.cpp1
-rw-r--r--eval/src/vespa/eval/eval/visit_stuff.cpp1
-rw-r--r--jdisc_core/src/main/java/com/yahoo/jdisc/handler/UnsafeContentInputStream.java37
-rw-r--r--jdisc_core/src/test/java/com/yahoo/jdisc/handler/UnsafeContentInputStreamTestCase.java36
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java9
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java83
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java122
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java3
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java91
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java86
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java12
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java67
-rw-r--r--searchlib/abi-spec.json2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java3
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj4
-rw-r--r--searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp87
-rw-r--r--searchlib/src/tests/hitcollector/hitcollector_test.cpp12
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/.gitignore1
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/CMakeLists.txt8
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp4
-rw-r--r--searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp348
-rw-r--r--searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp6
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp36
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h3
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_graph.cpp14
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_graph.h84
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp94
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index.h10
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.cpp3
-rw-r--r--searchlib/src/vespa/searchlib/tensor/hnsw_index_utils.h7
-rw-r--r--vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedHandlerV3.java2
-rw-r--r--vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedReaderFactory.java21
-rw-r--r--vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStream.java18
-rw-r--r--vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/FeedReaderFactoryTestCase.java40
-rw-r--r--vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStreamTestCase.java113
-rw-r--r--vespaclient-container-plugin/src/test/java/com/yahoo/vespaxmlparser/MockFeedReaderFactory.java4
-rw-r--r--vespajlib/abi-spec.json17
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java25
78 files changed, 1632 insertions, 172 deletions
diff --git a/bundle-plugin/src/main/java/com/yahoo/container/plugin/bundle/AnalyzeBundle.java b/bundle-plugin/src/main/java/com/yahoo/container/plugin/bundle/AnalyzeBundle.java
index 0626c786822..35db7b5fef3 100644
--- a/bundle-plugin/src/main/java/com/yahoo/container/plugin/bundle/AnalyzeBundle.java
+++ b/bundle-plugin/src/main/java/com/yahoo/container/plugin/bundle/AnalyzeBundle.java
@@ -49,11 +49,14 @@ public class AnalyzeBundle {
PublicPackages pp = publicPackages(jarFile);
exports.addAll(pp.exports);
globals.addAll(pp.globals);
+
+ // TODO: remove and clean up everything related to global packages.
+ if (! pp.globals.isEmpty()) throw new RuntimeException("Found global packages in bundle " + jarFile.getAbsolutePath());
}
return new PublicPackages(exports, globals);
}
- public static PublicPackages publicPackages(File jarFile) {
+ static PublicPackages publicPackages(File jarFile) {
try {
Optional<Manifest> jarManifest = JarFiles.getManifest(jarFile);
if (jarManifest.isPresent()) {
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java
index 651572b3e36..0fb1407830a 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/Deployment.java
@@ -155,7 +155,7 @@ public class Deployment implements com.yahoo.config.provision.Deployment {
}
session.waitUntilActivated(timeoutBudget);
- log.log(Level.INFO, session.logPre() + "activated successfully using " +
+ log.log(Level.INFO, session.logPre() + "Session " + session.getSessionId() + " activated successfully using " +
hostProvisioner.map(provisioner -> provisioner.getClass().getSimpleName()).orElse("no host provisioner") +
". Config generation " + session.getMetaData().getGeneration() +
(previousActiveSession != null ? ". Based on previous active session " + previousActiveSession.getSessionId() : "") +
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java
index b7b5c6380ef..543f9c2e303 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/maintenance/ApplicationPackageMaintainer.java
@@ -49,12 +49,15 @@ public class ApplicationPackageMaintainer extends ConfigServerMaintainer {
@Override
protected void maintain() {
+ log.fine(() -> "Running");
if (! distributeApplicationPackage.value()) return;
try (var fileDownloader = new FileDownloader(createConnectionPool(configserverConfig), downloadDirectory)){
for (var applicationId : applicationRepository.listApplications()) {
+ log.fine(() -> "Verifying application package for " + applicationId);
RemoteSession session = applicationRepository.getActiveSession(applicationId);
FileReference applicationPackage = session.getApplicationPackageReference();
+ log.fine(() -> "Verifying application package file reference " + applicationPackage + " for session " + session.getSessionId());
if (applicationPackage != null && missingOnDisk(applicationPackage)) {
log.fine(() -> "Downloading missing application package for application " + applicationId + " - session " + session.getSessionId());
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java
index 2397dba6b5e..aa851a95335 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java
@@ -103,7 +103,9 @@ public class PreparedModelsBuilder extends ModelsBuilder<PreparedModelsBuilder.P
modelVersion,
wantedNodeVespaVersion);
- log.log(Level.FINE, "Create and validate model " + modelVersion + " for " + applicationId);
+
+ log.log(properties.zone().system().isCd() ? Level.INFO : Level.FINE,
+ "Create and validate model " + modelVersion + " for " + applicationId + ", previous model is " + modelOf(modelVersion));
ValidationParameters validationParameters =
new ValidationParameters(params.ignoreValidationErrors() ? IgnoreValidationErrors.TRUE : IgnoreValidationErrors.FALSE);
ModelCreateResult result = modelFactory.createAndValidateModel(modelContext, validationParameters);
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java
index a29b105f43f..2e101762fc4 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/PrepareParams.java
@@ -84,7 +84,7 @@ public final class PrepareParams {
private boolean verbose = false;
private boolean isBootstrap = false;
private ApplicationId applicationId = ApplicationId.defaultId();
- private TimeoutBudget timeoutBudget = new TimeoutBudget(Clock.systemUTC(), Duration.ofSeconds(30));
+ private TimeoutBudget timeoutBudget = new TimeoutBudget(Clock.systemUTC(), Duration.ofSeconds(60));
private Optional<Version> vespaVersion = Optional.empty();
private List<ContainerEndpoint> containerEndpoints = null;
private Optional<String> tlsSecretsKeyName = Optional.empty();
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java
index 3c83263f781..4542b8267e8 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java
@@ -126,10 +126,10 @@ public class SessionPreparer {
Preparation preparation = new Preparation(hostValidator, logger, params, currentActiveApplicationSet,
tenantPath, serverDbSessionDir, applicationPackage, sessionZooKeeperClient);
- // Note: Done before pre-processing, requires that to be done by users of the distributed package
+ preparation.preprocess();
+
var distributedApplicationPackage = preparation.distributeApplicationPackage();
- preparation.preprocess();
try {
AllocatedHosts allocatedHosts = preparation.buildModels(now);
preparation.makeResult(allocatedHosts);
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionStateWatcher.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionStateWatcher.java
index dee11bcd332..0a8d6cdde69 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionStateWatcher.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionStateWatcher.java
@@ -51,7 +51,6 @@ public class SessionStateWatcher {
private void sessionChanged(Session.Status status) {
long sessionId = remoteSession.getSessionId();
- log.log(Level.FINE, remoteSession.logPre() + "Session change: Session " + remoteSession.getSessionId() + " changed status to " + status);
// valid for NEW -> PREPARE transitions, not ACTIVATE -> PREPARE.
if (status.equals(Session.Status.PREPARE)) {
@@ -91,6 +90,8 @@ public class SessionStateWatcher {
ChildData node = fileCache.getCurrentData();
if (node != null) {
newStatus = Session.Status.parse(Utf8.toString(node.getData()));
+ log.log(Level.FINE, remoteSession.logPre() + "Session change: Remote session " + remoteSession.getSessionId() +
+ " changed status from " + currentStatus.name() + " to " + newStatus.name());
sessionChanged(newStatus);
}
} catch (Exception e) {
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java
index 222910d5bc6..0b5f2538892 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/ServiceRegistry.java
@@ -4,6 +4,7 @@ package com.yahoo.vespa.hosted.controller.api.integration;
import com.yahoo.vespa.hosted.controller.api.integration.aws.ApplicationRoleService;
import com.yahoo.vespa.hosted.controller.api.integration.aws.AwsEventFetcher;
import com.yahoo.vespa.hosted.controller.api.integration.aws.ResourceTagger;
+import com.yahoo.vespa.hosted.controller.api.integration.billing.PlanController;
import com.yahoo.vespa.hosted.controller.api.integration.certificates.EndpointCertificateProvider;
import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServer;
import com.yahoo.vespa.hosted.controller.api.integration.deployment.ApplicationStore;
@@ -76,4 +77,6 @@ public interface ServiceRegistry {
SystemMonitor systemMonitor();
+ PlanController planController();
+
}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/CostCalculator.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/CostCalculator.java
new file mode 100644
index 00000000000..628beec8450
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/CostCalculator.java
@@ -0,0 +1,19 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.integration.billing;
+
+import com.yahoo.config.provision.NodeResources;
+import com.yahoo.vespa.hosted.controller.api.integration.resource.CostInfo;
+
+
+/**
+ * @author ogronnesby
+ */
+public interface CostCalculator {
+
+ /** Calculate the cost for the given usage */
+ CostInfo calculate(ResourceUsage usage);
+
+ /** Estimate the cost for the given resources */
+ double calculate(NodeResources resources);
+
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/Plan.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/Plan.java
new file mode 100644
index 00000000000..75a88136c45
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/Plan.java
@@ -0,0 +1,23 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.integration.billing;
+
+/**
+ * A Plan decides two different things:
+ *
+ * - How to map from usage to a sum of money that is owed.
+ * - Limits on how much resources can be used.
+ *
+ * @author ogronnesby
+ */
+public interface Plan {
+
+ /** The ID of the plan as used in APIs and storage systems */
+ String id();
+
+ /** The calculator used to calculate a bill for usage */
+ CostCalculator calculator();
+
+ /** The quota limits associated with the plan */
+ Object quota();
+
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/PlanController.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/PlanController.java
new file mode 100644
index 00000000000..f13c251d212
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/PlanController.java
@@ -0,0 +1,10 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.integration.billing;
+
+import com.yahoo.config.provision.TenantName;
+
+public interface PlanController {
+
+ Plan getPlan(TenantName tenant);
+
+} \ No newline at end of file
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/ResourceUsage.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/ResourceUsage.java
new file mode 100644
index 00000000000..cbfd2b6ff50
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/ResourceUsage.java
@@ -0,0 +1,54 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.hosted.controller.api.integration.billing;
+
+import com.yahoo.config.provision.ApplicationId;
+import com.yahoo.config.provision.zone.ZoneId;
+
+import java.math.BigDecimal;
+
+/**
+ * @author olaa
+ */
+public class ResourceUsage {
+
+ private final ApplicationId applicationId;
+ private final ZoneId zoneId;
+ private final Plan plan;
+ private final BigDecimal cpuMillis;
+ private final BigDecimal memoryMillis;
+ private final BigDecimal diskMillis;
+
+ public ResourceUsage(ApplicationId applicationId, ZoneId zoneId, Plan plan,
+ BigDecimal cpuMillis, BigDecimal memoryMillis, BigDecimal diskMillis) {
+ this.applicationId = applicationId;
+ this.zoneId = zoneId;
+ this.cpuMillis = cpuMillis;
+ this.memoryMillis = memoryMillis;
+ this.diskMillis = diskMillis;
+ this.plan = plan;
+ }
+
+ public ApplicationId getApplicationId() {
+ return applicationId;
+ }
+
+ public ZoneId getZoneId() {
+ return zoneId;
+ }
+
+ public BigDecimal getCpuMillis() {
+ return cpuMillis;
+ }
+
+ public BigDecimal getMemoryMillis() {
+ return memoryMillis;
+ }
+
+ public BigDecimal getDiskMillis() {
+ return diskMillis;
+ }
+
+ public Plan getPlan() {
+ return plan;
+ }
+}
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/package-info.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/package-info.java
new file mode 100644
index 00000000000..ae31f4a782d
--- /dev/null
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/billing/package-info.java
@@ -0,0 +1,5 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+@ExportPackage
+package com.yahoo.vespa.hosted.controller.api.integration.billing;
+
+import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java
index 68dff26529f..2fdf442dbe0 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java
@@ -219,8 +219,10 @@ enum PathGroup {
/** Paths used for invoice management */
hostedAccountant(PathPrefix.api,
"/billing/v1/invoice/{*}",
- "/billing/v1/billing");
+ "/billing/v1/billing"),
+ /** Path used for listing endpoint certificate request info */
+ endpointCertificateRequestInfo(PathPrefix.none, "/certificateRequests/");
final List<String> pathSpecs;
final PathPrefix prefix;
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java
index 9a5a0ad0e77..83adba6f59b 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/Policy.java
@@ -166,7 +166,12 @@ enum Policy {
/** Invoice management */
hostedAccountant(Privilege.grant(Action.all())
.on(PathGroup.hostedAccountant)
- .in(SystemName.PublicCd));
+ .in(SystemName.PublicCd)),
+
+ /** Listing endpoint certificate request info */
+ endpointCertificateRequestInfo(Privilege.grant(Action.read)
+ .on(PathGroup.endpointCertificateRequestInfo)
+ .in(SystemName.all()));
private final Set<Privilege> privileges;
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java
index bf5ba4001fa..b9d534019db 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/RoleDefinition.java
@@ -61,6 +61,7 @@ public enum RoleDefinition {
/** Admin — the administrative function for user management etc. */
administrator(Policy.tenantUpdate,
Policy.tenantManager,
+ Policy.tenantDelete,
Policy.applicationManager,
Policy.paymentInstrumentRead,
Policy.paymentInstrumentUpdate,
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificateManager.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificateManager.java
index 64549825b04..425364f6741 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificateManager.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificateManager.java
@@ -6,7 +6,6 @@ import com.google.common.io.BaseEncoding;
import com.yahoo.config.application.api.DeploymentInstanceSpec;
import com.yahoo.config.provision.ApplicationId;
import com.yahoo.config.provision.ClusterSpec;
-import com.yahoo.config.provision.Environment;
import com.yahoo.config.provision.SystemName;
import com.yahoo.config.provision.zone.RoutingMethod;
import com.yahoo.config.provision.zone.ZoneApi;
@@ -51,7 +50,6 @@ import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
import java.util.stream.Collectors;
-import java.util.stream.Stream;
/**
* Looks up stored endpoint certificate metadata, provisions new certificates if none is found,
@@ -323,10 +321,10 @@ public class EndpointCertificateManager {
}
/** Create a common name based on a hash of the ApplicationId. This should always be less than 64 characters long. */
+ @SuppressWarnings("UnstableApiUsage")
private static String commonNameHashOf(ApplicationId application, SystemName system) {
var hashCode = Hashing.sha1().hashString(application.serializedForm(), Charset.defaultCharset());
var base32encoded = BaseEncoding.base32().omitPadding().lowerCase().encode(hashCode.asBytes());
return 'v' + base32encoded + Endpoint.dnsSuffix(system);
}
-
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueue.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueue.java
index ee5d50414c1..786547d4a67 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueue.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueue.java
@@ -1,7 +1,6 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.hosted.controller.dns;
-import java.util.logging.Level;
import com.yahoo.vespa.hosted.controller.api.integration.dns.NameService;
import java.util.ArrayList;
@@ -10,6 +9,8 @@ import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.LinkedBlockingDeque;
+import java.util.function.UnaryOperator;
+import java.util.logging.Level;
import java.util.logging.Logger;
/**
@@ -39,12 +40,14 @@ public class NameServiceQueue {
return Collections.unmodifiableCollection(requests);
}
- /** Returns a copy of this containing only the n most recent requests */
+ /** Returns a copy of this containing the last n requests */
public NameServiceQueue last(int n) {
- requireNonNegative(n);
- if (requests.size() <= n) return this;
- List<NameServiceRequest> requests = new ArrayList<>(this.requests);
- return new NameServiceQueue(requests.subList(requests.size() - n, requests.size()));
+ return resize(n, (requests) -> requests.subList(requests.size() - n, requests.size()));
+ }
+
+ /** Returns a copy of this containing the first n requests */
+ public NameServiceQueue first(int n) {
+ return resize(n, (requests) -> requests.subList(0, n));
}
/** Returns a copy of this with given request queued according to priority */
@@ -58,6 +61,7 @@ public class NameServiceQueue {
return queue;
}
+ /** Returns a copy of this with given request added */
public NameServiceQueue with(NameServiceRequest request) {
return with(request, Priority.normal);
}
@@ -91,6 +95,13 @@ public class NameServiceQueue {
return requests.toString();
}
+ private NameServiceQueue resize(int n, UnaryOperator<List<NameServiceRequest>> resizer) {
+ requireNonNegative(n);
+ if (requests.size() <= n) return this;
+ List<NameServiceRequest> requests = new ArrayList<>(this.requests);
+ return new NameServiceQueue(resizer.apply(requests));
+ }
+
private static void requireNonNegative(int n) {
if (n < 0) throw new IllegalArgumentException("n must be >= 0, got " + n);
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/NameServiceDispatcher.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/NameServiceDispatcher.java
index 73daef7c2b0..e7eaf083a57 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/NameServiceDispatcher.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/NameServiceDispatcher.java
@@ -45,7 +45,7 @@ public class NameServiceDispatcher extends ControllerMaintainer {
var remaining = queue.dispatchTo(nameService, requestCount);
if (queue == remaining) return; // Queue unchanged
- var dispatched = queue.last(requestCount);
+ var dispatched = queue.first(requestCount);
if (!dispatched.requests().isEmpty()) {
log.log(Level.INFO, "Dispatched name service request(s) in " +
Duration.between(instant, clock.instant()) +
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java
index dc3dbabcc07..bd0143ef879 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/security/CloudAccessControl.java
@@ -8,6 +8,8 @@ import com.yahoo.vespa.flags.FetchVector;
import com.yahoo.vespa.flags.FlagSource;
import com.yahoo.vespa.flags.Flags;
import com.yahoo.vespa.hosted.controller.Application;
+import com.yahoo.vespa.hosted.controller.api.integration.ServiceRegistry;
+import com.yahoo.vespa.hosted.controller.api.integration.billing.PlanController;
import com.yahoo.vespa.hosted.controller.api.integration.organization.BillingInfo;
import com.yahoo.vespa.hosted.controller.api.integration.user.Roles;
import com.yahoo.vespa.hosted.controller.api.integration.user.UserId;
@@ -34,11 +36,13 @@ public class CloudAccessControl implements AccessControl {
private final UserManagement userManagement;
private final BooleanFlag enablePublicSignup;
+ private final PlanController planController;
@Inject
- public CloudAccessControl(UserManagement userManagement, FlagSource flagSource) {
+ public CloudAccessControl(UserManagement userManagement, FlagSource flagSource, ServiceRegistry serviceRegistry) {
this.userManagement = userManagement;
this.enablePublicSignup = Flags.ENABLE_PUBLIC_SIGNUP_FLOW.bindTo(flagSource);
+ planController = serviceRegistry.planController();
}
@Override
@@ -97,12 +101,17 @@ public class CloudAccessControl implements AccessControl {
@Override
public void deleteTenant(TenantName tenant, Credentials credentials) {
- // TODO: allow only if 0 resources, 0 balance
+ if(!(allowedByPrivilegedRole((Auth0Credentials) credentials) || isTrial(tenant)))
+ throw new ForbiddenException("Please contact the Vespa team for assistance in deleting non-trial tenants");
for (TenantRole role : Roles.tenantRoles(tenant))
userManagement.deleteRole(role);
}
+ private boolean isTrial(TenantName tenant) {
+ return planController.getPlan(tenant).id().equals("trial");
+ }
+
@Override
public void createApplication(TenantAndApplicationId id, Credentials credentials) {
for (Role role : Roles.applicationRoles(id.tenant(), id.application()))
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueueTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueueTest.java
index d0362ae98b8..30ed9b5432c 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueueTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/dns/NameServiceQueueTest.java
@@ -74,13 +74,21 @@ public class NameServiceQueueTest {
assertTrue(queue.requests().isEmpty());
assertTrue("Removed " + r1, nameService.findRecords(Record.Type.CNAME, r1.name()).isEmpty());
- // Keep n most recent requests
+ // Keep n last requests
queue = queue.with(req1).with(req2).with(req3).with(req4).with(req6)
.last(2);
assertEquals(List.of(req4, req6), List.copyOf(queue.requests()));
assertSame(queue, queue.last(2));
assertSame(queue, queue.last(10));
assertTrue(queue.last(0).requests().isEmpty());
+
+ // Keep n first requests
+ queue = NameServiceQueue.EMPTY.with(req1).with(req2).with(req3).with(req4).with(req6)
+ .first(3);
+ assertEquals(List.of(req1, req2, req3), List.copyOf(queue.requests()));
+ assertSame(queue, queue.first(3));
+ assertSame(queue, queue.first(10));
+ assertTrue(queue.first(0).requests().isEmpty());
}
}
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java
index 4f81e443d9c..b7e7c9814e3 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ServiceRegistryMock.java
@@ -12,6 +12,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.aws.MockAwsEventFetcher
import com.yahoo.vespa.hosted.controller.api.integration.aws.MockResourceTagger;
import com.yahoo.vespa.hosted.controller.api.integration.aws.NoopApplicationRoleService;
import com.yahoo.vespa.hosted.controller.api.integration.aws.ResourceTagger;
+import com.yahoo.vespa.hosted.controller.api.integration.billing.PlanController;
import com.yahoo.vespa.hosted.controller.api.integration.certificates.EndpointCertificateMock;
import com.yahoo.vespa.hosted.controller.api.integration.configserver.ConfigServer;
import com.yahoo.vespa.hosted.controller.api.integration.dns.MemoryNameService;
@@ -59,6 +60,7 @@ public class ServiceRegistryMock extends AbstractComponent implements ServiceReg
private final MockRunDataStore mockRunDataStore = new MockRunDataStore();
private final MockResourceTagger mockResourceTagger = new MockResourceTagger();
private final ApplicationRoleService applicationRoleService = new NoopApplicationRoleService();
+ private final PlanController planController = (tenantName) -> null;
public ServiceRegistryMock(SystemName system) {
this.zoneRegistryMock = new ZoneRegistryMock(system);
@@ -201,4 +203,9 @@ public class ServiceRegistryMock extends AbstractComponent implements ServiceReg
return endpointCertificateMock;
}
+ @Override
+ public PlanController planController() {
+ return planController;
+ }
+
}
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index 3a9aabc83ba..3e81521550a 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -19,6 +19,7 @@ vespa_define_module(
src/tests/eval/gbdt
src/tests/eval/inline_operation
src/tests/eval/interpreted_function
+ src/tests/eval/multiply_add
src/tests/eval/node_tools
src/tests/eval/node_types
src/tests/eval/param_usage
diff --git a/eval/src/apps/tensor_conformance/generate.cpp b/eval/src/apps/tensor_conformance/generate.cpp
index 7d48307b786..df1c06593cb 100644
--- a/eval/src/apps/tensor_conformance/generate.cpp
+++ b/eval/src/apps/tensor_conformance/generate.cpp
@@ -100,6 +100,8 @@ void generate_tensor_map(TestBuilder &dst) {
generate_op1_map("relu(a)", operation::Relu::f, Sub2(Div16(N())), dst);
generate_op1_map("sigmoid(a)", operation::Sigmoid::f, Sub2(Div16(N())), dst);
generate_op1_map("elu(a)", operation::Elu::f, Sub2(Div16(N())), dst);
+ // TODO(havardpe): add erf when supported by Java
+ // generate_op1_map("erf(a)", operation::Erf::f, Sub2(Div16(N())), dst);
generate_op1_map("a in [1,5,7,13,42]", MyIn::f, N(), dst);
generate_map_expr("map(a,f(a)((a+1)*2))", MyOp::f, Div16(N()), dst);
}
diff --git a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
index fe9396398da..de5a3fbf395 100644
--- a/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
+++ b/eval/src/tests/eval/inline_operation/inline_operation_test.cpp
@@ -65,6 +65,7 @@ TEST(InlineOperationTest, op1_lambdas_are_recognized) {
EXPECT_EQ(as_op1("relu(a)"), &Relu::f);
EXPECT_EQ(as_op1("sigmoid(a)"), &Sigmoid::f);
EXPECT_EQ(as_op1("elu(a)"), &Elu::f);
+ EXPECT_EQ(as_op1("erf(a)"), &Erf::f);
//-------------------------------------------
EXPECT_EQ(as_op1("1/a"), &Inv::f);
EXPECT_EQ(as_op1("1.0/a"), &Inv::f);
diff --git a/eval/src/tests/eval/multiply_add/CMakeLists.txt b/eval/src/tests/eval/multiply_add/CMakeLists.txt
new file mode 100644
index 00000000000..c50aa4f50a2
--- /dev/null
+++ b/eval/src/tests/eval/multiply_add/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_multiply_add_test_app TEST
+ SOURCES
+ multiply_add_test.cpp
+ DEPENDS
+ vespaeval
+ gtest
+)
+vespa_add_test(NAME eval_multiply_add_test_app COMMAND eval_multiply_add_test_app)
diff --git a/eval/src/tests/eval/multiply_add/multiply_add_test.cpp b/eval/src/tests/eval/multiply_add/multiply_add_test.cpp
new file mode 100644
index 00000000000..35cab0a6030
--- /dev/null
+++ b/eval/src/tests/eval/multiply_add/multiply_add_test.cpp
@@ -0,0 +1,44 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/function.h>
+#include <vespa/eval/eval/llvm/compiled_function.h>
+#include <vespa/eval/eval/interpreted_function.h>
+#include <vespa/eval/tensor/default_tensor_engine.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib::eval;
+
+using Engine = vespalib::tensor::DefaultTensorEngine;
+
+double gcc_fun(double a, double b) {
+ return (a * 3) + b;
+}
+
+TEST(MultiplyAddTest, multiply_add_gives_same_result) {
+ auto fun = Function::parse("a*3+b");
+ CompiledFunction cfun(*fun, PassParams::ARRAY);
+ NodeTypes node_types = NodeTypes(*fun, {ValueType::double_type(), ValueType::double_type()});
+ InterpretedFunction ifun(Engine::ref(), *fun, node_types);
+ auto llvm_fun = cfun.get_function();
+ //-------------------------------------------------------------------------
+ double a = -1.0/3.0;
+ double b = 1.0;
+ std::vector<double> ab({a, b});
+ SimpleParams params(ab);
+ InterpretedFunction::Context ictx(ifun);
+ //-------------------------------------------------------------------------
+ const Value &result_value = ifun.eval(ictx, params);
+ double ifun_res = result_value.as_double();
+ double llvm_res = llvm_fun(&ab[0]);
+ double gcc_res = gcc_fun(a, b);
+ fprintf(stderr, "ifun_res: %a\n", ifun_res);
+ fprintf(stderr, "llvm_res: %a\n", llvm_res);
+ fprintf(stderr, "gcc_res: %a\n", gcc_res);
+ EXPECT_EQ(ifun_res, llvm_res);
+ EXPECT_DOUBLE_EQ(llvm_res + 1.0, gcc_res + 1.0);
+ if (llvm_res != gcc_res) {
+ fprintf(stderr, "WARNING: diverging results caused by fused multiply add\n");
+ }
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/tests/eval/node_tools/node_tools_test.cpp b/eval/src/tests/eval/node_tools/node_tools_test.cpp
index ca89650127e..13185065f57 100644
--- a/eval/src/tests/eval/node_tools/node_tools_test.cpp
+++ b/eval/src/tests/eval/node_tools/node_tools_test.cpp
@@ -99,6 +99,7 @@ TEST("require that call node types can be copied") {
TEST_DO(verify_copy("relu(a)"));
TEST_DO(verify_copy("sigmoid(a)"));
TEST_DO(verify_copy("elu(a)"));
+ TEST_DO(verify_copy("erf(a)"));
}
TEST("require that tensor node types can NOT be copied (yet)") {
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index 7912ec213bc..f595c58ef29 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -191,6 +191,7 @@ TEST("require that various operations resolve appropriate type") {
TEST_DO(verify_op1("relu(%s)")); // Relu
TEST_DO(verify_op1("sigmoid(%s)")); // Sigmoid
TEST_DO(verify_op1("elu(%s)")); // Elu
+ TEST_DO(verify_op1("erf(%s)")); // Erf
}
TEST("require that map resolves correct type") {
diff --git a/eval/src/vespa/eval/eval/call_nodes.cpp b/eval/src/vespa/eval/eval/call_nodes.cpp
index 69a9151a2bb..2fc25bdbc77 100644
--- a/eval/src/vespa/eval/eval/call_nodes.cpp
+++ b/eval/src/vespa/eval/eval/call_nodes.cpp
@@ -42,6 +42,7 @@ CallRepo::CallRepo() : _map() {
add(nodes::Relu());
add(nodes::Sigmoid());
add(nodes::Elu());
+ add(nodes::Erf());
}
} // namespace vespalib::eval::nodes
diff --git a/eval/src/vespa/eval/eval/call_nodes.h b/eval/src/vespa/eval/eval/call_nodes.h
index 8210616750e..c5a41756005 100644
--- a/eval/src/vespa/eval/eval/call_nodes.h
+++ b/eval/src/vespa/eval/eval/call_nodes.h
@@ -138,6 +138,7 @@ struct IsNan : CallHelper<IsNan> { IsNan() : Helper("isNan", 1) {} };
struct Relu : CallHelper<Relu> { Relu() : Helper("relu", 1) {} };
struct Sigmoid : CallHelper<Sigmoid> { Sigmoid() : Helper("sigmoid", 1) {} };
struct Elu : CallHelper<Elu> { Elu() : Helper("elu", 1) {} };
+struct Erf : CallHelper<Erf> { Erf() : Helper("erf", 1) {} };
//-----------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp
index cd31f92f96a..31167be5fe1 100644
--- a/eval/src/vespa/eval/eval/key_gen.cpp
+++ b/eval/src/vespa/eval/eval/key_gen.cpp
@@ -85,6 +85,7 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
void visit(const Relu &) override { add_byte(59); }
void visit(const Sigmoid &) override { add_byte(60); }
void visit(const Elu &) override { add_byte(61); }
+ void visit(const Erf &) override { add_byte(62); }
// traverse
bool open(const Node &node) override { node.accept(*this); return true; }
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index 89f1789e97b..6f9bee025c9 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -644,6 +644,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const Elu &) override {
make_call_1("vespalib_eval_elu");
}
+ void visit(const Erf &) override {
+ make_call_1("erf");
+ }
};
FunctionBuilder::~FunctionBuilder() { }
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index e80633b5c41..c5b5ca59401 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -344,6 +344,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const Elu &node) override {
make_map(node, operation::Elu::f);
}
+ void visit(const Erf &node) override {
+ make_map(node, operation::Erf::f);
+ }
//-------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp
index 7bbe095c060..1c194736138 100644
--- a/eval/src/vespa/eval/eval/node_tools.cpp
+++ b/eval/src/vespa/eval/eval/node_tools.cpp
@@ -180,6 +180,7 @@ struct CopyNode : NodeTraverser, NodeVisitor {
void visit(const Relu &node) override { copy_call(node); }
void visit(const Sigmoid &node) override { copy_call(node); }
void visit(const Elu &node) override { copy_call(node); }
+ void visit(const Erf &node) override { copy_call(node); }
// traverse nodes
bool open(const Node &) override { return !error; }
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index 468b9a58655..cbc96e719e0 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -274,6 +274,7 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
void visit(const Relu &node) override { resolve_op1(node); }
void visit(const Sigmoid &node) override { resolve_op1(node); }
void visit(const Elu &node) override { resolve_op1(node); }
+ void visit(const Erf &node) override { resolve_op1(node); }
//-------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h
index d3e066c8f53..95a5bec8be7 100644
--- a/eval/src/vespa/eval/eval/node_visitor.h
+++ b/eval/src/vespa/eval/eval/node_visitor.h
@@ -83,6 +83,7 @@ struct NodeVisitor {
virtual void visit(const nodes::Relu &) = 0;
virtual void visit(const nodes::Sigmoid &) = 0;
virtual void visit(const nodes::Elu &) = 0;
+ virtual void visit(const nodes::Erf &) = 0;
virtual ~NodeVisitor() {}
};
@@ -150,6 +151,7 @@ struct EmptyNodeVisitor : NodeVisitor {
void visit(const nodes::Relu &) override {}
void visit(const nodes::Sigmoid &) override {}
void visit(const nodes::Elu &) override {}
+ void visit(const nodes::Erf &) override {}
};
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/eval/operation.cpp b/eval/src/vespa/eval/eval/operation.cpp
index fa8be4d20bc..b97ac3f2261 100644
--- a/eval/src/vespa/eval/eval/operation.cpp
+++ b/eval/src/vespa/eval/eval/operation.cpp
@@ -49,6 +49,7 @@ double IsNan::f(double a) { return std::isnan(a) ? 1.0 : 0.0; }
double Relu::f(double a) { return std::max(a, 0.0); }
double Sigmoid::f(double a) { return 1.0 / (1.0 + std::exp(-1.0 * a)); }
double Elu::f(double a) { return (a < 0) ? std::exp(a) - 1 : a; }
+double Erf::f(double a) { return std::erf(a); }
//-----------------------------------------------------------------------------
double Inv::f(double a) { return (1.0 / a); }
double Square::f(double a) { return (a * a); }
@@ -106,6 +107,7 @@ std::map<vespalib::string,op1_t> make_op1_map() {
add_op1(map, "relu(a)", Relu::f);
add_op1(map, "sigmoid(a)", Sigmoid::f);
add_op1(map, "elu(a)", Elu::f);
+ add_op1(map, "erf(a)", Erf::f);
//-------------------------------------
add_op1(map, "1/a", Inv::f);
add_op1(map, "a*a", Square::f);
diff --git a/eval/src/vespa/eval/eval/operation.h b/eval/src/vespa/eval/eval/operation.h
index 02d3322f867..3170c868214 100644
--- a/eval/src/vespa/eval/eval/operation.h
+++ b/eval/src/vespa/eval/eval/operation.h
@@ -48,6 +48,7 @@ struct IsNan { static double f(double a); };
struct Relu { static double f(double a); };
struct Sigmoid { static double f(double a); };
struct Elu { static double f(double a); };
+struct Erf { static double f(double a); };
//-----------------------------------------------------------------------------
struct Inv { static double f(double a); };
struct Square { static double f(double a); };
diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp
index 709234a1a2c..b1dfa6d3c9c 100644
--- a/eval/src/vespa/eval/eval/test/eval_spec.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp
@@ -151,6 +151,7 @@ EvalSpec::add_function_call_cases() {
add_rule({"a", -1.0, 1.0}, "relu(a)", [](double a){ return std::max(a, 0.0); });
add_rule({"a", -1.0, 1.0}, "sigmoid(a)", [](double a){ return 1.0 / (1.0 + std::exp(-1.0 * a)); });
add_rule({"a", -1.0, 1.0}, "elu(a)", [](double a){ return (a < 0) ? std::exp(a)-1 : a; });
+ add_rule({"a", -1.0, 1.0}, "erf(a)", [](double a){ return std::erf(a); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "atan2(a,b)", [](double a, double b){ return std::atan2(a, b); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "ldexp(a,b)", [](double a, double b){ return std::ldexp(a, b); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "pow(a,b)", [](double a, double b){ return std::pow(a, b); });
@@ -162,11 +163,11 @@ EvalSpec::add_function_call_cases() {
void
EvalSpec::add_tensor_operation_cases() {
add_rule({"a", -1.0, 1.0}, "map(a,f(x)(sin(x)))", [](double x){ return std::sin(x); });
- add_rule({"a", -1.0, 1.0}, "map(a,f(x)(x+x*3))", [](double x){ return (x + (x * 3)); });
+ add_rule({"a", -1.0, 1.0}, "map(a,f(x)(x*x*3))", [](double x){ return ((x * x) * 3); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); });
- add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x+y*3))", [](double x, double y){ return (x + (y * 3)); });
+ add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "join(a,b,f(x,y)(x*y*3))", [](double x, double y){ return ((x * y) * 3); });
add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x+y))", [](double x, double y){ return (x + y); });
- add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x+y*3))", [](double x, double y){ return (x + (y * 3)); });
+ add_rule({"a", -1.0, 1.0}, {"b", -1.0, 1.0}, "merge(a,b,f(x,y)(x*y*3))", [](double x, double y){ return ((x * y) * 3); });
add_rule({"a", -1.0, 1.0}, "reduce(a,avg)", [](double a){ return a; });
add_rule({"a", -1.0, 1.0}, "reduce(a,count)", [](double){ return 1.0; });
add_rule({"a", -1.0, 1.0}, "reduce(a,prod)", [](double a){ return a; });
diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
index 41c6dd21e24..95e720cd1a2 100644
--- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
+++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
@@ -412,6 +412,7 @@ struct TestContext {
TEST_DO(test_map_op("relu(a)", operation::Relu::f, Sub2(Div16(N()))));
TEST_DO(test_map_op("sigmoid(a)", operation::Sigmoid::f, Sub2(Div16(N()))));
TEST_DO(test_map_op("elu(a)", operation::Elu::f, Sub2(Div16(N()))));
+ TEST_DO(test_map_op("erf(a)", operation::Erf::f, Sub2(Div16(N()))));
TEST_DO(test_map_op("a in [1,5,7,13,42]", MyIn::f, N()));
TEST_DO(test_map_op("(a+1)*2", MyOp::f, Div16(N())));
}
diff --git a/eval/src/vespa/eval/eval/visit_stuff.cpp b/eval/src/vespa/eval/eval/visit_stuff.cpp
index 821e609ebd0..9306a720837 100644
--- a/eval/src/vespa/eval/eval/visit_stuff.cpp
+++ b/eval/src/vespa/eval/eval/visit_stuff.cpp
@@ -35,6 +35,7 @@ vespalib::string name_of(map_fun_t fun) {
if (fun == operation::Relu::f) return "relu";
if (fun == operation::Sigmoid::f) return "sigmoid";
if (fun == operation::Elu::f) return "elu";
+ if (fun == operation::Erf::f) return "erf";
return "[other map function]";
}
diff --git a/jdisc_core/src/main/java/com/yahoo/jdisc/handler/UnsafeContentInputStream.java b/jdisc_core/src/main/java/com/yahoo/jdisc/handler/UnsafeContentInputStream.java
index 748c2951a6a..1662ed5b46a 100644
--- a/jdisc_core/src/main/java/com/yahoo/jdisc/handler/UnsafeContentInputStream.java
+++ b/jdisc_core/src/main/java/com/yahoo/jdisc/handler/UnsafeContentInputStream.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.jdisc.handler;
+import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.util.Objects;
@@ -19,6 +20,8 @@ public class UnsafeContentInputStream extends InputStream {
private final ReadableContentChannel content;
private ByteBuffer buf = ByteBuffer.allocate(0);
+ private byte [] marked;
+ private int readSinceMarked;
/**
* <p>Constructs a new ContentInputStream that reads from the given {@link ReadableContentChannel}.</p>
@@ -37,7 +40,15 @@ public class UnsafeContentInputStream extends InputStream {
if (buf == null) {
return -1;
}
- return ((int)buf.get()) & 0xFF;
+ byte b = buf.get();
+ if (marked != null) {
+ if (readSinceMarked < marked.length) {
+ marked[readSinceMarked++] = b;
+ } else {
+ marked = null;
+ }
+ }
+ return ((int)b) & 0xFF;
}
@Override
@@ -79,4 +90,28 @@ public class UnsafeContentInputStream extends InputStream {
}
}
+
+ @Override
+ public synchronized void mark(int readlimit) {
+ marked = new byte[readlimit];
+ readSinceMarked = 0;
+ }
+
+ @Override
+ public synchronized void reset() throws IOException {
+ if (marked == null) {
+ throw new IOException("mark has not been called, or too much has been read since marked.");
+ }
+ ByteBuffer newBuf = ByteBuffer.allocate(readSinceMarked + buf.remaining());
+ newBuf.put(marked, 0, readSinceMarked);
+ newBuf.put(buf);
+ newBuf.flip();
+ buf = newBuf;
+ marked = null;
+ }
+
+ @Override
+ public boolean markSupported() {
+ return true;
+ }
}
diff --git a/jdisc_core/src/test/java/com/yahoo/jdisc/handler/UnsafeContentInputStreamTestCase.java b/jdisc_core/src/test/java/com/yahoo/jdisc/handler/UnsafeContentInputStreamTestCase.java
index c00fab6cb56..c96450c1bd2 100644
--- a/jdisc_core/src/test/java/com/yahoo/jdisc/handler/UnsafeContentInputStreamTestCase.java
+++ b/jdisc_core/src/test/java/com/yahoo/jdisc/handler/UnsafeContentInputStreamTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.jdisc.handler;
+import com.yahoo.text.Utf8;
import org.junit.Test;
import java.io.BufferedReader;
@@ -8,11 +9,11 @@ import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.ByteBuffer;
-import java.util.concurrent.Future;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
/**
* @author Simon Thoresen Hult
@@ -32,7 +33,37 @@ public class UnsafeContentInputStreamTestCase {
assertNull(reader.readLine());
}
- @SuppressWarnings("deprecation")
+ @Test
+ public void testMark() throws IOException {
+ BufferedContentChannel channel = new BufferedContentChannel();
+ FastContentWriter writer = new FastContentWriter(channel);
+ writer.write("Hello ");
+ writer.write("World!");
+ writer.close();
+
+ InputStream stream = asInputStream(channel);
+ assertTrue(stream.markSupported());
+ int first = stream.read();
+ assertEquals('H', first);
+ stream.mark(10);
+ byte [] buf = new byte[8];
+ stream.read(buf);
+ assertEquals("ello Wor", Utf8.toString(buf));
+ stream.reset();
+ stream.mark(5);
+ buf = new byte [9];
+ stream.read(buf);
+ assertEquals("ello Worl", Utf8.toString(buf));
+ try {
+ stream.reset();
+ fail("UnsafeContentInputStream.reset expected to fail when your read past readLimit.");
+ } catch (IOException e) {
+ assertEquals("mark has not been called, or too much has been read since marked.", e.getMessage());
+ } catch (Throwable t) {
+ fail("Did not expect " + t);
+ }
+ }
+
@Test
public void requireThatCompletionsAreCalledWithDeprecatedContentWriter() throws IOException {
BufferedContentChannel channel = new BufferedContentChannel();
@@ -63,7 +94,6 @@ public class UnsafeContentInputStreamTestCase {
assertTrue(writer.isDone());
}
- @SuppressWarnings("deprecation")
@Test
public void requireThatCloseDrainsStreamWithDeprecatedContentWriter() {
BufferedContentChannel channel = new BufferedContentChannel();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index d14ad033a69..3d36b1bfffc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -2,11 +2,15 @@
package ai.vespa.rankingexpression.importer.onnx;
+import ai.vespa.rankingexpression.importer.operations.ConstantOfShape;
+import ai.vespa.rankingexpression.importer.operations.Expand;
import ai.vespa.rankingexpression.importer.operations.Gather;
+import ai.vespa.rankingexpression.importer.operations.OnnxConstant;
import ai.vespa.rankingexpression.importer.operations.OnnxCast;
import ai.vespa.rankingexpression.importer.operations.Gemm;
import ai.vespa.rankingexpression.importer.operations.ConcatReduce;
import ai.vespa.rankingexpression.importer.operations.OnnxConcat;
+import ai.vespa.rankingexpression.importer.operations.Range;
import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Slice;
@@ -81,11 +85,15 @@ class GraphImporter {
case "cast": return new OnnxCast(modelName, nodeName, inputs, attributes);
case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
case "concat": return new OnnxConcat(modelName, nodeName, inputs, attributes);
+ case "constant": return new OnnxConstant(modelName, nodeName, inputs, attributes);
+ case "constantofshape": return new ConstantOfShape(modelName, nodeName, inputs, attributes);
case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu(attributes.get("alpha").orElse(eluAlpha).asDouble()));
+ case "erf": return new Map(modelName, nodeName, inputs, ScalarFunctions.erf());
case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
+ case "expand": return new Expand(modelName, nodeName, inputs);
case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
case "gather": return new Gather(modelName, nodeName, inputs, attributes);
case "gemm": return new Gemm(modelName, nodeName, inputs, attributes);
@@ -100,6 +108,7 @@ class GraphImporter {
case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
+ case "range": return new Range(modelName, nodeName, inputs);
case "reshape": return new Reshape(modelName, nodeName, inputs, attributes);
case "reducel1": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.abs(), null);
case "reducel2": return new Reduce(modelName, nodeName, inputs, attributes, com.yahoo.tensor.functions.Reduce.Aggregator.sum, ScalarFunctions.square(), ScalarFunctions.sqrt());
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java
new file mode 100644
index 00000000000..887e350b430
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConstantOfShape.java
@@ -0,0 +1,83 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+public class ConstantOfShape extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+
+ private TensorType.Value valueTypeOfTensor = TensorType.Value.DOUBLE;
+ private double valueToFillWith = 0.0;
+
+
+ public ConstantOfShape(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+
+ Optional<Value> value = attributeMap.get("value");
+ if (value.isPresent()) {
+ Tensor t = value.get().asTensor();
+ valueTypeOfTensor = t.type().valueType();
+ valueToFillWith = t.valueIterator().next();
+ }
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(1)) return null;
+
+ IntermediateOperation input = inputs.get(0);
+ if (input.getConstantValue().isEmpty()) {
+ throw new IllegalArgumentException("ConstantOfShape: 'shape' input must be a constant.");
+ }
+ Tensor shape = input.getConstantValue().get().asTensor();
+ if (shape.type().dimensions().size() > 1) {
+ throw new IllegalArgumentException("ConstantOfShape: 'shape' input must be a tensor with 0 or 1 dimensions.");
+ }
+
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(valueTypeOfTensor);
+ Iterator<Double> iter = shape.valueIterator();
+ for (int i = 0; iter.hasNext(); i++) {
+ builder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, iter.next().longValue()));
+ }
+ return builder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputTypesPresent(1)) return null;
+ ExpressionNode valueExpr = new ConstantNode(new DoubleValue(valueToFillWith));
+ TensorFunction function = Generate.bound(type.type(), wrapScalar(valueExpr));
+ return function;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ addConstraintsFrom(type, renamer);
+ }
+
+ @Override
+ public ConstantOfShape withInputs(List<IntermediateOperation> inputs) {
+ return new ConstantOfShape(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "ConstantOfShape"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java
new file mode 100644
index 00000000000..30a7bc3bbad
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Expand.java
@@ -0,0 +1,122 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+public class Expand extends IntermediateOperation {
+
+ public Expand(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) return null;
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
+ Optional<Value> shapeValue = inputs.get(1).getConstantValue();
+ if (shapeValue.isEmpty())
+ throw new IllegalArgumentException("Expand " + name + ": shape must be a constant.");
+
+ Tensor shape = shapeValue.get().asTensor();
+ if (shape.type().rank() != 1)
+ throw new IllegalArgumentException("Expand " + name + ": shape must be a 1-d tensor.");
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+
+ int inputRank = inputType.rank();
+ int shapeSize = shape.type().dimensions().get(0).size().get().intValue();
+ int sizeDiff = shapeSize - inputRank;
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(inputType.type().valueType());
+ Iterator<Double> iter = shape.valueIterator();
+
+ // Add any extra dimensions
+ for (int i = 0; i < sizeDiff; ++i) {
+ typeBuilder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, iter.next().intValue()));
+ }
+
+ // Dimensions are matched innermost
+ for (int i = sizeDiff; i < shapeSize; i++) {
+ int shapeDimSize = iter.next().intValue();
+ int inputDimSize = inputType.dimensions().get(i - sizeDiff).size().get().intValue();
+ if (shapeDimSize != inputDimSize && shapeDimSize != 1 && inputDimSize != 1) {
+ throw new IllegalArgumentException("Expand " + name + ": dimension sizes of input and shape " +
+ "are not compatible. Either they must be equal or one must be of size 1.");
+ }
+ int dimSize = Math.max(shapeDimSize, inputDimSize);
+ typeBuilder.add(TensorType.Dimension.indexed(vespaName() + "_" + i, dimSize));
+ }
+
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(2)) return null;
+
+ IntermediateOperation input = inputs.get(0);
+ OrderedTensorType inputType = input.type().get();
+ OrderedTensorType type = type().get();
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ int sizeDiff = type().get().rank() - inputType.rank();
+ for (int i = sizeDiff; i < type().get().rank(); ++i) {
+ String inputDimensionName = inputType.dimensions().get(i - sizeDiff).name();
+ String typeDimensionName = type.dimensionNames().get(i);
+ long inputDimensionSize = inputType.dimensions().get(i - sizeDiff).size().get();
+
+ ExpressionNode index;
+ if (inputDimensionSize == 1) {
+ index = new ConstantNode(new DoubleValue(0.0));
+ } else {
+ index = new EmbracedNode(new ReferenceNode(typeDimensionName));
+ }
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(index)));
+ }
+
+ TensorFunction<Reference> externalRef = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(externalRef, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+ return Generate.bound(type.type(), wrapScalar(sliceExpression));
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ addConstraintsFrom(type, renamer);
+ }
+
+ @Override
+ public Expand withInputs(List<IntermediateOperation> inputs) {
+ return new Expand(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "Expand"; }
+
+}
+
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
index 2a34ae53d5e..91ff5d9cdd8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gather.java
@@ -105,6 +105,9 @@ public class Gather extends IntermediateOperation {
private ExpressionNode createSliceExpression(List<Slice.DimensionValue<Reference>> dimensionValues, String referenceName) {
TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(referenceName));
+ if (dimensionValues.isEmpty()) {
+ return new TensorFunctionNode(inputIndices);
+ }
Slice<Reference> sliceIndices = new Slice<>(inputIndices, dimensionValues);
return new TensorFunctionNode(sliceIndices);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java
new file mode 100644
index 00000000000..3c5ddf48cfc
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/OnnxConstant.java
@@ -0,0 +1,91 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+
+public class OnnxConstant extends IntermediateOperation {
+
+ private final AttributeMap attributeMap;
+ private final Value value;
+
+ public OnnxConstant(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
+ this.attributeMap = attributeMap;
+ this.value = value();
+ setConstantValueFunction(type -> new TensorValue(this.value.asTensor()));
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ OrderedTensorType type;
+ if (value instanceof TensorValue) {
+ type = OrderedTensorType.fromSpec(value.type().toString()).rename(vespaName() + "_");
+ } else {
+ type = OrderedTensorType.fromDimensionList(TensorType.Value.DOUBLE, Collections.emptyList());
+ }
+ return type;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // will be added by function() since this is constant.
+ }
+
+ @Override
+ public Optional<Value> getConstantValue() {
+ return Optional.of(new TensorValue(value.asTensor().withType(type().get().type())));
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ addConstraintsFrom(type, renamer);
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+ @Override
+ public OnnxConstant withInputs(List<IntermediateOperation> inputs) {
+ return new OnnxConstant(modelName(), name(), inputs, attributeMap);
+ }
+
+ @Override
+ public String operationName() { return "Constant"; }
+
+ @Override
+ public String toString() {
+ return "Constant(" + type + ")";
+ }
+
+ @Override
+ public String toFullString() {
+ return "\t" + type + ":\tConstant(" + type + ")";
+ }
+
+ private Value value() {
+ Optional<Value> value = attributeMap.get("value");
+ if (value.isEmpty()) {
+ value = attributeMap.get("value_float");
+ if (value.isEmpty()) {
+ value = attributeMap.get("value_int");
+ }
+ }
+ if (value.isEmpty()) {
+ throw new IllegalArgumentException("Node '" + name + "' of type " +
+ "constant has missing or non-supported 'value' attribute");
+ }
+ return value.get();
+ }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
new file mode 100644
index 00000000000..6df686cf910
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Range.java
@@ -0,0 +1,86 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+public class Range extends IntermediateOperation {
+
+ private double start;
+ private double limit;
+ private double delta;
+ private long elements;
+
+ public Range(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ private double getConstantInput(int index, String name) {
+ IntermediateOperation input = inputs.get(index);
+ if (input.getConstantValue().isEmpty()) {
+ throw new IllegalArgumentException("Range: " + name + " input must be a constant.");
+ }
+ Tensor value = input.getConstantValue().get().asTensor();
+ if ( ! input.getConstantValue().get().hasDouble()) {
+ throw new IllegalArgumentException("Range: " + name + " input must be a scalar.");
+ }
+ return value.asDouble();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if ( ! allInputTypesPresent(3)) return null;
+
+ start = getConstantInput(0, "start"); // must be constant because we need to know type
+ limit = getConstantInput(1, "limit");
+ delta = getConstantInput(2, "delta");
+ elements = (long) Math.ceil((limit - start) / delta);
+
+ OrderedTensorType type = new OrderedTensorType.Builder()
+ .add(TensorType.Dimension.indexed(vespaName(), elements))
+ .build();
+ return type;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if ( ! allInputTypesPresent(3)) return null;
+ String dimensionName = type().get().dimensionNames().get(0);
+ ExpressionNode startExpr = new ConstantNode(new DoubleValue(start));
+ ExpressionNode deltaExpr = new ConstantNode(new DoubleValue(delta));
+ ExpressionNode dimExpr = new EmbracedNode(new ReferenceNode(dimensionName));
+ ExpressionNode stepExpr = new ArithmeticNode(deltaExpr, ArithmeticOperator.MULTIPLY, dimExpr);
+ ExpressionNode addExpr = new ArithmeticNode(startExpr, ArithmeticOperator.PLUS, stepExpr);
+ TensorFunction function = Generate.bound(type.type(), wrapScalar(addExpr));
+ return function;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ addConstraintsFrom(type, renamer);
+ }
+
+ @Override
+ public Range withInputs(List<IntermediateOperation> inputs) {
+ return new Range(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "Range"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
index 8696d0f1858..69283f10711 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java
@@ -78,14 +78,12 @@ public class Select extends IntermediateOperation {
List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions();
List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions();
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
-
// These tensors should have the same dimension names
- renamer.addConstraint(aDim0, bDim0, DimensionRenamer.Constraint.equal(false), this);
- renamer.addConstraint(aDim1, bDim1, DimensionRenamer.Constraint.equal(false), this);
+ for (int i = 0; i < aDimensions.size(); ++i) {
+ String aDim = aDimensions.get(i).name();
+ String bDim = bDimensions.get(i).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer.Constraint.equal(false), this);
+ }
}
@Override
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index d5dff7fb1b7..7b9868d71f5 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
@@ -66,6 +67,9 @@ public class OnnxOperationsTestCase {
assertEval("leakyrelu", x, evaluate("max(0.01 * x, x)", x));
assertEval("leakyrelu", x, evaluate("max(0.001 * x, x)", x), createAttribute("alpha", 0.001f));
+ x = evaluate("tensor(d0[7]):[-40.0, -0.5, -0.1, 0.0, 0.1, 0.5, 40.0]");
+ assertEval("erf", x, evaluate("erf(x)", x));
+
x = evaluate("tensor(d0[3]):[0.01, 1.0, 10.0]");
assertEval("log", x, evaluate("log(x)", x));
assertEval("sqrt", x, evaluate("sqrt(x)", x));
@@ -405,9 +409,14 @@ public class OnnxOperationsTestCase {
@Test
public void testGather1() throws ParseException {
- // 1 dim input, 1 dim indices
+ // 1 dim input, 0 dim indices
Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
- Tensor y = evaluate("tensor(d0[3]):[0,2,4]");
+ Tensor y = evaluate("tensor():[0]");
+ assertEval("gather", x, y, evaluate("tensor():[1]"));
+
+ // 1 dim input, 1 dim indices
+ x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ y = evaluate("tensor(d0[3]):[0,2,4]");
assertEval("gather", x, y, evaluate("tensor(d0[3]):[1,3,5]"));
// 2 dim input, 1 dim indices - axis 0
@@ -533,6 +542,43 @@ public class OnnxOperationsTestCase {
assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[3,6]"), createAttribute("axis", 1), 2);
}
+ @Test
+ public void testRange11() throws ParseException {
+ Tensor start = evaluate("tensor():[3]");
+ Tensor limit = evaluate("tensor():[9]");
+ Tensor delta = evaluate("tensor():[3]");
+ assertEval("range", start, limit, delta, evaluate("tensor(d0[2]):[3,6]"));
+
+ start = evaluate("tensor():[10]");
+ limit = evaluate("tensor():[4]");
+ delta = evaluate("tensor():[-2]");
+ assertEval("range", start, limit, delta, evaluate("tensor(d0[3]):[10,8,6]"));
+ assertEval("range", start, limit, delta, evaluate("tensor(d0[3]):[10,8,6]"));
+ }
+
+ @Test
+ public void testConstant12() throws ParseException {
+ assertEval("constant", evaluate("tensor(d0[3]):[1,2,3]"), createAttribute("value", evaluate("tensor(d0[3]):[1,2,3]")));
+ assertEval("constant", evaluate("tensor<float>():[313.0]"), createAttribute("value_float", 313.0f));
+ assertEval("constant", evaluate("tensor():[42]"), createAttribute("value_int", 42));
+ }
+
+ @Test
+ public void testConstantOfShape9() throws ParseException {
+ Tensor shape = evaluate("tensor(d0[3]):[1,2,3]");
+ assertEval("constantofshape", shape, evaluate("tensor(d0[1],d1[2],d2[3]):[0,0,0,0,0,0]"));
+ assertEval("constantofshape", shape, evaluate("tensor<float>(d0[1],d1[2],d2[3]):[1,1,1,1,1,1]"), createAttribute("value", evaluate("tensor<float>(d0[1]):[1]")));
+ }
+
+ @Test
+ public void testExpand8() throws ParseException {
+ Tensor input = evaluate("tensor(d0[3],d1[1]):[1,2,3]");
+ Tensor shape = evaluate("tensor(d0[2]):[3,4]");
+ assertEval("expand", input, shape, evaluate("tensor(d0[3],d1[4]):[1,1,1,1,2,2,2,2,3,3,3,3]"));
+ shape = evaluate("tensor(d0[3]):[2,1,4]");
+ assertEval("expand", input, shape, evaluate("tensor(d0[2],d1[3],d2[4]):[1,1,1,1,2,2,2,2,3,3,3,3,1,1,1,1,2,2,2,2,3,3,3,3]"));
+ }
+
private Tensor evaluate(String expr) throws ParseException {
return evaluate(expr, null, null, null);
}
@@ -558,6 +604,10 @@ public class OnnxOperationsTestCase {
return renameToStandardType(op, tensor);
}
+ private void assertEval(String opName, Tensor expected, AttributeConverter attr) {
+ assertEval(opName, null, null, null, null, null, expected, attr, 0);
+ }
+
private void assertEval(String opName, Tensor x, Tensor expected) {
assertEval(opName, x, null, null, null, null, expected, null, 0);
}
@@ -667,6 +717,10 @@ public class OnnxOperationsTestCase {
return new Attributes().attr(name, vals).build();
}
+ static AttributeConverter createAttribute(String name, Tensor val) {
+ return new Attributes().attr(name, val).build();
+ }
+
static Attributes createAttributes() {
return new Attributes();
}
@@ -700,9 +754,14 @@ public class OnnxOperationsTestCase {
Attributes attr(String name, Tensor tensor) {
Onnx.TensorProto.Builder builder = Onnx.TensorProto.newBuilder();
- builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);;
tensor.type().dimensions().forEach(d -> builder.addDims(d.size().get()));
- tensor.valueIterator().forEachRemaining(builder::addDoubleData);
+ if (tensor.type().valueType() == TensorType.Value.FLOAT) {
+ builder.setDataType(Onnx.TensorProto.DataType.FLOAT);
+ tensor.valueIterator().forEachRemaining(d -> builder.addFloatData(d.floatValue()));
+ } else {
+ builder.setDataType(Onnx.TensorProto.DataType.DOUBLE);
+ tensor.valueIterator().forEachRemaining(builder::addDoubleData);
+ }
Onnx.TensorProto val = builder.build();
nodeBuilder.addAttribute(Onnx.AttributeProto.newBuilder().setName(name).setType(TENSOR).setT(val).build());
return this;
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 6bce791914c..c22d906e2b2 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1024,6 +1024,7 @@
"public static final int SQRT",
"public static final int TAN",
"public static final int TANH",
+ "public static final int ERF",
"public static final int ATAN2",
"public static final int FMOD",
"public static final int LDEXP",
@@ -1373,6 +1374,7 @@
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function sqrt",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function tan",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function tanh",
+ "public static final enum com.yahoo.searchlib.rankingexpression.rule.Function erf",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function atan2",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function fmod",
"public static final enum com.yahoo.searchlib.rankingexpression.rule.Function ldexp",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
index c3c1c371a68..99afb3b38d0 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;
+import com.yahoo.tensor.functions.ScalarFunctions;
+
import java.io.Serializable;
import static java.lang.Math.*;
@@ -36,6 +38,7 @@ public enum Function implements Serializable {
sqrt { public double evaluate(double x, double y) { return sqrt(x); } },
tan { public double evaluate(double x, double y) { return tan(x); } },
tanh { public double evaluate(double x, double y) { return tanh(x); } },
+ erf { public double evaluate(double x, double y) { return ScalarFunctions.Erf.erf(x); } },
atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } },
fmod(2) { public double evaluate(double x, double y) { return x % y; } },
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 8aa10bf7b34..5f27bbcbeee 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -115,6 +115,7 @@ TOKEN :
<SQRT: "sqrt"> |
<TAN: "tan"> |
<TANH: "tanh"> |
+ <ERF: "erf"> |
<ATAN2: "atan2"> |
<FMOD: "fmod"> |
@@ -727,7 +728,8 @@ Function unaryFunctionName() : { }
<SQUARE> { return Function.square; } |
<SQRT> { return Function.sqrt; } |
<TAN> { return Function.tan; } |
- <TANH> { return Function.tanh; }
+ <TANH> { return Function.tanh; } |
+ <ERF> { return Function.erf; }
}
Function binaryFunctionName() : { }
diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
index 7a4c6c9e56a..0a8b59c7d7e 100644
--- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
+++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp
@@ -33,8 +33,8 @@ using search::AttributeGuard;
using search::AttributeVector;
using search::attribute::DistanceMetric;
using search::attribute::HnswIndexParams;
-using search::queryeval::NearestNeighborBlueprint;
using search::queryeval::GlobalFilter;
+using search::queryeval::NearestNeighborBlueprint;
using search::tensor::DefaultNearestNeighborIndexFactory;
using search::tensor::DenseTensorAttribute;
using search::tensor::DocVectorAccess;
@@ -44,6 +44,7 @@ using search::tensor::HnswNode;
using search::tensor::NearestNeighborIndex;
using search::tensor::NearestNeighborIndexFactory;
using search::tensor::NearestNeighborIndexSaver;
+using search::tensor::PrepareResult;
using search::tensor::TensorAttribute;
using vespalib::eval::TensorSpec;
using vespalib::eval::ValueType;
@@ -97,6 +98,12 @@ public:
}
};
+class MockPrepareResult : public PrepareResult {
+public:
+ uint32_t docid;
+ MockPrepareResult(uint32_t docid_in) : docid(docid_in) {}
+};
+
class MockNearestNeighborIndex : public NearestNeighborIndex {
private:
using Entry = std::pair<uint32_t, DoubleVector>;
@@ -105,6 +112,8 @@ private:
const DocVectorAccess& _vectors;
EntryVector _adds;
EntryVector _removes;
+ mutable EntryVector _prepare_adds;
+ EntryVector _complete_adds;
generation_t _transfer_gen;
generation_t _trim_gen;
mutable size_t _memory_usage_cnt;
@@ -115,6 +124,8 @@ public:
: _vectors(vectors),
_adds(),
_removes(),
+ _prepare_adds(),
+ _complete_adds(),
_transfer_gen(std::numeric_limits<generation_t>::max()),
_trim_gen(std::numeric_limits<generation_t>::max()),
_memory_usage_cnt(0),
@@ -124,6 +135,8 @@ public:
void clear() {
_adds.clear();
_removes.clear();
+ _prepare_adds.clear();
+ _complete_adds.clear();
}
int get_index_value() const {
return _index_value;
@@ -134,10 +147,13 @@ public:
void expect_empty_add() const {
EXPECT_TRUE(_adds.empty());
}
+ void expect_entry(uint32_t exp_docid, const DoubleVector& exp_vector, const EntryVector& entries) const {
+ EXPECT_EQUAL(1u, entries.size());
+ EXPECT_EQUAL(exp_docid, entries.back().first);
+ EXPECT_EQUAL(exp_vector, entries.back().second);
+ }
void expect_add(uint32_t exp_docid, const DoubleVector& exp_vector) const {
- EXPECT_EQUAL(1u, _adds.size());
- EXPECT_EQUAL(exp_docid, _adds.back().first);
- EXPECT_EQUAL(exp_vector, _adds.back().second);
+ expect_entry(exp_docid, exp_vector, _adds);
}
void expect_adds(const EntryVector &exp_adds) const {
EXPECT_EQUAL(exp_adds, _adds);
@@ -146,9 +162,13 @@ public:
EXPECT_TRUE(_removes.empty());
}
void expect_remove(uint32_t exp_docid, const DoubleVector& exp_vector) const {
- EXPECT_EQUAL(1u, _removes.size());
- EXPECT_EQUAL(exp_docid, _removes.back().first);
- EXPECT_EQUAL(exp_vector, _removes.back().second);
+ expect_entry(exp_docid, exp_vector, _removes);
+ }
+ void expect_prepare_add(uint32_t exp_docid, const DoubleVector& exp_vector) const {
+ expect_entry(exp_docid, exp_vector, _prepare_adds);
+ }
+ void expect_complete_add(uint32_t exp_docid, const DoubleVector& exp_vector) const {
+ expect_entry(exp_docid, exp_vector, _complete_adds);
}
generation_t get_transfer_gen() const { return _transfer_gen; }
generation_t get_trim_gen() const { return _trim_gen; }
@@ -158,14 +178,21 @@ public:
auto vector = _vectors.get_vector(docid).typify<double>();
_adds.emplace_back(docid, DoubleVector(vector.begin(), vector.end()));
}
- std::unique_ptr<search::tensor::PrepareResult> prepare_add_document(uint32_t,
- vespalib::tensor::TypedCells,
- vespalib::GenerationHandler::Guard) const override {
- return std::unique_ptr<search::tensor::PrepareResult>();
+ std::unique_ptr<PrepareResult> prepare_add_document(uint32_t docid,
+ vespalib::tensor::TypedCells vector,
+ vespalib::GenerationHandler::Guard guard) const override {
+ (void) guard;
+ auto d_vector = vector.typify<double>();
+ _prepare_adds.emplace_back(docid, DoubleVector(d_vector.begin(), d_vector.end()));
+ return std::make_unique<MockPrepareResult>(docid);
}
void complete_add_document(uint32_t docid,
- std::unique_ptr<search::tensor::PrepareResult>) override {
- add_document(docid);
+ std::unique_ptr<PrepareResult> prepare_result) override {
+ auto* mock_result = dynamic_cast<MockPrepareResult*>(prepare_result.get());
+ assert(mock_result);
+ EXPECT_EQUAL(docid, mock_result->docid);
+ auto vector = _vectors.get_vector(docid).typify<double>();
+ _complete_adds.emplace_back(docid, DoubleVector(vector.begin(), vector.end()));
}
void remove_document(uint32_t docid) override {
auto vector = _vectors.get_vector(docid).typify<double>();
@@ -342,6 +369,16 @@ struct Fixture {
set_tensor_internal(docid, *createTensor(spec));
}
+ std::unique_ptr<PrepareResult> prepare_set_tensor(uint32_t docid, const TensorSpec& spec) const {
+ return _tensorAttr->prepare_set_tensor(docid, *createTensor(spec));
+ }
+
+ void complete_set_tensor(uint32_t docid, const TensorSpec& spec, std::unique_ptr<PrepareResult> prepare_result) {
+ ensureSpace(docid);
+ _tensorAttr->complete_set_tensor(docid, *createTensor(spec), std::move(prepare_result));
+ _attr->commit();
+ }
+
void set_empty_tensor(uint32_t docid) {
set_tensor_internal(docid, *_tensorAttr->getEmptyTensor());
}
@@ -687,6 +724,30 @@ TEST_F("setTensor() updates nearest neighbor index", DenseTensorAttributeMockInd
index.expect_add(1, {7, 9});
}
+TEST_F("nearest neighbor index can be updated in two phases", DenseTensorAttributeMockIndex)
+{
+ auto& index = f.mock_index();
+ {
+ auto vec_a = vec_2d(3, 5);
+ auto prepare_result = f.prepare_set_tensor(1, vec_a);
+ index.expect_prepare_add(1, {3, 5});
+ f.complete_set_tensor(1, vec_a, std::move(prepare_result));
+ f.assertGetTensor(vec_a, 1);
+ index.expect_complete_add(1, {3, 5});
+ }
+ index.clear();
+ {
+ // Replaces previous value.
+ auto vec_b = vec_2d(7, 9);
+ auto prepare_result = f.prepare_set_tensor(1, vec_b);
+ index.expect_prepare_add(1, {7, 9});
+ f.complete_set_tensor(1, vec_b, std::move(prepare_result));
+ index.expect_remove(1, {3, 5});
+ f.assertGetTensor(vec_b, 1);
+ index.expect_complete_add(1, {7, 9});
+ }
+}
+
TEST_F("clearDoc() updates nearest neighbor index", DenseTensorAttributeMockIndex)
{
auto& index = f.mock_index();
diff --git a/searchlib/src/tests/hitcollector/hitcollector_test.cpp b/searchlib/src/tests/hitcollector/hitcollector_test.cpp
index 31a24d2a8f1..2274314c7da 100644
--- a/searchlib/src/tests/hitcollector/hitcollector_test.cpp
+++ b/searchlib/src/tests/hitcollector/hitcollector_test.cpp
@@ -55,7 +55,7 @@ void checkResult(const ResultSet & rs, const std::vector<RankedHit> & exp)
for (uint32_t i = 0; i < exp.size(); ++i) {
EXPECT_EQUAL(rh[i]._docId, exp[i]._docId);
- EXPECT_EQUAL(rh[i]._rankValue, exp[i]._rankValue);
+ EXPECT_EQUAL(rh[i]._rankValue + 1.0, exp[i]._rankValue + 1.0);
}
} else {
ASSERT_TRUE(rs.getArray() == nullptr);
@@ -328,7 +328,7 @@ TEST("testScaling") {
finalScores[3] = 300;
finalScores[4] = 400;
- testScaling(initScores, std::move(finalScores), exp);
+ TEST_DO(testScaling(initScores, std::move(finalScores), exp));
}
{ // scale down and adjust up
exp[0]._rankValue = 200; // scaled
@@ -342,7 +342,7 @@ TEST("testScaling") {
finalScores[3] = 500;
finalScores[4] = 600;
- testScaling(initScores, std::move(finalScores), exp);
+ TEST_DO(testScaling(initScores, std::move(finalScores), exp));
}
{ // scale up and adjust down
@@ -357,7 +357,7 @@ TEST("testScaling") {
finalScores[3] = 3250;
finalScores[4] = 4500;
- testScaling(initScores, std::move(finalScores), exp);
+ TEST_DO(testScaling(initScores, std::move(finalScores), exp));
}
{ // minimal scale (second phase range = 0 (4 - 4) -> 1)
exp[0]._rankValue = 1; // scaled
@@ -371,7 +371,7 @@ TEST("testScaling") {
finalScores[3] = 4;
finalScores[4] = 4;
- testScaling(initScores, std::move(finalScores), exp);
+ TEST_DO(testScaling(initScores, std::move(finalScores), exp));
}
{ // minimal scale (first phase range = 0 (4000 - 4000) -> 1)
std::vector<feature_t> is(initScores);
@@ -387,7 +387,7 @@ TEST("testScaling") {
finalScores[3] = 400;
finalScores[4] = 500;
- testScaling(is, std::move(finalScores), exp);
+ TEST_DO(testScaling(is, std::move(finalScores), exp));
}
}
diff --git a/searchlib/src/tests/tensor/hnsw_index/.gitignore b/searchlib/src/tests/tensor/hnsw_index/.gitignore
new file mode 100644
index 00000000000..bc9bc27160b
--- /dev/null
+++ b/searchlib/src/tests/tensor/hnsw_index/.gitignore
@@ -0,0 +1 @@
+/mt_stress_hnsw_app
diff --git a/searchlib/src/tests/tensor/hnsw_index/CMakeLists.txt b/searchlib/src/tests/tensor/hnsw_index/CMakeLists.txt
index 04c7312a63f..b6a87502fdf 100644
--- a/searchlib/src/tests/tensor/hnsw_index/CMakeLists.txt
+++ b/searchlib/src/tests/tensor/hnsw_index/CMakeLists.txt
@@ -7,3 +7,11 @@ vespa_add_executable(searchlib_hnsw_index_test_app TEST
gtest
)
vespa_add_test(NAME searchlib_hnsw_index_test_app COMMAND searchlib_hnsw_index_test_app)
+
+vespa_add_executable(mt_stress_hnsw_app TEST
+ SOURCES
+ stress_hnsw_mt.cpp
+ DEPENDS
+ searchlib
+ gtest
+)
diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
index 33ea0f2df5b..7dc0efc106d 100644
--- a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp
@@ -616,8 +616,8 @@ TEST_F(TwoPhaseTest, two_phase_add)
complete_add(7, std::move(up));
// 1 filtered out because it was removed
- // TODO: 5 filtered out because it was updated
- expect_levels(7, {{2}, {4,5}});
+ // 5 filtered out because it was updated
+ expect_levels(7, {{2}, {4}});
}
diff --git a/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
new file mode 100644
index 00000000000..4dec9550f6f
--- /dev/null
+++ b/searchlib/src/tests/tensor/hnsw_index/stress_hnsw_mt.cpp
@@ -0,0 +1,348 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <sys/types.h>
+#include <sys/stat.h>
+#include <fcntl.h>
+#include <stdio.h>
+#include <unistd.h>
+#include <chrono>
+#include <cstdlib>
+#include <future>
+#include <vector>
+
+#include <vespa/eval/tensor/dense/typed_cells.h>
+#include <vespa/searchlib/common/bitvector.h>
+#include <vespa/searchlib/tensor/distance_functions.h>
+#include <vespa/searchlib/tensor/doc_vector_access.h>
+#include <vespa/searchlib/tensor/hnsw_index.h>
+#include <vespa/searchlib/tensor/inv_log_level_generator.h>
+#include <vespa/searchlib/tensor/random_level_generator.h>
+#include <vespa/vespalib/data/input.h>
+#include <vespa/vespalib/data/memory_input.h>
+#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/gtest/gtest.h>
+#include <vespa/vespalib/util/blockingthreadstackexecutor.h>
+#include <vespa/vespalib/util/generationhandler.h>
+#include <vespa/vespalib/util/lambdatask.h>
+
+#include <vespa/log/log.h>
+LOG_SETUP("stress_hnsw_mt");
+
+using namespace search::tensor;
+using namespace vespalib::slime;
+using search::BitVector;
+using vespalib::GenerationHandler;
+using vespalib::MemoryUsage;
+using vespalib::Slime;
+
+#define NUM_DIMS 128
+#define NUM_POSSIBLE_V 1000000
+#define NUM_POSSIBLE_DOCS 30000
+#define NUM_OPS 1000000
+
+class RndGen {
+private:
+ std::mt19937_64 urng;
+ std::uniform_real_distribution<double> uf;
+public:
+ RndGen() : urng(0x1234deadbeef5678uLL), uf(0.0, 1.0) {}
+
+ double nextUniform() {
+ return uf(urng);
+ }
+};
+
+using ConstVectorRef = vespalib::ConstArrayRef<float>;
+
+struct MallocPointVector {
+ float v[NUM_DIMS];
+ operator ConstVectorRef() const { return ConstVectorRef(v, NUM_DIMS); }
+};
+static MallocPointVector *aligned_alloc_pv(size_t num) {
+ size_t num_bytes = num * sizeof(MallocPointVector);
+ double mega_bytes = num_bytes / (1024.0*1024.0);
+ fprintf(stderr, "allocate %.2f MB of vectors\n", mega_bytes);
+ char *mem = (char *)malloc(num_bytes + 512);
+ mem += 512;
+ size_t val = (size_t)mem;
+ size_t unalign = val % 512;
+ mem -= unalign;
+ return reinterpret_cast<MallocPointVector *>(mem);
+}
+
+void read_vector_file(MallocPointVector *p) {
+ std::string data_set = "sift";
+ std::string data_dir = ".";
+ char *home = getenv("HOME");
+ if (home) {
+ data_dir = home;
+ data_dir += "/" + data_set;
+ }
+ std::string fn = data_dir + "/" + data_set + "_base.fvecs";
+ int fd = open(fn.c_str(), O_RDONLY);
+ if (fd < 0) {
+ perror(fn.c_str());
+ exit(1);
+ }
+ int d;
+ size_t rv;
+ fprintf(stderr, "reading %u vectors from %s\n", NUM_POSSIBLE_V, fn.c_str());
+ for (uint32_t i = 0; i < NUM_POSSIBLE_V; ++i) {
+ rv = read(fd, &d, 4);
+ ASSERT_EQ(rv, 4u);
+ ASSERT_EQ(d, NUM_DIMS);
+ rv = read(fd, &p[i].v, NUM_DIMS*sizeof(float));
+ ASSERT_EQ(rv, sizeof(MallocPointVector));
+ }
+ close(fd);
+ fprintf(stderr, "reading %u vectors OK\n", NUM_POSSIBLE_V);
+}
+
+class MyDocVectorStore : public DocVectorAccess {
+private:
+ MallocPointVector *_vectors;
+public:
+ MyDocVectorStore() {
+ _vectors = aligned_alloc_pv(NUM_POSSIBLE_DOCS);
+ }
+ MyDocVectorStore& set(uint32_t docid, ConstVectorRef vec) {
+ assert(docid < NUM_POSSIBLE_DOCS);
+ memcpy(&_vectors[docid], vec.cbegin(), sizeof(MallocPointVector));
+ return *this;
+ }
+ vespalib::tensor::TypedCells get_vector(uint32_t docid) const override {
+ assert(docid < NUM_POSSIBLE_DOCS);
+ ConstVectorRef ref(_vectors[docid]);
+ return vespalib::tensor::TypedCells(ref);
+ }
+};
+
+using FloatSqEuclideanDistance = SquaredEuclideanDistance<float>;
+using HnswIndexUP = std::unique_ptr<HnswIndex>;
+
+class Stressor : public ::testing::Test {
+private:
+ struct LoadedVectors {
+ MallocPointVector *pv_storage;
+ void load() {
+ pv_storage = aligned_alloc_pv(size());
+ read_vector_file(pv_storage);
+ }
+ size_t size() const { return NUM_POSSIBLE_V; }
+ vespalib::ConstArrayRef<float> operator[] (size_t i) {
+ return pv_storage[i];
+ }
+ } loaded_vectors;
+public:
+ BitVector::UP in_progress;
+ std::mutex in_progress_lock;
+ BitVector::UP existing_ids;
+ RndGen rng;
+ MyDocVectorStore vectors;
+ GenerationHandler gen_handler;
+ HnswIndexUP index;
+ vespalib::BlockingThreadStackExecutor multi_prepare_workers;
+ vespalib::BlockingThreadStackExecutor write_thread;
+
+ using PrepUP = std::unique_ptr<PrepareResult>;
+ using ReadGuard = GenerationHandler::Guard;
+ using PrepareFuture = std::future<PrepUP>;
+
+ // union of data required by tasks
+ struct TaskBase : vespalib::Executor::Task {
+ Stressor &parent;
+ uint32_t docid;
+ ConstVectorRef vec;
+ PrepareFuture prepare_future;
+ ReadGuard read_guard;
+
+ TaskBase(Stressor &p, uint32_t d, ConstVectorRef v, PrepareFuture f, ReadGuard g)
+ : parent(p), docid(d), vec(v), prepare_future(std::move(f)), read_guard(g)
+ {}
+ TaskBase(Stressor &p, uint32_t d, ConstVectorRef v, ReadGuard g) // prepare add
+ : TaskBase(p, d, v, PrepareFuture(), g) {}
+ TaskBase(Stressor &p, uint32_t d, ConstVectorRef v, PrepareFuture r) // complete add+update
+ : TaskBase(p, d, v, std::move(r), ReadGuard()) {}
+ TaskBase(Stressor &p, uint32_t d) // complete remove
+ : TaskBase(p, d, ConstVectorRef(), PrepareFuture(), ReadGuard()) {}
+
+ ~TaskBase() {}
+ };
+
+ struct PrepareAddTask : TaskBase {
+ using TaskBase::TaskBase;
+ std::promise<PrepUP> result_promise;
+ auto get_result_future() {
+ return result_promise.get_future();
+ }
+ void run() override {
+ auto v = vespalib::tensor::TypedCells(vec);
+ auto up = parent.index->prepare_add_document(docid, v, read_guard);
+ result_promise.set_value(std::move(up));
+ }
+ };
+
+ struct CompleteAddTask : TaskBase {
+ using TaskBase::TaskBase;
+ void run() override {
+ auto prepare_result = prepare_future.get();
+ parent.vectors.set(docid, vec);
+ parent.index->complete_add_document(docid, std::move(prepare_result));
+ parent.existing_ids->setBit(docid);
+ parent.commit(docid);
+ }
+ };
+
+ struct CompleteRemoveTask : TaskBase {
+ using TaskBase::TaskBase;
+ void run() override {
+ parent.index->remove_document(docid);
+ parent.existing_ids->clearBit(docid);
+ parent.commit(docid);
+ }
+ };
+
+ struct CompleteUpdateTask : TaskBase {
+ using TaskBase::TaskBase;
+ void run() override {
+ auto prepare_result = prepare_future.get();
+ parent.index->remove_document(docid);
+ parent.vectors.set(docid, vec);
+ parent.index->complete_add_document(docid, std::move(prepare_result));
+ EXPECT_EQ(parent.existing_ids->testBit(docid), true);
+ parent.commit(docid);
+ }
+ };
+
+ Stressor()
+ : loaded_vectors(),
+ in_progress(BitVector::create(NUM_POSSIBLE_DOCS)),
+ existing_ids(BitVector::create(NUM_POSSIBLE_DOCS)),
+ rng(),
+ vectors(),
+ gen_handler(),
+ index(),
+ multi_prepare_workers(10, 128*1024, 50),
+ write_thread(1, 128*1024, 500)
+ {
+ loaded_vectors.load();
+ }
+
+ ~Stressor() {}
+
+ void init() {
+ uint32_t m = 16;
+ index = std::make_unique<HnswIndex>(vectors, std::make_unique<FloatSqEuclideanDistance>(),
+ std::make_unique<InvLogLevelGenerator>(m),
+ HnswIndex::Config(2*m, m, 200, true));
+ }
+ size_t get_rnd(size_t size) {
+ return rng.nextUniform() * size;
+ }
+ void add_document(uint32_t docid) {
+ size_t vec_num = get_rnd(loaded_vectors.size());
+ ConstVectorRef vec = loaded_vectors[vec_num];
+ auto guard = take_read_guard();
+ auto prepare_task = std::make_unique<PrepareAddTask>(*this, docid, vec, guard);
+ auto complete_task = std::make_unique<CompleteAddTask>(*this, docid, vec, prepare_task->get_result_future());
+ auto r = multi_prepare_workers.execute(std::move(prepare_task));
+ ASSERT_EQ(r.get(), nullptr);
+ r = write_thread.execute(std::move(complete_task));
+ ASSERT_EQ(r.get(), nullptr);
+ }
+ void remove_document(uint32_t docid) {
+ auto task = std::make_unique<CompleteRemoveTask>(*this, docid);
+ auto r = write_thread.execute(std::move(task));
+ ASSERT_EQ(r.get(), nullptr);
+ }
+ void update_document(uint32_t docid) {
+ size_t vec_num = get_rnd(loaded_vectors.size());
+ ConstVectorRef vec = loaded_vectors[vec_num];
+ auto guard = take_read_guard();
+ auto prepare_task = std::make_unique<PrepareAddTask>(*this, docid, vec, guard);
+ auto complete_task = std::make_unique<CompleteUpdateTask>(*this, docid, vec, prepare_task->get_result_future());
+ auto r = multi_prepare_workers.execute(std::move(prepare_task));
+ ASSERT_EQ(r.get(), nullptr);
+ r = write_thread.execute(std::move(complete_task));
+ ASSERT_EQ(r.get(), nullptr);
+ }
+ void commit(uint32_t docid) {
+ index->transfer_hold_lists(gen_handler.getCurrentGeneration());
+ gen_handler.incGeneration();
+ gen_handler.updateFirstUsedGeneration();
+ index->trim_hold_lists(gen_handler.getFirstUsedGeneration());
+ std::lock_guard<std::mutex> guard(in_progress_lock);
+ in_progress->clearBit(docid);
+ // printf("commit: %u\n", docid);
+ }
+ void gen_operation() {
+ uint32_t docid = get_rnd(NUM_POSSIBLE_DOCS);
+ {
+ std::lock_guard<std::mutex> guard(in_progress_lock);
+ while (in_progress->testBit(docid)) {
+ docid = get_rnd(NUM_POSSIBLE_DOCS);
+ }
+ in_progress->setBit(docid);
+ }
+ if (existing_ids->testBit(docid)) {
+ if (get_rnd(100) < 70) {
+ // printf("start remove op: %u\n", docid);
+ remove_document(docid);
+ } else {
+ // printf("start update op: %u\n", docid);
+ update_document(docid);
+ }
+ } else {
+ // printf("start add op: %u\n", docid);
+ add_document(docid);
+ }
+ }
+ GenerationHandler::Guard take_read_guard() {
+ return gen_handler.takeGuard();
+ }
+ MemoryUsage memory_usage() const {
+ return index->memory_usage();
+ }
+ uint32_t count_in_progress() {
+ std::lock_guard<std::mutex> guard(in_progress_lock);
+ in_progress->invalidateCachedCount();
+ return in_progress->countTrueBits();
+ }
+ std::string json_state() {
+ Slime actualSlime;
+ SlimeInserter inserter(actualSlime);
+ index->get_state(inserter);
+ vespalib::SimpleBuffer buf;
+ vespalib::slime::JsonFormat::encode(actualSlime, buf, false);
+ return buf.get().make_string();
+ }
+};
+
+
+TEST_F(Stressor, stress)
+{
+ init();
+ for (int i = 0; i < NUM_OPS; ++i) {
+ gen_operation();
+ if (i % 1000 == 0) {
+ uint32_t cnt = count_in_progress();
+ fprintf(stderr, "generating operations %d / %d; in progress: %u ops\n",
+ i, NUM_OPS, cnt);
+ auto r = write_thread.execute(vespalib::makeLambdaTask([&]() {
+ EXPECT_TRUE(index->check_link_symmetry());
+ }));
+ EXPECT_EQ(r.get(), nullptr);
+ }
+ }
+ fprintf(stderr, "waiting for queued operations...\n");
+ multi_prepare_workers.sync();
+ write_thread.sync();
+ EXPECT_EQ(count_in_progress(), 0);
+ EXPECT_TRUE(index->check_link_symmetry());
+ fprintf(stderr, "HNSW index state after test:\n%s\n", json_state().c_str());
+ existing_ids->invalidateCachedCount();
+ fprintf(stderr, "Expected valid nodes: %u\n", existing_ids->countTrueBits());
+ fprintf(stderr, "all done.\n");
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp b/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp
index bcc886fccad..2db6437664e 100644
--- a/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp
+++ b/searchlib/src/tests/tensor/hnsw_saver/hnsw_save_load_test.cpp
@@ -37,7 +37,7 @@ using V = std::vector<uint32_t>;
void populate(HnswGraph &graph) {
// no 0
graph.make_node_for_document(1, 1);
- graph.make_node_for_document(2, 2);
+ auto er = graph.make_node_for_document(2, 2);
// no 3
graph.make_node_for_document(4, 2);
graph.make_node_for_document(5, 0);
@@ -49,7 +49,7 @@ void populate(HnswGraph &graph) {
graph.set_link_array(6, 0, V{1, 2, 4});
graph.set_link_array(2, 1, V{4});
graph.set_link_array(4, 1, V{2});
- graph.set_entry_node({2, 1});
+ graph.set_entry_node({2, er, 1});
}
void modify(HnswGraph &graph) {
@@ -63,7 +63,7 @@ void modify(HnswGraph &graph) {
graph.set_link_array(4, 1, V{7});
graph.set_link_array(7, 1, V{4});
- graph.set_entry_node({4, 1});
+ graph.set_entry_node({4, graph.get_node_ref(4), 1});
}
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
index c9ed4039655..76533839de7 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.cpp
@@ -5,6 +5,7 @@
#include "nearest_neighbor_index.h"
#include "nearest_neighbor_index_saver.h"
#include "tensor_attribute.hpp"
+#include <vespa/eval/tensor/dense/dense_tensor_view.h>
#include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/fastlib/io/bufferedfile.h>
@@ -18,6 +19,7 @@ LOG_SETUP(".searchlib.tensor.dense_tensor_attribute");
using search::attribute::LoadUtils;
using vespalib::eval::ValueType;
using vespalib::slime::ObjectInserter;
+using vespalib::tensor::DenseTensorView;
using vespalib::tensor::MutableDenseTensorView;
using vespalib::tensor::Tensor;
@@ -77,6 +79,15 @@ can_use_index_save_file(const search::attribute::Config &config, const search::a
}
void
+DenseTensorAttribute::internal_set_tensor(DocId docid, const Tensor& tensor)
+{
+ checkTensorType(tensor);
+ consider_remove_from_index(docid);
+ EntryRef ref = _denseTensorStore.setTensor(tensor);
+ setTensorRef(docid, ref);
+}
+
+void
DenseTensorAttribute::consider_remove_from_index(DocId docid)
{
if (_index && _refVector[docid].valid()) {
@@ -126,15 +137,32 @@ DenseTensorAttribute::clearDoc(DocId docId)
void
DenseTensorAttribute::setTensor(DocId docId, const Tensor &tensor)
{
- checkTensorType(tensor);
- consider_remove_from_index(docId);
- EntryRef ref = _denseTensorStore.setTensor(tensor);
- setTensorRef(docId, ref);
+ internal_set_tensor(docId, tensor);
if (_index) {
_index->add_document(docId);
}
}
+std::unique_ptr<PrepareResult>
+DenseTensorAttribute::prepare_set_tensor(DocId docid, const Tensor& tensor) const
+{
+ if (_index) {
+ const auto* view = dynamic_cast<const DenseTensorView*>(&tensor);
+ assert(view);
+ return _index->prepare_add_document(docid, view->cellsRef(), getGenerationHandler().takeGuard());
+ }
+ return std::unique_ptr<PrepareResult>();
+}
+
+void
+DenseTensorAttribute::complete_set_tensor(DocId docid, const Tensor& tensor,
+ std::unique_ptr<PrepareResult> prepare_result)
+{
+ internal_set_tensor(docid, tensor);
+ if (_index) {
+ _index->complete_add_document(docid, std::move(prepare_result));
+ }
+}
std::unique_ptr<Tensor>
DenseTensorAttribute::getTensor(DocId docId) const
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h
index f0383627ea2..7fd06357114 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_attribute.h
@@ -23,6 +23,7 @@ private:
DenseTensorStore _denseTensorStore;
std::unique_ptr<NearestNeighborIndex> _index;
+ void internal_set_tensor(DocId docid, const Tensor& tensor);
void consider_remove_from_index(DocId docid);
vespalib::MemoryUsage memory_usage() const override;
@@ -33,6 +34,8 @@ public:
// Implements AttributeVector and ITensorAttribute
uint32_t clearDoc(DocId docId) override;
void setTensor(DocId docId, const Tensor &tensor) override;
+ std::unique_ptr<PrepareResult> prepare_set_tensor(DocId docid, const Tensor& tensor) const override;
+ void complete_set_tensor(DocId docid, const Tensor& tensor, std::unique_ptr<PrepareResult> prepare_result) override;
std::unique_ptr<Tensor> getTensor(DocId docId) const override;
void getTensor(DocId docId, vespalib::tensor::MutableDenseTensorView &tensor) const override;
bool onLoad() override;
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_graph.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_graph.cpp
index 37e3ea1adbd..564676a2d44 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_graph.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_graph.cpp
@@ -13,13 +13,14 @@ HnswGraph::HnswGraph()
links(HnswIndex::make_default_link_store_config()),
entry_docid_and_level()
{
+ node_refs.ensure_size(1, AtomicEntryRef());
EntryNode entry;
set_entry_node(entry);
}
HnswGraph::~HnswGraph() {}
-void
+HnswGraph::NodeRef
HnswGraph::make_node_for_document(uint32_t docid, uint32_t num_levels)
{
node_refs.ensure_size(docid + 1, AtomicEntryRef());
@@ -29,15 +30,23 @@ HnswGraph::make_node_for_document(uint32_t docid, uint32_t num_levels)
vespalib::Array<AtomicEntryRef> levels(num_levels, AtomicEntryRef());
auto node_ref = nodes.add(levels);
node_refs[docid].store_release(node_ref);
+ return node_ref;
}
void
HnswGraph::remove_node_for_document(uint32_t docid)
{
auto node_ref = node_refs[docid].load_acquire();
- nodes.remove(node_ref);
+ assert(node_ref.valid());
+ auto levels = nodes.get(node_ref);
vespalib::datastore::EntryRef invalid;
node_refs[docid].store_release(invalid);
+ // Ensure data referenced through the old ref can be recycled:
+ nodes.remove(node_ref);
+ for (size_t i = 0; i < levels.size(); ++i) {
+ auto old_links_ref = levels[i].load_acquire();
+ links.remove(old_links_ref);
+ }
}
void
@@ -47,6 +56,7 @@ HnswGraph::set_link_array(uint32_t docid, uint32_t level, const LinkArrayRef& ne
auto node_ref = node_refs[docid].load_acquire();
assert(node_ref.valid());
auto levels = nodes.get_writable(node_ref);
+ assert(level < levels.size());
auto old_links_ref = levels[level].load_acquire();
levels[level].store_release(new_links_ref);
links.remove(old_links_ref);
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_graph.h b/searchlib/src/vespa/searchlib/tensor/hnsw_graph.h
index 125692af627..8b40eb87bae 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_graph.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_graph.h
@@ -23,6 +23,7 @@ struct HnswGraph {
// Provides mapping from document id -> node reference.
// The reference is used to lookup the node data in NodeStore.
using NodeRefVector = vespalib::RcuVector<AtomicEntryRef>;
+ using NodeRef = vespalib::datastore::EntryRef;
// This stores the level arrays for all nodes.
// Each node consists of an array of levels (from level 0 to n) where each entry is a reference to the link array at that level.
@@ -45,33 +46,64 @@ struct HnswGraph {
~HnswGraph();
- void make_node_for_document(uint32_t docid, uint32_t num_levels);
+ NodeRef make_node_for_document(uint32_t docid, uint32_t num_levels);
void remove_node_for_document(uint32_t docid);
+ NodeRef get_node_ref(uint32_t docid) const {
+ return node_refs[docid].load_acquire();
+ }
+
+ bool still_valid(uint32_t docid, NodeRef node_ref) const {
+ return node_ref.valid() && (get_node_ref(docid) == node_ref);
+ }
+
+ LevelArrayRef get_level_array(NodeRef node_ref) const {
+ if (node_ref.valid()) {
+ return nodes.get(node_ref);
+ }
+ return LevelArrayRef();
+ }
+
LevelArrayRef get_level_array(uint32_t docid) const {
- auto node_ref = node_refs[docid].load_acquire();
- assert(node_ref.valid());
- return nodes.get(node_ref);
+ auto node_ref = get_node_ref(docid);
+ return get_level_array(node_ref);
+ }
+
+ LinkArrayRef get_link_array(LevelArrayRef levels, uint32_t level) const {
+ if (level < levels.size()) {
+ auto links_ref = levels[level].load_acquire();
+ if (links_ref.valid()) {
+ return links.get(links_ref);
+ }
+ }
+ return LinkArrayRef();
}
LinkArrayRef get_link_array(uint32_t docid, uint32_t level) const {
auto levels = get_level_array(docid);
- assert(level < levels.size());
- return links.get(levels[level].load_acquire());
+ return get_link_array(levels, level);
+ }
+
+ LinkArrayRef get_link_array(NodeRef node_ref, uint32_t level) const {
+ auto levels = get_level_array(node_ref);
+ return get_link_array(levels, level);
}
-
+
void set_link_array(uint32_t docid, uint32_t level, const LinkArrayRef& new_links);
struct EntryNode {
uint32_t docid;
+ NodeRef node_ref;
int32_t level;
EntryNode()
: docid(0), // Note that docid 0 is reserved and never used
+ node_ref(),
level(-1)
{}
- EntryNode(uint32_t docid_in, int32_t level_in)
+ EntryNode(uint32_t docid_in, NodeRef node_ref_in, int32_t level_in)
: docid(docid_in),
+ node_ref(node_ref_in),
level(level_in)
{}
};
@@ -80,15 +112,43 @@ struct HnswGraph {
uint64_t value = node.level;
value <<= 32;
value |= node.docid;
+ if (node.node_ref.valid()) {
+ assert(node.level >= 0);
+ assert(node.docid > 0);
+ } else {
+ assert(node.level == -1);
+ assert(node.docid == 0);
+ }
entry_docid_and_level.store(value, std::memory_order_release);
}
+ uint64_t get_entry_atomic() const {
+ return entry_docid_and_level.load(std::memory_order_acquire);
+ }
+
EntryNode get_entry_node() const {
EntryNode entry;
- uint64_t value = entry_docid_and_level.load(std::memory_order_acquire);
- entry.docid = (uint32_t)value;
- entry.level = (int32_t)(value >> 32);
- return entry;
+ while (true) {
+ uint64_t value = get_entry_atomic();
+ entry.docid = (uint32_t)value;
+ entry.node_ref = get_node_ref(entry.docid);
+ entry.level = (int32_t)(value >> 32);
+ if ((entry.docid == 0)
+ && (entry.level == -1)
+ && ! entry.node_ref.valid())
+ {
+ // invalid in every way
+ return entry;
+ }
+ if ((entry.docid > 0)
+ && (entry.level > -1)
+ && entry.node_ref.valid()
+ && (get_entry_atomic() == value))
+ {
+ // valid in every way
+ return entry;
+ }
+ }
}
size_t size() const { return node_refs.size(); }
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
index a03614a785e..36d970dfd01 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp
@@ -72,10 +72,10 @@ HnswIndex::max_links_for_level(uint32_t level) const
}
bool
-HnswIndex::have_closer_distance(HnswCandidate candidate, const LinkArrayRef& result) const
+HnswIndex::have_closer_distance(HnswCandidate candidate, const HnswCandidateVector& result) const
{
- for (uint32_t result_docid : result) {
- double dist = calc_distance(candidate.docid, result_docid);
+ for (const auto & neighbor : result) {
+ double dist = calc_distance(candidate.docid, neighbor.docid);
if (dist < candidate.distance) {
return true;
}
@@ -91,7 +91,7 @@ HnswIndex::select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_
SelectResult result;
for (const auto & candidate : sorted) {
if (result.used.size() < max_links) {
- result.used.push_back(candidate.docid);
+ result.used.push_back(candidate);
} else {
result.unused.push_back(candidate.docid);
}
@@ -114,7 +114,7 @@ HnswIndex::select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint
result.unused.push_back(candidate.docid);
continue;
}
- result.used.push_back(candidate.docid);
+ result.used.push_back(candidate);
if (result.used.size() == max_links) {
while (!nearest.empty()) {
candidate = nearest.top();
@@ -148,7 +148,12 @@ HnswIndex::shrink_if_needed(uint32_t docid, uint32_t level)
neighbors.emplace_back(neighbor_docid, dist);
}
auto split = select_neighbors(neighbors, max_links);
- _graph.set_link_array(docid, level, split.used);
+ LinkArray new_links;
+ new_links.reserve(split.used.size());
+ for (const auto & neighbor : split.used) {
+ new_links.push_back(neighbor.docid);
+ }
+ _graph.set_link_array(docid, level, new_links);
for (uint32_t removed_docid : split.unused) {
remove_link_to(removed_docid, docid, level);
}
@@ -201,10 +206,13 @@ HnswIndex::find_nearest_in_layer(const TypedCells& input, const HnswCandidate& e
bool keep_searching = true;
while (keep_searching) {
keep_searching = false;
- for (uint32_t neighbor_docid : _graph.get_link_array(nearest.docid, level)) {
+ for (uint32_t neighbor_docid : _graph.get_link_array(nearest.node_ref, level)) {
+ auto neighbor_ref = _graph.get_node_ref(neighbor_docid);
double dist = calc_distance(input, neighbor_docid);
- if (dist < nearest.distance) {
- nearest = HnswCandidate(neighbor_docid, dist);
+ if (_graph.still_valid(neighbor_docid, neighbor_ref)
+ && dist < nearest.distance)
+ {
+ nearest = HnswCandidate(neighbor_docid, neighbor_ref, dist);
keep_searching = true;
}
}
@@ -239,16 +247,20 @@ HnswIndex::search_layer(const TypedCells& input, uint32_t neighbors_to_find,
break;
}
candidates.pop();
- for (uint32_t neighbor_docid : _graph.get_link_array(cand.docid, level)) {
- if ((neighbor_docid >= doc_id_limit) || visited.is_marked(neighbor_docid)) {
+ for (uint32_t neighbor_docid : _graph.get_link_array(cand.node_ref, level)) {
+ auto neighbor_ref = _graph.get_node_ref(neighbor_docid);
+ if ((! neighbor_ref.valid())
+ || (neighbor_docid >= doc_id_limit)
+ || visited.is_marked(neighbor_docid))
+ {
continue;
}
visited.mark(neighbor_docid);
double dist_to_input = calc_distance(input, neighbor_docid);
if (dist_to_input < limit_dist) {
- candidates.emplace(neighbor_docid, dist_to_input);
+ candidates.emplace(neighbor_docid, neighbor_ref, dist_to_input);
if ((!filter) || filter->testBit(neighbor_docid)) {
- best_neighbors.emplace(neighbor_docid, dist_to_input);
+ best_neighbors.emplace(neighbor_docid, neighbor_ref, dist_to_input);
if (best_neighbors.size() > neighbors_to_find) {
best_neighbors.pop();
limit_dist = best_neighbors.top().distance;
@@ -287,11 +299,13 @@ HnswIndex::internal_prepare_add(uint32_t docid, TypedCells input_vector) const
PreparedAddDoc op(docid, level);
auto entry = _graph.get_entry_node();
if (entry.docid == 0) {
+ // graph has no entry point
return op;
}
int search_level = entry.level;
double entry_dist = calc_distance(input_vector, entry.docid);
- HnswCandidate entry_point(entry.docid, entry_dist);
+ // TODO: check if entry docid/node_ref is still valid here
+ HnswCandidate entry_point(entry.docid, entry.node_ref, entry_dist);
while (search_level > op.max_level) {
entry_point = find_nearest_in_layer(input_vector, entry_point, search_level);
--search_level;
@@ -305,22 +319,35 @@ HnswIndex::internal_prepare_add(uint32_t docid, TypedCells input_vector) const
while (search_level >= 0) {
search_layer(input_vector, _cfg.neighbors_to_explore_at_construction(), best_neighbors, search_level);
auto neighbors = select_neighbors(best_neighbors.peek(), _cfg.max_links_on_inserts());
- auto use = neighbors.used;
- op.connections[search_level].assign(use.begin(), use.end());
+ op.connections[search_level].reserve(neighbors.used.size());
+ for (const auto & neighbor : neighbors.used) {
+ auto neighbor_levels = _graph.get_level_array(neighbor.node_ref);
+ if (size_t(search_level) < neighbor_levels.size()) {
+ op.connections[search_level].emplace_back(neighbor.docid, neighbor.node_ref);
+ } else {
+ LOG(warning, "in prepare_add(%u), selected neighbor %u is missing level %d (has %zu levels)",
+ docid, neighbor.docid, search_level, neighbor_levels.size());
+ }
+ }
--search_level;
}
return op;
}
HnswIndex::LinkArray
-HnswIndex::filter_valid_docids(const LinkArrayRef &docids)
+HnswIndex::filter_valid_docids(uint32_t level, const PreparedAddDoc::Links &neighbors, uint32_t self_docid)
{
LinkArray valid;
- valid.reserve(docids.size());
- for (uint32_t docid : docids) {
- auto node_ref = _graph.node_refs[docid].load_acquire();
- if (node_ref.valid()) {
- valid.push_back(docid);
+ valid.reserve(neighbors.size());
+ for (const auto & neighbor : neighbors) {
+ uint32_t docid = neighbor.first;
+ HnswGraph::NodeRef node_ref = neighbor.second;
+ if (_graph.still_valid(docid, node_ref)) {
+ assert(docid != self_docid);
+ auto levels = _graph.get_level_array(node_ref);
+ if (level < levels.size()) {
+ valid.push_back(docid);
+ }
}
}
return valid;
@@ -329,13 +356,13 @@ HnswIndex::filter_valid_docids(const LinkArrayRef &docids)
void
HnswIndex::internal_complete_add(uint32_t docid, PreparedAddDoc &op)
{
- _graph.make_node_for_document(docid, op.max_level + 1);
+ auto node_ref = _graph.make_node_for_document(docid, op.max_level + 1);
for (int level = 0; level <= op.max_level; ++level) {
- auto neighbors = filter_valid_docids(op.connections[level]);
+ auto neighbors = filter_valid_docids(level, op.connections[level], docid);
connect_new_node(docid, neighbors, level);
}
if (op.max_level > get_entry_level()) {
- _graph.set_entry_node({docid, op.max_level});
+ _graph.set_entry_node({docid, node_ref, op.max_level});
}
}
@@ -392,19 +419,18 @@ void
HnswIndex::remove_document(uint32_t docid)
{
bool need_new_entrypoint = (docid == get_entry_docid());
- LinkArray empty;
LevelArrayRef node_levels = _graph.get_level_array(docid);
for (int level = node_levels.size(); level-- > 0; ) {
LinkArrayRef my_links = _graph.get_link_array(docid, level);
for (uint32_t neighbor_id : my_links) {
if (need_new_entrypoint) {
- _graph.set_entry_node({neighbor_id, level});
+ auto entry_node_ref = _graph.get_node_ref(neighbor_id);
+ _graph.set_entry_node({neighbor_id, entry_node_ref, level});
need_new_entrypoint = false;
}
remove_link_to(neighbor_id, docid, level);
}
mutual_reconnect(my_links, level);
- _graph.set_link_array(docid, level, empty);
}
if (need_new_entrypoint) {
HnswGraph::EntryNode entry;
@@ -530,12 +556,14 @@ HnswIndex::top_k_candidates(const TypedCells &vector, uint32_t k, const BitVecto
{
FurthestPriQ best_neighbors;
auto entry = _graph.get_entry_node();
- if (entry.level < 0) {
+ if (entry.docid == 0) {
+ // graph has no entry point
return best_neighbors;
}
int search_level = entry.level;
double entry_dist = calc_distance(vector, entry.docid);
- HnswCandidate entry_point(entry.docid, entry_dist);
+ // TODO: check if entry docid/node_ref is still valid here
+ HnswCandidate entry_point(entry.docid, entry.node_ref, entry_dist);
while (search_level > 0) {
entry_point = find_nearest_in_layer(vector, entry_point, search_level);
--search_level;
@@ -568,13 +596,13 @@ HnswIndex::set_node(uint32_t docid, const HnswNode &node)
{
size_t num_levels = node.size();
assert(num_levels > 0);
- _graph.make_node_for_document(docid, num_levels);
+ auto node_ref = _graph.make_node_for_document(docid, num_levels);
for (size_t level = 0; level < num_levels; ++level) {
connect_new_node(docid, node.level(level), level);
}
int max_level = num_levels - 1;
if (get_entry_level() < max_level) {
- _graph.set_entry_node({docid, max_level});
+ _graph.set_entry_node({docid, node_ref, max_level});
}
}
@@ -593,6 +621,8 @@ HnswIndex::check_link_symmetry() const
auto neighbor_links = _graph.get_link_array(neighbor_docid, level);
if (! has_link_to(neighbor_links, docid)) {
all_sym = false;
+ LOG(warning, "check_link_symmetry: docid %zu links to %u on level %u, but no backlink",
+ docid, neighbor_docid, level);
}
}
++level;
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
index df7453023cf..ab3eced8fdc 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h
@@ -91,10 +91,11 @@ protected:
* where the candidate is located.
* Used by select_neighbors_heuristic().
*/
- bool have_closer_distance(HnswCandidate candidate, const LinkArrayRef& curr_result) const;
+ bool have_closer_distance(HnswCandidate candidate, const HnswCandidateVector& curr_result) const;
struct SelectResult {
- LinkArray used;
+ HnswCandidateVector used;
LinkArray unused;
+ ~SelectResult() {}
};
SelectResult select_neighbors_heuristic(const HnswCandidateVector& neighbors, uint32_t max_links) const;
SelectResult select_neighbors_simple(const HnswCandidateVector& neighbors, uint32_t max_links) const;
@@ -123,7 +124,8 @@ protected:
struct PreparedAddDoc : public PrepareResult {
uint32_t docid;
int32_t max_level;
- std::vector<LinkArray> connections;
+ using Links = std::vector<std::pair<uint32_t, HnswGraph::NodeRef>>;
+ std::vector<Links> connections;
PreparedAddDoc(uint32_t docid_in, int32_t max_level_in)
: docid(docid_in), max_level(max_level_in), connections(max_level+1)
{}
@@ -131,7 +133,7 @@ protected:
PreparedAddDoc(PreparedAddDoc&& other) = default;
};
PreparedAddDoc internal_prepare_add(uint32_t docid, TypedCells input_vector) const;
- LinkArray filter_valid_docids(const LinkArrayRef &docids);
+ LinkArray filter_valid_docids(uint32_t level, const PreparedAddDoc::Links &neighbors, uint32_t me);
void internal_complete_add(uint32_t docid, PreparedAddDoc &op);
public:
HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distance_func,
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.cpp
index 9f49c0647c6..ac98b28d105 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index_loader.cpp
@@ -39,7 +39,8 @@ HnswIndexLoader::load(const fileutil::LoadedBuffer& buf)
}
if (_failed) return false;
_graph.node_refs.ensure_size(num_nodes);
- _graph.set_entry_node({entry_docid, entry_level});
+ auto entry_node_ref = _graph.get_node_ref(entry_docid);
+ _graph.set_entry_node({entry_docid, entry_node_ref, entry_level});
return true;
}
diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index_utils.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index_utils.h
index b11d3f36a7a..99266505780 100644
--- a/searchlib/src/vespa/searchlib/tensor/hnsw_index_utils.h
+++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index_utils.h
@@ -5,6 +5,7 @@
#include <cstdint>
#include <queue>
#include <vector>
+#include "hnsw_graph.h"
namespace search::tensor {
@@ -13,8 +14,12 @@ namespace search::tensor {
*/
struct HnswCandidate {
uint32_t docid;
+ HnswGraph::NodeRef node_ref;
double distance;
- HnswCandidate(uint32_t docid_in, double distance_in) : docid(docid_in), distance(distance_in) {}
+ HnswCandidate(uint32_t docid_in, double distance_in)
+ : docid(docid_in), node_ref(), distance(distance_in) {}
+ HnswCandidate(uint32_t docid_in, HnswGraph::NodeRef node_ref_in, double distance_in)
+ : docid(docid_in), node_ref(node_ref_in), distance(distance_in) {}
};
struct GreaterDistance {
diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedHandlerV3.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedHandlerV3.java
index 03916949cae..a932ca935e0 100644
--- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedHandlerV3.java
+++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedHandlerV3.java
@@ -84,7 +84,7 @@ public class FeedHandlerV3 extends LoggingRequestHandler {
SourceSessionParams sourceSessionParams = sourceSessionParams(request);
clientFeederByClientId.put(clientId,
new ClientFeederV3(retainSource(sessionCache, sourceSessionParams),
- new FeedReaderFactory(),
+ new FeedReaderFactory(true), //TODO make error debugging configurable
docTypeManager,
clientId,
metric,
diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedReaderFactory.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedReaderFactory.java
index 6a3229e86b7..81b08d5fb25 100644
--- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedReaderFactory.java
+++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/FeedReaderFactory.java
@@ -3,6 +3,7 @@ package com.yahoo.vespa.http.server;
import com.yahoo.document.DocumentTypeManager;
import com.yahoo.document.json.JsonFeedReader;
+import com.yahoo.text.Utf8;
import com.yahoo.vespa.http.client.config.FeedParams;
import com.yahoo.vespaxmlparser.FeedReader;
import com.yahoo.vespaxmlparser.VespaXMLFeedReader;
@@ -14,6 +15,12 @@ import java.io.InputStream;
* @author dybis
*/
public class FeedReaderFactory {
+ private static final int MARK_READLIMIT = 200;
+
+ private final boolean debug;
+ public FeedReaderFactory(boolean debug) {
+ this.debug = debug;
+ }
/**
* Creates FeedReader
@@ -28,10 +35,22 @@ public class FeedReaderFactory {
FeedParams.DataFormat dataFormat) {
switch (dataFormat) {
case XML_UTF8:
+ byte [] peek = null;
+ int bytesPeeked = 0;
try {
+ if (debug && inputStream.markSupported()) {
+ peek = new byte[MARK_READLIMIT];
+ inputStream.mark(MARK_READLIMIT);
+ bytesPeeked = inputStream.read(peek);
+ inputStream.reset();
+ }
return new VespaXMLFeedReader(inputStream, docTypeManager);
} catch (Exception e) {
- throw new RuntimeException("Could not create VespaXMLFeedReader", e);
+ if (bytesPeeked > 0) {
+ throw new RuntimeException("Could not create VespaXMLFeedReader. First characters are: '" + Utf8.toString(peek, 0, bytesPeeked) + "'", e);
+ } else {
+ throw new RuntimeException("Could not create VespaXMLFeedReader.", e);
+ }
}
case JSON_UTF8:
return new JsonFeedReader(inputStream, docTypeManager);
diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStream.java b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStream.java
index c9f255f026e..c8ae79deebd 100644
--- a/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStream.java
+++ b/vespaclient-container-plugin/src/main/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStream.java
@@ -13,6 +13,7 @@ public class ByteLimitedInputStream extends InputStream {
private final InputStream wrappedStream;
private int remaining;
+ private int remainingWhenMarked;
public ByteLimitedInputStream(InputStream wrappedStream, int limit) {
this.wrappedStream = wrappedStream;
@@ -78,4 +79,21 @@ public class ByteLimitedInputStream extends InputStream {
}
}
+ @Override
+ public synchronized void mark(int readlimit) {
+ wrappedStream.mark(readlimit);
+ remainingWhenMarked = remaining;
+ }
+
+ @Override
+ public synchronized void reset() throws IOException {
+ wrappedStream.reset();
+ remaining = remainingWhenMarked;
+ }
+
+ @Override
+ public boolean markSupported() {
+ return wrappedStream.markSupported();
+ }
+
}
diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/FeedReaderFactoryTestCase.java b/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/FeedReaderFactoryTestCase.java
new file mode 100644
index 00000000000..47f057013b7
--- /dev/null
+++ b/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/FeedReaderFactoryTestCase.java
@@ -0,0 +1,40 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.vespa.http.server;
+
+import com.yahoo.document.DocumentTypeManager;
+import com.yahoo.text.Utf8;
+import com.yahoo.vespa.http.client.config.FeedParams;
+import org.junit.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.InputStream;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+public class FeedReaderFactoryTestCase {
+ DocumentTypeManager manager = new DocumentTypeManager();
+
+ private InputStream createStream(String s) {
+ return new ByteArrayInputStream(Utf8.toBytes(s));
+ }
+
+ @Test
+ public void testXmlExceptionWithDebug() {
+ try {
+ new FeedReaderFactory(true).createReader(createStream("Some malformed xml"), manager, FeedParams.DataFormat.XML_UTF8);
+ fail();
+ } catch (RuntimeException e) {
+ assertEquals("Could not create VespaXMLFeedReader. First characters are: 'Some malformed xml'", e.getMessage());
+ }
+ }
+ @Test
+ public void testXmlException() {
+ try {
+ new FeedReaderFactory(false).createReader(createStream("Some malformed xml"), manager, FeedParams.DataFormat.XML_UTF8);
+ fail();
+ } catch (RuntimeException e) {
+ assertEquals("Could not create VespaXMLFeedReader.", e.getMessage());
+ }
+ }
+}
diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStreamTestCase.java b/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStreamTestCase.java
index 3aa3cdcb3a8..3dd8145ec73 100644
--- a/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStreamTestCase.java
+++ b/vespaclient-container-plugin/src/test/java/com/yahoo/vespa/http/server/util/ByteLimitedInputStreamTestCase.java
@@ -8,8 +8,8 @@ import java.io.IOException;
import java.io.InputStream;
import java.nio.charset.StandardCharsets;
-import static org.hamcrest.core.Is.is;
-import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
/**
* @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
@@ -29,63 +29,78 @@ public class ByteLimitedInputStreamTestCase {
public void requireThatBasicsWork() throws IOException {
ByteLimitedInputStream stream = create("abcdefghijklmnopqr".getBytes(StandardCharsets.US_ASCII), 9);
- assertThat(stream.available(), is(9));
- assertThat(stream.read(), is(97));
- assertThat(stream.available(), is(8));
- assertThat(stream.read(), is(98));
- assertThat(stream.available(), is(7));
- assertThat(stream.read(), is(99));
- assertThat(stream.available(), is(6));
- assertThat(stream.read(), is(100));
- assertThat(stream.available(), is(5));
- assertThat(stream.read(), is(101));
- assertThat(stream.available(), is(4));
- assertThat(stream.read(), is(102));
- assertThat(stream.available(), is(3));
- assertThat(stream.read(), is(103));
- assertThat(stream.available(), is(2));
- assertThat(stream.read(), is(104));
- assertThat(stream.available(), is(1));
- assertThat(stream.read(), is(105));
- assertThat(stream.available(), is(0));
- assertThat(stream.read(), is(-1));
- assertThat(stream.available(), is(0));
- assertThat(stream.read(), is(-1));
- assertThat(stream.available(), is(0));
- assertThat(stream.read(), is(-1));
- assertThat(stream.available(), is(0));
- assertThat(stream.read(), is(-1));
- assertThat(stream.available(), is(0));
- assertThat(stream.read(), is(-1));
- assertThat(stream.available(), is(0));
+ assertEquals(9, stream.available());
+ assertEquals(97, stream.read());
+ assertEquals(8, stream.available());
+ assertEquals(98, stream.read());
+ assertEquals(7, stream.available());
+ assertEquals(99, stream.read());
+ assertEquals(6, stream.available());
+ assertEquals(100, stream.read());
+ assertEquals(5, stream.available());
+ assertEquals(101, stream.read());
+ assertEquals(4, stream.available());
+ assertEquals(102, stream.read());
+ assertEquals(3, stream.available());
+ assertEquals(103, stream.read());
+ assertEquals(2, stream.available());
+ assertEquals(104, stream.read());
+ assertEquals(1, stream.available());
+ assertEquals(105, stream.read());
+ assertEquals(0, stream.available());
+ assertEquals(-1, stream.read());
+ assertEquals(0, stream.available());
+ assertEquals(-1, stream.read());
+ assertEquals(0, stream.available());
+ assertEquals(-1, stream.read());
+ assertEquals(0, stream.available());
+ assertEquals(-1, stream.read());
+ assertEquals(0, stream.available());
+ assertEquals(-1, stream.read());
+ assertEquals(0, stream.available());
}
@Test
public void requireThatChunkedReadWorks() throws IOException {
ByteLimitedInputStream stream = create("abcdefghijklmnopqr".getBytes(StandardCharsets.US_ASCII), 9);
- assertThat(stream.available(), is(9));
+ assertEquals(9, stream.available());
byte[] toBuf = new byte[4];
- assertThat(stream.read(toBuf), is(4));
- assertThat(toBuf[0], is((byte) 97));
- assertThat(toBuf[1], is((byte) 98));
- assertThat(toBuf[2], is((byte) 99));
- assertThat(toBuf[3], is((byte) 100));
- assertThat(stream.available(), is(5));
+ assertEquals(4, stream.read(toBuf));
+ assertEquals(97, toBuf[0]);
+ assertEquals(98, toBuf[1]);
+ assertEquals(99, toBuf[2]);
+ assertEquals(100, toBuf[3]);
+ assertEquals(5, stream.available());
- assertThat(stream.read(toBuf), is(4));
- assertThat(toBuf[0], is((byte) 101));
- assertThat(toBuf[1], is((byte) 102));
- assertThat(toBuf[2], is((byte) 103));
- assertThat(toBuf[3], is((byte) 104));
- assertThat(stream.available(), is(1));
+ assertEquals(4, stream.read(toBuf));
+ assertEquals(101, toBuf[0]);
+ assertEquals(102, toBuf[1]);
+ assertEquals(103, toBuf[2]);
+ assertEquals(104, toBuf[3]);
+ assertEquals(1, stream.available());
- assertThat(stream.read(toBuf), is(1));
- assertThat(toBuf[0], is((byte) 105));
- assertThat(stream.available(), is(0));
+ assertEquals(1, stream.read(toBuf));
+ assertEquals(105, toBuf[0]);
+ assertEquals(0, stream.available());
- assertThat(stream.read(toBuf), is(-1));
- assertThat(stream.available(), is(0));
+ assertEquals(-1, stream.read(toBuf));
+ assertEquals(0, stream.available());
+ }
+
+ @Test
+ public void requireMarkWorks() throws IOException {
+ InputStream stream = create("abcdefghijklmnopqr".getBytes(StandardCharsets.US_ASCII), 9);
+ assertEquals(97, stream.read());
+ assertTrue(stream.markSupported());
+ stream.mark(5);
+ assertEquals(98, stream.read());
+ assertEquals(99, stream.read());
+ stream.reset();
+ assertEquals(98, stream.read());
+ assertEquals(99, stream.read());
+ assertEquals(100, stream.read());
+ assertEquals(101, stream.read());
}
}
diff --git a/vespaclient-container-plugin/src/test/java/com/yahoo/vespaxmlparser/MockFeedReaderFactory.java b/vespaclient-container-plugin/src/test/java/com/yahoo/vespaxmlparser/MockFeedReaderFactory.java
index 9a61af7266f..df1d5505632 100644
--- a/vespaclient-container-plugin/src/test/java/com/yahoo/vespaxmlparser/MockFeedReaderFactory.java
+++ b/vespaclient-container-plugin/src/test/java/com/yahoo/vespaxmlparser/MockFeedReaderFactory.java
@@ -13,6 +13,10 @@ import java.io.InputStream;
*/
public class MockFeedReaderFactory extends FeedReaderFactory {
+ public MockFeedReaderFactory() {
+ super(true);
+ }
+
@Override
public FeedReader createReader(
InputStream inputStream,
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index d9467a41f78..154b6871392 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -2075,6 +2075,22 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.ScalarFunctions$Erf": {
+ "superClass": "java.lang.Object",
+ "interfaces": [
+ "java.util.function.DoubleUnaryOperator"
+ ],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>()",
+ "public double applyAsDouble(double)",
+ "public java.lang.String toString()",
+ "public static double erf(double)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.ScalarFunctions$Exp": {
"superClass": "java.lang.Object",
"interfaces": [
@@ -2506,6 +2522,7 @@
"public static java.util.function.DoubleUnaryOperator square()",
"public static java.util.function.DoubleUnaryOperator tan()",
"public static java.util.function.DoubleUnaryOperator tanh()",
+ "public static java.util.function.DoubleUnaryOperator erf()",
"public static java.util.function.DoubleUnaryOperator elu()",
"public static java.util.function.DoubleUnaryOperator elu(double)",
"public static java.util.function.DoubleUnaryOperator leakyrelu()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index d9204e24d68..c19b07cf96f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -50,6 +50,7 @@ public class ScalarFunctions {
public static DoubleUnaryOperator square() { return new Square(); }
public static DoubleUnaryOperator tan() { return new Tan(); }
public static DoubleUnaryOperator tanh() { return new Tanh(); }
+ public static DoubleUnaryOperator erf() { return new Erf(); }
public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator elu(double alpha) { return new Elu(alpha); }
@@ -330,6 +331,30 @@ public class ScalarFunctions {
public String toString() { return "f(a)(tanh(a))"; }
}
+ public static class Erf implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return erf(operand); }
+ @Override
+ public String toString() { return "f(a)(erf(a))"; }
+
+ // Use Horner's method
+ // From https://introcs.cs.princeton.edu/java/21function/ErrorFunction.java.html
+ public static double erf(double v) {
+ double t = 1.0 / (1.0 + 0.5 * Math.abs(v));
+ double ans = 1 - t * Math.exp(-v*v - 1.26551223 +
+ t * ( 1.00002368 +
+ t * ( 0.37409196 +
+ t * ( 0.09678418 +
+ t * (-0.18628806 +
+ t * ( 0.27886807 +
+ t * (-1.13520398 +
+ t * ( 1.48851587 +
+ t * (-0.82215223 +
+ t * ( 0.17087277))))))))));
+ if (v >= 0) return ans;
+ else return -ans;
+ }
+ }
// Variable-length operators -----------------------------------------------------------------------------