summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-21 15:12:41 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-21 15:12:41 +0200
commit2936e1e221b2102b24c32493ee950a05958e5b84 (patch)
treefbfa0611039b52609eec4f04b1e9b440c730ccf6 /config-model
parent3753d2d8b27f3941974a51dc0de5d07d879bacd2 (diff)
Don't mutate expressions in imported models
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java15
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java6
3 files changed, 13 insertions, 11 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 8766ff441bc..5ae04582de3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -189,7 +189,10 @@ public class ConvertedModel {
}
// Transform and save macro - must come after reading expressions due to optimization transforms
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
+ // and must use the macro expression added to the profile, which may differ from the one saved in the model,
+ // after rewrite
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k,
+ profile.getMacros().get(k).getRankingExpression()));
return expressions;
}
@@ -259,7 +262,8 @@ public class ConvertedModel {
private void transformGeneratedMacro(ModelStore store,
Set<String> constantsReplacedByMacros,
- String macroName, RankingExpression expression) {
+ String macroName,
+ RankingExpression expression) {
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
store.writeMacro(macroName, expression);
@@ -322,7 +326,7 @@ public class ConvertedModel {
* Add the generated macros to the rank profile
*/
private void addGeneratedMacros(ImportedModel model, RankProfile profile) {
- model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
+ model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy()));
}
/**
@@ -339,9 +343,8 @@ public class ConvertedModel {
Set<String> macroNames = new HashSet<>();
addMacroNamesIn(expression.getRoot(), macroNames, model);
for (String macroName : macroNames) {
- if ( ! model.macros().containsKey(macroName)) {
- continue;
- }
+ if ( ! model.macros().containsKey(macroName)) continue;
+
RankProfile.Macro macro = profile.getMacros().get(macroName);
if (macro == null) {
throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
index 72eac2282f4..f68f218a1eb 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
@@ -29,8 +29,7 @@ class ImportedModels {
* @throws IllegalArgumentException if the model cannot be loaded
*/
public ImportedModel imported(String modelName, File modelDir) {
- return modelImporter.importModel(modelName, modelDir);
- // return importedModels.computeIfAbsent(modelName, __ -> modelImporter.importModel(modelName, modelDir)); // TODO
+ return importedModels.computeIfAbsent(modelName, __ -> modelImporter.importModel(modelName, modelDir));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 374ca864a36..9a96555bb78 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -359,9 +359,9 @@ public class RankingExpressionWithTensorFlowTestCase {
search.assertFirstPhaseExpression(expression, "my_profile");
search.assertFirstPhaseExpression(expression, "my_profile_child");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
-// search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child");
-// search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile_child");
+ search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child");
// At this point the expression is stored - copy application to another location which do not have a models dir