aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-15 13:55:49 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-15 13:55:49 +0200
commitffcf49b0690147d9fdc47c543da3609981045fd5 (patch)
treea07b7a2403b96e0308e376e7fcc3eec157400c33 /config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
parent482bc20e6fb7f84cbb54a6194f0e129b2f4f7029 (diff)
Convert all outputs
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java8
1 files changed, 4 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index d28299b1d30..585adc0c0d4 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -28,7 +28,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, ImportedModel> importedModels = new HashMap<>();
+ private final Map<Path, ConvertedModel> convertedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -44,9 +44,9 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()),
- context, tensorFlowImporter, importedModels);
- return convertedModel.expression();
+ Path modelPath = Path.fromString(ConvertedModel.FeatureArguments.asString(feature.getArguments().expressions().get(0)));
+ ConvertedModel convertedModel = convertedModels.computeIfAbsent(modelPath, path -> new ConvertedModel(path, context, tensorFlowImporter));
+ return convertedModel.expression(asFeatureArguments(feature.getArguments()));
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);