aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-22 09:56:52 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-22 09:56:52 +0200
commitb1cd814eedf509399adbe6da3160e81c12421a4e (patch)
tree414d433869ac86dbdd6f3762ebca2a4ab73b6630 /config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
parent6464389e9f980ee1a8d71a075262039939ae1094 (diff)
Scope transforms and converters to rank profile
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java11
1 files changed, 8 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 1343af1b6ec..774b166c45a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -10,6 +10,8 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import java.io.UncheckedIOException;
+import java.util.HashMap;
+import java.util.Map;
/**
* Replaces instances of the tensorflow(model-path, signature, output)
@@ -20,7 +22,10 @@ import java.io.UncheckedIOException;
*/
public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
- private final ImportedModels importedModels = new ImportedModels(new TensorFlowImporter());
+ private final ImportedModels importedTensorFlowModels = new ImportedModels(new TensorFlowImporter());
+
+ /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
+ private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -37,8 +42,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
- // TODO: Increase scope of this instance to a rank profile:
- ConvertedModel convertedModel = new ConvertedModel(modelPath, context, importedModels);
+ ConvertedModel convertedModel =
+ convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context, importedTensorFlowModels));
return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {