diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 1343af1b6ec..774b166c45a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -10,6 +10,8 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import java.io.UncheckedIOException; +import java.util.HashMap; +import java.util.Map; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -20,7 +22,10 @@ import java.io.UncheckedIOException; */ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - private final ImportedModels importedModels = new ImportedModels(new TensorFlowImporter()); + private final ImportedModels importedTensorFlowModels = new ImportedModels(new TensorFlowImporter()); + + /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ + private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -37,8 +42,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0))); - // TODO: Increase scope of this instance to a rank profile: - ConvertedModel convertedModel = new ConvertedModel(modelPath, context, importedModels); + ConvertedModel convertedModel = + convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> new ConvertedModel(modelPath, context, importedTensorFlowModels)); return convertedModel.expression(asFeatureArguments(feature.getArguments())); } catch (IllegalArgumentException | UncheckedIOException e) { |