diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 683 |
1 files changed, 10 insertions, 673 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 41da32f64c3..75d72111e9a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,59 +1,18 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition.expressiontransforms; -import com.google.common.base.Joiner; -import com.yahoo.collections.Pair; -import com.yahoo.config.application.api.ApplicationFile; -import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.model.application.provider.FilesApplicationPackage; -import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.tensor.serialization.TypedBinaryFormat; -import java.io.BufferedReader; -import java.io.File; -import java.io.IOException; -import java.io.StringReader; import java.io.UncheckedIOException; -import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; -import java.util.HashSet; -import java.util.List; import java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.stream.Collectors; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -62,12 +21,12 @@ import java.util.stream.Collectors; * * @author bratseth */ -public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { +public class TensorFlowFeatureConverter extends MLImportFeatureConverter { private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<Path, TensorFlowModel> importedModels = new HashMap<>(); + private final Map<Path, ImportedModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -95,634 +54,12 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } private ExpressionNode transformFromTensorFlowModel(ModelStore store, - RankProfile profile, - QueryProfileRegistry queryProfiles) { - TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), - k -> tensorFlowImporter.importModel(store.arguments().modelName(), - store.tensorFlowModelDir())); - - // Add constants - Set<String> constantsReplacedByMacros = new HashSet<>(); - model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); - model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, - constantsReplacedByMacros, k, v)); - - // Find the specified expression - Signature signature = chooseSignature(model, store.arguments().signature()); - String output = chooseOutput(signature, store.arguments().output()); - if (signature.skippedOutputs().containsKey(output)) { - String message = "Could not import TensorFlow model output '" + output + "'"; - if (!signature.skippedOutputs().get(output).isEmpty()) { - message += ": " + signature.skippedOutputs().get(output); - } - if (!signature.importWarnings().isEmpty()) { - message += ": " + String.join(", ", signature.importWarnings()); - } - throw new IllegalArgumentException(message); - } - - RankingExpression expression = model.expressions().get(output); - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - verifyRequiredMacros(expression, model, profile, queryProfiles); - addGeneratedMacros(model, profile); - reduceBatchDimensions(expression, model, profile, queryProfiles); - - model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v)); - - store.writeConverted(expression); - return expression.getRoot(); - } - - private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { - for (Pair<String, Tensor> constant : store.readSmallConstants()) - profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); - - for (RankingConstant constant : store.readLargeConstants()) { - if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) - profile.getSearch().addRankingConstant(constant); - } - - for (Pair<String, RankingExpression> macro : store.readMacros()) { - addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); - } - - return store.readConverted().getRoot(); - } - - /** - * Returns the specified, existing signature, or the only signature if none is specified. - * Throws IllegalArgumentException in all other cases. - */ - private Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) { - if ( ! signatureName.isPresent()) { - if (importResult.signatures().size() == 0) - throw new IllegalArgumentException("No signatures are available"); - if (importResult.signatures().size() > 1) - throw new IllegalArgumentException("Model has multiple signatures (" + - Joiner.on(", ").join(importResult.signatures().keySet()) + - "), one must be specified " + - "as a second argument to tensorflow()"); - return importResult.signatures().values().stream().findFirst().get(); - } - else { - Signature signature = importResult.signatures().get(signatureName.get()); - if (signature == null) - throw new IllegalArgumentException("Model does not have the specified signature '" + - signatureName.get() + "'"); - return signature; - } - } - - /** - * Returns the specified, existing output expression, or the only output expression if no output name is specified. - * Throws IllegalArgumentException in all other cases. - */ - private String chooseOutput(Signature signature, Optional<String> outputName) { - if ( ! outputName.isPresent()) { - if (signature.outputs().size() == 0) - throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); - if (signature.outputs().size() > 1) - throw new IllegalArgumentException(signature + " has multiple outputs (" + - Joiner.on(", ").join(signature.outputs().keySet()) + - "), one must be specified " + - "as a third argument to tensorflow()"); - return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get()); - } - else { - String output = signature.outputs().get(outputName.get()); - if (output == null) { - if (signature.skippedOutputs().containsKey(outputName.get())) - throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " + - signature.skippedOutputs().get(outputName.get())); - else - throw new IllegalArgumentException("Model does not have the specified output '" + - outputName.get() + "'"); - } - return output; - } - } - - private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - store.writeSmallConstant(constantName, constantValue); - profile.addConstant(constantName, asValue(constantValue)); - } - - private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Set<String> constantsReplacedByMacros, - String constantName, Tensor constantValue) { - RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); - if (macroOverridingConstant != null) { - TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); - if ( ! macroType.equals(constantValue.type())) - throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + - typeMismatchExplanation(constantValue.type(), macroType)); - constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later - } - else { - Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); - } - } - } - - private void transformGeneratedMacro(ModelStore store, - Set<String> constantsReplacedByMacros, - String macroName, RankingExpression expression) { - - expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - store.writeMacro(macroName, expression); - } - - private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { - if (profile.getMacros().containsKey(macroName)) { - throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists."); - } - profile.addMacro(macroName, false); // todo: inline if only used once - RankProfile.Macro macro = profile.getMacros().get(macroName); - macro.setRankingExpression(expression); - macro.setTextualExpression(expression.getRoot().toString()); - } - - private String skippedOutputsDescription(TensorFlowModel.Signature signature) { - if (signature.skippedOutputs().isEmpty()) return ""; - StringBuilder b = new StringBuilder(": "); - signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); - return b.toString(); - } - - /** - * Verify that the macros referred in the given expression exists in the given rank profile, - * and return tensors of the types specified in requiredMacros. - */ - private void verifyRequiredMacros(RankingExpression expression, TensorFlowModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - TensorType requiredType = model.requiredMacros().get(macroName); - if (requiredType == null) continue; // Not a required macro - - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) - throw new IllegalArgumentException("Model refers placeholder '" + macroName + - "' of type " + requiredType + " but this macro is not present in " + - profile); - // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second - // phase and summary features), as it may only resolve correctly given those bindings - // Or, probably better, annotate the macros with type constraints here and verify during general - // type verification - TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); - if ( actualType == null) - throw new IllegalArgumentException("Model refers placeholder '" + macroName + - "' of type " + requiredType + - " which must be produced by a macro in the rank profile, but " + - "this macro references a feature which is not declared"); - if ( ! actualType.isAssignableTo(requiredType)) - throw new IllegalArgumentException("Model refers placeholder '" + macroName + "'. " + - typeMismatchExplanation(requiredType, actualType)); - } - } - - private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) { - return "The required type of this is " + requiredType + ", but this macro returns " + actualType + - (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " + - "in query profile types - see the documentation." - : ""); - } - - /** - * Add the generated macros to the rank profile - */ - private void addGeneratedMacros(TensorFlowModel model, RankProfile profile) { - model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); - } - - /** - * Check if batch dimensions of inputs can be reduced out. If the input - * macro specifies that a single exemplar should be evaluated, we can - * reduce the batch dimension out. - */ - private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - - // Check generated macros for inputs to reduce - Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames, model); - for (String macroName : macroNames) { - if ( ! model.macros().containsKey(macroName)) { - continue; - } - RankProfile.Macro macro = profile.getMacros().get(macroName); - if (macro == null) { - throw new IllegalArgumentException("Model refers to generated macro '" + macroName + - "but this macro is not present in " + profile); - } - RankingExpression macroExpression = macro.getRankingExpression(); - macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); - } - - // Check expression for inputs to reduce - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, model, typeContext); - TensorType typeAfterReducing = root.type(typeContext); - root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); - expression.setRoot(root); - } - - private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model, - TypeContext<Reference> typeContext) { - if (node instanceof TensorFunctionNode) { - TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); - if (tensorFunction instanceof Rename) { - List<ExpressionNode> children = ((TensorFunctionNode)node).children(); - if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(tensorFunction, typeContext); - } - } - } - } - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - if (model.requiredMacros().containsKey(referenceNode.getName())) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); - } - } - if (node instanceof CompositeNode) { - List<ExpressionNode> children = ((CompositeNode)node).children(); - List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); - for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); - } - return ((CompositeNode)node).setChildren(transformedChildren); - } - return node; - } - - private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { - TensorFunction result = function; - TensorType type = function.type(context); - if (type.dimensions().size() > 1) { - List<String> reduceDimensions = new ArrayList<>(); - for (TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } - } - if (reduceDimensions.size() > 0) { - result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); - } - } - return new TensorFunctionNode(result); - } - - /** - * If batch dimensions have been reduced away above, bring them back here - * for any following computation of the tensor. - * Todo: determine when this is not necessary! - */ - private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) { - return node; - } - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (TensorType.Dimension dimension : before.dimensions()) { - if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { - typeBuilder.indexed(dimension.name(), 1); - } - } - TensorType expandDimensionsType = typeBuilder.build(); - if (expandDimensionsType.dimensions().size() > 0) { - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); - Generate generatedFunction = new Generate(expandDimensionsType, - new GeneratorLambdaFunctionNode(expandDimensionsType, - generatedExpression) - .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - return node; - } - - /** - * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. - * This method does that for the given expression and returns the result. - */ - private RankingExpression replaceConstantsByMacros(RankingExpression expression, - Set<String> constantsReplacedByMacros) { - if (constantsReplacedByMacros.isEmpty()) return expression; - return new RankingExpression(expression.getName(), - replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); - } - - private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { - if (node instanceof ReferenceNode) { - Reference reference = ((ReferenceNode)node).reference(); - if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { - String argument = reference.simpleArgument().get(); - if (constantsReplacedByMacros.contains(argument)) - return new ReferenceNode(argument); - } - } - if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above - CompositeNode composite = (CompositeNode)node; - return composite.setChildren(composite.children().stream() - .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) - .collect(Collectors.toList())); - } - return node; - } - - private void addMacroNamesIn(ExpressionNode node, Set<String> names, TensorFlowModel model) { - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode)node; - if (referenceNode.getOutput() == null) { // macro references cannot specify outputs - names.add(referenceNode.getName()); - if (model.macros().containsKey(referenceNode.getName())) { - addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); - } - } - } - else if (node instanceof CompositeNode) { - for (ExpressionNode child : ((CompositeNode)node).children()) - addMacroNamesIn(child, names, model); - } - } - - private Value asValue(Tensor tensor) { - if (tensor.type().rank() == 0) - return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors - else - return new TensorValue(tensor); - } - - /** - * Provides read/write access to the correct directories of the application package given by the feature arguments - */ - private static class ModelStore { - - private final ApplicationPackage application; - private final FeatureArguments arguments; - - public ModelStore(ApplicationPackage application, Arguments arguments) { - this.application = application; - this.arguments = new FeatureArguments(arguments); - } - - - - public FeatureArguments arguments() { return arguments; } - - public boolean hasStoredModel() { - try { - return application.getFile(arguments.expressionPath()).exists(); - } - catch (UnsupportedOperationException e) { - return false; - } - } - - /** - * Returns the directory which (if hasTensorFlowModels is true) - * contains the source model to use for these arguments - */ - public File tensorFlowModelDir() { - return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath())); - } - - /** - * Adds this expression to the application package, such that it can be read later. - */ - public void writeConverted(RankingExpression expression) { - application.getFile(arguments.expressionPath()) - .writeFile(new StringReader(expression.getRoot().toString())); - } - - /** Reads the previously stored ranking expression for these arguments */ - public RankingExpression readConverted() { - try { - return new RankingExpression(application.getFile(arguments.expressionPath()).createReader()); - } - catch (IOException e) { - throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - - /** Adds this macro expression to the application package to it can be read later. */ - public void writeMacro(String name, RankingExpression expression) { - application.getFile(arguments.macrosPath()).appendFile(name + "\t" + - expression.getRoot().toString() + "\n"); - } - - /** Reads the previously stored macro expressions for these arguments */ - public List<Pair<String, RankingExpression>> readMacros() { - try { - ApplicationFile file = application.getFile(arguments.macrosPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, RankingExpression>> macros = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - try { - RankingExpression expression = new RankingExpression(parts[1]); - macros.add(new Pair<>(name, expression)); - } - catch (ParseException e) { - throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); - } - } - return macros; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Reads the information about all the large (aka ranking) constants stored in the application package - * (the constant value itself is replicated with file distribution). - */ - public List<RankingConstant> readLargeConstants() { - try { - List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { - String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); - constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Adds this constant to the application package as a file, - * such that it can be distributed using file distribution. - * - * @return the path to the stored constant, relative to the application package root - */ - public Path writeLargeConstant(String name, Tensor constant) { - Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); - - // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: - Path constantPath = constantsPath.append(name + ".tbf"); - - // Remember the constant in a file we replicate in ZooKeeper - application.getFile(arguments.largeConstantsPath().append(name + ".constant")) - .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); - - // Write content explicitly as a file on the file system as this is distributed using file distribution - createIfNeeded(constantsPath); - IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); - return correct(constantPath); - } - - private List<Pair<String, Tensor>> readSmallConstants() { - try { - ApplicationFile file = application.getFile(arguments.smallConstantsPath()); - if (!file.exists()) return Collections.emptyList(); - - List<Pair<String, Tensor>> constants = new ArrayList<>(); - BufferedReader reader = new BufferedReader(file.createReader()); - String line; - while (null != (line = reader.readLine())) { - String[] parts = line.split("\t"); - String name = parts[0]; - TensorType type = TensorType.fromSpec(parts[1]); - Tensor tensor = Tensor.from(type, parts[2]); - constants.add(new Pair<>(name, tensor)); - } - return constants; - } - catch (IOException e) { - throw new UncheckedIOException(e); - } - } - - /** - * Append this constant to the single file used for small constants distributed as config - */ - public void writeSmallConstant(String name, Tensor constant) { - // Secret file format for remembering constants: - application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + - constant.type().toString() + "\t" + - constant.toString() + "\n"); - } - - /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ - private Path correct(Path path) { - if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) - && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { - return Path.fromString(FilesApplicationPackage.preprocessed).append(path); - } - else { - return path; - } - } - - private void createIfNeeded(Path path) { - File dir = application.getFileReference(path); - if ( ! dir.exists()) { - if (!dir.mkdirs()) - throw new IllegalStateException("Could not create " + dir); - } - } - - } - - /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */ - private static class FeatureArguments { - - private final Path modelPath; - - /** Optional arguments */ - private final Optional<String> signature, output; - - public FeatureArguments(Arguments arguments) { - if (arguments.isEmpty()) - throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + - "the tensorflow model directory under [application]/models"); - if (arguments.expressions().size() > 3) - throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments"); - - modelPath = Path.fromString(asString(arguments.expressions().get(0))); - signature = optionalArgument(1, arguments); - output = optionalArgument(2, arguments); - } - - /** Returns modelPath with slashes replaced by underscores */ - public String modelName() { return modelPath.toString().replace('/', '_'); } - - /** Returns relative path to this model below the "models/" dir in the application package */ - public Path modelPath() { return modelPath; } - public Optional<String> signature() { return signature; } - public Optional<String> output() { return output; } - - /** Path to the small constants file */ - public Path smallConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); - } - - /** Path to the large (ranking) constants directory */ - public Path largeConstantsPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); - } - - /** Path to the macros file */ - public Path macrosPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt"); - } - - public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR - .append(modelPath).append("expressions").append(expressionFileName()); - } - - private String expressionFileName() { - StringBuilder fileName = new StringBuilder(); - signature.ifPresent(s -> fileName.append(s).append(".")); - output.ifPresent(s -> fileName.append(s).append(".")); - if (fileName.length() == 0) // single signature and output - fileName.append("single."); - fileName.append("expression"); - return fileName.toString(); - } - - private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { - if (argumentIndex >= arguments.expressions().size()) - return Optional.empty(); - return Optional.of(asString(arguments.expressions().get(argumentIndex))); - } - - private String asString(ExpressionNode node) { - if ( ! (node instanceof ConstantNode)) - throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); - return stripQuotes(((ConstantNode)node).sourceString()); - } - - private String stripQuotes(String s) { - if ( ! isQuoteSign(s.codePointAt(0))) return s; - if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) - throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); - return s.substring(1, s.length()-1); - } - - private boolean isQuoteSign(int c) { - return c == '\'' || c == '"'; - } - + RankProfile profile, + QueryProfileRegistry queryProfiles) { + ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), + k -> tensorFlowImporter.importModel(store.arguments().modelName(), + store.modelDir())); + return transformFromImportedModel(model, store, profile, queryProfiles); } } |