diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-21 15:12:41 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-08-21 15:12:41 +0200 |
commit | 2936e1e221b2102b24c32493ee950a05958e5b84 (patch) | |
tree | fbfa0611039b52609eec4f04b1e9b440c730ccf6 | |
parent | 3753d2d8b27f3941974a51dc0de5d07d879bacd2 (diff) |
Don't mutate expressions in imported models
5 files changed, 27 insertions, 12 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 8766ff441bc..5ae04582de3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -189,7 +189,10 @@ public class ConvertedModel { } // Transform and save macro - must come after reading expressions due to optimization transforms - model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); + // and must use the macro expression added to the profile, which may differ from the one saved in the model, + // after rewrite + model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, + profile.getMacros().get(k).getRankingExpression())); return expressions; } @@ -259,7 +262,8 @@ public class ConvertedModel { private void transformGeneratedMacro(ModelStore store, Set<String> constantsReplacedByMacros, - String macroName, RankingExpression expression) { + String macroName, + RankingExpression expression) { expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); store.writeMacro(macroName, expression); @@ -322,7 +326,7 @@ public class ConvertedModel { * Add the generated macros to the rank profile */ private void addGeneratedMacros(ImportedModel model, RankProfile profile) { - model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); + model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy())); } /** @@ -339,9 +343,8 @@ public class ConvertedModel { Set<String> macroNames = new HashSet<>(); addMacroNamesIn(expression.getRoot(), macroNames, model); for (String macroName : macroNames) { - if ( ! model.macros().containsKey(macroName)) { - continue; - } + if ( ! model.macros().containsKey(macroName)) continue; + RankProfile.Macro macro = profile.getMacros().get(macroName); if (macro == null) { throw new IllegalArgumentException("Model refers to generated macro '" + macroName + diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java index 72eac2282f4..f68f218a1eb 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java @@ -29,8 +29,7 @@ class ImportedModels { * @throws IllegalArgumentException if the model cannot be loaded */ public ImportedModel imported(String modelName, File modelDir) { - return modelImporter.importModel(modelName, modelDir); - // return importedModels.computeIfAbsent(modelName, __ -> modelImporter.importModel(modelName, modelDir)); // TODO + return importedModels.computeIfAbsent(modelName, __ -> modelImporter.importModel(modelName, modelDir)); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 374ca864a36..9a96555bb78 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -359,9 +359,9 @@ public class RankingExpressionWithTensorFlowTestCase { search.assertFirstPhaseExpression(expression, "my_profile"); search.assertFirstPhaseExpression(expression, "my_profile_child"); assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search); -// search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); - search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child"); -// search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile_child"); + search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile"); search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child"); // At this point the expression is stored - copy application to another location which do not have a models dir diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java index 01ed3b35d4c..34445a31ac3 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java @@ -182,6 +182,16 @@ public class RankingExpression implements Serializable { } } + /** Returns a deep copy of this expression */ + public RankingExpression copy() { + try { + return new RankingExpression(name, root.toString()); + } + catch (ParseException e) { + throw new RuntimeException("Programming error: Could not parse serialized expression", e); + } + } + /** * Returns the name of this ranking expression, or "" if no name is set. * diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index 4b49f17f74e..184e92781c3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -70,7 +70,10 @@ public class ImportedModel { */ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } - /** Returns an immutable map of macros that are part of this model */ + /** + * Returns an immutable map of macros that are part of this model. + * Note that the macros themselves are *not* copies and *not* immutable - they must be copied before modification. + */ public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); } /** Returns an immutable map of the macros that must be provided by the environment running this model */ |