diff options
29 files changed, 381 insertions, 126 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 272b668b5fb..90a27d1f036 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -1,15 +1,17 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.ml.OnnxModelInfo; -import com.yahoo.searchlib.rankingexpression.Reference; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -21,6 +23,7 @@ public class OnnxModel extends DistributableResource { private OnnxModelInfo modelInfo = null; private final Map<String, String> inputMap = new HashMap<>(); private final Map<String, String> outputMap = new HashMap<>(); + private final Set<String> initializers = new HashSet<>(); private String statelessExecutionMode = null; private Integer statelessInterOpThreads = null; @@ -101,11 +104,13 @@ public class OnnxModel extends DistributableResource { for (String onnxName : modelInfo.getOutputs()) { addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); } + initializers.addAll(modelInfo.getInitializers()); this.modelInfo = modelInfo; } public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } + public Set<String> getInitializers() { return Set.copyOf(initializers); } public String getDefaultOutput() { return modelInfo != null ? modelInfo.getDefaultOutput() : ""; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index af072c5b59a..7f578f07fe3 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -2,7 +2,6 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.FeatureNames; -import com.yahoo.schema.RankProfile; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.parser.ParseException; @@ -12,13 +11,12 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.functions.DynamicTensor; import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Slice; import java.io.StringReader; import java.util.HashSet; import java.util.Set; +import java.util.logging.Logger; /** * Analyzes expression to figure out what inputs it needs @@ -27,6 +25,8 @@ import java.util.Set; */ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { + private static final Logger log = Logger.getLogger(InputRecorder.class.getName()); + private final Set<String> neededInputs; private final Set<String> handled = new HashSet<>(); @@ -120,7 +120,11 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { if (model == null) { throw new IllegalArgumentException("missing onnx model: " + arg); } - for (String onnxInput : model.getInputMap().values()) { + model.getInputMap().forEach((onnxName, onnxInput) -> { + if (model.getInitializers().contains(onnxName)) { + log.fine(() -> "For input '%s': skipping name '%s' as it's an initializer".formatted(onnxInput, onnxName)); + return; + } var reader = new StringReader(onnxInput); try { var asExpression = new RankingExpression(reader); @@ -128,7 +132,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { } catch (ParseException e) { throw new IllegalArgumentException("illegal onnx input '" + onnxInput + "': " + e.getMessage()); } - } + }); return; } neededInputs.add(feature.toString()); diff --git a/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java b/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java index 5509d11885c..b414d3757e2 100644 --- a/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java +++ b/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java @@ -20,7 +20,7 @@ import java.util.List; * Class converting a collection of schemas from the intermediate format. * * @author arnej27959 - **/ + */ public class ConvertSchemaCollection { private final IntermediateCollection input; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java index 0cc52edf3cc..b1eace947cc 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java @@ -2,11 +2,15 @@ package com.yahoo.vespa.model.application.validation.change; import com.yahoo.config.model.api.ConfigChangeAction; +import com.yahoo.config.model.api.ServiceInfo; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.vespa.model.VespaModel; import com.yahoo.config.application.api.ValidationId; +import com.yahoo.vespa.model.container.ApplicationContainer; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; import com.yahoo.vespa.model.content.cluster.ContentCluster; +import java.util.ArrayList; import java.util.List; /** @@ -19,16 +23,23 @@ public class ContentClusterRemovalValidator implements ChangeValidator { @Override public List<ConfigChangeAction> validate(VespaModel current, VespaModel next, DeployState deployState) { + List<ConfigChangeAction> actions = new ArrayList<>(); for (String currentClusterId : current.getContentClusters().keySet()) { ContentCluster nextCluster = next.getContentClusters().get(currentClusterId); - if (nextCluster == null) + if (nextCluster == null) { deployState.validationOverrides().invalid(ValidationId.contentClusterRemoval, - "Content cluster '" + currentClusterId + "' is removed. " + - "This will cause loss of all data in this cluster", - deployState.now()); - } + "Content cluster '" + currentClusterId + "' is removed. " + + "This will cause loss of all data in this cluster", + deployState.now()); - return List.of(); + // If we allow the removal, we must restart all containers to ensure mbus is OK. + for (ApplicationContainerCluster cluster : next.getContainerClusters().values()) { + actions.add(new VespaRestartAction(cluster.id(), + "Content cluster '" + currentClusterId + "' has been removed", + cluster.getContainers().stream().map(ApplicationContainer::getServiceInfo).toList())); + } + } + } + return actions; } - } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 2742dc59fcd..7c89a349d7d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -42,13 +42,16 @@ public class OnnxModelInfo { private final Map<String, OnnxTypeInfo> inputs; private final Map<String, OnnxTypeInfo> outputs; private final Map<String, TensorType> vespaTypes = new HashMap<>(); + private final Set<String> initializers; - private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, + Map<String, OnnxTypeInfo> outputs, Set<String> initializers, String defaultOutput) { this.app = app; this.modelPath = path; this.inputs = Collections.unmodifiableMap(inputs); this.outputs = Collections.unmodifiableMap(outputs); this.defaultOutput = defaultOutput; + this.initializers = Set.copyOf(initializers); } public String getModelPath() { @@ -63,6 +66,8 @@ public class OnnxModelInfo { return outputs.keySet(); } + public Set<String> getInitializers() { return initializers; } + public String getDefaultOutput() { return defaultOutput; } @@ -208,6 +213,14 @@ public class OnnxModelInfo { } g.writeEndArray(); + g.writeArrayFieldStart("initializers"); + for (Onnx.TensorProto initializers : model.getGraph().getInitializerList()) { + g.writeStartObject(); + g.writeStringField("name", initializers.getName()); + g.writeEndObject(); + } + g.writeEndArray(); + g.writeEndObject(); g.close(); return out.toString(); @@ -218,6 +231,7 @@ public class OnnxModelInfo { JsonNode root = m.readTree(json); Map<String, OnnxTypeInfo> inputs = new HashMap<>(); Map<String, OnnxTypeInfo> outputs = new HashMap<>(); + Set<String> initializers = new HashSet<>(); String defaultOutput = ""; String path = null; @@ -233,7 +247,13 @@ public class OnnxModelInfo { if (root.get("outputs").has(0)) { defaultOutput = root.get("outputs").get(0).get("name").textValue(); } - return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput); + var initializerRoot = root.get("initializers"); + if (initializerRoot != null) { + for (JsonNode initializer : initializerRoot) { + initializers.add(initializer.get("name").textValue()); + } + } + return new OnnxModelInfo(app, path, inputs, outputs, initializers, defaultOutput); } static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java index 5c360a9343f..65dfce8ff6c 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java @@ -1,14 +1,18 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation.change; +import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ValidationId; import com.yahoo.config.application.api.ValidationOverrides; +import com.yahoo.config.model.api.ConfigChangeAction; import com.yahoo.config.provision.Environment; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.application.validation.ValidationTester; import com.yahoo.yolean.Exceptions; import org.junit.jupiter.api.Test; +import java.util.List; + import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; @@ -36,7 +40,12 @@ public class ContentClusterRemovalValidatorTest { @Test void testOverridingContentRemovalValidation() { VespaModel previous = tester.deploy(null, getServices("contentClusterId"), Environment.prod, null).getFirst(); - tester.deploy(previous, getServices("newContentClusterId"), Environment.prod, removalOverride); // Allowed due to override + var result = tester.deploy(previous, getServices("newContentClusterId"), Environment.prod, removalOverride); // Allowed due to override + assertEquals(result.getFirst().getContainerClusters().values().stream() + .flatMap(cluster -> cluster.getContainers().stream()) + .map(container -> container.getServiceInfo()) + .toList(), + result.getSecond().stream().flatMap(action -> action.getServices().stream()).toList()); } private static String getServices(String contentClusterId) { diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/noderepository/NodeRepositoryNode.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/noderepository/NodeRepositoryNode.java index 945e0730fe6..34a7c1a6f7c 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/noderepository/NodeRepositoryNode.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/noderepository/NodeRepositoryNode.java @@ -113,6 +113,8 @@ public class NodeRepositoryNode { private String cloudAccount; @JsonProperty("wireguardPubKey") private String wireguardPubKey; + @JsonProperty("archiveUri") + private String archiveUri; public String getUrl() { return url; @@ -460,6 +462,10 @@ public class NodeRepositoryNode { public void setWireguardPubKey(String wireguardPubKey) { this.wireguardPubKey = wireguardPubKey; } + public String getArchiveUri() { return archiveUri; } + + public void setArchiveUri(String archiveUri) { this.archiveUri = archiveUri; } + // --- Helper methods for code that (wrongly) consume this directly public boolean hasType(NodeType type) { @@ -521,6 +527,7 @@ public class NodeRepositoryNode { ", switchHostname='" + switchHostname + '\'' + ", cloudAccount='" + cloudAccount + '\'' + ", wireguardPubKey='" + wireguardPubKey + '\'' + + ", archiveUri='" + archiveUri + '\'' + '}'; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java index a1534ebc533..44ea693d8cd 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java @@ -47,6 +47,7 @@ import com.yahoo.vespa.hosted.controller.api.integration.deployment.RevisionId; import com.yahoo.vespa.hosted.controller.api.integration.deployment.TesterId; import com.yahoo.vespa.hosted.controller.api.integration.noderepository.RestartFilter; import com.yahoo.vespa.hosted.controller.api.integration.secrets.TenantSecretStore; +import com.yahoo.vespa.hosted.controller.application.Change; import com.yahoo.vespa.hosted.controller.application.Deployment; import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics; import com.yahoo.vespa.hosted.controller.application.DeploymentMetrics.Warning; @@ -60,6 +61,7 @@ import com.yahoo.vespa.hosted.controller.application.pkg.ApplicationPackageValid import com.yahoo.vespa.hosted.controller.athenz.impl.AthenzFacade; import com.yahoo.vespa.hosted.controller.certificate.EndpointCertificates; import com.yahoo.vespa.hosted.controller.concurrent.Once; +import com.yahoo.vespa.hosted.controller.deployment.DeploymentStatus; import com.yahoo.vespa.hosted.controller.deployment.DeploymentTrigger; import com.yahoo.vespa.hosted.controller.deployment.JobStatus; import com.yahoo.vespa.hosted.controller.deployment.Run; @@ -586,7 +588,16 @@ public class ApplicationController { } // Validate new deployment spec thoroughly before storing it. - controller.jobController().deploymentStatus(application.get()); + DeploymentStatus status = controller.jobController().deploymentStatus(application.get()); + Change dummyChange = Change.of(RevisionId.forProduction(Long.MAX_VALUE)); // Should always run everywhere. + for (var jobs : status.jobsToRun(applicationPackage.deploymentSpec().instanceNames().stream() + .collect(toMap(name -> name, __ -> dummyChange))) + .entrySet()) { + for (var job : jobs.getValue()) { + decideCloudAccountOf(new DeploymentId(jobs.getKey().application(), job.type().zone()), + applicationPackage.deploymentSpec()); + } + } for (Notification notification : controller.notificationsDb().listNotifications(NotificationSource.from(application.get().id()), true)) { if ( notification.source().instance().isPresent() @@ -696,7 +707,7 @@ public class ApplicationController { throw new IllegalArgumentException("Requested cloud account '" + requestedAccount.get().value() + "' is not valid for tenant '" + tenant + "'"); } - if (!controller.zoneRegistry().hasZone(zoneId, requestedAccount.get())) { + if ( ! controller.zoneRegistry().hasZone(zoneId, requestedAccount.get())) { throw new IllegalArgumentException("Zone " + zoneId + " is not configured in requested cloud account '" + requestedAccount.get().value() + "'"); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java index f86073cfb25..3880b028eb0 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/ApplicationPackageStream.java @@ -5,6 +5,7 @@ import java.io.FilterInputStream; import java.io.IOException; import java.io.InputStream; import java.io.UncheckedIOException; +import java.nio.file.attribute.FileTime; import java.util.HashMap; import java.util.Map; import java.util.concurrent.atomic.AtomicReference; @@ -32,6 +33,7 @@ public class ApplicationPackageStream { private final Supplier<Predicate<String>> filter; private final Supplier<InputStream> in; private final AtomicReference<ApplicationPackage> truncatedPackage = new AtomicReference<>(); + private final FileTime createdAt = FileTime.fromMillis(System.currentTimeMillis()); /** Stream that effectively copies the input stream to its {@link #truncatedPackage()} when exhausted. */ public ApplicationPackageStream(Supplier<InputStream> in) { @@ -60,7 +62,7 @@ public class ApplicationPackageStream { * and the first to be exhausted will populate the truncated application package. */ public InputStream zipStream() { - return new Stream(in.get(), replacer.get(), filter.get(), truncatedPackage); + return new Stream(in.get(), replacer.get(), filter.get(), createdAt, truncatedPackage); } /** @@ -85,6 +87,7 @@ public class ApplicationPackageStream { private final ZipInputStream inZip; private final Replacer replacer; private final Predicate<String> filter; + private final FileTime createdAt; private byte[] currentOut = new byte[0]; private InputStream currentIn = InputStream.nullInputStream(); private boolean includeCurrent = false; @@ -92,11 +95,12 @@ public class ApplicationPackageStream { private boolean closed = false; private boolean done = false; - private Stream(InputStream in, Replacer replacer, Predicate<String> filter, AtomicReference<ApplicationPackage> truncatedPackage) { + private Stream(InputStream in, Replacer replacer, Predicate<String> filter, FileTime createdAt, AtomicReference<ApplicationPackage> truncatedPackage) { this.in = in; this.inZip = new ZipInputStream(in); this.replacer = replacer; this.filter = filter; + this.createdAt = createdAt; this.truncatedPackage = truncatedPackage; } @@ -129,10 +133,12 @@ public class ApplicationPackageStream { ZipEntry next = inZip.getNextEntry(); String name; + FileTime modifiedAt; InputStream content = null; if (next == null) { // We may still have replacements to fill in, but if we don't, we're done filling, forever! name = replacer.next(); + modifiedAt = createdAt; if (name == null) { outZip.close(); // This typically makes new output available, so must check for that after this. teeZip.close(); @@ -144,6 +150,7 @@ public class ApplicationPackageStream { } else { name = next.getName(); + modifiedAt = next.getLastModifiedTime(); content = new FilterInputStream(inZip) { @Override public void close() { } }; // Protect inZip from replacements closing it. } @@ -153,8 +160,8 @@ public class ApplicationPackageStream { currentIn = InputStream.nullInputStream(); } else { - if (includeCurrent) teeZip.putNextEntry(new ZipEntry(name)); - outZip.putNextEntry(new ZipEntry(name)); + if (includeCurrent) teeZip.putNextEntry(new ZipEntry(name) {{ setLastModifiedTime(modifiedAt); }}); + outZip.putNextEntry(new ZipEntry(name) {{ setLastModifiedTime(modifiedAt); }}); } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java index a76e76611c2..050b77a391e 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentStatus.java @@ -390,7 +390,7 @@ public class DeploymentStatus { } /** The set of jobs that need to run for the given changes to be considered complete. */ - private Map<JobId, List<Job>> jobsToRun(Map<InstanceName, Change> changes) { + public Map<JobId, List<Job>> jobsToRun(Map<InstanceName, Change> changes) { return jobsToRun(changes, false); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java index 73c64be3e47..10e4052f067 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java @@ -713,6 +713,7 @@ public class JobController { // TODO(mpolden): Enable for public CD once all tests have been updated if (controller.system() != SystemName.PublicCd) { controller.applications().validatePackage(applicationPackage, application.get()); + controller.applications().decideCloudAccountOf(new DeploymentId(id, type.zone()), applicationPackage.deploymentSpec()); } controller.applications().store(application); }); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java index 8f0536480f5..d9ee82f5e90 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java @@ -1429,9 +1429,14 @@ public class ControllerTest { // Deployment fails because zone is not configured in requested cloud account tester.controllerTester().flagSource().withListFlag(PermanentFlags.CLOUD_ACCOUNTS.id(), List.of(cloudAccount), String.class); - context.submit(applicationPackage) - .runJobExpectingFailure(systemTest, "Zone test.us-east-1 is not configured in requested cloud account '012345678912'") - .abortJob(stagingTest); + assertEquals("Zone test.us-east-1 is not configured in requested cloud account '012345678912'", + assertThrows(IllegalArgumentException.class, + () -> context.submit(applicationPackage)) + .getMessage()); + assertEquals("Zone dev.us-east-1 is not configured in requested cloud account '012345678912'", + assertThrows(IllegalArgumentException.class, + () -> context.runJob(devUsEast1, applicationPackage)) + .getMessage()); // Deployment to prod succeeds once all zones are configured in requested account tester.controllerTester().zoneRegistry().configureCloudAccount(CloudAccount.from(cloudAccount), diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ExpressionConverter.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ExpressionConverter.java index 45d8637aa3e..ccad9d6d08b 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ExpressionConverter.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/ExpressionConverter.java @@ -55,6 +55,11 @@ public abstract class ExpressionConverter implements Cloneable { return new CatExpression(lst); } + public Expression innerConvert(ChoiceExpression exp) { + var convertedInnerExpressions = exp.asList().stream().map(inner -> convert(inner)).toList(); + return new ChoiceExpression(convertedInnerExpressions); + } + public Expression innerConvert(ForEachExpression exp) { return new ForEachExpression(convert(exp.getInnerExpression())); } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/InputExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/InputExpression.java index 30c824d410d..bba1b09cda2 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/InputExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/InputExpression.java @@ -21,7 +21,9 @@ public final class InputExpression extends Expression { public InputExpression(String fieldName) { super(null); - this.fieldName = Objects.requireNonNull(fieldName); + if (fieldName == null) + throw new IllegalArgumentException("'input' must be given a field name as argument"); + this.fieldName = fieldName; } public String getFieldName() { diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/ChoiceTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/ChoiceTestCase.java index 351c925ed56..7ece841e9b7 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/ChoiceTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/ChoiceTestCase.java @@ -4,7 +4,13 @@ package com.yahoo.vespa.indexinglanguage.expressions; import com.yahoo.document.DataType; import com.yahoo.document.Field; import com.yahoo.document.datatypes.StringFieldValue; +import com.yahoo.language.Linguistics; +import com.yahoo.language.process.Embedder; +import com.yahoo.language.simple.SimpleLinguistics; +import com.yahoo.vespa.indexinglanguage.ExpressionSearcher; import com.yahoo.vespa.indexinglanguage.SimpleTestAdapter; +import com.yahoo.vespa.indexinglanguage.parser.ParseException; +import com.yahoo.yolean.Exceptions; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -44,4 +50,24 @@ public class ChoiceTestCase { } } + @Test + public void testIllegalChoiceExpression() throws ParseException { + try { + parse("input (foo || 99999999) | attribute"); + } + catch (IllegalArgumentException e) { + assertEquals("'input' must be given a field name as argument", Exceptions.toMessageString(e)); + } + } + + @Test + public void testInnerConvert() throws ParseException { + var expression = parse("(input foo || 99999999) | attribute"); + new ExpressionSearcher<>(AttributeExpression.class).searchIn(expression); // trigger innerConvert + } + + private static Expression parse(String s) throws ParseException { + return Expression.fromString(s, new SimpleLinguistics(), Embedder.throwsOnUse.asMap()); + } + } diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 9bb60827a68..c27ed9d2c31 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -106,6 +106,11 @@ </dependency> <dependency> + <groupId>org.lz4</groupId> + <artifactId>lz4-java</artifactId> + </dependency> + + <dependency> <groupId>junit</groupId> <artifactId>junit</artifactId> <scope>test</scope> diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java index 7cdc27b6d63..02fa7b68dc4 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxEvaluator.java @@ -7,6 +7,7 @@ import ai.onnxruntime.OnnxTensor; import ai.onnxruntime.OnnxValue; import ai.onnxruntime.OrtException; import ai.onnxruntime.OrtSession; +import ai.vespa.modelintegration.evaluator.OnnxRuntime.ModelPathOrData; import ai.vespa.modelintegration.evaluator.OnnxRuntime.ReferencedOrtSession; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -28,7 +29,11 @@ public class OnnxEvaluator implements AutoCloseable { private final ReferencedOrtSession session; OnnxEvaluator(String modelPath, OnnxEvaluatorOptions options, OnnxRuntime runtime) { - session = createSession(modelPath, runtime, options, true); + session = createSession(ModelPathOrData.of(modelPath), runtime, options, true); + } + + OnnxEvaluator(byte[] data, OnnxEvaluatorOptions options, OnnxRuntime runtime) { + session = createSession(ModelPathOrData.of(data), runtime, options, true); } public Tensor evaluate(Map<String, Tensor> inputs, String output) { @@ -125,19 +130,20 @@ public class OnnxEvaluator implements AutoCloseable { } } - private static ReferencedOrtSession createSession(String modelPath, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) { + private static ReferencedOrtSession createSession( + ModelPathOrData model, OnnxRuntime runtime, OnnxEvaluatorOptions options, boolean tryCuda) { if (options == null) { options = new OnnxEvaluatorOptions(); } try { - return runtime.acquireSession(modelPath, options, tryCuda && options.requestingGpu()); + return runtime.acquireSession(model, options, tryCuda && options.requestingGpu()); } catch (OrtException e) { if (e.getCode() == OrtException.OrtErrorCode.ORT_NO_SUCHFILE) { - throw new IllegalArgumentException("No such file: " + modelPath); + throw new IllegalArgumentException("No such file: " + model.path().get()); } if (tryCuda && isCudaError(e) && !options.gpuDeviceRequired()) { // Failed in CUDA native code, but GPU device is optional, so we can proceed without it - return createSession(modelPath, runtime, options, false); + return createSession(model, runtime, options, false); } if (isCudaError(e)) { throw new IllegalArgumentException("GPU device is required, but CUDA initialization failed", e); @@ -146,6 +152,9 @@ public class OnnxEvaluator implements AutoCloseable { } } + // For unit testing + OrtSession ortSession() { return session.instance(); } + private String mapToInternalName(String outputName) throws OrtException { var info = session.instance().getOutputInfo(); var internalNames = info.keySet(); diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java index 42830041c02..ece1db55c1e 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/OnnxRuntime.java @@ -10,9 +10,15 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.jdisc.ResourceReference; import com.yahoo.jdisc.refcount.DebugReferencesWithStack; import com.yahoo.jdisc.refcount.References; +import net.jpountz.xxhash.XXHashFactory; +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; +import java.util.Optional; import java.util.logging.Level; import java.util.logging.Logger; @@ -26,14 +32,22 @@ import static com.yahoo.yolean.Exceptions.throwUnchecked; public class OnnxRuntime extends AbstractComponent { // For unit testing - @FunctionalInterface interface OrtSessionFactory { + interface OrtSessionFactory { OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException; + OrtSession create(byte[] data, OrtSession.SessionOptions opts) throws OrtException; } private static final Logger log = Logger.getLogger(OnnxRuntime.class.getName()); private static final OrtEnvironmentResult ortEnvironment = getOrtEnvironment(); - private static final OrtSessionFactory defaultFactory = (path, opts) -> ortEnvironment().createSession(path, opts); + private static final OrtSessionFactory defaultFactory = new OrtSessionFactory() { + @Override public OrtSession create(String path, OrtSession.SessionOptions opts) throws OrtException { + return ortEnvironment().createSession(path, opts); + } + @Override public OrtSession create(byte[] data, OrtSession.SessionOptions opts) throws OrtException { + return ortEnvironment().createSession(data, opts); + } + }; private final Object monitor = new Object(); private final Map<OrtSessionId, SharedOrtSession> sessions = new HashMap<>(); @@ -43,6 +57,14 @@ public class OnnxRuntime extends AbstractComponent { OnnxRuntime(OrtSessionFactory factory) { this.factory = factory; } + public OnnxEvaluator evaluatorOf(byte[] model) { + return new OnnxEvaluator(model, null, this); + } + + public OnnxEvaluator evaluatorOf(byte[] model, OnnxEvaluatorOptions options) { + return new OnnxEvaluator(model, options, this); + } + public OnnxEvaluator evaluatorOf(String modelPath) { return new OnnxEvaluator(modelPath, null, this); } @@ -105,8 +127,8 @@ public class OnnxRuntime extends AbstractComponent { }; } - ReferencedOrtSession acquireSession(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException { - var sessionId = new OrtSessionId(modelPath, options, loadCuda); + ReferencedOrtSession acquireSession(ModelPathOrData model, OnnxEvaluatorOptions options, boolean loadCuda) throws OrtException { + var sessionId = new OrtSessionId(calculateModelHash(model), options, loadCuda); synchronized (monitor) { var sharedSession = sessions.get(sessionId); if (sharedSession != null) { @@ -114,8 +136,9 @@ public class OnnxRuntime extends AbstractComponent { } } + var opts = options.getOptions(loadCuda); // Note: identical models loaded simultaneously will result in duplicate session instances - var session = factory.create(modelPath, options.getOptions(loadCuda)); + var session = model.path().isPresent() ? factory.create(model.path().get(), opts) : factory.create(model.data().get(), opts); log.fine(() -> "Created new session (%s)".formatted(System.identityHashCode(session))); var sharedSession = new SharedOrtSession(sessionId, session); @@ -125,25 +148,52 @@ public class OnnxRuntime extends AbstractComponent { return referencedSession; } + private static long calculateModelHash(ModelPathOrData model) { + if (model.path().isPresent()) { + try (var hasher = XXHashFactory.fastestInstance().newStreamingHash64(0); + var in = Files.newInputStream(Paths.get(model.path().get()))) { + byte[] buffer = new byte[8192]; + int bytesRead; + while ((bytesRead = in.read(buffer)) != -1) { + hasher.update(buffer, 0, bytesRead); + } + return hasher.getValue(); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } else { + var data = model.data().get(); + return XXHashFactory.fastestInstance().hash64().hash(data, 0, data.length, 0); + } + } + int sessionsCached() { synchronized(monitor) { return sessions.size(); } } - public static class ReferencedOrtSession implements AutoCloseable { + static class ReferencedOrtSession implements AutoCloseable { private final OrtSession instance; private final ResourceReference ref; - public ReferencedOrtSession(OrtSession instance, ResourceReference ref) { + ReferencedOrtSession(OrtSession instance, ResourceReference ref) { this.instance = instance; this.ref = ref; } - public OrtSession instance() { return instance; } + OrtSession instance() { return instance; } @Override public void close() { ref.close(); } } + record ModelPathOrData(Optional<String> path, Optional<byte[]> data) { + static ModelPathOrData of(String path) { return new ModelPathOrData(Optional.of(path), Optional.empty()); } + static ModelPathOrData of(byte[] data) { return new ModelPathOrData(Optional.empty(), Optional.of(data)); } + ModelPathOrData { + if (path.isEmpty() == data.isEmpty()) throw new IllegalArgumentException("Either path or data must be non-empty"); + } + } + // Assumes options are never modified after being stored in `onnxSessions` - record OrtSessionId(String modelPath, OnnxEvaluatorOptions options, boolean loadCuda) {} + private record OrtSessionId(long modelHash, OnnxEvaluatorOptions options, boolean loadCuda) {} - record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {} + private record OrtEnvironmentResult(OrtEnvironment env, Throwable failure) {} private class SharedOrtSession { private final OrtSessionId id; diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java index 5aba54de11b..5a367ef83e4 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxEvaluatorTest.java @@ -5,30 +5,26 @@ package ai.vespa.modelintegration.evaluator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; -import org.junit.jupiter.api.BeforeAll; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; import java.util.HashMap; import java.util.Map; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeNotNull; +import static org.junit.Assume.assumeTrue; /** * @author lesters */ public class OnnxEvaluatorTest { - private static OnnxRuntime runtime; - - @BeforeAll - public static void beforeAll() { - if (OnnxRuntime.isRuntimeAvailable()) runtime = new OnnxRuntime(); - } - @Test public void testSimpleModel() { - assumeNotNull(runtime); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/simple/simple.onnx"); // Input types @@ -53,7 +49,8 @@ public class OnnxEvaluatorTest { @Test public void testBatchDimension() { - assumeNotNull(runtime); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/pytorch/one_layer.onnx"); // Input types @@ -72,21 +69,23 @@ public class OnnxEvaluatorTest { @Test public void testMatMul() { - assumeNotNull(runtime); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); String expected = "tensor<float>(d0[2],d1[4]):[38,44,50,56,83,98,113,128]"; String input1 = "tensor<float>(d0[2],d1[3]):[1,2,3,4,5,6]"; String input2 = "tensor<float>(d0[3],d1[4]):[1,2,3,4,5,6,7,8,9,10,11,12]"; - assertEvaluate("simple/matmul.onnx", expected, input1, input2); + assertEvaluate(runtime, "simple/matmul.onnx", expected, input1, input2); } @Test public void testTypes() { - assumeNotNull(runtime); - assertEvaluate("add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); - assertEvaluate("add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]"); - assertEvaluate("add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]"); - assertEvaluate("cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]"); - assertEvaluate("cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]"); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + assertEvaluate(runtime, "add_double.onnx", "tensor(d0[1]):[3]", "tensor(d0[1]):[1]", "tensor(d0[1]):[2]"); + assertEvaluate(runtime, "add_float.onnx", "tensor<float>(d0[1]):[3]", "tensor<float>(d0[1]):[1]", "tensor<float>(d0[1]):[2]"); + assertEvaluate(runtime, "add_int64.onnx", "tensor<double>(d0[1]):[3]", "tensor<double>(d0[1]):[1]", "tensor<double>(d0[1]):[2]"); + assertEvaluate(runtime, "cast_int8_float.onnx", "tensor<float>(d0[1]):[-128]", "tensor<int8>(d0[1]):[128]"); + assertEvaluate(runtime, "cast_float_int8.onnx", "tensor<int8>(d0[1]):[-1]", "tensor<float>(d0[1]):[255]"); // ONNX Runtime 1.8.0 does not support much of bfloat16 yet // assertEvaluate("cast_bfloat16_float.onnx", "tensor<float>(d0[1]):[1]", "tensor<bfloat16>(d0[1]):[1]"); @@ -94,7 +93,8 @@ public class OnnxEvaluatorTest { @Test public void testNotIdentifiers() { - assumeNotNull(runtime); + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/badnames.onnx"); var inputInfo = evaluator.getInputInfo(); var outputInfo = evaluator.getOutputInfo(); @@ -159,7 +159,18 @@ public class OnnxEvaluatorTest { assertEquals(3, allResults.size()); } - private void assertEvaluate(String model, String output, String... input) { + @Test + public void testLoadModelFromBytes() throws IOException { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + var model = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx")); + var evaluator = runtime.evaluatorOf(model); + assertEquals(3, evaluator.getInputs().size()); + assertEquals(1, evaluator.getOutputs().size()); + evaluator.close(); + } + + private void assertEvaluate(OnnxRuntime runtime, String model, String output, String... input) { OnnxEvaluator evaluator = runtime.evaluatorOf("src/test/models/onnx/" + model); Map<String, Tensor> inputs = new HashMap<>(); for (int i = 0; i < input.length; ++i) { diff --git a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java index 81b1237e770..fdbd4fa4e5c 100644 --- a/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java +++ b/model-integration/src/test/java/ai/vespa/modelintegration/evaluator/OnnxRuntimeTest.java @@ -2,16 +2,18 @@ package ai.vespa.modelintegration.evaluator; -import ai.onnxruntime.OrtException; -import ai.onnxruntime.OrtSession; import org.junit.jupiter.api.Test; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; + import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertNotSame; import static org.junit.jupiter.api.Assertions.assertSame; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; /** * @author bjorncs @@ -19,30 +21,81 @@ import static org.mockito.Mockito.verify; class OnnxRuntimeTest { @Test - void reuses_sessions_while_active() throws OrtException { - var runtime = new OnnxRuntime((__, ___) -> mock(OrtSession.class)); - var session1 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); - var session2 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); - var session3 = runtime.acquireSession("model2", new OnnxEvaluatorOptions(), false); - assertSame(session1.instance(), session2.instance()); - assertNotSame(session1.instance(), session3.instance()); + void reuses_sessions_while_active() { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + OnnxRuntime runtime = new OnnxRuntime(); + String model1 = "src/test/models/onnx/simple/simple.onnx"; + var evaluator1 = runtime.evaluatorOf(model1); + var evaluator2 = runtime.evaluatorOf(model1); + String model2 = "src/test/models/onnx/simple/matmul.onnx"; + var evaluator3 = runtime.evaluatorOf(model2); + assertSameSession(evaluator1, evaluator2); + assertNotSameSession(evaluator1, evaluator3); assertEquals(2, runtime.sessionsCached()); - session1.close(); - session2.close(); + evaluator1.close(); + evaluator2.close(); assertEquals(1, runtime.sessionsCached()); - verify(session1.instance()).close(); - verify(session3.instance(), never()).close(); + assertClosed(evaluator1); + assertNotClosed(evaluator3); - session3.close(); + evaluator3.close(); assertEquals(0, runtime.sessionsCached()); - verify(session3.instance()).close(); + assertClosed(evaluator3); - var session4 = runtime.acquireSession("model1", new OnnxEvaluatorOptions(), false); - assertNotSame(session1.instance(), session4.instance()); + var session4 = runtime.evaluatorOf(model1); + assertNotSameSession(evaluator1, session4); assertEquals(1, runtime.sessionsCached()); session4.close(); assertEquals(0, runtime.sessionsCached()); - verify(session4.instance()).close(); + assertClosed(session4); + } + + @Test + void loads_model_from_byte_array() throws IOException { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + byte[] bytes = Files.readAllBytes(Paths.get("src/test/models/onnx/simple/simple.onnx")); + var evaluator1 = runtime.evaluatorOf(bytes); + var evaluator2 = runtime.evaluatorOf(bytes); + assertEquals(3, evaluator1.getInputs().size()); + assertEquals(1, runtime.sessionsCached()); + assertSameSession(evaluator1, evaluator2); + evaluator2.close(); + evaluator1.close(); + assertEquals(0, runtime.sessionsCached()); + assertClosed(evaluator1); + } + + @Test + void loading_same_model_from_bytes_and_file_resolve_to_same_instance() throws IOException { + assumeTrue(OnnxRuntime.isRuntimeAvailable()); + var runtime = new OnnxRuntime(); + String modelPath = "src/test/models/onnx/simple/simple.onnx"; + byte[] bytes = Files.readAllBytes(Paths.get(modelPath)); + try (var evaluator1 = runtime.evaluatorOf(bytes); + var evaluator2 = runtime.evaluatorOf(modelPath)) { + assertSameSession(evaluator1, evaluator2); + assertEquals(1, runtime.sessionsCached()); + } + } + + private static void assertClosed(OnnxEvaluator evaluator) { assertTrue(isClosed(evaluator), "Session is not closed"); } + private static void assertNotClosed(OnnxEvaluator evaluator) { assertFalse(isClosed(evaluator), "Session is closed"); } + private static void assertSameSession(OnnxEvaluator evaluator1, OnnxEvaluator evaluator2) { + assertSame(evaluator1.ortSession(), evaluator2.ortSession()); + } + private static void assertNotSameSession(OnnxEvaluator evaluator1, OnnxEvaluator evaluator2) { + assertNotSame(evaluator1.ortSession(), evaluator2.ortSession()); + } + + private static boolean isClosed(OnnxEvaluator evaluator) { + try { + evaluator.getInputs(); + return false; + } catch (IllegalStateException e) { + assertEquals("Asking for inputs from a closed OrtSession.", e.getMessage()); + return true; + } } }
\ No newline at end of file diff --git a/searchcore/src/tests/proton/docsummary/docsummary_test.cpp b/searchcore/src/tests/proton/docsummary/docsummary_test.cpp index 020f4ff42c1..70f81d629d5 100644 --- a/searchcore/src/tests/proton/docsummary/docsummary_test.cpp +++ b/searchcore/src/tests/proton/docsummary/docsummary_test.cpp @@ -107,7 +107,10 @@ namespace proton { class MockDocsumFieldWriterFactory : public search::docsummary::IDocsumFieldWriterFactory { public: - std::unique_ptr<DocsumFieldWriter> create_docsum_field_writer(const vespalib::string&, const vespalib::string&, const vespalib::string&) override { + std::unique_ptr<DocsumFieldWriter> create_docsum_field_writer(const vespalib::string&, + const vespalib::string&, + const vespalib::string&, + std::shared_ptr<search::MatchingElementsFields>) override { return {}; } diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.cpp b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.cpp index d6f06c9161e..6ac4fea2921 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.cpp +++ b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.cpp @@ -23,8 +23,7 @@ namespace search::docsummary { DocsumFieldWriterFactory::DocsumFieldWriterFactory(bool use_v8_geo_positions, const IDocsumEnvironment& env, const IQueryTermFilterFactory& query_term_filter_factory) : _use_v8_geo_positions(use_v8_geo_positions), _env(env), - _query_term_filter_factory(query_term_filter_factory), - _matching_elems_fields(std::make_shared<MatchingElementsFields>()) + _query_term_filter_factory(query_term_filter_factory) { } @@ -58,7 +57,8 @@ throw_missing_source(const vespalib::string& command) std::unique_ptr<DocsumFieldWriter> DocsumFieldWriterFactory::create_docsum_field_writer(const vespalib::string& field_name, const vespalib::string& command, - const vespalib::string& source) + const vespalib::string& source, + std::shared_ptr<MatchingElementsFields> matching_elems_fields) { std::unique_ptr<DocsumFieldWriter> fieldWriter; if (command == command::dynamic_teaser) { @@ -116,9 +116,9 @@ DocsumFieldWriterFactory::create_docsum_field_writer(const vespalib::string& fie if (has_attribute_manager()) { auto attr_ctx = getEnvironment().getAttributeManager()->createContext(); if (attr_ctx->getAttribute(source_field) != nullptr) { - fieldWriter = AttributeDFWFactory::create(*getEnvironment().getAttributeManager(), source_field, true, _matching_elems_fields); + fieldWriter = AttributeDFWFactory::create(*getEnvironment().getAttributeManager(), source_field, true, matching_elems_fields); } else { - fieldWriter = AttributeCombinerDFW::create(source_field, *attr_ctx, true, _matching_elems_fields); + fieldWriter = AttributeCombinerDFW::create(source_field, *attr_ctx, true, matching_elems_fields); } throw_if_nullptr(fieldWriter, command); } @@ -126,7 +126,7 @@ DocsumFieldWriterFactory::create_docsum_field_writer(const vespalib::string& fie const vespalib::string& source_field = source.empty() ? field_name : source; if (has_attribute_manager()) { auto attr_ctx = getEnvironment().getAttributeManager()->createContext(); - fieldWriter = MatchedElementsFilterDFW::create(source_field,*attr_ctx, _matching_elems_fields); + fieldWriter = MatchedElementsFilterDFW::create(source_field,*attr_ctx, matching_elems_fields); throw_if_nullptr(fieldWriter, command); } } else if (command == command::documentid) { diff --git a/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.h b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.h index e50fb85cca6..7175f043701 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.h +++ b/searchsummary/src/vespa/searchsummary/docsummary/docsum_field_writer_factory.h @@ -20,7 +20,6 @@ class DocsumFieldWriterFactory : public IDocsumFieldWriterFactory const IDocsumEnvironment& _env; const IQueryTermFilterFactory& _query_term_filter_factory; protected: - std::shared_ptr<MatchingElementsFields> _matching_elems_fields; const IDocsumEnvironment& getEnvironment() const noexcept { return _env; } bool has_attribute_manager() const noexcept; public: @@ -28,7 +27,8 @@ public: ~DocsumFieldWriterFactory() override; std::unique_ptr<DocsumFieldWriter> create_docsum_field_writer(const vespalib::string& field_name, const vespalib::string& command, - const vespalib::string& source) override; + const vespalib::string& source, + std::shared_ptr<MatchingElementsFields> matching_elems_fields) override; }; } diff --git a/searchsummary/src/vespa/searchsummary/docsummary/i_docsum_field_writer_factory.h b/searchsummary/src/vespa/searchsummary/docsummary/i_docsum_field_writer_factory.h index 6a5cd691857..bc2ebe3c40c 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/i_docsum_field_writer_factory.h +++ b/searchsummary/src/vespa/searchsummary/docsummary/i_docsum_field_writer_factory.h @@ -5,6 +5,8 @@ #include <memory> #include <vespa/vespalib/stllike/string.h> +namespace search { class MatchingElementsFields; } + namespace search::docsummary { class DocsumFieldWriter; @@ -21,7 +23,8 @@ public: */ virtual std::unique_ptr<DocsumFieldWriter> create_docsum_field_writer(const vespalib::string& field_name, const vespalib::string& command, - const vespalib::string& source) = 0; + const vespalib::string& source, + std::shared_ptr<MatchingElementsFields> matching_elems_fields) = 0; }; } diff --git a/searchsummary/src/vespa/searchsummary/docsummary/resultconfig.cpp b/searchsummary/src/vespa/searchsummary/docsummary/resultconfig.cpp index f620dcb1df5..eddb67f5822 100644 --- a/searchsummary/src/vespa/searchsummary/docsummary/resultconfig.cpp +++ b/searchsummary/src/vespa/searchsummary/docsummary/resultconfig.cpp @@ -5,6 +5,7 @@ #include "docsum_field_writer_factory.h" #include "resultclass.h" #include <vespa/config-summary.h> +#include <vespa/searchlib/common/matching_elements_fields.h> #include <vespa/vespalib/stllike/hash_map.hpp> #include <vespa/vespalib/util/exceptions.h> #include <atomic> @@ -124,6 +125,7 @@ ResultConfig::readConfig(const SummaryConfig &cfg, const char *configId, IDocsum break; } resClass->set_omit_summary_features(cfg_class.omitsummaryfeatures); + auto matching_elems_fields = std::make_shared<MatchingElementsFields>(); for (const auto & field : cfg_class.fields) { const char *fieldname = field.name.c_str(); vespalib::string command = field.command; @@ -134,7 +136,8 @@ ResultConfig::readConfig(const SummaryConfig &cfg, const char *configId, IDocsum try { docsum_field_writer = docsum_field_writer_factory.create_docsum_field_writer(fieldname, command, - source_name); + source_name, + matching_elems_fields); } catch (const vespalib::IllegalArgumentException& ex) { LOG(error, "Exception during setup of summary result class '%s': field='%s', command='%s', source='%s': %s", cfg_class.name.c_str(), fieldname, command.c_str(), source_name.c_str(), ex.getMessage().c_str()); diff --git a/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.cpp b/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.cpp index 36873b713aa..b103d7d85b2 100644 --- a/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.cpp +++ b/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.cpp @@ -48,7 +48,8 @@ DocsumFieldWriterFactory::~DocsumFieldWriterFactory() = default; std::unique_ptr<DocsumFieldWriter> DocsumFieldWriterFactory::create_docsum_field_writer(const vespalib::string& field_name, const vespalib::string& command, - const vespalib::string& source) + const vespalib::string& source, + std::shared_ptr<MatchingElementsFields> matching_elems_fields) { std::unique_ptr<DocsumFieldWriter> fieldWriter; using namespace search::docsummary; @@ -65,10 +66,10 @@ DocsumFieldWriterFactory::create_docsum_field_writer(const vespalib::string& fie } else if ((command == command::matched_attribute_elements_filter) || (command == command::matched_elements_filter)) { vespalib::string source_field = source.empty() ? field_name : source; - populate_fields(*_matching_elems_fields, _vsm_fields_config, source_field); - fieldWriter = MatchedElementsFilterDFW::create(source_field, _matching_elems_fields); + populate_fields(*matching_elems_fields, _vsm_fields_config, source_field); + fieldWriter = MatchedElementsFilterDFW::create(source_field, matching_elems_fields); } else { - return search::docsummary::DocsumFieldWriterFactory::create_docsum_field_writer(field_name, command, source); + return search::docsummary::DocsumFieldWriterFactory::create_docsum_field_writer(field_name, command, source, matching_elems_fields); } return fieldWriter; } diff --git a/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.h b/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.h index 078c466d3d2..ac5cae8c49d 100644 --- a/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.h +++ b/streamingvisitors/src/vespa/vsm/vsm/docsum_field_writer_factory.h @@ -21,7 +21,8 @@ public: std::unique_ptr<search::docsummary::DocsumFieldWriter> create_docsum_field_writer(const vespalib::string& field_name, const vespalib::string& command, - const vespalib::string& source) override; + const vespalib::string& source, + std::shared_ptr<search::MatchingElementsFields> matching_elems_fields) override; }; } diff --git a/vdslib/src/main/java/com/yahoo/vdslib/VisitorStatistics.java b/vdslib/src/main/java/com/yahoo/vdslib/VisitorStatistics.java index fda456ace05..14b44ede12a 100644 --- a/vdslib/src/main/java/com/yahoo/vdslib/VisitorStatistics.java +++ b/vdslib/src/main/java/com/yahoo/vdslib/VisitorStatistics.java @@ -2,6 +2,7 @@ package com.yahoo.vdslib; public class VisitorStatistics { + int bucketsVisited = 0; long documentsVisited = 0; long bytesVisited = 0; @@ -20,8 +21,8 @@ public class VisitorStatistics { public void setBucketsVisited(int bucketsVisited) { this.bucketsVisited = bucketsVisited; } /** - * @return the number of documents matching the document selection in the backend and that - * has been passed to the client-specified visitor instance (dumpvisitor, searchvisitor etc). + * Returns the number of documents matching the document selection in the backend that + * has been passed to the client-specified visitor instance (dumpvisitor, searchvisitor etc). */ public long getDocumentsVisited() { return documentsVisited; } public void setDocumentsVisited(long documentsVisited) { this.documentsVisited = documentsVisited; } @@ -30,9 +31,9 @@ public class VisitorStatistics { public void setBytesVisited(long bytesVisited) { this.bytesVisited = bytesVisited; } /** - * @return Number of documents returned to the visitor client by the backend. This number may - * be lower than that returned by getDocumentsVisited() since the client-specified visitor - * instance may further have filtered the set of documents returned by the backend. + * Returns the number of documents returned to the visitor client by the backend. This number may + * be lower than that returned by getDocumentsVisited() since the client-specified visitor + * instance may further have filtered the set of documents returned by the backend. */ public long getDocumentsReturned() { return documentsReturned; } public void setDocumentsReturned(long documentsReturned) { this.documentsReturned = documentsReturned; } @@ -40,15 +41,14 @@ public class VisitorStatistics { public long getBytesReturned() { return bytesReturned; } public void setBytesReturned(long bytesReturned) { this.bytesReturned = bytesReturned; } + @Override public String toString() { - String out = + return "Buckets visited: " + bucketsVisited + "\n" + "Documents visited: " + documentsVisited + "\n" + "Bytes visited: " + bytesVisited + "\n" + "Documents returned: " + documentsReturned + "\n" + "Bytes returned: " + bytesReturned + "\n"; - - return out; } } diff --git a/vespa-feed-client/src/test/java/ai/vespa/feed/client/impl/ApacheClusterTest.java b/vespa-feed-client/src/test/java/ai/vespa/feed/client/impl/ApacheClusterTest.java index 9195b5ab858..cf9a36f2aa8 100644 --- a/vespa-feed-client/src/test/java/ai/vespa/feed/client/impl/ApacheClusterTest.java +++ b/vespa-feed-client/src/test/java/ai/vespa/feed/client/impl/ApacheClusterTest.java @@ -35,7 +35,7 @@ class ApacheClusterTest { final WireMockExtension server = new WireMockExtension(); @Test - void testClient() throws IOException, ExecutionException, InterruptedException, TimeoutException { + void testClient() throws Exception { for (Compression compression : Compression.values()) { try (ApacheCluster cluster = new ApacheCluster(new FeedClientBuilderImpl(List.of(URI.create("http://localhost:" + server.port()))) .setCompression(compression))) { @@ -48,25 +48,28 @@ class ApacheClusterTest { Map.of("name1", () -> "value1", "name2", () -> "value2"), "content".getBytes(UTF_8), - Duration.ofSeconds(20)), + Duration.ofSeconds(10)), vessel); - HttpResponse response = vessel.get(15, TimeUnit.SECONDS); - assertEquals("{}", new String(response.body(), UTF_8)); - assertEquals(200, response.code()); - ByteArrayOutputStream buffer = new ByteArrayOutputStream(); - try (OutputStream zip = new GZIPOutputStream(buffer)) { zip.write("content".getBytes(UTF_8)); } - server.verify(1, anyRequestedFor(anyUrl())); - RequestPatternBuilder expected = postRequestedFor(urlEqualTo("/path")).withHeader("name1", equalTo("value1")) - .withHeader("name2", equalTo("value2")) - .withHeader("Content-Type", equalTo("application/json; charset=UTF-8")) - .withRequestBody(equalTo("content")); - expected = switch (compression) { - case auto, none -> expected.withoutHeader("Content-Encoding"); - case gzip -> expected.withHeader("Content-Encoding", equalTo("gzip")); + AutoCloseable verifyResponse = () -> { + HttpResponse response = vessel.get(15, TimeUnit.SECONDS); + assertEquals("{}", new String(response.body(), UTF_8)); + assertEquals(200, response.code()); }; - server.verify(1, expected); - server.resetRequests(); + AutoCloseable verifyServer = () -> { + server.verify(1, anyRequestedFor(anyUrl())); + RequestPatternBuilder expected = postRequestedFor(urlEqualTo("/path")).withHeader("name1", equalTo("value1")) + .withHeader("name2", equalTo("value2")) + .withHeader("Content-Type", equalTo("application/json; charset=UTF-8")) + .withRequestBody(equalTo("content")); + expected = switch (compression) { + case auto, none -> expected.withoutHeader("Content-Encoding"); + case gzip -> expected.withHeader("Content-Encoding", equalTo("gzip")); + }; + server.verify(1, expected); + server.resetRequests(); + }; + try (verifyServer; verifyResponse) { } } } } |