From 8d1451224710fa7a65e9a7410113c64c104a8dd0 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Wed, 23 Jun 2021 10:00:18 +0200 Subject: Add feature falg for controlling onnx dryrun verification. --- .../com/yahoo/config/model/api/ModelContext.java | 1 + .../yahoo/config/model/deploy/TestProperties.java | 6 ++ .../searchdefinition/derived/RankProfileList.java | 4 + .../RankingExpressionWithOnnxModelTestCase.java | 117 +++++++++++---------- .../config/server/deploy/ModelContextImpl.java | 3 + .../src/main/java/com/yahoo/vespa/flags/Flags.java | 7 ++ 6 files changed, 85 insertions(+), 53 deletions(-) diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java index 81ee0a4c4c3..badb6e01885 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java @@ -86,6 +86,7 @@ public interface ModelContext { @ModelFeatureFlag(owners = {"baldersheim"}) default boolean distributeExternalRankExpressions() { return false; } @ModelFeatureFlag(owners = {"baldersheim"}) default int maxConcurrentMergesPerNode() { throw new UnsupportedOperationException("TODO specify default value"); } @ModelFeatureFlag(owners = {"baldersheim"}) default int maxMergeQueueSize() { throw new UnsupportedOperationException("TODO specify default value"); } + @ModelFeatureFlag(owners = {"baldersheim"}) default boolean dryRunOnnxOnSetup() { throw new UnsupportedOperationException("TODO specify default value"); } @ModelFeatureFlag(owners = {"geirst"}) default boolean enableFeedBlockInDistributor() { return true; } @ModelFeatureFlag(owners = {"hmusum"}, removeAfter = "7.406") default int clusterControllerMaxHeapSizeInMb() { return 128; } @ModelFeatureFlag(owners = {"hmusum"}, removeAfter = "7.422") default int metricsProxyMaxHeapSizeInMb(ClusterSpec.Type type) { return 256; } diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java index fe1bf93f32b..c62dc6e4631 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java @@ -63,6 +63,7 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea private int maxMergeQueueSize = 1024; private int largeRankExpressionLimit = 0x10000; private boolean allowDisableMtls = true; + private boolean dryRunOnnxOnSetup = false; private List operatorCertificates = Collections.emptyList(); @Override public ModelContext.FeatureFlags featureFlags() { return this; } @@ -107,7 +108,12 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea @Override public int largeRankExpressionLimit() { return largeRankExpressionLimit; } @Override public int maxConcurrentMergesPerNode() { return maxConcurrentMergesPerNode; } @Override public int maxMergeQueueSize() { return maxMergeQueueSize; } + @Override public boolean dryRunOnnxOnSetup() { return dryRunOnnxOnSetup; } + public TestProperties setDryRunOnnxOnSetup(boolean value) { + dryRunOnnxOnSetup = value; + return this; + } public TestProperties useExternalRankExpression(boolean value) { useExternalRankExpression = value; return this; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index 7c533cce006..1796a4cba17 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -35,6 +35,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ private final RankingConstants rankingConstants; private final LargeRankExpressions largeRankExpressions; private final OnnxModels onnxModels; + private final boolean dryRunOnnxOnSetup; public static RankProfileList empty = new RankProfileList(); @@ -42,6 +43,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ rankingConstants = new RankingConstants(); largeRankExpressions = new LargeRankExpressions(); onnxModels = new OnnxModels(); + dryRunOnnxOnSetup = true; } /** @@ -62,6 +64,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ this.rankingConstants = rankingConstants; this.largeRankExpressions = largeRankExpressions; onnxModels = search == null ? new OnnxModels() : search.onnxModels(); // as ONNX models come from parsing rank expressions + dryRunOnnxOnSetup = deployProperties.featureFlags().dryRunOnnxOnSetup(); deriveRankProfiles(rankProfileRegistry, queryProfiles, importedModels, search, attributeFields, deployProperties); } @@ -137,6 +140,7 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ log.warning("Illegal file reference " + model); // Let tests pass ... we should find a better way else { OnnxModelsConfig.Model.Builder modelBuilder = new OnnxModelsConfig.Model.Builder(); + modelBuilder.dry_run_on_setup(dryRunOnnxOnSetup); modelBuilder.name(model.getName()); modelBuilder.fileref(model.getFileReference()); model.getInputMap().forEach((name, source) -> modelBuilder.input(new OnnxModelsConfig.Model.Input.Builder().name(name).source(source))); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java index f460383f42b..ab148130a7d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxModelTestCase.java @@ -2,8 +2,10 @@ package com.yahoo.searchdefinition.processing; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -27,9 +29,9 @@ public class RankingExpressionWithOnnxModelTestCase { @Test public void testOnnxModelFeature() throws Exception { - VespaModel model = loadModel(applicationDir); + VespaModel model = loadModel(applicationDir, false); assertTransformedFeature(model); - assertGeneratedConfig(model); + assertGeneratedConfig(model, false); Path storedApplicationDir = applicationDir.append("copy"); try { @@ -39,73 +41,82 @@ public class RankingExpressionWithOnnxModelTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); - VespaModel storedModel = loadModel(storedApplicationDir); + VespaModel storedModel = loadModel(storedApplicationDir, true); assertTransformedFeature(storedModel); - assertGeneratedConfig(storedModel); + assertGeneratedConfig(storedModel, true); } finally { IOUtils.recursiveDeleteDir(storedApplicationDir.toFile()); } } - private VespaModel loadModel(Path path) throws Exception { + private VespaModel loadModel(Path path, boolean dryRunOnnx) throws Exception { FilesApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile()); - DeployState state = new DeployState.Builder().applicationPackage(applicationPackage).build(); + ModelContext.Properties properties = new TestProperties().setDryRunOnnxOnSetup(dryRunOnnx); + DeployState state = new DeployState.Builder().applicationPackage(applicationPackage).properties(properties).build(); return new VespaModel(state); } - private void assertGeneratedConfig(VespaModel model) { - DocumentDatabase db = ((IndexedSearchCluster)model.getSearchClusters().get(0)).getDocumentDbs().get(0); + private void assertGeneratedConfig(VespaModel vespaModel, boolean expectDryRunOnnx) { + DocumentDatabase db = ((IndexedSearchCluster)vespaModel.getSearchClusters().get(0)).getDocumentDbs().get(0); OnnxModelsConfig.Builder builder = new OnnxModelsConfig.Builder(); ((OnnxModelsConfig.Producer) db).getConfig(builder); OnnxModelsConfig config = new OnnxModelsConfig(builder); assertEquals(6, config.model().size()); + for (OnnxModelsConfig.Model model : config.model()) { + assertEquals(expectDryRunOnnx, model.dry_run_on_setup()); + } - assertEquals("my_model", config.model(0).name()); - assertEquals(3, config.model(0).input().size()); - assertEquals("second/input:0", config.model(0).input(0).name()); - assertEquals("constant(my_constant)", config.model(0).input(0).source()); - assertEquals("first_input", config.model(0).input(1).name()); - assertEquals("attribute(document_field)", config.model(0).input(1).source()); - assertEquals("third_input", config.model(0).input(2).name()); - assertEquals("rankingExpression(my_function)", config.model(0).input(2).source()); - assertEquals(3, config.model(0).output().size()); - assertEquals("path/to/output:0", config.model(0).output(0).name()); - assertEquals("out", config.model(0).output(0).as()); - assertEquals("path/to/output:1", config.model(0).output(1).name()); - assertEquals("path_to_output_1", config.model(0).output(1).as()); - assertEquals("path/to/output:2", config.model(0).output(2).name()); - assertEquals("path_to_output_2", config.model(0).output(2).as()); - - assertEquals("files_model_onnx", config.model(1).name()); - assertEquals(3, config.model(1).input().size()); - assertEquals(3, config.model(1).output().size()); - assertEquals("path/to/output:0", config.model(1).output(0).name()); - assertEquals("path_to_output_0", config.model(1).output(0).as()); - assertEquals("path/to/output:1", config.model(1).output(1).name()); - assertEquals("path_to_output_1", config.model(1).output(1).as()); - assertEquals("path/to/output:2", config.model(1).output(2).name()); - assertEquals("path_to_output_2", config.model(1).output(2).as()); - assertEquals("files_model_onnx", config.model(1).name()); - - assertEquals("another_model", config.model(2).name()); - assertEquals("third_input", config.model(2).input(2).name()); - assertEquals("rankingExpression(another_function)", config.model(2).input(2).source()); - - assertEquals("files_summary_model_onnx", config.model(3).name()); - assertEquals(3, config.model(3).input().size()); - assertEquals(3, config.model(3).output().size()); - - assertEquals("dynamic_model", config.model(5).name()); - assertEquals(1, config.model(5).input().size()); - assertEquals(1, config.model(5).output().size()); - assertEquals("rankingExpression(my_function)", config.model(5).input(0).source()); - - assertEquals("unbound_model", config.model(4).name()); - assertEquals(1, config.model(4).input().size()); - assertEquals(1, config.model(4).output().size()); - assertEquals("rankingExpression(my_function)", config.model(4).input(0).source()); - + OnnxModelsConfig.Model model = config.model(0); + assertEquals("my_model", model.name()); + assertEquals(3, model.input().size()); + assertEquals("second/input:0", model.input(0).name()); + assertEquals("constant(my_constant)", model.input(0).source()); + assertEquals("first_input", model.input(1).name()); + assertEquals("attribute(document_field)", model.input(1).source()); + assertEquals("third_input", model.input(2).name()); + assertEquals("rankingExpression(my_function)", model.input(2).source()); + assertEquals(3, model.output().size()); + assertEquals("path/to/output:0", model.output(0).name()); + assertEquals("out", model.output(0).as()); + assertEquals("path/to/output:1", model.output(1).name()); + assertEquals("path_to_output_1", model.output(1).as()); + assertEquals("path/to/output:2", model.output(2).name()); + assertEquals("path_to_output_2", model.output(2).as()); + + model = config.model(1); + assertEquals("files_model_onnx", model.name()); + assertEquals(3, model.input().size()); + assertEquals(3, model.output().size()); + assertEquals("path/to/output:0", model.output(0).name()); + assertEquals("path_to_output_0", model.output(0).as()); + assertEquals("path/to/output:1", model.output(1).name()); + assertEquals("path_to_output_1", model.output(1).as()); + assertEquals("path/to/output:2", model.output(2).name()); + assertEquals("path_to_output_2", model.output(2).as()); + assertEquals("files_model_onnx", model.name()); + + model = config.model(2); + assertEquals("another_model", model.name()); + assertEquals("third_input", model.input(2).name()); + assertEquals("rankingExpression(another_function)", model.input(2).source()); + + model = config.model(3); + assertEquals("files_summary_model_onnx", model.name()); + assertEquals(3, model.input().size()); + assertEquals(3, model.output().size()); + + model = config.model(4); + assertEquals("unbound_model", model.name()); + assertEquals(1, model.input().size()); + assertEquals(1, model.output().size()); + assertEquals("rankingExpression(my_function)", model.input(0).source()); + + model = config.model(5); + assertEquals("dynamic_model", model.name()); + assertEquals(1, model.input().size()); + assertEquals(1, model.output().size()); + assertEquals("rankingExpression(my_function)", model.input(0).source()); } private void assertTransformedFeature(VespaModel model) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index 94cfba12453..2b8eaf86cda 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -182,6 +182,7 @@ public class ModelContextImpl implements ModelContext { private final int maxMergeQueueSize; private final int largeRankExpressionLimit; private final boolean throwIfResourceLimitsSpecified; + private final boolean dryRunOnnxOnSetup; public FeatureFlags(FlagSource source, ApplicationId appId) { this.dedicatedClusterControllerFlavor = parseDedicatedClusterControllerFlavor(flagValue(source, appId, Flags.DEDICATED_CLUSTER_CONTROLLER_FLAVOR)); @@ -209,6 +210,7 @@ public class ModelContextImpl implements ModelContext { this.maxConcurrentMergesPerContentNode = flagValue(source, appId, Flags.MAX_CONCURRENT_MERGES_PER_NODE); this.maxMergeQueueSize = flagValue(source, appId, Flags.MAX_MERGE_QUEUE_SIZE); this.throwIfResourceLimitsSpecified = flagValue(source, appId, Flags.THROW_EXCEPTION_IF_RESOURCE_LIMITS_SPECIFIED); + this.dryRunOnnxOnSetup = flagValue(source, appId, Flags.DRY_RUN_ONNX_ON_SETUP); } @Override public Optional dedicatedClusterControllerFlavor() { return Optional.ofNullable(dedicatedClusterControllerFlavor); } @@ -238,6 +240,7 @@ public class ModelContextImpl implements ModelContext { @Override public int maxConcurrentMergesPerNode() { return maxConcurrentMergesPerContentNode; } @Override public int maxMergeQueueSize() { return maxMergeQueueSize; } @Override public boolean throwIfResourceLimitsSpecified() { return throwIfResourceLimitsSpecified; } + @Override public boolean dryRunOnnxOnSetup() { return dryRunOnnxOnSetup; } private static V flagValue(FlagSource source, ApplicationId appId, UnboundFlag flag) { return flag.bindTo(source) diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 5f784bcefa1..59a52bf49b8 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -286,6 +286,13 @@ public class Flags { "Whether to load local sessions when bootstrapping config server", "Takes effect on restart of config server"); + public static final UnboundBooleanFlag DRY_RUN_ONNX_ON_SETUP = defineFeatureFlag( + "dryrun-onnx-on-setup", false, + List.of("baldersheim"), "2021-06-23", "2021-08-01", + "Whether to dry run onnx models on setup for better error checking", + "Takes effect on next internal redeployment", + APPLICATION_ID); + /** WARNING: public for testing: All flags should be defined in {@link Flags}. */ public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, List owners, String createdAt, String expiresAt, String description, -- cgit v1.2.3 From aecf4b2d38f185400ae3a44a2b716f04c1a2c6d3 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Wed, 23 Jun 2021 11:31:18 +0200 Subject: 'dryrun' -> 'dry-run' --- flags/src/main/java/com/yahoo/vespa/flags/Flags.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index 59a52bf49b8..b3ea6e1d27c 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -287,7 +287,7 @@ public class Flags { "Takes effect on restart of config server"); public static final UnboundBooleanFlag DRY_RUN_ONNX_ON_SETUP = defineFeatureFlag( - "dryrun-onnx-on-setup", false, + "dry-run-onnx-on-setup", false, List.of("baldersheim"), "2021-06-23", "2021-08-01", "Whether to dry run onnx models on setup for better error checking", "Takes effect on next internal redeployment", -- cgit v1.2.3