diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms')
3 files changed, 40 insertions, 38 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java index 828d8228db1..867740c7912 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConvertedModel.java @@ -69,7 +69,7 @@ public class ConvertedModel { ModelImporter modelImporter, Map<Path, ImportedModel> importedModels) { ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments); - if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files + if ( ! store.hasStoredModel()) // not converted yet - access from models/ directory convertedExpression = importModel(store, context.rankProfile(), context.queryProfiles(), modelImporter, importedModels); else convertedExpression = transformFromStoredModel(store, context.rankProfile()); @@ -88,10 +88,10 @@ public class ConvertedModel { public ExpressionNode expression() { return convertedExpression; } - ExpressionNode transformFromImportedModel(ImportedModel model, - ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { + private ExpressionNode transformFromImportedModel(ImportedModel model, + ModelStore store, + RankProfile profile, + QueryProfileRegistry queryProfiles) { // Add constants Set<String> constantsReplacedByMacros = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); @@ -631,13 +631,25 @@ public class ConvertedModel { } /** Encapsulates the arguments to the import feature */ - static abstract class FeatureArguments { + static class FeatureArguments { Path modelPath; /** Optional arguments */ Optional<String> signature, output; + public FeatureArguments(Arguments arguments) { + this(Path.fromString(asString(arguments.expressions().get(0))), + optionalArgument(1, arguments), + optionalArgument(2, arguments)); + } + + public FeatureArguments(Path modelPath, Optional<String> signature, Optional<String> output) { + this.modelPath = modelPath; + this.signature = signature; + this.output = output; + } + /** Returns modelPath with slashes replaced by underscores */ public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); } @@ -676,26 +688,26 @@ public class ConvertedModel { return fileName.toString(); } - Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { if (argumentIndex >= arguments.expressions().size()) return Optional.empty(); return Optional.of(asString(arguments.expressions().get(argumentIndex))); } - String asString(ExpressionNode node) { + private static String asString(ExpressionNode node) { if ( ! (node instanceof ConstantNode)) throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); return stripQuotes(((ConstantNode)node).sourceString()); } - private String stripQuotes(String s) { + private static String stripQuotes(String s) { if ( ! isQuoteSign(s.codePointAt(0))) return s; if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) throw new IllegalArgumentException("argument [" + s + "] is missing endquote"); return s.substring(1, s.length()-1); } - private boolean isQuoteSign(int c) { + private static boolean isQuoteSign(int c) { return c == '\'' || c == '"'; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index ce656b74b54..d31ffefde65 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -47,8 +47,8 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans if ( ! feature.getName().equals("onnx")) return feature; try { - ConvertedModel.FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments()); - ConvertedModel convertedModel = new ConvertedModel(arguments, context, onnxImporter, importedModels); + ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()), + context, onnxImporter, importedModels); return convertedModel.expression(); } catch (IllegalArgumentException | UncheckedIOException e) { @@ -56,18 +56,14 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans } } - static class OnnxFeatureArguments extends ConvertedModel.FeatureArguments { - public OnnxFeatureArguments(Arguments arguments) { - if (arguments.isEmpty()) - throw new IllegalArgumentException("An onnx node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); - if (arguments.expressions().size() > 3) - throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); + private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + if (arguments.isEmpty()) + throw new IllegalArgumentException("An onnx node must take an argument pointing to " + + "the onnx model directory under [application]/models"); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("An onnx feature can have at most 2 arguments"); - modelPath = Path.fromString(asString(arguments.expressions().get(0))); - output = optionalArgument(1, arguments); - signature = Optional.of("default"); - } + return new ConvertedModel.FeatureArguments(arguments); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index b2c096d4e95..d28299b1d30 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -44,8 +44,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if ( ! feature.getName().equals("tensorflow")) return feature; try { - ConvertedModel.FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments()); - ConvertedModel convertedModel = new ConvertedModel(arguments, context, tensorFlowImporter, importedModels); + ConvertedModel convertedModel = new ConvertedModel(asFeatureArguments(feature.getArguments()), + context, tensorFlowImporter, importedModels); return convertedModel.expression(); } catch (IllegalArgumentException | UncheckedIOException e) { @@ -53,20 +53,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - static class TensorFlowFeatureArguments extends ConvertedModel.FeatureArguments { - - public TensorFlowFeatureArguments(Arguments arguments) { - if (arguments.isEmpty()) - throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); - if (arguments.expressions().size() > 3) - throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); - - modelPath = Path.fromString(asString(arguments.expressions().get(0))); - signature = optionalArgument(1, arguments); - output = optionalArgument(2, arguments); - } + private ConvertedModel.FeatureArguments asFeatureArguments(Arguments arguments) { + if (arguments.isEmpty()) + throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + + "the tensorflow model directory under [application]/models"); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); + return new ConvertedModel.FeatureArguments(arguments); } } |