diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-08-10 16:08:54 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-08-10 16:08:54 +0200 |
commit | 66b8b332874136f95fff1290dbd7b7001e4a9398 (patch) | |
tree | 51034d282a1ce67e05616a4d545899cff22979cf /config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | |
parent | bc41b0e6489e3002d75c400e4dda4f4218306554 (diff) |
Refactor and remove duplication
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 20 |
1 files changed, 5 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 27e1ad51b33..fab5068ea6f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -43,28 +43,17 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter { if ( ! feature.getName().equals("tensorflow")) return feature; try { - FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments()); - ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); - if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files - return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles()); - else - return transformFromStoredModel(store, context.rankProfile()); + ConvertedModel.FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments()); + ConvertedModel convertedModel = new ConvertedModel(arguments, context, tensorFlowImporter, importedModels); + return convertedModel.expression(); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); } } - private ExpressionNode transformFromTensorFlowModel(ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { - ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), - k -> tensorFlowImporter.importModel(store.arguments().modelName(), - store.modelDir())); - return transformFromImportedModel(model, store, profile, queryProfiles); - } + static class TensorFlowFeatureArguments extends ConvertedModel.FeatureArguments { - static class TensorFlowFeatureArguments extends FeatureArguments { public TensorFlowFeatureArguments(Arguments arguments) { if (arguments.isEmpty()) throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + @@ -76,6 +65,7 @@ public class TensorFlowFeatureConverter extends MLImportFeatureConverter { signature = optionalArgument(1, arguments); output = optionalArgument(2, arguments); } + } } |