summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/OnnxModel.java 7
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java 14
-rw-r--r-- config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java 2
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java 25
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java 24
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java 11
-rw-r--r-- controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/noderepository/NodeRepositoryNode.java 7
-rw-r--r-- controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java 15
-rw-r--r-- controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java 15
-rw-r--r-- controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java 2
-rw-r--r-- controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java 1
-rw-r--r-- controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java 11
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ExpressionConverter.java 5
-rw-r--r-- indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/InputExpression.java 4
-rw-r--r-- indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/ChoiceTestCase.java 26
-rw-r--r-- model-integration/pom.xml 5
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java 19
-rw-r--r-- model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java 70
-rw-r--r-- model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java 53
-rw-r--r-- model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java 95
-rw-r--r-- searchcore/src/tests/proton/docsummary/docsummary_test.cpp 5
-rw-r--r-- searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.cpp 12
-rw-r--r-- searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.h 4
-rw-r--r-- searchsummary/src/vespa/searchsummary/docsummary/i_docsum_field_writer_factory.h 5
-rw-r--r-- searchsummary/src/vespa/searchsummary/docsummary/resultconfig.cpp 5
-rw-r--r-- streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.cpp 9
-rw-r--r-- streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.h 3
-rw-r--r-- vdslib/src/main/java/com/yahoo/vdslib/VisitorStatistics.java 16
-rw-r--r-- vespa-feed-client/src/test/java/ai/vespa/feed/client/impl/ApacheClusterTest.java 37
29 files changed, 381 insertions, 126 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
index 272b668b5fb..90a27d1f036 100644
--- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
+++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java
@@ -1,15 +1,17 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.schema;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.model.ml.OnnxModelInfo;
-import com.yahoo.searchlib.rankingexpression.Reference;
import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
+import java.util.Set;
/**
* A global ONNX model distributed using file distribution, similar to ranking constants.
@@ -21,6 +23,7 @@ public class OnnxModel extends DistributableResource {
private OnnxModelInfo modelInfo = null;
private final Map<String, String> inputMap = new HashMap<>();
private final Map<String, String> outputMap = new HashMap<>();
+ private final Set<String> initializers = new HashSet<>();
private String statelessExecutionMode = null;
private Integer statelessInterOpThreads = null;
@@ -101,11 +104,13 @@ public class OnnxModel extends DistributableResource {
for (String onnxName : modelInfo.getOutputs()) {
addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false);
}
+ initializers.addAll(modelInfo.getInitializers());
this.modelInfo = modelInfo;
}
public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); }
public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); }
+ public Set<String> getInitializers() { return Set.copyOf(initializers); }
public String getDefaultOutput() {
return modelInfo != null ? modelInfo.getDefaultOutput() : "";
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
index af072c5b59a..7f578f07fe3 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java
@@ -2,7 +2,6 @@
package com.yahoo.schema.expressiontransforms;
import com.yahoo.schema.FeatureNames;
-import com.yahoo.schema.RankProfile;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
@@ -12,13 +11,12 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.tensor.functions.DynamicTensor;
import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Slice;
import java.io.StringReader;
import java.util.HashSet;
import java.util.Set;
+import java.util.logging.Logger;
/**
* Analyzes expression to figure out what inputs it needs
@@ -27,6 +25,8 @@ import java.util.Set;
*/
public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
+ private static final Logger log = Logger.getLogger(InputRecorder.class.getName());
+
private final Set<String> neededInputs;
private final Set<String> handled = new HashSet<>();
@@ -120,7 +120,11 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
if (model == null) {
throw new IllegalArgumentException("missing onnx model: " + arg);
}
- for (String onnxInput : model.getInputMap().values()) {
+ model.getInputMap().forEach((onnxName, onnxInput) -> {
+ if (model.getInitializers().contains(onnxName)) {
+ log.fine(() -> "For input '%s': skipping name '%s' as it's an initializer".formatted(onnxInput, onnxName));
+ return;
+ }
var reader = new StringReader(onnxInput);
try {
var asExpression = new RankingExpression(reader);
@@ -128,7 +132,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> {
} catch (ParseException e) {
throw new IllegalArgumentException("illegal onnx input '" + onnxInput + "': " + e.getMessage());
}
- }
+ });
return;
}
neededInputs.add(feature.toString());
diff --git a/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java b/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java
index 5509d11885c..b414d3757e2 100644
--- a/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java
+++ b/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java
@@ -20,7 +20,7 @@ import java.util.List;
* Class converting a collection of schemas from the intermediate format.
*
* @author arnej27959
- **/
+ */
public class ConvertSchemaCollection {
private final IntermediateCollection input;
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java
index 0cc52edf3cc..b1eace947cc 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java
@@ -2,11 +2,15 @@
package com.yahoo.vespa.model.application.validation.change;
import com.yahoo.config.model.api.ConfigChangeAction;
+import com.yahoo.config.model.api.ServiceInfo;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.config.application.api.ValidationId;
+import com.yahoo.vespa.model.container.ApplicationContainer;
+import com.yahoo.vespa.model.container.ApplicationContainerCluster;
import com.yahoo.vespa.model.content.cluster.ContentCluster;
+import java.util.ArrayList;
import java.util.List;
/**
@@ -19,16 +23,23 @@ public class ContentClusterRemovalValidator implements ChangeValidator {
@Override
public List<ConfigChangeAction> validate(VespaModel current, VespaModel next, DeployState deployState) {
+ List<ConfigChangeAction> actions = new ArrayList<>();
for (String currentClusterId : current.getContentClusters().keySet()) {
ContentCluster nextCluster = next.getContentClusters().get(currentClusterId);
- if (nextCluster == null)
+ if (nextCluster == null) {
deployState.validationOverrides().invalid(ValidationId.contentClusterRemoval,
- "Content cluster '" + currentClusterId + "' is removed. " +
- "This will cause loss of all data in this cluster",
- deployState.now());
- }
+ "Content cluster '" + currentClusterId + "' is removed. " +
+ "This will cause loss of all data in this cluster",
+ deployState.now());
- return List.of();
+ // If we allow the removal, we must restart all containers to ensure mbus is OK.
+ for (ApplicationContainerCluster cluster : next.getContainerClusters().values()) {
+ actions.add(new VespaRestartAction(cluster.id(),
+ "Content cluster '" + currentClusterId + "' has been removed",
+ cluster.getContainers().stream().map(ApplicationContainer::getServiceInfo).toList()));
+ }
+ }
+ }
+ return actions;
}
-
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
index 2742dc59fcd..7c89a349d7d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java
@@ -42,13 +42,16 @@ public class OnnxModelInfo {
private final Map<String, OnnxTypeInfo> inputs;
private final Map<String, OnnxTypeInfo> outputs;
private final Map<String, TensorType> vespaTypes = new HashMap<>();
+ private final Set<String> initializers;
- private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) {
+ private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs,
+ Map<String, OnnxTypeInfo> outputs, Set<String> initializers, String defaultOutput) {
this.app = app;
this.modelPath = path;
this.inputs = Collections.unmodifiableMap(inputs);
this.outputs = Collections.unmodifiableMap(outputs);
this.defaultOutput = defaultOutput;
+ this.initializers = Set.copyOf(initializers);
}
public String getModelPath() {
@@ -63,6 +66,8 @@ public class OnnxModelInfo {
return outputs.keySet();
}
+ public Set<String> getInitializers() { return initializers; }
+
public String getDefaultOutput() {
return defaultOutput;
}
@@ -208,6 +213,14 @@ public class OnnxModelInfo {
}
g.writeEndArray();
+ g.writeArrayFieldStart("initializers");
+ for (Onnx.TensorProto initializers : model.getGraph().getInitializerList()) {
+ g.writeStartObject();
+ g.writeStringField("name", initializers.getName());
+ g.writeEndObject();
+ }
+ g.writeEndArray();
+
g.writeEndObject();
g.close();
return out.toString();
@@ -218,6 +231,7 @@ public class OnnxModelInfo {
JsonNode root = m.readTree(json);
Map<String, OnnxTypeInfo> inputs = new HashMap<>();
Map<String, OnnxTypeInfo> outputs = new HashMap<>();
+ Set<String> initializers = new HashSet<>();
String defaultOutput = "";
String path = null;
@@ -233,7 +247,13 @@ public class OnnxModelInfo {
if (root.get("outputs").has(0)) {
defaultOutput = root.get("outputs").get(0).get("name").textValue();
}
- return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput);
+ var initializerRoot = root.get("initializers");
+ if (initializerRoot != null) {
+ for (JsonNode initializer : initializerRoot) {
+ initializers.add(initializer.get("name").textValue());
+ }
+ }
+ return new OnnxModelInfo(app, path, inputs, outputs, initializers, defaultOutput);
}
static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException {
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java
index 5c360a9343f..65dfce8ff6c 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java
@@ -1,14 +1,18 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.application.validation.change;
+import com.yahoo.collections.Pair;
import com.yahoo.config.application.api.ValidationId;
import com.yahoo.config.application.api.ValidationOverrides;
+import com.yahoo.config.model.api.ConfigChangeAction;
import com.yahoo.config.provision.Environment;
import com.yahoo.vespa.model.VespaModel;
import com.yahoo.vespa.model.application.validation.ValidationTester;
import com.yahoo.yolean.Exceptions;
import org.junit.jupiter.api.Test;
+import java.util.List;
+
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.fail;
@@ -36,7 +40,12 @@ public class ContentClusterRemovalValidatorTest {
@Test
void testOverridingContentRemovalValidation() {
VespaModel previous = tester.deploy(null, getServices("contentClusterId"), Environment.prod, null).getFirst();
- tester.deploy(previous, getServices("newContentClusterId"), Environment.prod, removalOverride); // Allowed due to override
+ var result = tester.deploy(previous, getServices("newContentClusterId"), Environment.prod, removalOverride); // Allowed due to override
+ assertEquals(result.getFirst().getContainerClusters().values().stream()
+ .flatMap(cluster -> cluster.getContainers().stream())
+ .map(container -> container.getServiceInfo())
+ .toList(),
+ result.getSecond().stream().flatMap(action -> action.getServices().stream()).toList());
}
private static String getServices(String contentClusterId) {
diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/noderepository/NodeRepositoryNode.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/noderepository/NodeRepositoryNode.java
index 945e0730fe6..34a7c1a6f7c 100644
--- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/noderepository/NodeRepositoryNode.java
+++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/noderepository/NodeRepositoryNode.java
@@ -113,6 +113,8 @@ public class NodeRepositoryNode {
private String cloudAccount;
@JsonProperty("wireguardPubKey")
private String wireguardPubKey;
+ @JsonProperty("archiveUri")
+ private String archiveUri;
public String getUrl() {
return url;
@@ -460,6 +462,10 @@ public class NodeRepositoryNode {
public void setWireguardPubKey(String wireguardPubKey) { this.wireguardPubKey = wireguardPubKey; }
+ public String getArchiveUri() { return archiveUri; }
+
+ public void setArchiveUri(String archiveUri) { this.archiveUri = archiveUri; }
+
// --- Helper methods for code that (wrongly) consume this directly
public boolean hasType(NodeType type) {
@@ -521,6 +527,7 @@ public class NodeRepositoryNode {
", switchHostname='" + switchHostname + '\'' +
", cloudAccount='" + cloudAccount + '\'' +
", wireguardPubKey='" + wireguardPubKey + '\'' +
+ ", archiveUri='" + archiveUri + '\'' +
'}';
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
index a1534ebc533..44ea693d8cd 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java
@@ -47,6 +47,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.deployment.RevisionId;
import com.yahoo.vespa.hosted.controller.api.integration.deployment.TesterId;
import com.yahoo.vespa.hosted.controller.api.integration.noderepository.RestartFilter;
import com.yahoo.vespa.hosted.controller.api.integration.secrets.TenantSecretStore;
+import com.yahoo.vespa.hosted.controller.application.Change;
import com.yahoo.vespa.hosted.controller.application.Deployment;
import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics;
import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics.Warning;
@@ -60,6 +61,7 @@ import com.yahoo.vespa.hosted.controller.application.pkg.ApplicationPackageValid
import com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade;
import com.yahoo.vespa.hosted.controller.certificate.EndpointCertificates;
import com.yahoo.vespa.hosted.controller.concurrent.Once;
+import com.yahoo.vespa.hosted.controller.deployment.DeploymentStatus;
import com.yahoo.vespa.hosted.controller.deployment.DeploymentTrigger;
import com.yahoo.vespa.hosted.controller.deployment.JobStatus;
import com.yahoo.vespa.hosted.controller.deployment.Run;
@@ -586,7 +588,16 @@ public class ApplicationController {
}
// Validate new deployment spec thoroughly before storing it.
- controller.jobController().deploymentStatus(application.get());
+ DeploymentStatus status = controller.jobController().deploymentStatus(application.get());
+ Change dummyChange = Change.of(RevisionId.forProduction(Long.MAX_VALUE)); // Should always run everywhere.
+ for (var jobs : status.jobsToRun(applicationPackage.deploymentSpec().instanceNames().stream()
+ .collect(toMap(name -> name, __ -> dummyChange)))
+ .entrySet()) {
+ for (var job : jobs.getValue()) {
+ decideCloudAccountOf(new DeploymentId(jobs.getKey().application(), job.type().zone()),
+ applicationPackage.deploymentSpec());
+ }
+ }
for (Notification notification : controller.notificationsDb().listNotifications(NotificationSource.from(application.get().id()), true)) {
if ( notification.source().instance().isPresent()
@@ -696,7 +707,7 @@ public class ApplicationController {
throw new IllegalArgumentException("Requested cloud account '" + requestedAccount.get().value() +
"' is not valid for tenant '" + tenant + "'");
}
- if (!controller.zoneRegistry().hasZone(zoneId, requestedAccount.get())) {
+ if ( ! controller.zoneRegistry().hasZone(zoneId, requestedAccount.get())) {
throw new IllegalArgumentException("Zone " + zoneId + " is not configured in requested cloud account '" +
requestedAccount.get().value() + "'");
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java
index f86073cfb25..3880b028eb0 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java
@@ -5,6 +5,7 @@ import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
+import java.nio.file.attribute.FileTime;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicReference;
@@ -32,6 +33,7 @@ public class ApplicationPackageStream {
private final Supplier<Predicate<String>> filter;
private final Supplier<InputStream> in;
private final AtomicReference<ApplicationPackage> truncatedPackage = new AtomicReference<>();
+ private final FileTime createdAt = FileTime.fromMillis(System.currentTimeMillis());
/** Stream that effectively copies the input stream to its {@link #truncatedPackage()} when exhausted. */
public ApplicationPackageStream(Supplier<InputStream> in) {
@@ -60,7 +62,7 @@ public class ApplicationPackageStream {
* and the first to be exhausted will populate the truncated application package.
*/
public InputStream zipStream() {
- return new Stream(in.get(), replacer.get(), filter.get(), truncatedPackage);
+ return new Stream(in.get(), replacer.get(), filter.get(), createdAt, truncatedPackage);
}
/**
@@ -85,6 +87,7 @@ public class ApplicationPackageStream {
private final ZipInputStream inZip;
private final Replacer replacer;
private final Predicate<String> filter;
+ private final FileTime createdAt;
private byte[] currentOut = new byte[0];
private InputStream currentIn = InputStream.nullInputStream();
private boolean includeCurrent = false;
@@ -92,11 +95,12 @@ public class ApplicationPackageStream {
private boolean closed = false;
private boolean done = false;
- private Stream(InputStream in, Replacer replacer, Predicate<String> filter, AtomicReference<ApplicationPackage> truncatedPackage) {
+ private Stream(InputStream in, Replacer replacer, Predicate<String> filter, FileTime createdAt, AtomicReference<ApplicationPackage> truncatedPackage) {
this.in = in;
this.inZip = new ZipInputStream(in);
this.replacer = replacer;
this.filter = filter;
+ this.createdAt = createdAt;
this.truncatedPackage = truncatedPackage;
}
@@ -129,10 +133,12 @@ public class ApplicationPackageStream {
ZipEntry next = inZip.getNextEntry();
String name;
+ FileTime modifiedAt;
InputStream content = null;
if (next == null) {
// We may still have replacements to fill in, but if we don't, we're done filling, forever!
name = replacer.next();
+ modifiedAt = createdAt;
if (name == null) {
outZip.close(); // This typically makes new output available, so must check for that after this.
teeZip.close();
@@ -144,6 +150,7 @@ public class ApplicationPackageStream {
}
else {
name = next.getName();
+ modifiedAt = next.getLastModifiedTime();
content = new FilterInputStream(inZip) { @Override public void close() { } }; // Protect inZip from replacements closing it.
}
@@ -153,8 +160,8 @@ public class ApplicationPackageStream {
currentIn = InputStream.nullInputStream();
}
else {
- if (includeCurrent) teeZip.putNextEntry(new ZipEntry(name));
- outZip.putNextEntry(new ZipEntry(name));
+ if (includeCurrent) teeZip.putNextEntry(new ZipEntry(name) {{ setLastModifiedTime(modifiedAt); }});
+ outZip.putNextEntry(new ZipEntry(name) {{ setLastModifiedTime(modifiedAt); }});
}
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java
index a76e76611c2..050b77a391e 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java
@@ -390,7 +390,7 @@ public class DeploymentStatus {
}
/** The set of jobs that need to run for the given changes to be considered complete. */
- private Map<JobId, List<Job>> jobsToRun(Map<InstanceName, Change> changes) {
+ public Map<JobId, List<Job>> jobsToRun(Map<InstanceName, Change> changes) {
return jobsToRun(changes, false);
}
diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java
index 73c64be3e47..10e4052f067 100644
--- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java
+++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java
@@ -713,6 +713,7 @@ public class JobController {
// TODO(mpolden): Enable for public CD once all tests have been updated
if (controller.system() != SystemName.PublicCd) {
controller.applications().validatePackage(applicationPackage, application.get());
+ controller.applications().decideCloudAccountOf(new DeploymentId(id, type.zone()), applicationPackage.deploymentSpec());
}
controller.applications().store(application);
});
diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
index 8f0536480f5..d9ee82f5e90 100644
--- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
+++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java
@@ -1429,9 +1429,14 @@ public class ControllerTest {
// Deployment fails because zone is not configured in requested cloud account
tester.controllerTester().flagSource().withListFlag(PermanentFlags.CLOUD_ACCOUNTS.id(), List.of(cloudAccount), String.class);
- context.submit(applicationPackage)
- .runJobExpectingFailure(systemTest, "Zone test.us-east-1 is not configured in requested cloud account '012345678912'")
- .abortJob(stagingTest);
+ assertEquals("Zone test.us-east-1 is not configured in requested cloud account '012345678912'",
+ assertThrows(IllegalArgumentException.class,
+ () -> context.submit(applicationPackage))
+ .getMessage());
+ assertEquals("Zone dev.us-east-1 is not configured in requested cloud account '012345678912'",
+ assertThrows(IllegalArgumentException.class,
+ () -> context.runJob(devUsEast1, applicationPackage))
+ .getMessage());
// Deployment to prod succeeds once all zones are configured in requested account
tester.controllerTester().zoneRegistry().configureCloudAccount(CloudAccount.from(cloudAccount),
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ExpressionConverter.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ExpressionConverter.java
index 45d8637aa3e..ccad9d6d08b 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ExpressionConverter.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ExpressionConverter.java
@@ -55,6 +55,11 @@ public abstract class ExpressionConverter implements Cloneable {
return new CatExpression(lst);
}
+ public Expression innerConvert(ChoiceExpression exp) {
+ var convertedInnerExpressions = exp.asList().stream().map(inner -> convert(inner)).toList();
+ return new ChoiceExpression(convertedInnerExpressions);
+ }
+
public Expression innerConvert(ForEachExpression exp) {
return new ForEachExpression(convert(exp.getInnerExpression()));
}
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/InputExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/InputExpression.java
index 30c824d410d..bba1b09cda2 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/InputExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/InputExpression.java
@@ -21,7 +21,9 @@ public final class InputExpression extends Expression {
public InputExpression(String fieldName) {
super(null);
- this.fieldName = Objects.requireNonNull(fieldName);
+ if (fieldName == null)
+ throw new IllegalArgumentException("'input' must be given a field name as argument");
+ this.fieldName = fieldName;
}
public String getFieldName() {
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/ChoiceTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/ChoiceTestCase.java
index 351c925ed56..7ece841e9b7 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/ChoiceTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/ChoiceTestCase.java
@@ -4,7 +4,13 @@ package com.yahoo.vespa.indexinglanguage.expressions;
import com.yahoo.document.DataType;
import com.yahoo.document.Field;
import com.yahoo.document.datatypes.StringFieldValue;
+import com.yahoo.language.Linguistics;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.language.simple.SimpleLinguistics;
+import com.yahoo.vespa.indexinglanguage.ExpressionSearcher;
import com.yahoo.vespa.indexinglanguage.SimpleTestAdapter;
+import com.yahoo.vespa.indexinglanguage.parser.ParseException;
+import com.yahoo.yolean.Exceptions;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -44,4 +50,24 @@ public class ChoiceTestCase {
}
}
+ @Test
+ public void testIllegalChoiceExpression() throws ParseException {
+ try {
+ parse("input (foo || 99999999) | attribute");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("'input' must be given a field name as argument", Exceptions.toMessageString(e));
+ }
+ }
+
+ @Test
+ public void testInnerConvert() throws ParseException {
+ var expression = parse("(input foo || 99999999) | attribute");
+ new ExpressionSearcher<>(AttributeExpression.class).searchIn(expression); // trigger innerConvert
+ }
+
+ private static Expression parse(String s) throws ParseException {
+ return Expression.fromString(s, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
+ }
+
}
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 9bb60827a68..c27ed9d2c31 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -106,6 +106,11 @@
</dependency>
<dependency>
+ <groupId>org.lz4</groupId>
+ <artifactId>lz4-java</artifactId>
+ </dependency>
+
+ <dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<scope>test</scope>
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
index 7cdc27b6d63..02fa7b68dc4 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java
@@ -7,6 +7,7 @@ import ai.onnxruntime.OnnxTensor;
import ai.onnxruntime.OnnxValue;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime.ModelPathOrData;
import ai.vespa.modelintegration.evaluator.OnnxRuntime.ReferencedOrtSession;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -28,7 +29,11 @@ public class OnnxEvaluator implements AutoCloseable {
private final ReferencedOrtSession session;
OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options, OnnxRuntime runtime) {
- session = createSession(modelPath, runtime, options, true);
+ session = createSession(ModelPathOrData.of(modelPath), runtime, options, true);
+ }
+
+ OnnxEvaluator(byte[] data, OnnxEvaluatorOptions options, OnnxRuntime runtime) {
+ session = createSession(ModelPathOrData.of(data), runtime, options, true);
}
public Tensor evaluate(Map<String, Tensor> inputs, String output) {
@@ -125,19 +130,20 @@ public class OnnxEvaluator implements AutoCloseable {
}
}
- private static ReferencedOrtSession createSession(String modelPath, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) {
+ private static ReferencedOrtSession createSession(
+ ModelPathOrData model, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) {
if (options == null) {
options = new OnnxEvaluatorOptions();
}
try {
- return runtime.acquireSession(modelPath, options, tryCuda && options.requestingGpu());
+ return runtime.acquireSession(model, options, tryCuda && options.requestingGpu());
} catch (OrtException e) {
if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) {
- throw new IllegalArgumentException("No such file: " + modelPath);
+ throw new IllegalArgumentException("No such file: " + model.path().get());
}
if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) {
// Failed in CUDA native code, but GPU device is optional, so we can proceed without it
- return createSession(modelPath, runtime, options, false);
+ return createSession(model, runtime, options, false);
}
if (isCudaError(e)) {
throw new IllegalArgumentException("GPU device is required, but CUDA initialization failed", e);
@@ -146,6 +152,9 @@ public class OnnxEvaluator implements AutoCloseable {
}
}
+ // For unit testing
+ OrtSession ortSession() { return session.instance(); }
+
private String mapToInternalName(String outputName) throws OrtException {
var info = session.instance().getOutputInfo();
var internalNames = info.keySet();
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
index 42830041c02..ece1db55c1e 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java
@@ -10,9 +10,15 @@ import com.yahoo.component.annotation.Inject;
import com.yahoo.jdisc.ResourceReference;
import com.yahoo.jdisc.refcount.DebugReferencesWithStack;
import com.yahoo.jdisc.refcount.References;
+import net.jpountz.xxhash.XXHashFactory;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -26,14 +32,22 @@ import static com.yahoo.yolean.Exceptions.throwUnchecked;
public class OnnxRuntime extends AbstractComponent {
// For unit testing
- @FunctionalInterface interface OrtSessionFactory {
+ interface OrtSessionFactory {
OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException;
+ OrtSession create(byte[] data, OrtSession.SessionOptions opts) throws OrtException;
}
private static final Logger log = Logger.getLogger(OnnxRuntime.class.getName());
private static final OrtEnvironmentResult ortEnvironment = getOrtEnvironment();
- private static final OrtSessionFactory defaultFactory = (path, opts) -> ortEnvironment().createSession(path, opts);
+ private static final OrtSessionFactory defaultFactory = new OrtSessionFactory() {
+ @Override public OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException {
+ return ortEnvironment().createSession(path, opts);
+ }
+ @Override public OrtSession create(byte[] data, OrtSession.SessionOptions opts) throws OrtException {
+ return ortEnvironment().createSession(data, opts);
+ }
+ };
private final Object monitor = new Object();
private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>();
@@ -43,6 +57,14 @@ public class OnnxRuntime extends AbstractComponent {
OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; }
+ public OnnxEvaluator evaluatorOf(byte[] model) {
+ return new OnnxEvaluator(model, null, this);
+ }
+
+ public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) {
+ return new OnnxEvaluator(model, options, this);
+ }
+
public OnnxEvaluator evaluatorOf(String modelPath) {
return new OnnxEvaluator(modelPath, null, this);
}
@@ -105,8 +127,8 @@ public class OnnxRuntime extends AbstractComponent {
};
}
- ReferencedOrtSession acquireSession(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException {
- var sessionId = new OrtSessionId(modelPath, options, loadCuda);
+ ReferencedOrtSession acquireSession(ModelPathOrData model, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException {
+ var sessionId = new OrtSessionId(calculateModelHash(model), options, loadCuda);
synchronized (monitor) {
var sharedSession = sessions.get(sessionId);
if (sharedSession != null) {
@@ -114,8 +136,9 @@ public class OnnxRuntime extends AbstractComponent {
}
}
+ var opts = options.getOptions(loadCuda);
// Note: identical models loaded simultaneously will result in duplicate session instances
- var session = factory.create(modelPath, options.getOptions(loadCuda));
+ var session = model.path().isPresent() ? factory.create(model.path().get(), opts) : factory.create(model.data().get(), opts);
log.fine(() -> "Created new session (%s)".formatted(System.identityHashCode(session)));
var sharedSession = new SharedOrtSession(sessionId, session);
@@ -125,25 +148,52 @@ public class OnnxRuntime extends AbstractComponent {
return referencedSession;
}
+ private static long calculateModelHash(ModelPathOrData model) {
+ if (model.path().isPresent()) {
+ try (var hasher = XXHashFactory.fastestInstance().newStreamingHash64(0);
+ var in = Files.newInputStream(Paths.get(model.path().get()))) {
+ byte[] buffer = new byte[8192];
+ int bytesRead;
+ while ((bytesRead = in.read(buffer)) != -1) {
+ hasher.update(buffer, 0, bytesRead);
+ }
+ return hasher.getValue();
+ } catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ } else {
+ var data = model.data().get();
+ return XXHashFactory.fastestInstance().hash64().hash(data, 0, data.length, 0);
+ }
+ }
+
int sessionsCached() { synchronized(monitor) { return sessions.size(); } }
- public static class ReferencedOrtSession implements AutoCloseable {
+ static class ReferencedOrtSession implements AutoCloseable {
private final OrtSession instance;
private final ResourceReference ref;
- public ReferencedOrtSession(OrtSession instance, ResourceReference ref) {
+ ReferencedOrtSession(OrtSession instance, ResourceReference ref) {
this.instance = instance;
this.ref = ref;
}
- public OrtSession instance() { return instance; }
+ OrtSession instance() { return instance; }
@Override public void close() { ref.close(); }
}
+ record ModelPathOrData(Optional<String> path, Optional<byte[]> data) {
+ static ModelPathOrData of(String path) { return new ModelPathOrData(Optional.of(path), Optional.empty()); }
+ static ModelPathOrData of(byte[] data) { return new ModelPathOrData(Optional.empty(), Optional.of(data)); }
+ ModelPathOrData {
+ if (path.isEmpty() == data.isEmpty()) throw new IllegalArgumentException("Either path or data must be non-empty");
+ }
+ }
+
// Assumes options are never modified after being stored in `onnxSessions`
- record OrtSessionId(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) {}
+ private record OrtSessionId(long modelHash, OnnxEvaluatorOptions options, boolean loadCuda) {}
- record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {}
+ private record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {}
private class SharedOrtSession {
private final OrtSessionId id;
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
index 5aba54de11b..5a367ef83e4 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java
@@ -5,30 +5,26 @@ package ai.vespa.modelintegration.evaluator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
-import org.junit.jupiter.api.BeforeAll;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
-import static org.junit.Assume.assumeNotNull;
+import static org.junit.Assume.assumeTrue;
/**
* @author lesters
*/
public class OnnxEvaluatorTest {
- private static OnnxRuntime runtime;
-
- @BeforeAll
- public static void beforeAll() {
- if (OnnxRuntime.isRuntimeAvailable()) runtime = new OnnxRuntime();
- }
-
@Test
public void testSimpleModel() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/simple/simple.onnx");
// Input types
@@ -53,7 +49,8 @@ public class OnnxEvaluatorTest {
@Test
public void testBatchDimension() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/pytorch/one_layer.onnx");
// Input types
@@ -72,21 +69,23 @@ public class OnnxEvaluatorTest {
@Test
public void testMatMul() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]";
String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]";
String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]";
- assertEvaluate("simple/matmul.onnx", expected, input1, input2);
+ assertEvaluate(runtime, "simple/matmul.onnx", expected, input1, input2);
}
@Test
public void testTypes() {
- assumeNotNull(runtime);
- assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
- assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
- assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
- assertEvaluate("cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]");
- assertEvaluate("cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]");
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
+ assertEvaluate(runtime, "add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]");
+ assertEvaluate(runtime, "add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]");
+ assertEvaluate(runtime, "add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]");
+ assertEvaluate(runtime, "cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]");
+ assertEvaluate(runtime, "cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]");
// ONNX Runtime 1.8.0 does not support much of bfloat16 yet
// assertEvaluate("cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]");
@@ -94,7 +93,8 @@ public class OnnxEvaluatorTest {
@Test
public void testNotIdentifiers() {
- assumeNotNull(runtime);
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/badnames.onnx");
var inputInfo = evaluator.getInputInfo();
var outputInfo = evaluator.getOutputInfo();
@@ -159,7 +159,18 @@ public class OnnxEvaluatorTest {
assertEquals(3, allResults.size());
}
- private void assertEvaluate(String model, String output, String... input) {
+ @Test
+ public void testLoadModelFromBytes() throws IOException {
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
+ var model = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx"));
+ var evaluator = runtime.evaluatorOf(model);
+ assertEquals(3, evaluator.getInputs().size());
+ assertEquals(1, evaluator.getOutputs().size());
+ evaluator.close();
+ }
+
+ private void assertEvaluate(OnnxRuntime runtime, String model, String output, String... input) {
OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model);
Map<String, Tensor> inputs = new HashMap<>();
for (int i = 0; i < input.length; ++i) {
diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
index 81b1237e770..fdbd4fa4e5c 100644
--- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
+++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java
@@ -2,16 +2,18 @@
package ai.vespa.modelintegration.evaluator;
-import ai.onnxruntime.OrtException;
-import ai.onnxruntime.OrtSession;
import org.junit.jupiter.api.Test;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
-import static org.mockito.Mockito.mock;
-import static org.mockito.Mockito.never;
-import static org.mockito.Mockito.verify;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assumptions.assumeTrue;
/**
* @author bjorncs
@@ -19,30 +21,81 @@ import static org.mockito.Mockito.verify;
class OnnxRuntimeTest {
@Test
- void reuses_sessions_while_active() throws OrtException {
- var runtime = new OnnxRuntime((__, ___) -> mock(OrtSession.class));
- var session1 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
- var session2 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
- var session3 = runtime.acquireSession("model2", new OnnxEvaluatorOptions(), false);
- assertSame(session1.instance(), session2.instance());
- assertNotSame(session1.instance(), session3.instance());
+ void reuses_sessions_while_active() {
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ OnnxRuntime runtime = new OnnxRuntime();
+ String model1 = "src/test/models/onnx/simple/simple.onnx";
+ var evaluator1 = runtime.evaluatorOf(model1);
+ var evaluator2 = runtime.evaluatorOf(model1);
+ String model2 = "src/test/models/onnx/simple/matmul.onnx";
+ var evaluator3 = runtime.evaluatorOf(model2);
+ assertSameSession(evaluator1, evaluator2);
+ assertNotSameSession(evaluator1, evaluator3);
assertEquals(2, runtime.sessionsCached());
- session1.close();
- session2.close();
+ evaluator1.close();
+ evaluator2.close();
assertEquals(1, runtime.sessionsCached());
- verify(session1.instance()).close();
- verify(session3.instance(), never()).close();
+ assertClosed(evaluator1);
+ assertNotClosed(evaluator3);
- session3.close();
+ evaluator3.close();
assertEquals(0, runtime.sessionsCached());
- verify(session3.instance()).close();
+ assertClosed(evaluator3);
- var session4 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false);
- assertNotSame(session1.instance(), session4.instance());
+ var session4 = runtime.evaluatorOf(model1);
+ assertNotSameSession(evaluator1, session4);
assertEquals(1, runtime.sessionsCached());
session4.close();
assertEquals(0, runtime.sessionsCached());
- verify(session4.instance()).close();
+ assertClosed(session4);
+ }
+
+ @Test
+ void loads_model_from_byte_array() throws IOException {
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
+ byte[] bytes = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx"));
+ var evaluator1 = runtime.evaluatorOf(bytes);
+ var evaluator2 = runtime.evaluatorOf(bytes);
+ assertEquals(3, evaluator1.getInputs().size());
+ assertEquals(1, runtime.sessionsCached());
+ assertSameSession(evaluator1, evaluator2);
+ evaluator2.close();
+ evaluator1.close();
+ assertEquals(0, runtime.sessionsCached());
+ assertClosed(evaluator1);
+ }
+
+ @Test
+ void loading_same_model_from_bytes_and_file_resolve_to_same_instance() throws IOException {
+ assumeTrue(OnnxRuntime.isRuntimeAvailable());
+ var runtime = new OnnxRuntime();
+ String modelPath = "src/test/models/onnx/simple/simple.onnx";
+ byte[] bytes = Files.readAllBytes(Paths.get(modelPath));
+ try (var evaluator1 = runtime.evaluatorOf(bytes);
+ var evaluator2 = runtime.evaluatorOf(modelPath)) {
+ assertSameSession(evaluator1, evaluator2);
+ assertEquals(1, runtime.sessionsCached());
+ }
+ }
+
+ private static void assertClosed(OnnxEvaluator evaluator) { assertTrue(isClosed(evaluator), "Session is not closed"); }
+ private static void assertNotClosed(OnnxEvaluator evaluator) { assertFalse(isClosed(evaluator), "Session is closed"); }
+ private static void assertSameSession(OnnxEvaluator evaluator1, OnnxEvaluator evaluator2) {
+ assertSame(evaluator1.ortSession(), evaluator2.ortSession());
+ }
+ private static void assertNotSameSession(OnnxEvaluator evaluator1, OnnxEvaluator evaluator2) {
+ assertNotSame(evaluator1.ortSession(), evaluator2.ortSession());
+ }
+
+ private static boolean isClosed(OnnxEvaluator evaluator) {
+ try {
+ evaluator.getInputs();
+ return false;
+ } catch (IllegalStateException e) {
+ assertEquals("Asking for inputs from a closed OrtSession.", e.getMessage());
+ return true;
+ }
}
} \ No newline at end of file
diff --git a/searchcore/src/tests/proton/docsummary/docsummary_test.cpp b/searchcore/src/tests/proton/docsummary/docsummary_test.cpp
index 020f4ff42c1..70f81d629d5 100644
--- a/searchcore/src/tests/proton/docsummary/docsummary_test.cpp
+++ b/searchcore/src/tests/proton/docsummary/docsummary_test.cpp
@@ -107,7 +107,10 @@ namespace proton {
class MockDocsumFieldWriterFactory : public search::docsummary::IDocsumFieldWriterFactory
{
public:
- std::unique_ptr<DocsumFieldWriter> create_docsum_field_writer(const vespalib::string&, const vespalib::string&, const vespalib::string&) override {
+ std::unique_ptr<DocsumFieldWriter> create_docsum_field_writer(const vespalib::string&,
+ const vespalib::string&,
+ const vespalib::string&,
+ std::shared_ptr<search::MatchingElementsFields>) override {
return {};
}
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.cpp b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.cpp
index d6f06c9161e..6ac4fea2921 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.cpp
+++ b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.cpp
@@ -23,8 +23,7 @@ namespace search::docsummary {
DocsumFieldWriterFactory::DocsumFieldWriterFactory(bool use_v8_geo_positions, const IDocsumEnvironment& env, const IQueryTermFilterFactory& query_term_filter_factory)
: _use_v8_geo_positions(use_v8_geo_positions),
_env(env),
- _query_term_filter_factory(query_term_filter_factory),
- _matching_elems_fields(std::make_shared<MatchingElementsFields>())
+ _query_term_filter_factory(query_term_filter_factory)
{
}
@@ -58,7 +57,8 @@ throw_missing_source(const vespalib::string& command)
std::unique_ptr<DocsumFieldWriter>
DocsumFieldWriterFactory::create_docsum_field_writer(const vespalib::string& field_name,
const vespalib::string& command,
- const vespalib::string& source)
+ const vespalib::string& source,
+ std::shared_ptr<MatchingElementsFields> matching_elems_fields)
{
std::unique_ptr<DocsumFieldWriter> fieldWriter;
if (command == command::dynamic_teaser) {
@@ -116,9 +116,9 @@ DocsumFieldWriterFactory::create_docsum_field_writer(const vespalib::string& fie
if (has_attribute_manager()) {
auto attr_ctx = getEnvironment().getAttributeManager()->createContext();
if (attr_ctx->getAttribute(source_field) != nullptr) {
- fieldWriter = AttributeDFWFactory::create(*getEnvironment().getAttributeManager(), source_field, true, _matching_elems_fields);
+ fieldWriter = AttributeDFWFactory::create(*getEnvironment().getAttributeManager(), source_field, true, matching_elems_fields);
} else {
- fieldWriter = AttributeCombinerDFW::create(source_field, *attr_ctx, true, _matching_elems_fields);
+ fieldWriter = AttributeCombinerDFW::create(source_field, *attr_ctx, true, matching_elems_fields);
}
throw_if_nullptr(fieldWriter, command);
}
@@ -126,7 +126,7 @@ DocsumFieldWriterFactory::create_docsum_field_writer(const vespalib::string& fie
const vespalib::string& source_field = source.empty() ? field_name : source;
if (has_attribute_manager()) {
auto attr_ctx = getEnvironment().getAttributeManager()->createContext();
- fieldWriter = MatchedElementsFilterDFW::create(source_field,*attr_ctx, _matching_elems_fields);
+ fieldWriter = MatchedElementsFilterDFW::create(source_field,*attr_ctx, matching_elems_fields);
throw_if_nullptr(fieldWriter, command);
}
} else if (command == command::documentid) {
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.h b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.h
index e50fb85cca6..7175f043701 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.h
+++ b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.h
@@ -20,7 +20,6 @@ class DocsumFieldWriterFactory : public IDocsumFieldWriterFactory
const IDocsumEnvironment& _env;
const IQueryTermFilterFactory& _query_term_filter_factory;
protected:
- std::shared_ptr<MatchingElementsFields> _matching_elems_fields;
const IDocsumEnvironment& getEnvironment() const noexcept { return _env; }
bool has_attribute_manager() const noexcept;
public:
@@ -28,7 +27,8 @@ public:
~DocsumFieldWriterFactory() override;
std::unique_ptr<DocsumFieldWriter> create_docsum_field_writer(const vespalib::string& field_name,
const vespalib::string& command,
- const vespalib::string& source) override;
+ const vespalib::string& source,
+ std::shared_ptr<MatchingElementsFields> matching_elems_fields) override;
};
}
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/i_docsum_field_writer_factory.h b/searchsummary/src/vespa/searchsummary/docsummary/i_docsum_field_writer_factory.h
index 6a5cd691857..bc2ebe3c40c 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/i_docsum_field_writer_factory.h
+++ b/searchsummary/src/vespa/searchsummary/docsummary/i_docsum_field_writer_factory.h
@@ -5,6 +5,8 @@
#include <memory>
#include <vespa/vespalib/stllike/string.h>
+namespace search { class MatchingElementsFields; }
+
namespace search::docsummary {
class DocsumFieldWriter;
@@ -21,7 +23,8 @@ public:
*/
virtual std::unique_ptr<DocsumFieldWriter> create_docsum_field_writer(const vespalib::string& field_name,
const vespalib::string& command,
- const vespalib::string& source) = 0;
+ const vespalib::string& source,
+ std::shared_ptr<MatchingElementsFields> matching_elems_fields) = 0;
};
}
diff --git a/searchsummary/src/vespa/searchsummary/docsummary/resultconfig.cpp b/searchsummary/src/vespa/searchsummary/docsummary/resultconfig.cpp
index f620dcb1df5..eddb67f5822 100644
--- a/searchsummary/src/vespa/searchsummary/docsummary/resultconfig.cpp
+++ b/searchsummary/src/vespa/searchsummary/docsummary/resultconfig.cpp
@@ -5,6 +5,7 @@
#include "docsum_field_writer_factory.h"
#include "resultclass.h"
#include <vespa/config-summary.h>
+#include <vespa/searchlib/common/matching_elements_fields.h>
#include <vespa/vespalib/stllike/hash_map.hpp>
#include <vespa/vespalib/util/exceptions.h>
#include <atomic>
@@ -124,6 +125,7 @@ ResultConfig::readConfig(const SummaryConfig &cfg, const char *configId, IDocsum
break;
}
resClass->set_omit_summary_features(cfg_class.omitsummaryfeatures);
+ auto matching_elems_fields = std::make_shared<MatchingElementsFields>();
for (const auto & field : cfg_class.fields) {
const char *fieldname = field.name.c_str();
vespalib::string command = field.command;
@@ -134,7 +136,8 @@ ResultConfig::readConfig(const SummaryConfig &cfg, const char *configId, IDocsum
try {
docsum_field_writer = docsum_field_writer_factory.create_docsum_field_writer(fieldname,
command,
- source_name);
+ source_name,
+ matching_elems_fields);
} catch (const vespalib::IllegalArgumentException& ex) {
LOG(error, "Exception during setup of summary result class '%s': field='%s', command='%s', source='%s': %s",
cfg_class.name.c_str(), fieldname, command.c_str(), source_name.c_str(), ex.getMessage().c_str());
diff --git a/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.cpp b/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.cpp
index 36873b713aa..b103d7d85b2 100644
--- a/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.cpp
+++ b/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.cpp
@@ -48,7 +48,8 @@ DocsumFieldWriterFactory::~DocsumFieldWriterFactory() = default;
std::unique_ptr<DocsumFieldWriter>
DocsumFieldWriterFactory::create_docsum_field_writer(const vespalib::string& field_name,
const vespalib::string& command,
- const vespalib::string& source)
+ const vespalib::string& source,
+ std::shared_ptr<MatchingElementsFields> matching_elems_fields)
{
std::unique_ptr<DocsumFieldWriter> fieldWriter;
using namespace search::docsummary;
@@ -65,10 +66,10 @@ DocsumFieldWriterFactory::create_docsum_field_writer(const vespalib::string& fie
} else if ((command == command::matched_attribute_elements_filter) ||
(command == command::matched_elements_filter)) {
vespalib::string source_field = source.empty() ? field_name : source;
- populate_fields(*_matching_elems_fields, _vsm_fields_config, source_field);
- fieldWriter = MatchedElementsFilterDFW::create(source_field, _matching_elems_fields);
+ populate_fields(*matching_elems_fields, _vsm_fields_config, source_field);
+ fieldWriter = MatchedElementsFilterDFW::create(source_field, matching_elems_fields);
} else {
- return search::docsummary::DocsumFieldWriterFactory::create_docsum_field_writer(field_name, command, source);
+ return search::docsummary::DocsumFieldWriterFactory::create_docsum_field_writer(field_name, command, source, matching_elems_fields);
}
return fieldWriter;
}
diff --git a/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.h b/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.h
index 078c466d3d2..ac5cae8c49d 100644
--- a/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.h
+++ b/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.h
@@ -21,7 +21,8 @@ public:
std::unique_ptr<search::docsummary::DocsumFieldWriter>
create_docsum_field_writer(const vespalib::string& field_name,
const vespalib::string& command,
- const vespalib::string& source) override;
+ const vespalib::string& source,
+ std::shared_ptr<search::MatchingElementsFields> matching_elems_fields) override;
};
}
diff --git a/vdslib/src/main/java/com/yahoo/vdslib/VisitorStatistics.java b/vdslib/src/main/java/com/yahoo/vdslib/VisitorStatistics.java
index fda456ace05..14b44ede12a 100644
--- a/vdslib/src/main/java/com/yahoo/vdslib/VisitorStatistics.java
+++ b/vdslib/src/main/java/com/yahoo/vdslib/VisitorStatistics.java
@@ -2,6 +2,7 @@
package com.yahoo.vdslib;
public class VisitorStatistics {
+
int bucketsVisited = 0;
long documentsVisited = 0;
long bytesVisited = 0;
@@ -20,8 +21,8 @@ public class VisitorStatistics {
public void setBucketsVisited(int bucketsVisited) { this.bucketsVisited = bucketsVisited; }
/**
- * @return the number of documents matching the document selection in the backend and that
- * has been passed to the client-specified visitor instance (dumpvisitor, searchvisitor etc).
+ * Returns the number of documents matching the document selection in the backend that
+ * have been passed to the client-specified visitor instance (dumpvisitor, searchvisitor etc).
*/
public long getDocumentsVisited() { return documentsVisited; }
public void setDocumentsVisited(long documentsVisited) { this.documentsVisited = documentsVisited; }
@@ -30,9 +31,9 @@ public class VisitorStatistics {
public void setBytesVisited(long bytesVisited) { this.bytesVisited = bytesVisited; }
/**
- * @return Number of documents returned to the visitor client by the backend. This number may
- * be lower than that returned by getDocumentsVisited() since the client-specified visitor
- * instance may further have filtered the set of documents returned by the backend.
+ * Returns the number of documents returned to the visitor client by the backend. This number may
+ * be lower than that returned by getDocumentsVisited() since the client-specified visitor
+ * instance may further have filtered the set of documents returned by the backend.
*/
public long getDocumentsReturned() { return documentsReturned; }
public void setDocumentsReturned(long documentsReturned) { this.documentsReturned = documentsReturned; }
@@ -40,15 +41,14 @@ public class VisitorStatistics {
public long getBytesReturned() { return bytesReturned; }
public void setBytesReturned(long bytesReturned) { this.bytesReturned = bytesReturned; }
+ @Override
public String toString() {
- String out =
+ return
"Buckets visited: " + bucketsVisited + "\n" +
"Documents visited: " + documentsVisited + "\n" +
"Bytes visited: " + bytesVisited + "\n" +
"Documents returned: " + documentsReturned + "\n" +
"Bytes returned: " + bytesReturned + "\n";
-
- return out;
}
}
diff --git a/vespa-feed-client/src/test/java/ai/vespa/feed/client/impl/ApacheClusterTest.java b/vespa-feed-client/src/test/java/ai/vespa/feed/client/impl/ApacheClusterTest.java
index 9195b5ab858..cf9a36f2aa8 100644
--- a/vespa-feed-client/src/test/java/ai/vespa/feed/client/impl/ApacheClusterTest.java
+++ b/vespa-feed-client/src/test/java/ai/vespa/feed/client/impl/ApacheClusterTest.java
@@ -35,7 +35,7 @@ class ApacheClusterTest {
final WireMockExtension server = new WireMockExtension();
@Test
- void testClient() throws IOException, ExecutionException, InterruptedException, TimeoutException {
+ void testClient() throws Exception {
for (Compression compression : Compression.values()) {
try (ApacheCluster cluster = new ApacheCluster(new FeedClientBuilderImpl(List.of(URI.create("http://localhost:" + server.port())))
.setCompression(compression))) {
@@ -48,25 +48,28 @@ class ApacheClusterTest {
Map.of("name1", () -> "value1",
"name2", () -> "value2"),
"content".getBytes(UTF_8),
- Duration.ofSeconds(20)),
+ Duration.ofSeconds(10)),
vessel);
- HttpResponse response = vessel.get(15, TimeUnit.SECONDS);
- assertEquals("{}", new String(response.body(), UTF_8));
- assertEquals(200, response.code());
- ByteArrayOutputStream buffer = new ByteArrayOutputStream();
- try (OutputStream zip = new GZIPOutputStream(buffer)) { zip.write("content".getBytes(UTF_8)); }
- server.verify(1, anyRequestedFor(anyUrl()));
- RequestPatternBuilder expected = postRequestedFor(urlEqualTo("/path")).withHeader("name1", equalTo("value1"))
- .withHeader("name2", equalTo("value2"))
- .withHeader("Content-Type", equalTo("application/json; charset=UTF-8"))
- .withRequestBody(equalTo("content"));
- expected = switch (compression) {
- case auto, none -> expected.withoutHeader("Content-Encoding");
- case gzip -> expected.withHeader("Content-Encoding", equalTo("gzip"));
+ AutoCloseable verifyResponse = () -> {
+ HttpResponse response = vessel.get(15, TimeUnit.SECONDS);
+ assertEquals("{}", new String(response.body(), UTF_8));
+ assertEquals(200, response.code());
};
- server.verify(1, expected);
- server.resetRequests();
+ AutoCloseable verifyServer = () -> {
+ server.verify(1, anyRequestedFor(anyUrl()));
+ RequestPatternBuilder expected = postRequestedFor(urlEqualTo("/path")).withHeader("name1", equalTo("value1"))
+ .withHeader("name2", equalTo("value2"))
+ .withHeader("Content-Type", equalTo("application/json; charset=UTF-8"))
+ .withRequestBody(equalTo("content"));
+ expected = switch (compression) {
+ case auto, none -> expected.withoutHeader("Content-Encoding");
+ case gzip -> expected.withHeader("Content-Encoding", equalTo("gzip"));
+ };
+ server.verify(1, expected);
+ server.resetRequests();
+ };
+ try (verifyServer; verifyResponse) { }
}
}
}