diff options
Diffstat (limited to 'config-model')
6 files changed, 66 insertions, 17 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java index 272b668b5fb..90a27d1f036 100644 --- a/config-model/src/main/java/com/yahoo/schema/OnnxModel.java +++ b/config-model/src/main/java/com/yahoo/schema/OnnxModel.java @@ -1,15 +1,17 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.schema; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.model.ml.OnnxModelInfo; -import com.yahoo.searchlib.rankingexpression.Reference; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.Map; import java.util.Objects; import java.util.Optional; +import java.util.Set; /** * A global ONNX model distributed using file distribution, similar to ranking constants. @@ -21,6 +23,7 @@ public class OnnxModel extends DistributableResource { private OnnxModelInfo modelInfo = null; private final Map<String, String> inputMap = new HashMap<>(); private final Map<String, String> outputMap = new HashMap<>(); + private final Set<String> initializers = new HashSet<>(); private String statelessExecutionMode = null; private Integer statelessInterOpThreads = null; @@ -101,11 +104,13 @@ public class OnnxModel extends DistributableResource { for (String onnxName : modelInfo.getOutputs()) { addOutputNameMapping(onnxName, OnnxModelInfo.asValidIdentifier(onnxName), false); } + initializers.addAll(modelInfo.getInitializers()); this.modelInfo = modelInfo; } public Map<String, String> getInputMap() { return Collections.unmodifiableMap(inputMap); } public Map<String, String> getOutputMap() { return Collections.unmodifiableMap(outputMap); } + public Set<String> getInitializers() { return Set.copyOf(initializers); } public String getDefaultOutput() { return modelInfo != null ? modelInfo.getDefaultOutput() : ""; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java index af072c5b59a..7f578f07fe3 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/InputRecorder.java @@ -2,7 +2,6 @@ package com.yahoo.schema.expressiontransforms; import com.yahoo.schema.FeatureNames; -import com.yahoo.schema.RankProfile; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.parser.ParseException; @@ -12,13 +11,12 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.functions.DynamicTensor; import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Slice; import java.io.StringReader; import java.util.HashSet; import java.util.Set; +import java.util.logging.Logger; /** * Analyzes expression to figure out what inputs it needs @@ -27,6 +25,8 @@ import java.util.Set; */ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { + private static final Logger log = Logger.getLogger(InputRecorder.class.getName()); + private final Set<String> neededInputs; private final Set<String> handled = new HashSet<>(); @@ -120,7 +120,11 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { if (model == null) { throw new IllegalArgumentException("missing onnx model: " + arg); } - for (String onnxInput : model.getInputMap().values()) { + model.getInputMap().forEach((onnxName, onnxInput) -> { + if (model.getInitializers().contains(onnxName)) { + log.fine(() -> "For input '%s': skipping name '%s' as it's an initializer".formatted(onnxInput, onnxName)); + return; + } var reader = new StringReader(onnxInput); try { var asExpression = new RankingExpression(reader); @@ -128,7 +132,7 @@ public class InputRecorder extends ExpressionTransformer<InputRecorderContext> { } catch (ParseException e) { throw new IllegalArgumentException("illegal onnx input '" + onnxInput + "': " + e.getMessage()); } - } + }); return; } neededInputs.add(feature.toString()); diff --git a/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java b/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java index 5509d11885c..b414d3757e2 100644 --- a/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java +++ b/config-model/src/main/java/com/yahoo/schema/parser/ConvertSchemaCollection.java @@ -20,7 +20,7 @@ import java.util.List; * Class converting a collection of schemas from the intermediate format. * * @author arnej27959 - **/ + */ public class ConvertSchemaCollection { private final IntermediateCollection input; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java index 0cc52edf3cc..b1eace947cc 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidator.java @@ -2,11 +2,15 @@ package com.yahoo.vespa.model.application.validation.change; import com.yahoo.config.model.api.ConfigChangeAction; +import com.yahoo.config.model.api.ServiceInfo; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.vespa.model.VespaModel; import com.yahoo.config.application.api.ValidationId; +import com.yahoo.vespa.model.container.ApplicationContainer; +import com.yahoo.vespa.model.container.ApplicationContainerCluster; import com.yahoo.vespa.model.content.cluster.ContentCluster; +import java.util.ArrayList; import java.util.List; /** @@ -19,16 +23,23 @@ public class ContentClusterRemovalValidator implements ChangeValidator { @Override public List<ConfigChangeAction> validate(VespaModel current, VespaModel next, DeployState deployState) { + List<ConfigChangeAction> actions = new ArrayList<>(); for (String currentClusterId : current.getContentClusters().keySet()) { ContentCluster nextCluster = next.getContentClusters().get(currentClusterId); - if (nextCluster == null) + if (nextCluster == null) { deployState.validationOverrides().invalid(ValidationId.contentClusterRemoval, - "Content cluster '" + currentClusterId + "' is removed. " + - "This will cause loss of all data in this cluster", - deployState.now()); - } + "Content cluster '" + currentClusterId + "' is removed. " + + "This will cause loss of all data in this cluster", + deployState.now()); - return List.of(); + // If we allow the removal, we must restart all containers to ensure mbus is OK. + for (ApplicationContainerCluster cluster : next.getContainerClusters().values()) { + actions.add(new VespaRestartAction(cluster.id(), + "Content cluster '" + currentClusterId + "' has been removed", + cluster.getContainers().stream().map(ApplicationContainer::getServiceInfo).toList())); + } + } + } + return actions; } - } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java index 2742dc59fcd..7c89a349d7d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelInfo.java @@ -42,13 +42,16 @@ public class OnnxModelInfo { private final Map<String, OnnxTypeInfo> inputs; private final Map<String, OnnxTypeInfo> outputs; private final Map<String, TensorType> vespaTypes = new HashMap<>(); + private final Set<String> initializers; - private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, Map<String, OnnxTypeInfo> outputs, String defaultOutput) { + private OnnxModelInfo(ApplicationPackage app, String path, Map<String, OnnxTypeInfo> inputs, + Map<String, OnnxTypeInfo> outputs, Set<String> initializers, String defaultOutput) { this.app = app; this.modelPath = path; this.inputs = Collections.unmodifiableMap(inputs); this.outputs = Collections.unmodifiableMap(outputs); this.defaultOutput = defaultOutput; + this.initializers = Set.copyOf(initializers); } public String getModelPath() { @@ -63,6 +66,8 @@ public class OnnxModelInfo { return outputs.keySet(); } + public Set<String> getInitializers() { return initializers; } + public String getDefaultOutput() { return defaultOutput; } @@ -208,6 +213,14 @@ public class OnnxModelInfo { } g.writeEndArray(); + g.writeArrayFieldStart("initializers"); + for (Onnx.TensorProto initializers : model.getGraph().getInitializerList()) { + g.writeStartObject(); + g.writeStringField("name", initializers.getName()); + g.writeEndObject(); + } + g.writeEndArray(); + g.writeEndObject(); g.close(); return out.toString(); @@ -218,6 +231,7 @@ public class OnnxModelInfo { JsonNode root = m.readTree(json); Map<String, OnnxTypeInfo> inputs = new HashMap<>(); Map<String, OnnxTypeInfo> outputs = new HashMap<>(); + Set<String> initializers = new HashSet<>(); String defaultOutput = ""; String path = null; @@ -233,7 +247,13 @@ public class OnnxModelInfo { if (root.get("outputs").has(0)) { defaultOutput = root.get("outputs").get(0).get("name").textValue(); } - return new OnnxModelInfo(app, path, inputs, outputs, defaultOutput); + var initializerRoot = root.get("initializers"); + if (initializerRoot != null) { + for (JsonNode initializer : initializerRoot) { + initializers.add(initializer.get("name").textValue()); + } + } + return new OnnxModelInfo(app, path, inputs, outputs, initializers, defaultOutput); } static private void onnxTypeToJson(JsonGenerator g, Onnx.ValueInfoProto valueInfo) throws IOException { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java index 5c360a9343f..65dfce8ff6c 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/ContentClusterRemovalValidatorTest.java @@ -1,14 +1,18 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.application.validation.change; +import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ValidationId; import com.yahoo.config.application.api.ValidationOverrides; +import com.yahoo.config.model.api.ConfigChangeAction; import com.yahoo.config.provision.Environment; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.application.validation.ValidationTester; import com.yahoo.yolean.Exceptions; import org.junit.jupiter.api.Test; +import java.util.List; + import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; @@ -36,7 +40,12 @@ public class ContentClusterRemovalValidatorTest { @Test void testOverridingContentRemovalValidation() { VespaModel previous = tester.deploy(null, getServices("contentClusterId"), Environment.prod, null).getFirst(); - tester.deploy(previous, getServices("newContentClusterId"), Environment.prod, removalOverride); // Allowed due to override + var result = tester.deploy(previous, getServices("newContentClusterId"), Environment.prod, removalOverride); // Allowed due to override + assertEquals(result.getFirst().getContainerClusters().values().stream() + .flatMap(cluster -> cluster.getContainers().stream()) + .map(container -> container.getServiceInfo()) + .toList(), + result.getSecond().stream().flatMap(action -> action.getServices().stream()).toList()); } private static String getServices(String contentClusterId) { |