summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2021-06-23 10:00:18 +0200
committerHenning Baldersheim <balder@yahoo-inc.com>2021-06-23 10:00:18 +0200
commit8d1451224710fa7a65e9a7410113c64c104a8dd0 (patch)
treea6f12ad954bfa900f64f5f3c0f4b43a0aa2bcb70
parented7bc94c5c184b1dc735d8db4f95894102cd46f2 (diff)
Add feature falg for controlling onnx dryrun verification.
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java1
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java117
-rw-r--r--configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java3
-rw-r--r--flags/src/main/java/com/yahoo/vespa/flags/Flags.java7
6 files changed, 85 insertions, 53 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
index 81ee0a4c4c3..badb6e01885 100644
--- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
+++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java
@@ -86,6 +86,7 @@ public interface ModelContext {
@ModelFeatureFlag(owners = {"baldersheim"}) default boolean distributeExternalRankExpressions() { return false; }
@ModelFeatureFlag(owners = {"baldersheim"}) default int maxConcurrentMergesPerNode() { throw new UnsupportedOperationException("TODO specify default value"); }
@ModelFeatureFlag(owners = {"baldersheim"}) default int maxMergeQueueSize() { throw new UnsupportedOperationException("TODO specify default value"); }
+ @ModelFeatureFlag(owners = {"baldersheim"}) default boolean dryRunOnnxOnSetup() { throw new UnsupportedOperationException("TODO specify default value"); }
@ModelFeatureFlag(owners = {"geirst"}) default boolean enableFeedBlockInDistributor() { return true; }
@ModelFeatureFlag(owners = {"hmusum"}, removeAfter = "7.406") default int clusterControllerMaxHeapSizeInMb() { return 128; }
@ModelFeatureFlag(owners = {"hmusum"}, removeAfter = "7.422") default int metricsProxyMaxHeapSizeInMb(ClusterSpec.Type type) { return 256; }
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java
index fe1bf93f32b..c62dc6e4631 100644
--- a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java
+++ b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java
@@ -63,6 +63,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea
private int maxMergeQueueSize = 1024;
private int largeRankExpressionLimit = 0x10000;
private boolean allowDisableMtls = true;
+ private boolean dryRunOnnxOnSetup = false;
private List<X509Certificate> operatorCertificates = Collections.emptyList();
@Override public ModelContext.FeatureFlags featureFlags() { return this; }
@@ -107,7 +108,12 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea
@Override public int largeRankExpressionLimit() { return largeRankExpressionLimit; }
@Override public int maxConcurrentMergesPerNode() { return maxConcurrentMergesPerNode; }
@Override public int maxMergeQueueSize() { return maxMergeQueueSize; }
+ @Override public boolean dryRunOnnxOnSetup() { return dryRunOnnxOnSetup; }
+ public TestProperties setDryRunOnnxOnSetup(boolean value) {
+ dryRunOnnxOnSetup = value;
+ return this;
+ }
public TestProperties useExternalRankExpression(boolean value) {
useExternalRankExpression = value;
return this;
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
index 7c533cce006..1796a4cba17 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java
@@ -35,6 +35,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
private final RankingConstants rankingConstants;
private final LargeRankExpressions largeRankExpressions;
private final OnnxModels onnxModels;
+ private final boolean dryRunOnnxOnSetup;
public static RankProfileList empty = new RankProfileList();
@@ -42,6 +43,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
rankingConstants = new RankingConstants();
largeRankExpressions = new LargeRankExpressions();
onnxModels = new OnnxModels();
+ dryRunOnnxOnSetup = true;
}
/**
@@ -62,6 +64,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
this.rankingConstants = rankingConstants;
this.largeRankExpressions = largeRankExpressions;
onnxModels = search == null ? new OnnxModels() : search.onnxModels(); // as ONNX models come from parsing rank expressions
+ dryRunOnnxOnSetup = deployProperties.featureFlags().dryRunOnnxOnSetup();
deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields, deployProperties);
}
@@ -137,6 +140,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ
log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way
else {
OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder();
+ modelBuilder.dry_run_on_setup(dryRunOnnxOnSetup);
modelBuilder.name(model.getName());
modelBuilder.fileref(model.getFileReference());
model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source)));
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
index f460383f42b..ab148130a7d 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -2,8 +2,10 @@
package com.yahoo.searchdefinition.processing;
import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.api.ModelContext;
import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.config.model.deploy.DeployState;
+import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.vespa.config.search.RankProfilesConfig;
@@ -27,9 +29,9 @@ public class RankingExpressionWithOnnxModelTestCase {
@Test
public void testOnnxModelFeature() throws Exception {
- VespaModel model = loadModel(applicationDir);
+ VespaModel model = loadModel(applicationDir, false);
assertTransformedFeature(model);
- assertGeneratedConfig(model);
+ assertGeneratedConfig(model, false);
Path storedApplicationDir = applicationDir.append("copy");
try {
@@ -39,73 +41,82 @@ public class RankingExpressionWithOnnxModelTestCase {
IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedApplicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
- VespaModel storedModel = loadModel(storedApplicationDir);
+ VespaModel storedModel = loadModel(storedApplicationDir, true);
assertTransformedFeature(storedModel);
- assertGeneratedConfig(storedModel);
+ assertGeneratedConfig(storedModel, true);
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDir.toFile());
}
}
- private VespaModel loadModel(Path path) throws Exception {
+ private VespaModel loadModel(Path path, boolean dryRunOnnx) throws Exception {
FilesApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile());
- DeployState state = new DeployState.Builder().applicationPackage(applicationPackage).build();
+ ModelContext.Properties properties = new TestProperties().setDryRunOnnxOnSetup(dryRunOnnx);
+ DeployState state = new DeployState.Builder().applicationPackage(applicationPackage).properties(properties).build();
return new VespaModel(state);
}
- private void assertGeneratedConfig(VespaModel model) {
- DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0);
+ private void assertGeneratedConfig(VespaModel vespaModel, boolean expectDryRunOnnx) {
+ DocumentDatabase db = ((IndexedSearchCluster)vespaModel.getSearchClusters().get(0)).getDocumentDbs().get(0);
OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder();
((OnnxModelsConfig.Producer) db).getConfig(builder);
OnnxModelsConfig config = new OnnxModelsConfig(builder);
assertEquals(6, config.model().size());
+ for (OnnxModelsConfig.Model model : config.model()) {
+ assertEquals(expectDryRunOnnx, model.dry_run_on_setup());
+ }
- assertEquals("my_model", config.model(0).name());
- assertEquals(3, config.model(0).input().size());
- assertEquals("second/input:0", config.model(0).input(0).name());
- assertEquals("constant(my_constant)", config.model(0).input(0).source());
- assertEquals("first_input", config.model(0).input(1).name());
- assertEquals("attribute(document_field)", config.model(0).input(1).source());
- assertEquals("third_input", config.model(0).input(2).name());
- assertEquals("rankingExpression(my_function)", config.model(0).input(2).source());
- assertEquals(3, config.model(0).output().size());
- assertEquals("path/to/output:0", config.model(0).output(0).name());
- assertEquals("out", config.model(0).output(0).as());
- assertEquals("path/to/output:1", config.model(0).output(1).name());
- assertEquals("path_to_output_1", config.model(0).output(1).as());
- assertEquals("path/to/output:2", config.model(0).output(2).name());
- assertEquals("path_to_output_2", config.model(0).output(2).as());
-
- assertEquals("files_model_onnx", config.model(1).name());
- assertEquals(3, config.model(1).input().size());
- assertEquals(3, config.model(1).output().size());
- assertEquals("path/to/output:0", config.model(1).output(0).name());
- assertEquals("path_to_output_0", config.model(1).output(0).as());
- assertEquals("path/to/output:1", config.model(1).output(1).name());
- assertEquals("path_to_output_1", config.model(1).output(1).as());
- assertEquals("path/to/output:2", config.model(1).output(2).name());
- assertEquals("path_to_output_2", config.model(1).output(2).as());
- assertEquals("files_model_onnx", config.model(1).name());
-
- assertEquals("another_model", config.model(2).name());
- assertEquals("third_input", config.model(2).input(2).name());
- assertEquals("rankingExpression(another_function)", config.model(2).input(2).source());
-
- assertEquals("files_summary_model_onnx", config.model(3).name());
- assertEquals(3, config.model(3).input().size());
- assertEquals(3, config.model(3).output().size());
-
- assertEquals("dynamic_model", config.model(5).name());
- assertEquals(1, config.model(5).input().size());
- assertEquals(1, config.model(5).output().size());
- assertEquals("rankingExpression(my_function)", config.model(5).input(0).source());
-
- assertEquals("unbound_model", config.model(4).name());
- assertEquals(1, config.model(4).input().size());
- assertEquals(1, config.model(4).output().size());
- assertEquals("rankingExpression(my_function)", config.model(4).input(0).source());
-
+ OnnxModelsConfig.Model model = config.model(0);
+ assertEquals("my_model", model.name());
+ assertEquals(3, model.input().size());
+ assertEquals("second/input:0", model.input(0).name());
+ assertEquals("constant(my_constant)", model.input(0).source());
+ assertEquals("first_input", model.input(1).name());
+ assertEquals("attribute(document_field)", model.input(1).source());
+ assertEquals("third_input", model.input(2).name());
+ assertEquals("rankingExpression(my_function)", model.input(2).source());
+ assertEquals(3, model.output().size());
+ assertEquals("path/to/output:0", model.output(0).name());
+ assertEquals("out", model.output(0).as());
+ assertEquals("path/to/output:1", model.output(1).name());
+ assertEquals("path_to_output_1", model.output(1).as());
+ assertEquals("path/to/output:2", model.output(2).name());
+ assertEquals("path_to_output_2", model.output(2).as());
+
+ model = config.model(1);
+ assertEquals("files_model_onnx", model.name());
+ assertEquals(3, model.input().size());
+ assertEquals(3, model.output().size());
+ assertEquals("path/to/output:0", model.output(0).name());
+ assertEquals("path_to_output_0", model.output(0).as());
+ assertEquals("path/to/output:1", model.output(1).name());
+ assertEquals("path_to_output_1", model.output(1).as());
+ assertEquals("path/to/output:2", model.output(2).name());
+ assertEquals("path_to_output_2", model.output(2).as());
+ assertEquals("files_model_onnx", model.name());
+
+ model = config.model(2);
+ assertEquals("another_model", model.name());
+ assertEquals("third_input", model.input(2).name());
+ assertEquals("rankingExpression(another_function)", model.input(2).source());
+
+ model = config.model(3);
+ assertEquals("files_summary_model_onnx", model.name());
+ assertEquals(3, model.input().size());
+ assertEquals(3, model.output().size());
+
+ model = config.model(4);
+ assertEquals("unbound_model", model.name());
+ assertEquals(1, model.input().size());
+ assertEquals(1, model.output().size());
+ assertEquals("rankingExpression(my_function)", model.input(0).source());
+
+ model = config.model(5);
+ assertEquals("dynamic_model", model.name());
+ assertEquals(1, model.input().size());
+ assertEquals(1, model.output().size());
+ assertEquals("rankingExpression(my_function)", model.input(0).source());
}
private void assertTransformedFeature(VespaModel model) {
diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
index 94cfba12453..2b8eaf86cda 100644
--- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
+++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java
@@ -182,6 +182,7 @@ public class ModelContextImpl implements ModelContext {
private final int maxMergeQueueSize;
private final int largeRankExpressionLimit;
private final boolean throwIfResourceLimitsSpecified;
+ private final boolean dryRunOnnxOnSetup;
public FeatureFlags(FlagSource source, ApplicationId appId) {
this.dedicatedClusterControllerFlavor = parseDedicatedClusterControllerFlavor(flagValue(source, appId, Flags.DEDICATED_CLUSTER_CONTROLLER_FLAVOR));
@@ -209,6 +210,7 @@ public class ModelContextImpl implements ModelContext {
this.maxConcurrentMergesPerContentNode = flagValue(source, appId, Flags.MAX_CONCURRENT_MERGES_PER_NODE);
this.maxMergeQueueSize = flagValue(source, appId, Flags.MAX_MERGE_QUEUE_SIZE);
this.throwIfResourceLimitsSpecified = flagValue(source, appId, Flags.THROW_EXCEPTION_IF_RESOURCE_LIMITS_SPECIFIED);
+ this.dryRunOnnxOnSetup = flagValue(source, appId, Flags.DRY_RUN_ONNX_ON_SETUP);
}
@Override public Optional<NodeResources> dedicatedClusterControllerFlavor() { return Optional.ofNullable(dedicatedClusterControllerFlavor); }
@@ -238,6 +240,7 @@ public class ModelContextImpl implements ModelContext {
@Override public int maxConcurrentMergesPerNode() { return maxConcurrentMergesPerContentNode; }
@Override public int maxMergeQueueSize() { return maxMergeQueueSize; }
@Override public boolean throwIfResourceLimitsSpecified() { return throwIfResourceLimitsSpecified; }
+ @Override public boolean dryRunOnnxOnSetup() { return dryRunOnnxOnSetup; }
private static <V> V flagValue(FlagSource source, ApplicationId appId, UnboundFlag<? extends V, ?, ?> flag) {
return flag.bindTo(source)
diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
index 5f784bcefa1..59a52bf49b8 100644
--- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
+++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java
@@ -286,6 +286,13 @@ public class Flags {
"Whether to load local sessions when bootstrapping config server",
"Takes effect on restart of config server");
+ public static final UnboundBooleanFlag DRY_RUN_ONNX_ON_SETUP = defineFeatureFlag(
+ "dryrun-onnx-on-setup", false,
+ List.of("baldersheim"), "2021-06-23", "2021-08-01",
+ "Whether to dry run onnx models on setup for better error checking",
+ "Takes effect on next internal redeployment",
+ APPLICATION_ID);
+
/** WARNING: public for testing: All flags should be defined in {@link Flags}. */
public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, List<String> owners,
String createdAt, String expiresAt, String description,