diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-21 14:08:35 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-08-21 14:08:35 +0200 |
commit | 3753d2d8b27f3941974a51dc0de5d07d879bacd2 (patch) | |
tree | 6d78dd7f51d8c7da325b3f683fd60f0781b65c38 | |
parent | 61c3bef3dc00d485ca87cb2e2b145e1b20626bf7 (diff) |
Reduce scope of converted models
5 files changed, 101 insertions, 42 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 2bc0ccf6686..8766ff441bc 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -54,7 +54,7 @@ import java.util.Set; import java.util.stream.Collectors; /** - * A machine learned model imported from the models/ directory in the application package. + * A machine learned model imported from the models/ directory in the application package, for a single rank profile. * This encapsulates the difference between reading a model * - from a file application package, where it is represented by an ImportedModel, and * - from a ZK application package, where the models/ directory is unavailable and models are read from @@ -74,25 +74,29 @@ public class ConvertedModel { */ private final Map<String, RankingExpression> expressions; + /** + * Create a converted model for a rank profile given from either an imported model, + * or (if unavailable) from stored application package data. + */ public ConvertedModel(Path modelPath, RankProfileTransformContext context, - ModelImporter modelImporter, + ImportedModels importedModels, FeatureArguments arguments) { // TODO: Remove this.modelPath = modelPath; this.modelName = toModelName(modelPath); ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), modelPath); - if ( store.hasSourceModel()) // not converted yet - access from models/ directory - expressions = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, arguments); + if ( store.hasSourceModel()) + expressions = convertModel(store, context.rankProfile(), context.queryProfiles(), importedModels, arguments); else expressions = transformFromStoredModel(store, context.rankProfile()); } - private Map<String, RankingExpression> importModel(ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ModelImporter modelImporter, - FeatureArguments arguments) { - ImportedModel model = modelImporter.importModel(store.modelFiles.modelName(), store.sourceModelDir()); + private Map<String, RankingExpression> convertModel(ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ImportedModels importedModels, + FeatureArguments arguments) { + ImportedModel model = importedModels.imported(store.modelFiles.modelName(), store.sourceModelDir()); return transformFromImportedModel(model, store, profile, queryProfiles, arguments); } @@ -262,9 +266,12 @@ public class ConvertedModel { } private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) - throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists."); - + if (profile.getMacros().containsKey(macroName)) { + if ( ! profile.getMacros().get(macroName).getRankingExpression().equals(expression)) + throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists in " + profile + + " - with a different definition"); + return; + } profile.addMacro(macroName, false); // todo: inline if only used once RankProfile.Macro macro = profile.getMacros().get(macroName); macro.setRankingExpression(expression); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java new file mode 100644 index 00000000000..72eac2282f4 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java @@ -0,0 +1,36 @@ +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter; + +import java.io.File; +import java.util.HashMap; +import java.util.Map; + +/** + * Lazily loaded models imported from the models/ directory in the application package + * + * @author bratseth + */ +class ImportedModels { + + private final ModelImporter modelImporter; + + /** The cache of already imported models */ + private final Map<String, ImportedModel> importedModels = new HashMap<>(); + + ImportedModels(ModelImporter modelImporter) { + this.modelImporter = modelImporter; + } + + /** + * Returns the model at the given location in the application package (lazily loaded), + * + * @throws IllegalArgumentException if the model cannot be loaded + */ + public ImportedModel imported(String modelName, File modelDir) { + return modelImporter.importModel(modelName, modelDir); + // return importedModels.computeIfAbsent(modelName, __ -> modelImporter.importModel(modelName, modelDir)); // TODO + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 97395c1aad3..cb3873e9d71 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -28,11 +28,11 @@ import java.util.Optional; */ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private final OnnxImporter onnxImporter = new OnnxImporter(); - /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ private final Map<Path, ConvertedModel> convertedModels = new HashMap<>(); + private final ImportedModels importedModels = new ImportedModels(new OnnxImporter()); + @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { if (node instanceof ReferenceNode) @@ -48,7 +48,8 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans try { Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); - ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, onnxImporter, new ConvertedModel.FeatureArguments(feature.getArguments()))); + // TODO: Increase scope of this instance to a rank profile: + ConvertedModel convertedModel = new ConvertedModel(modelPath, context, importedModels, new ConvertedModel.FeatureArguments(feature.getArguments())); return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index b3778e2af84..102798d2b94 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -5,6 +5,7 @@ import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter; import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -25,10 +26,7 @@ import java.util.Map; */ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); - - /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, ConvertedModel> convertedModels = new HashMap<>(); + private final ImportedModels importedModels = new ImportedModels(new TensorFlowImporter()); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -45,7 +43,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); - ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, tensorFlowImporter, new ConvertedModel.FeatureArguments(feature.getArguments()))); + // TODO: Increase scope of this instance to a rank profile: + ConvertedModel convertedModel = new ConvertedModel(modelPath, context, importedModels, new ConvertedModel.FeatureArguments(feature.getArguments())); return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 29859817736..374ca864a36 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -259,8 +259,8 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testImportingFromStoredExpressionsWithMacroOverridingConstant() throws IOException { - String rankProfile = + public void testImportingFromStoredExpressionsWithMacroOverridingConstantAndInheritance() throws IOException { + String rankProfiles = " rank-profile my_profile {\n" + " macro Placeholder() {\n" + " expression: tensor(d0[2],d1[784])(0.0)\n" + @@ -271,14 +271,17 @@ public class RankingExpressionWithTensorFlowTestCase { " first-phase {\n" + " expression: tensorflow('mnist_softmax/saved')" + " }\n" + + " }" + + " rank-profile my_profile_child inherits my_profile {\n" + " }"; - String vespaExpressionWithoutConstant = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_saved_layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))"; - RankProfileSearchFixture search = fixtureWithUncompiled(rankProfile, new StoringApplicationPackage(applicationDir)); + RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir)); search.compileRankProfile("my_profile"); + search.compileRankProfile("my_profile_child"); search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by macro is not added", search.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read")); @@ -291,9 +294,11 @@ public class RankingExpressionWithTensorFlowTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfile, storedApplication); + RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfiles, storedApplication); searchFromStored.compileRankProfile("my_profile"); + searchFromStored.compileRankProfile("my_profile_child"); searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile_child"); assertNull("Constant overridden by macro is not added", searchFromStored.search().getRankingConstants().get("mnist_softmax_saved_layer_Variable_read")); assertLargeConstant("mnist_softmax_saved_layer_Variable_1_read", searchFromStored, Optional.of(10L)); @@ -331,22 +336,33 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { + public void testImportingFromStoredExpressionsWithSmallConstantsAndInheritance() throws IOException { + final String rankProfiles = + " rank-profile my_profile {\n" + + " macro input() {\n" + + " expression: tensor(d0[1],d1[784])(0.0)\n" + + " }\n" + + " first-phase {\n" + + " expression: tensorflow('mnist/saved')" + + " }\n" + + " }" + + " rank-profile my_profile_child inherits my_profile {\n" + + " }"; + final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"; final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))"; - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "tensorflow('mnist/saved')", - null, - null, - "input", - application); + RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir)); + search.compileRankProfile("my_profile"); + search.compileRankProfile("my_profile_child"); search.assertFirstPhaseExpression(expression, "my_profile"); + search.assertFirstPhaseExpression(expression, "my_profile_child"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); - search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); +// search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child"); +// search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child"); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -355,16 +371,16 @@ public class RankingExpressionWithTensorFlowTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "tensorflow('mnist/saved')", - null, - null, - "input", - storedApplication); + RankProfileSearchFixture searchFromStored = fixtureWithUncompiled(rankProfiles, storedApplication); + searchFromStored.compileRankProfile("my_profile"); + searchFromStored.compileRankProfile("my_profile_child"); searchFromStored.assertFirstPhaseExpression(expression, "my_profile"); + searchFromStored.assertFirstPhaseExpression(expression, "my_profile_child"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile_child"); searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); + searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child"); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); |