diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 24 |
1 files changed, 9 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index b2c096d4e95..d28299b1d30 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -44,8 +44,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - ConvertedModel.FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments()); - ConvertedModel convertedModel = new ConvertedModel(arguments, context, tensorFlowImporter, importedModels); + ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()), + context, tensorFlowImporter, importedModels); return convertedModel.expression(); } catch (IllegalArgumentException | UncheckedIOException e) { @@ -53,20 +53,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - static class TensorFlowFeatureArguments extends ConvertedModel.FeatureArguments { - - public TensorFlowFeatureArguments(Arguments arguments) { - if (arguments.isEmpty()) - throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); - if (arguments.expressions().size() > 3) - throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); - - modelPath = Path.fromString(asString(arguments.expressions().get(0))); - signature = optionalArgument(1, arguments); - output = optionalArgument(2, arguments); - } + private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + if (arguments.isEmpty()) + throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + + "the tensorflow model directory under [application]/models"); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); + return new ConvertedModel.FeatureArguments(arguments); } } |