summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-16 11:23:09 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-16 11:23:09 +0100
commitd9e17187fe49f662520d282c38e5cf779cbb8195 (patch)
tree19521bf836aa57fb4b9056ff12f77d8a0c957f60 /config-model
parentdd744223e7db5f14805b7e23dbe69f143b60f1a3 (diff)
Refactor (no functional changes)
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java242
1 files changed, 134 insertions, 108 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 9fd4199f833..7cefa9d9187 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -59,27 +59,52 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
- try {
- if ( ! feature.getName().equals("tensorflow")) return feature;
-
- if (feature.getArguments().isEmpty())
- throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
- "the tensorflow model directory under [application]/models");
+ if ( ! feature.getName().equals("tensorflow")) return feature;
- // modelPath: The relative path to this model below the "models/" dir in the application package
- Path modelPath = Path.fromString(asString(feature.getArguments().expressions().get(0)));
- Optional<String> signatureArg = optionalArgument(1, feature.getArguments());
- Optional<String> outputArg = optionalArgument(2, feature.getArguments());
- if (new File(ApplicationPackage.MODELS_DIR.append(modelPath).getRelative()).getCanonicalFile().exists())
- return transformFromTensorFlowModel(modelPath, signatureArg, outputArg, context.rankProfile());
+ try {
+ FeatureArguments arguments = new FeatureArguments(feature.getArguments());
+ if (arguments.modelDir().exists())
+ return transformFromTensorFlowModel(arguments, context.rankProfile());
else
- return transformFromStoredConvertedModel(modelPath, signatureArg, outputArg);
+ return transformFromStoredConvertedModel(arguments);
}
- catch (IllegalArgumentException | IOException e) {
+ catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
}
}
+ private ExpressionNode transformFromTensorFlowModel(FeatureArguments arguments, RankProfile rankProfile) {
+ TensorFlowModel model = importedModels.computeIfAbsent(arguments.modelPath(),
+ k -> tensorFlowImporter.importModel(arguments.modelDir().toString()));
+
+ // Find the specified expression
+ Signature signature = chooseSignature(model, arguments.signature());
+ String output = chooseOutput(signature, arguments.output());
+ RankingExpression expression = model.expressions().get(output);
+ writeConverted(arguments, expression);
+
+ // Add all constants (after finding outputs to fail faster when the output is not found)
+ if (constantsInConfig)
+ model.constants().forEach((k, v) -> rankProfile.addConstantTensor(k, new TensorValue(v)));
+ else // correct way, disabled for now
+ model.constants().forEach((k, v) -> transformConstant(arguments, rankProfile, k, v));
+
+ return expression.getRoot();
+ }
+
+ private ExpressionNode transformFromStoredConvertedModel(FeatureArguments arguments) {
+ File expressionFile = null;
+ try {
+ return new RankingExpression(IOUtils.readFile(arguments.expressionFile())).getRoot();
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException("Could not read " + expressionFile, e);
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + expressionFile, e);
+ }
+ }
+
/**
* Returns the specified, existing signature, or the only signature if none is specified.
* Throws IllegalArgumentException in all other cases.
@@ -133,77 +158,25 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
- private ExpressionNode transformFromTensorFlowModel(Path modelPath,
- Optional<String> signatureArg,
- Optional<String> outputArg,
- RankProfile rankProfile) {
- TensorFlowModel model = importedModels.computeIfAbsent(modelPath, k -> importModel(modelPath));
-
- // Find the specified expression
- Signature signature = chooseSignature(model, signatureArg);
- String output = chooseOutput(signature, outputArg);
- RankingExpression expression = model.expressions().get(output);
- writeConverted(modelPath, signatureArg, outputArg, expression);
-
- // Add all constants (after finding outputs to fail faster when the output is not found)
- if (constantsInConfig)
- model.constants().forEach((k, v) -> rankProfile.addConstantTensor(k, new TensorValue(v)));
- else // correct way, disabled for now
- model.constants().forEach((k, v) -> transformConstant(modelPath, rankProfile, k, v));
-
- return expression.getRoot();
- }
-
- private ExpressionNode transformFromStoredConvertedModel(Path modelPath,
- Optional<String> signatureArg,
- Optional<String> outputArg) {
- File expressionFile = null;
+ private void writeConverted(FeatureArguments arguments, RankingExpression expression) {
try {
- expressionFile = expressionFile(modelPath, signatureArg, outputArg);
- return new RankingExpression(IOUtils.readFile(expressionFile)).getRoot();
- }
- catch (IOException e) {
- throw new UncheckedIOException("Could not read " + expressionFile, e);
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + expressionFile, e);
- }
- }
-
- private TensorFlowModel importModel(Path modelPath) {
- try {
- return tensorFlowImporter.importModel(new File(ApplicationPackage.MODELS_DIR.append(modelPath)
- .getRelative())
- .getCanonicalPath());
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- private void writeConverted(Path modelPath, Optional<String> signatureArg, Optional<String> outputArg, RankingExpression expression) {
- try {
- IOUtils.writeFile(expressionFile(modelPath, signatureArg, outputArg), expression.getRoot().toString(), false);
+ IOUtils.writeFile(arguments.expressionFile(), expression.getRoot().toString(), false);
}
catch (IOException e) {
throw new UncheckedIOException(e);
}
}
- private void transformConstant(Path modelPath, RankProfile profile, String constantName, Tensor constantValue) {
+ private void transformConstant(FeatureArguments arguments, RankProfile profile, String constantName, Tensor constantValue) {
try {
if (profile.getSearch().getRankingConstants().containsKey(constantName)) return;
- File constantFilePath = new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath)
- .append("constants")
- .getRelative())
- .getCanonicalFile();
- if ( ! constantFilePath.exists())
- if ( ! constantFilePath.mkdir())
- throw new IOException("Could not create directory " + constantFilePath);
+ if ( ! arguments.constantsDir().exists())
+ if ( ! arguments.constantsDir().mkdir())
+ throw new IOException("Could not create directory " + arguments.constantsDir());
// "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
- File constantFile = new File(constantFilePath, constantName + ".tbf");
+ File constantFile = new File(arguments.constantsDir(), constantName + ".tbf");
IOUtils.writeFile(constantFile, TypedBinaryFormat.encode(constantValue));
profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFile.getPath()));
}
@@ -219,47 +192,100 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
return b.toString();
}
- private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
- if (argumentIndex >= arguments.expressions().size())
- return Optional.empty();
- return Optional.of(asString(arguments.expressions().get(argumentIndex)));
- }
+ /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */
+ private static class FeatureArguments {
- private String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
- }
+ private final Path modelPath;
- private String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
- }
+ /** Optional arguments */
+ private final Optional<String> signature, output;
- private boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
- }
+ public FeatureArguments(Arguments arguments) {
+ if (arguments.isEmpty())
+ throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
+ "the tensorflow model directory under [application]/models");
+ if (arguments.expressions().size() > 3)
+ throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
- private File expressionFile(Path modelPath, Optional<String> signatureArg, Optional<String> outputArg) {
- try {
- StringBuilder fileName = new StringBuilder();
- signatureArg.ifPresent(s -> fileName.append(s).append("."));
- outputArg.ifPresent(s -> fileName.append(s).append("."));
- if (fileName.length() == 0) // single signature and output
- fileName.append("single.");
- fileName.append("expression");
-
- return new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath)
- .append("expressions")
- .append(fileName.toString())
- .getRelative())
- .getCanonicalFile();
+ modelPath = Path.fromString(asString(arguments.expressions().get(0)));
+ signature = optionalArgument(1, arguments);
+ output = optionalArgument(2, arguments);
}
- catch (IOException e) {
- throw new UncheckedIOException(e);
+
+ /** Returns relative path to this model below the "models/" dir in the application package */
+ public Path modelPath() { return modelPath; }
+ public Optional<String> signature() { return signature; }
+ public Optional<String> output() { return output; }
+
+ /**
+ * Returns a File representing the actual location of the TensorFlow models given as part of the
+ * application package. This directory exists only when we are reading an application package supplied
+ * by a user.
+ */
+ public File modelDir() {
+ try {
+ return new File(ApplicationPackage.MODELS_DIR.append(modelPath).getRelative()).getCanonicalFile();
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ public File expressionFile() {
+ try {
+ StringBuilder fileName = new StringBuilder();
+ signature.ifPresent(s -> fileName.append(s).append("."));
+ output.ifPresent(s -> fileName.append(s).append("."));
+ if (fileName.length() == 0) // single signature and output
+ fileName.append("single.");
+ fileName.append("expression");
+
+ return new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath)
+ .append("expressions")
+ .append(fileName.toString())
+ .getRelative())
+ .getCanonicalFile();
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ public File constantsDir() {
+ try {
+ return new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath)
+ .append("constants")
+ .getRelative())
+ .getCanonicalFile();
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ if (argumentIndex >= arguments.expressions().size())
+ return Optional.empty();
+ return Optional.of(asString(arguments.expressions().get(argumentIndex)));
+ }
+
+ private String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
}
+
+ private String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
+ }
+
+ private boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
+ }
+
}
}