aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java11
1 files changed, 8 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 1343af1b6ec..774b166c45a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -10,6 +10,8 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import java.io.UncheckedIOException;
+import java.util.HashMap;
+import java.util.Map;
/**
* Replaces instances of the tensorflow(model-path, signature, output)
@@ -20,7 +22,10 @@ import java.io.UncheckedIOException;
*/
public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
- private final ImportedModels importedModels = new ImportedModels(new TensorFlowImporter());
+ private final ImportedModels importedTensorFlowModels = new ImportedModels(new TensorFlowImporter());
+
+ /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
+ private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -37,8 +42,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
- // TODO: Increase scope of this instance to a rank profile:
- ConvertedModel convertedModel = new ConvertedModel(modelPath, context, importedModels);
+ ConvertedModel convertedModel =
+ convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context, importedTensorFlowModels));
return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {