diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-16 11:23:09 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-16 11:23:09 +0100 |
commit | d9e17187fe49f662520d282c38e5cf779cbb8195 (patch) | |
tree | 19521bf836aa57fb4b9056ff12f77d8a0c957f60 /config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | |
parent | dd744223e7db5f14805b7e23dbe69f143b60f1a3 (diff) |
Refactor (no functional changes)
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 242 |
1 files changed, 134 insertions, 108 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 9fd4199f833..7cefa9d9187 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -59,27 +59,52 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) { - try { - if ( ! feature.getName().equals("tensorflow")) return feature; - - if (feature.getArguments().isEmpty()) - throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); + if ( ! feature.getName().equals("tensorflow")) return feature; - // modelPath: The relative path to this model below the "models/" dir in the application package - Path modelPath = Path.fromString(asString(feature.getArguments().expressions().get(0))); - Optional<String> signatureArg = optionalArgument(1, feature.getArguments()); - Optional<String> outputArg = optionalArgument(2, feature.getArguments()); - if (new File(ApplicationPackage.MODELS_DIR.append(modelPath).getRelative()).getCanonicalFile().exists()) - return transformFromTensorFlowModel(modelPath, signatureArg, outputArg, context.rankProfile()); + try { + FeatureArguments arguments = new FeatureArguments(feature.getArguments()); + if (arguments.modelDir().exists()) + return transformFromTensorFlowModel(arguments, context.rankProfile()); else - return transformFromStoredConvertedModel(modelPath, signatureArg, outputArg); + return transformFromStoredConvertedModel(arguments); } - catch (IllegalArgumentException | IOException e) { + catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); } } + private ExpressionNode transformFromTensorFlowModel(FeatureArguments arguments, RankProfile rankProfile) { + TensorFlowModel model = importedModels.computeIfAbsent(arguments.modelPath(), + k -> tensorFlowImporter.importModel(arguments.modelDir().toString())); + + // Find the specified expression + Signature signature = chooseSignature(model, arguments.signature()); + String output = chooseOutput(signature, arguments.output()); + RankingExpression expression = model.expressions().get(output); + writeConverted(arguments, expression); + + // Add all constants (after finding outputs to fail faster when the output is not found) + if (constantsInConfig) + model.constants().forEach((k, v) -> rankProfile.addConstantTensor(k, new TensorValue(v))); + else // correct way, disabled for now + model.constants().forEach((k, v) -> transformConstant(arguments, rankProfile, k, v)); + + return expression.getRoot(); + } + + private ExpressionNode transformFromStoredConvertedModel(FeatureArguments arguments) { + File expressionFile = null; + try { + return new RankingExpression(IOUtils.readFile(arguments.expressionFile())).getRoot(); + } + catch (IOException e) { + throw new UncheckedIOException("Could not read " + expressionFile, e); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + expressionFile, e); + } + } + /** * Returns the specified, existing signature, or the only signature if none is specified. * Throws IllegalArgumentException in all other cases. @@ -133,77 +158,25 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private ExpressionNode transformFromTensorFlowModel(Path modelPath, - Optional<String> signatureArg, - Optional<String> outputArg, - RankProfile rankProfile) { - TensorFlowModel model = importedModels.computeIfAbsent(modelPath, k -> importModel(modelPath)); - - // Find the specified expression - Signature signature = chooseSignature(model, signatureArg); - String output = chooseOutput(signature, outputArg); - RankingExpression expression = model.expressions().get(output); - writeConverted(modelPath, signatureArg, outputArg, expression); - - // Add all constants (after finding outputs to fail faster when the output is not found) - if (constantsInConfig) - model.constants().forEach((k, v) -> rankProfile.addConstantTensor(k, new TensorValue(v))); - else // correct way, disabled for now - model.constants().forEach((k, v) -> transformConstant(modelPath, rankProfile, k, v)); - - return expression.getRoot(); - } - - private ExpressionNode transformFromStoredConvertedModel(Path modelPath, - Optional<String> signatureArg, - Optional<String> outputArg) { - File expressionFile = null; + private void writeConverted(FeatureArguments arguments, RankingExpression expression) { try { - expressionFile = expressionFile(modelPath, signatureArg, outputArg); - return new RankingExpression(IOUtils.readFile(expressionFile)).getRoot(); - } - catch (IOException e) { - throw new UncheckedIOException("Could not read " + expressionFile, e); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + expressionFile, e); - } - } - - private TensorFlowModel importModel(Path modelPath) { - try { - return tensorFlowImporter.importModel(new File(ApplicationPackage.MODELS_DIR.append(modelPath) - .getRelative()) - .getCanonicalPath()); - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - private void writeConverted(Path modelPath, Optional<String> signatureArg, Optional<String> outputArg, RankingExpression expression) { - try { - IOUtils.writeFile(expressionFile(modelPath, signatureArg, outputArg), expression.getRoot().toString(), false); + IOUtils.writeFile(arguments.expressionFile(), expression.getRoot().toString(), false); } catch (IOException e) { throw new UncheckedIOException(e); } } - private void transformConstant(Path modelPath, RankProfile profile, String constantName, Tensor constantValue) { + private void transformConstant(FeatureArguments arguments, RankProfile profile, String constantName, Tensor constantValue) { try { if (profile.getSearch().getRankingConstants().containsKey(constantName)) return; - File constantFilePath = new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath) - .append("constants") - .getRelative()) - .getCanonicalFile(); - if ( ! constantFilePath.exists()) - if ( ! constantFilePath.mkdir()) - throw new IOException("Could not create directory " + constantFilePath); + if ( ! arguments.constantsDir().exists()) + if ( ! arguments.constantsDir().mkdir()) + throw new IOException("Could not create directory " + arguments.constantsDir()); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: - File constantFile = new File(constantFilePath, constantName + ".tbf"); + File constantFile = new File(arguments.constantsDir(), constantName + ".tbf"); IOUtils.writeFile(constantFile, TypedBinaryFormat.encode(constantValue)); profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFile.getPath())); } @@ -219,47 +192,100 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil return b.toString(); } - private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } + /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */ + private static class FeatureArguments { - private String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } + private final Path modelPath; - private String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); - } + /** Optional arguments */ + private final Optional<String> signature, output; - private boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; - } + public FeatureArguments(Arguments arguments) { + if (arguments.isEmpty()) + throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + + "the tensorflow model directory under [application]/models"); + if (arguments.expressions().size() > 3) + throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); - private File expressionFile(Path modelPath, Optional<String> signatureArg, Optional<String> outputArg) { - try { - StringBuilder fileName = new StringBuilder(); - signatureArg.ifPresent(s -> fileName.append(s).append(".")); - outputArg.ifPresent(s -> fileName.append(s).append(".")); - if (fileName.length() == 0) // single signature and output - fileName.append("single."); - fileName.append("expression"); - - return new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath) - .append("expressions") - .append(fileName.toString()) - .getRelative()) - .getCanonicalFile(); + modelPath = Path.fromString(asString(arguments.expressions().get(0))); + signature = optionalArgument(1, arguments); + output = optionalArgument(2, arguments); } - catch (IOException e) { - throw new UncheckedIOException(e); + + /** Returns relative path to this model below the "models/" dir in the application package */ + public Path modelPath() { return modelPath; } + public Optional<String> signature() { return signature; } + public Optional<String> output() { return output; } + + /** + * Returns a File representing the actual location of the TensorFlow models given as part of the + * application package. This directory exists only when we are reading an application package supplied + * by a user. + */ + public File modelDir() { + try { + return new File(ApplicationPackage.MODELS_DIR.append(modelPath).getRelative()).getCanonicalFile(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public File expressionFile() { + try { + StringBuilder fileName = new StringBuilder(); + signature.ifPresent(s -> fileName.append(s).append(".")); + output.ifPresent(s -> fileName.append(s).append(".")); + if (fileName.length() == 0) // single signature and output + fileName.append("single."); + fileName.append("expression"); + + return new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath) + .append("expressions") + .append(fileName.toString()) + .getRelative()) + .getCanonicalFile(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + public File constantsDir() { + try { + return new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath) + .append("constants") + .getRelative()) + .getCanonicalFile(); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + private String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); } + + private String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + } } |