summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-08-10 16:08:54 +0200
committerJon Bratseth <bratseth@oath.com>2018-08-10 16:08:54 +0200
commit66b8b332874136f95fff1290dbd7b7001e4a9398 (patch)
tree51034d282a1ce67e05616a4d545899cff22979cf /config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
parentbc41b0e6489e3002d75c400e4dda4f4218306554 (diff)
Refactor and remove duplication
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java20
1 files changed, 5 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 27e1ad51b33..fab5068ea6f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -43,28 +43,17 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments());
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
- if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files
- return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles());
- else
- return transformFromStoredModel(store, context.rankProfile());
+ ConvertedModel.FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments());
+ ConvertedModel convertedModel = new ConvertedModel(arguments, context, tensorFlowImporter, importedModels);
+ return convertedModel.expression();
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
}
}
- private ExpressionNode transformFromTensorFlowModel(ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
- ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
- k -> tensorFlowImporter.importModel(store.arguments().modelName(),
- store.modelDir()));
- return transformFromImportedModel(model, store, profile, queryProfiles);
- }
+ static class TensorFlowFeatureArguments extends ConvertedModel.FeatureArguments {
- static class TensorFlowFeatureArguments extends FeatureArguments {
public TensorFlowFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
@@ -76,6 +65,7 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
signature = optionalArgument(1, arguments);
output = optionalArgument(2, arguments);
}
+
}
}