summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-21 15:12:41 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-21 15:12:41 +0200
commit2936e1e221b2102b24c32493ee950a05958e5b84 (patch)
treefbfa0611039b52609eec4f04b1e9b440c730ccf6
parent3753d2d8b27f3941974a51dc0de5d07d879bacd2 (diff)
Don't mutate expressions in imported models
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java15
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java3
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java6
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java5
5 files changed, 27 insertions, 12 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
index 8766ff441bc..5ae04582de3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java
@@ -189,7 +189,10 @@ public class ConvertedModel {
}
// Transform and save macro - must come after reading expressions due to optimization transforms
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
+ // and must use the macro expression added to the profile, which may differ from the one saved in the model,
+ // after rewrite
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k,
+ profile.getMacros().get(k).getRankingExpression()));
return expressions;
}
@@ -259,7 +262,8 @@ public class ConvertedModel {
private void transformGeneratedMacro(ModelStore store,
Set<String> constantsReplacedByMacros,
- String macroName, RankingExpression expression) {
+ String macroName,
+ RankingExpression expression) {
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
store.writeMacro(macroName, expression);
@@ -322,7 +326,7 @@ public class ConvertedModel {
* Add the generated macros to the rank profile
*/
private void addGeneratedMacros(ImportedModel model, RankProfile profile) {
- model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
+ model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v.copy()));
}
/**
@@ -339,9 +343,8 @@ public class ConvertedModel {
Set<String> macroNames = new HashSet<>();
addMacroNamesIn(expression.getRoot(), macroNames, model);
for (String macroName : macroNames) {
- if ( ! model.macros().containsKey(macroName)) {
- continue;
- }
+ if ( ! model.macros().containsKey(macroName)) continue;
+
RankProfile.Macro macro = profile.getMacros().get(macroName);
if (macro == null) {
throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
index 72eac2282f4..f68f218a1eb 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ImportedModels.java
@@ -29,8 +29,7 @@ class ImportedModels {
* @throws IllegalArgumentException if the model cannot be loaded
*/
public ImportedModel imported(String modelName, File modelDir) {
- return modelImporter.importModel(modelName, modelDir);
- // return importedModels.computeIfAbsent(modelName, __ -> modelImporter.importModel(modelName, modelDir)); // TODO
+ return importedModels.computeIfAbsent(modelName, __ -> modelImporter.importModel(modelName, modelDir));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 374ca864a36..9a96555bb78 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -359,9 +359,9 @@ public class RankingExpressionWithTensorFlowTestCase {
search.assertFirstPhaseExpression(expression, "my_profile");
search.assertFirstPhaseExpression(expression, "my_profile_child");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
-// search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child");
-// search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile_child");
+ search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile_child");
// At this point the expression is stored - copy application to another location which do not have a models dir
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
index 01ed3b35d4c..34445a31ac3 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java
@@ -182,6 +182,16 @@ public class RankingExpression implements Serializable {
}
}
+ /** Returns a deep copy of this expression */
+ public RankingExpression copy() {
+ try {
+ return new RankingExpression(name, root.toString());
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Programming error: Could not parse serialized expression", e);
+ }
+ }
+
/**
* Returns the name of this ranking expression, or "" if no name is set.
*
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 4b49f17f74e..184e92781c3 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -70,7 +70,10 @@ public class ImportedModel {
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
- /** Returns an immutable map of macros that are part of this model */
+ /**
+ * Returns an immutable map of macros that are part of this model.
+ * Note that the macros themselves are *not* copies and *not* immutable - they must be copied before modification.
+ */
public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); }
/** Returns an immutable map of the macros that must be provided by the environment running this model */