path: root/config-model
author    Lester Solbakken <lesters@users.noreply.github.com>    2018-06-06 13:38:37 +0200
committer GitHub <noreply@github.com>    2018-06-06 13:38:37 +0200
commit    240176d60c44507f4e6733c7512620e80554c8de (patch)
tree      7b1f54a9acd169da88a524f4899ddcf76d02db28 /config-model
parent    e4f626c587cf1cc4d5c05da5e15523f4162107f0 (diff)
parent    e4626398c7e9c1b4b0fa5dbd974e1696c377dd77 (diff)
Merge pull request #6046 from vespa-engine/lesters/refactor-onnx-tensorflow-import
Refactor ONNX and TF import to use same code base
Diffstat (limited to 'config-model')
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java    674
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java         636
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java   677
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java       22
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java  26
5 files changed, 718 insertions, 1317 deletions
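
The refactor below extracts the conversion logic previously duplicated between OnnxFeatureConverter and TensorFlowFeatureConverter into a shared base class, MLImportFeatureConverter. As a rough sketch of the intended usage pattern (illustration only, not part of this commit; the class name ExampleFeatureConverter is hypothetical, the feature-name check is omitted, and imports mirror those of OnnxFeatureConverter below), a converter subclass imports the source model on first deployment and otherwise reads back the stored conversion:

    // Hypothetical subclass sketch, same package as MLImportFeatureConverter.
    // Imports as in OnnxFeatureConverter below (java.util.Map, java.util.HashMap, com.yahoo.path.Path, ...).
    public class ExampleFeatureConverter extends MLImportFeatureConverter {

        private final OnnxImporter importer = new OnnxImporter();

        /** Cache of imported models, so the same model is not imported once per rank profile */
        private final Map<Path, ImportedModel> importedModels = new HashMap<>();

        @Override
        public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
            if ( ! (node instanceof ReferenceNode)) return node; // only feature references are candidates
            ReferenceNode feature = (ReferenceNode) node;

            FeatureArguments arguments = new OnnxFeatureConverter.OnnxFeatureArguments(feature.getArguments());
            ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);

            if ( ! store.hasStoredModel()) { // not converted yet: import the source model files
                ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
                        k -> importer.importModel(store.arguments().modelName(), store.modelDir()));
                return transformFromImportedModel(model, store, context.rankProfile(), context.queryProfiles());
            }
            return transformFromStoredModel(store, context.rankProfile()); // already converted: read it back
        }

    }
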
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
new file mode 100644
index 00000000000..effa261be3b
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MLImportFeatureConverter.java
@@ -0,0 +1,674 @@
+package com.yahoo.searchdefinition.expressiontransforms;
+
+import com.google.common.base.Joiner;
+import com.yahoo.collections.Pair;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.FeatureNames;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.StringReader;
+import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Base class for converters that replace instances of a pseudo-feature referring to an imported
+ * ML ranking model with the corresponding native Vespa ranking expression.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+abstract class MLImportFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+
+ ExpressionNode transformFromImportedModel(ImportedModel model,
+ ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
+ // Add constants
+ Set<String> constantsReplacedByMacros = new HashSet<>();
+ model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
+ model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
+ constantsReplacedByMacros, k, v));
+
+ // Find the specified expression
+ ImportedModel.Signature signature = chooseSignature(model, store.arguments().signature());
+ String output = chooseOutput(signature, store.arguments().output());
+ if (signature.skippedOutputs().containsKey(output)) {
+ String message = "Could not import model output '" + output + "'";
+ if (!signature.skippedOutputs().get(output).isEmpty()) {
+ message += ": " + signature.skippedOutputs().get(output);
+ }
+ if (!signature.importWarnings().isEmpty()) {
+ message += ": " + String.join(", ", signature.importWarnings());
+ }
+ throw new IllegalArgumentException(message);
+ }
+
+ RankingExpression expression = model.expressions().get(output);
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ verifyRequiredMacros(expression, model, profile, queryProfiles);
+ addGeneratedMacros(model, profile);
+ reduceBatchDimensions(expression, model, profile, queryProfiles);
+
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
+
+ store.writeConverted(expression);
+ return expression.getRoot();
+ }
+
+ ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+ for (Pair<String, Tensor> constant : store.readSmallConstants())
+ profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
+
+ for (RankingConstant constant : store.readLargeConstants()) {
+ if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
+ profile.getSearch().addRankingConstant(constant);
+ }
+
+ for (Pair<String, RankingExpression> macro : store.readMacros()) {
+ addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
+ }
+
+ return store.readConverted().getRoot();
+ }
+
+ /**
+ * Returns the specified, existing signature, or the only signature if none is specified.
+ * Throws IllegalArgumentException in all other cases.
+ */
+ private ImportedModel.Signature chooseSignature(ImportedModel importResult, Optional<String> signatureName) {
+ if ( ! signatureName.isPresent()) {
+ if (importResult.signatures().size() == 0)
+ throw new IllegalArgumentException("No signatures are available");
+ if (importResult.signatures().size() > 1)
+ throw new IllegalArgumentException("Model has multiple signatures (" +
+ Joiner.on(", ").join(importResult.signatures().keySet()) +
+ "), one must be specified " +
+ "as a second argument to tensorflow()");
+ return importResult.signatures().values().stream().findFirst().get();
+ }
+ else {
+ ImportedModel.Signature signature = importResult.signatures().get(signatureName.get());
+ if (signature == null)
+ throw new IllegalArgumentException("Model does not have the specified signature '" +
+ signatureName.get() + "'");
+ return signature;
+ }
+ }
+
+ /**
+ * Returns the specified, existing output expression, or the only output expression if no output name is specified.
+ * Throws IllegalArgumentException in all other cases.
+ */
+ private String chooseOutput(ImportedModel.Signature signature, Optional<String> outputName) {
+ if ( ! outputName.isPresent()) {
+ if (signature.outputs().size() == 0)
+ throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
+ if (signature.outputs().size() > 1)
+ throw new IllegalArgumentException(signature + " has multiple outputs (" +
+ Joiner.on(", ").join(signature.outputs().keySet()) +
+ "), one must be specified " +
+ "as a third argument to tensorflow()");
+ return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get());
+ }
+ else {
+ String output = signature.outputs().get(outputName.get());
+ if (output == null) {
+ if (signature.skippedOutputs().containsKey(outputName.get()))
+ throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
+ signature.skippedOutputs().get(outputName.get()));
+ else
+ throw new IllegalArgumentException("Model does not have the specified output '" +
+ outputName.get() + "'");
+ }
+ return output;
+ }
+ }
+
+ private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ store.writeSmallConstant(constantName, constantValue);
+ profile.addConstant(constantName, asValue(constantValue));
+ }
+
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ Set<String> constantsReplacedByMacros,
+ String constantName, Tensor constantValue) {
+ RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
+ if (macroOverridingConstant != null) {
+ TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
+ if ( ! macroType.equals(constantValue.type()))
+ throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
+ typeMismatchExplanation(constantValue.type(), macroType));
+ constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
+ }
+ else {
+ Path constantPath = store.writeLargeConstant(constantName, constantValue);
+ if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
+ profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
+ constantPath.toString()));
+ }
+ }
+ }
+
+ private void transformGeneratedMacro(ModelStore store,
+ Set<String> constantsReplacedByMacros,
+ String macroName, RankingExpression expression) {
+
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ store.writeMacro(macroName, expression);
+ }
+
+ private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ if (profile.getMacros().containsKey(macroName)) {
+ throw new IllegalArgumentException("Generated macro '" + macroName + "' already exists.");
+ }
+ profile.addMacro(macroName, false); // todo: inline if only used once
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ macro.setRankingExpression(expression);
+ macro.setTextualExpression(expression.getRoot().toString());
+ }
+
+ private String skippedOutputsDescription(ImportedModel.Signature signature) {
+ if (signature.skippedOutputs().isEmpty()) return "";
+ StringBuilder b = new StringBuilder(": ");
+ signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
+ return b.toString();
+ }
+
+ /**
+ * Verifies that the macros referred to in the given expression exist in the given rank profile,
+ * and that they return tensors of the types specified in requiredMacros.
+ */
+ private void verifyRequiredMacros(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ TensorType requiredType = model.requiredMacros().get(macroName);
+ if (requiredType == null) continue; // Not a required macro
+
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null)
+ throw new IllegalArgumentException("Model refers input '" + macroName +
+ "' of type " + requiredType + " but this macro is not present in " +
+ profile);
+ // TODO: We should verify this in the function reference(s) from which this is invoked (starting from first/second
+ // phase and summary features), as it may only resolve correctly given those bindings
+ // Or, probably better, annotate the macros with type constraints here and verify during general
+ // type verification
+ TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
+ if ( actualType == null)
+ throw new IllegalArgumentException("Model refers input '" + macroName +
+ "' of type " + requiredType +
+ " which must be produced by a macro in the rank profile, but " +
+ "this macro references a feature which is not declared");
+ if ( ! actualType.isAssignableTo(requiredType))
+ throw new IllegalArgumentException("Model refers input '" + macroName + "'. " +
+ typeMismatchExplanation(requiredType, actualType));
+ }
+ }
+
+ private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
+ return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
+ (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
+ "in query profile types - see the documentation."
+ : "");
+ }
+
+ /**
+ * Add the generated macros to the rank profile
+ */
+ private void addGeneratedMacros(ImportedModel model, RankProfile profile) {
+ model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
+ }
+
+ /**
+ * Check if batch dimensions of inputs can be reduced out. If the input
+ * macro specifies that a single exemplar should be evaluated, we can
+ * reduce the batch dimension out.
+ */
+ private void reduceBatchDimensions(RankingExpression expression, ImportedModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
+ TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+
+ // Check generated macros for inputs to reduce
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ if ( ! model.macros().containsKey(macroName)) {
+ continue;
+ }
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null) {
+ throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
+ "but this macro is not present in " + profile);
+ }
+ RankingExpression macroExpression = macro.getRankingExpression();
+ macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
+ }
+
+ // Check expression for inputs to reduce
+ ExpressionNode root = expression.getRoot();
+ root = reduceBatchDimensionsAtInput(root, model, typeContext);
+ TensorType typeAfterReducing = root.type(typeContext);
+ root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+ expression.setRoot(root);
+ }
+
+ private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedModel model,
+ TypeContext<Reference> typeContext) {
+ if (node instanceof TensorFunctionNode) {
+ TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
+ if (tensorFunction instanceof Rename) {
+ List<ExpressionNode> children = ((TensorFunctionNode)node).children();
+ if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) children.get(0);
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(tensorFunction, typeContext);
+ }
+ }
+ }
+ }
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) node;
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
+ }
+ }
+ if (node instanceof CompositeNode) {
+ List<ExpressionNode> children = ((CompositeNode)node).children();
+ List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
+ for (ExpressionNode child : children) {
+ transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
+ }
+ return ((CompositeNode)node).setChildren(transformedChildren);
+ }
+ return node;
+ }
+
+ private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ TensorFunction result = function;
+ TensorType type = function.type(context);
+ if (type.dimensions().size() > 1) {
+ List<String> reduceDimensions = new ArrayList<>();
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1) {
+ reduceDimensions.add(dimension.name());
+ }
+ }
+ if (reduceDimensions.size() > 0) {
+ result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+ }
+ }
+ return new TensorFunctionNode(result);
+ }
+
+ /**
+ * If batch dimensions have been reduced away above, bring them back here
+ * for any following computation of the tensor.
+ * Todo: determine when this is not necessary!
+ */
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ if (after.equals(before)) {
+ return node;
+ }
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : before.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
+ typeBuilder.indexed(dimension.name(), 1);
+ }
+ }
+ TensorType expandDimensionsType = typeBuilder.build();
+ if (expandDimensionsType.dimensions().size() > 0) {
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
+ Generate generatedFunction = new Generate(expandDimensionsType,
+ new GeneratorLambdaFunctionNode(expandDimensionsType,
+ generatedExpression)
+ .asLongListToDoubleOperator());
+ Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
+ return new TensorFunctionNode(expand);
+ }
+ return node;
+ }
+
+ /**
+ * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
+ * This method does that for the given expression and returns the result.
+ */
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ Set<String> constantsReplacedByMacros) {
+ if (constantsReplacedByMacros.isEmpty()) return expression;
+ return new RankingExpression(expression.getName(),
+ replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
+ }
+
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ if (node instanceof ReferenceNode) {
+ Reference reference = ((ReferenceNode)node).reference();
+ if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
+ String argument = reference.simpleArgument().get();
+ if (constantsReplacedByMacros.contains(argument))
+ return new ReferenceNode(argument);
+ }
+ }
+ if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
+ CompositeNode composite = (CompositeNode)node;
+ return composite.setChildren(composite.children().stream()
+ .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
+ .collect(Collectors.toList()));
+ }
+ return node;
+ }
+
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names, ImportedModel model) {
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode)node;
+ if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
+ names.add(referenceNode.getName());
+ if (model.macros().containsKey(referenceNode.getName())) {
+ addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
+ }
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children())
+ addMacroNamesIn(child, names, model);
+ }
+ }
+
+ private Value asValue(Tensor tensor) {
+ if (tensor.type().rank() == 0)
+ return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
+ else
+ return new TensorValue(tensor);
+ }
+
+ /**
+ * Provides read/write access to the correct directories of the application package given by the feature arguments
+ */
+ static class ModelStore {
+
+ private final ApplicationPackage application;
+ private final FeatureArguments arguments;
+
+ ModelStore(ApplicationPackage application, FeatureArguments arguments) {
+ this.application = application;
+ this.arguments = arguments;
+ }
+
+ public FeatureArguments arguments() { return arguments; }
+
+ public boolean hasStoredModel() {
+ try {
+ return application.getFile(arguments.expressionPath()).exists();
+ }
+ catch (UnsupportedOperationException e) {
+ return false;
+ }
+ }
+
+ /**
+ * Returns the directory which contains the source model to use for these arguments
+ */
+ public File modelDir() {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
+ }
+
+ /**
+ * Adds this expression to the application package, such that it can be read later.
+ */
+ void writeConverted(RankingExpression expression) {
+ application.getFile(arguments.expressionPath())
+ .writeFile(new StringReader(expression.getRoot().toString()));
+ }
+
+ /** Reads the previously stored ranking expression for these arguments */
+ RankingExpression readConverted() {
+ try {
+ return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+
+ /** Adds this macro expression to the application package so it can be read later. */
+ void writeMacro(String name, RankingExpression expression) {
+ application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
+ expression.getRoot().toString() + "\n");
+ }
+
+ /** Reads the previously stored macro expressions for these arguments */
+ List<Pair<String, RankingExpression>> readMacros() {
+ try {
+ ApplicationFile file = application.getFile(arguments.macrosPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, RankingExpression>> macros = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ try {
+ RankingExpression expression = new RankingExpression(parts[1]);
+ macros.add(new Pair<>(name, expression));
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+ return macros;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Reads the information about all the large (aka ranking) constants stored in the application package
+ * (the constant value itself is replicated with file distribution).
+ */
+ List<RankingConstant> readLargeConstants() {
+ try {
+ List<RankingConstant> constants = new ArrayList<>();
+ for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
+ String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
+ constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Adds this constant to the application package as a file,
+ * such that it can be distributed using file distribution.
+ *
+ * @return the path to the stored constant, relative to the application package root
+ */
+ Path writeLargeConstant(String name, Tensor constant) {
+ Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
+
+ // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
+ Path constantPath = constantsPath.append(name + ".tbf");
+
+ // Remember the constant in a file we replicate in ZooKeeper
+ application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
+ .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
+
+ // Write content explicitly as a file on the file system as this is distributed using file distribution
+ createIfNeeded(constantsPath);
+ IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
+ return correct(constantPath);
+ }
+
+ private List<Pair<String, Tensor>> readSmallConstants() {
+ try {
+ ApplicationFile file = application.getFile(arguments.smallConstantsPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, Tensor>> constants = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ TensorType type = TensorType.fromSpec(parts[1]);
+ Tensor tensor = Tensor.from(type, parts[2]);
+ constants.add(new Pair<>(name, tensor));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Append this constant to the single file used for small constants distributed as config
+ */
+ public void writeSmallConstant(String name, Tensor constant) {
+ // Secret file format for remembering constants:
+ application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
+ }
+
+ /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
+ private Path correct(Path path) {
+ if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
+ && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
+ return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
+ }
+ else {
+ return path;
+ }
+ }
+
+ private void createIfNeeded(Path path) {
+ File dir = application.getFileReference(path);
+ if ( ! dir.exists()) {
+ if (!dir.mkdirs())
+ throw new IllegalStateException("Could not create " + dir);
+ }
+ }
+
+ }
+
+ /** Encapsulates the arguments to the import feature */
+ static abstract class FeatureArguments {
+
+ Path modelPath;
+
+ /** Optional arguments */
+ Optional<String> signature, output;
+
+ /** Returns modelPath with slashes and dots replaced by underscores */
+ public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
+
+ /** Returns relative path to this model below the "models/" dir in the application package */
+ public Path modelPath() { return modelPath; }
+ public Optional<String> signature() { return signature; }
+ public Optional<String> output() { return output; }
+
+ /** Path to the small constants file */
+ public Path smallConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
+ }
+
+ /** Path to the large (ranking) constants directory */
+ public Path largeConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+ }
+
+ /** Path to the macros file */
+ public Path macrosPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
+ }
+
+ public Path expressionPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
+ .append(modelPath).append("expressions").append(expressionFileName());
+ }
+
+ private String expressionFileName() {
+ StringBuilder fileName = new StringBuilder();
+ signature.ifPresent(s -> fileName.append(s).append("."));
+ output.ifPresent(s -> fileName.append(s).append("."));
+ if (fileName.length() == 0) // single signature and output
+ fileName.append("single.");
+ fileName.append("expression");
+ return fileName.toString();
+ }
+
+ Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ if (argumentIndex >= arguments.expressions().size())
+ return Optional.empty();
+ return Optional.of(asString(arguments.expressions().get(argumentIndex)));
+ }
+
+ String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
+ }
+
+ private String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
+ }
+
+ private boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
+ }
+
+ }
+}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 1c41ad8284e..44eeb364603 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -2,58 +2,20 @@
package com.yahoo.searchdefinition.expressiontransforms;
-import com.google.common.base.Joiner;
-import com.yahoo.collections.Pair;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.model.application.provider.FilesApplicationPackage;
-import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankingConstant;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter;
-import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.OnnxImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.tensor.serialization.TypedBinaryFormat;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.StringReader;
import java.io.UncheckedIOException;
-import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
import java.util.Map;
import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
/**
* Replaces instances of the onnx(model-path, output)
@@ -63,12 +25,12 @@ import java.util.stream.Collectors;
* @author bratseth
* @author lesters
*/
-public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+public class OnnxFeatureConverter extends MLImportFeatureConverter {
private final OnnxImporter onnxImporter = new OnnxImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, OnnxModel> importedModels = new HashMap<>();
+ private final Map<Path, ImportedModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -84,7 +46,8 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
if ( ! feature.getName().equals("onnx")) return feature;
try {
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
+ FeatureArguments arguments = new OnnxFeatureArguments(feature.getArguments());
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files
return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles());
else
@@ -98,597 +61,24 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
private ExpressionNode transformFromOnnxModel(ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles) {
- OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
k -> onnxImporter.importModel(store.arguments().modelName(),
- store.onnxModelDir()));
-
- // Add constants
- Set<String> constantsReplacedByMacros = new HashSet<>();
- model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
- model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
- constantsReplacedByMacros, k, v));
-
- // Find the specified expression
- String output = chooseOutput(model, store.arguments().output());
- if (model.skippedOutputs().containsKey(output)) {
- String message = "Could not import Onnx model output '" + output + "'";
- if (!model.skippedOutputs().get(output).isEmpty()) {
- message += ": " + model.skippedOutputs().get(output);
- }
- if (!model.importWarnings().isEmpty()) {
- message += ": " + String.join(", ", model.importWarnings());
- }
- throw new IllegalArgumentException(message);
- }
-
- RankingExpression expression = model.expressions().get(output);
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- verifyRequiredMacros(expression, model, profile, queryProfiles);
- addGeneratedMacros(model, profile);
- reduceBatchDimensions(expression, model, profile, queryProfiles);
-
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v));
-
- store.writeConverted(expression);
- return expression.getRoot();
- }
-
- private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
- for (Pair<String, Tensor> constant : store.readSmallConstants())
- profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
-
- for (RankingConstant constant : store.readLargeConstants()) {
- if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
- profile.getSearch().addRankingConstant(constant);
- }
-
- for (Pair<String, RankingExpression> macro : store.readMacros()) {
- addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
- }
-
- return store.readConverted().getRoot();
- }
-
- /**
- * Returns the specified, existing output expression, or the only output expression if no output name is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private String chooseOutput(OnnxModel model, Optional<String> outputName) {
- if ( ! outputName.isPresent()) {
- if (model.outputs().size() == 0)
- throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model));
- if (model.outputs().size() > 1)
- throw new IllegalArgumentException("Onnx model has multiple outputs (" +
- Joiner.on(", ").join(model.outputs().keySet()) +
- "), one must be specified " +
- "as a second argument to onnx()");
- return model.outputs().get(model.outputs().keySet().stream().findFirst().get());
- }
- else {
- String output = model.outputs().get(outputName.get());
- if (output == null) {
- if (model.skippedOutputs().containsKey(outputName.get()))
- throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
- model.skippedOutputs().get(outputName.get()));
- else
- throw new IllegalArgumentException("Model does not have the specified output '" +
- outputName.get() + "'");
- }
- return output;
- }
+ store.modelDir()));
+ return transformFromImportedModel(model, store, profile, queryProfiles);
}
- private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
- store.writeSmallConstant(constantName, constantValue);
- profile.addConstant(constantName, asValue(constantValue));
- }
-
- private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
- Set<String> constantsReplacedByMacros,
- String constantName, Tensor constantValue) {
- RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
- if (macroOverridingConstant != null) {
- TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
- if ( ! macroType.equals(constantValue.type()))
- throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
- "The required type of this is " + constantValue.type() +
- ", but the macro returns " + macroType);
- constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
- }
- else {
- Path constantPath = store.writeLargeConstant(constantName, constantValue);
- if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
- profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
- }
- }
- }
-
- private void transformGeneratedMacro(ModelStore store, RankProfile profile,
- Set<String> constantsReplacedByMacros,
- String macroName, RankingExpression expression) {
-
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- store.writeMacro(macroName, expression);
- }
-
- private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
- if (profile.getMacros().containsKey(macroName)) {
- throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists.");
- }
- profile.addMacro(macroName, false); // todo: inline if only used once
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- macro.setRankingExpression(expression);
- macro.setTextualExpression(expression.getRoot().toString());
- }
-
- private String skippedOutputsDescription(OnnxModel model) {
- if (model.skippedOutputs().isEmpty()) return "";
- StringBuilder b = new StringBuilder(": ");
- model.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
- return b.toString();
- }
-
- /**
- * Verify that the macros referred in the given expression exists in the given rank profile,
- * and return tensors of the types specified in requiredMacros.
- */
- private void verifyRequiredMacros(RankingExpression expression, OnnxModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- TensorType requiredType = model.requiredMacros().get(macroName);
- if (requiredType == null) continue; // Not a required macro
-
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null)
- throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
- "' of type " + requiredType + " but this macro is not present in " +
- profile);
- // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
- // phase and summary features), as it may only resolve correctly given those bindings
- // Or, probably better, annotate the macros with type constraints here and verify during general
- // type verification
- TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
- if ( actualType == null)
- throw new IllegalArgumentException("Model refers input '" + macroName +
- "' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro references a feature which is not declared");
- if ( ! actualType.isAssignableTo(requiredType))
- throw new IllegalArgumentException("Model refers input '" + macroName +
- "' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro produces type " + actualType);
- }
- }
-
- /**
- * Add the generated macros to the rank profile
- */
- private void addGeneratedMacros(OnnxModel model, RankProfile profile) {
- model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
- }
-
- /**
- * Check if batch dimensions of inputs can be reduced out. If the input
- * macro specifies that a single exemplar should be evaluated, we can
- * reduce the batch dimension out.
- */
- private void reduceBatchDimensions(RankingExpression expression, OnnxModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated macros for inputs to reduce
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- if ( ! model.macros().containsKey(macroName)) {
- continue;
- }
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null) {
- throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
- "but this macro is not present in " + profile);
- }
- RankingExpression macroExpression = macro.getRankingExpression();
- macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model,
- TypeContext<Reference> typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- }
- }
- return new TensorFunctionNode(result);
- }
-
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- * Todo: determine when this is not necessary!
- */
- private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) {
- return node;
- }
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
-
- /**
- * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
- * This method does that for the given expression and returns the result.
- */
- private RankingExpression replaceConstantsByMacros(RankingExpression expression,
- Set<String> constantsReplacedByMacros) {
- if (constantsReplacedByMacros.isEmpty()) return expression;
- return new RankingExpression(expression.getName(),
- replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
- }
-
- private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
- if (node instanceof ReferenceNode) {
- Reference reference = ((ReferenceNode)node).reference();
- if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
- String argument = reference.simpleArgument().get();
- if (constantsReplacedByMacros.contains(argument))
- return new ReferenceNode(argument);
- }
- }
- if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
- CompositeNode composite = (CompositeNode)node;
- return composite.setChildren(composite.children().stream()
- .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
- .collect(Collectors.toList()));
- }
- return node;
- }
-
- private void addMacroNamesIn(ExpressionNode node, Set<String> names, OnnxModel model) {
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode)node;
- if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
- names.add(referenceNode.getName());
- if (model.macros().containsKey(referenceNode.getName())) {
- addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
- }
- }
- }
- else if (node instanceof CompositeNode) {
- for (ExpressionNode child : ((CompositeNode)node).children())
- addMacroNamesIn(child, names, model);
- }
- }
-
- private Value asValue(Tensor tensor) {
- if (tensor.type().rank() == 0)
- return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
- else
- return new TensorValue(tensor);
- }
-
- /**
- * Provides read/write access to the correct directories of the application package given by the feature arguments
- */
- private static class ModelStore {
-
- private final ApplicationPackage application;
- private final FeatureArguments arguments;
-
- public ModelStore(ApplicationPackage application, Arguments arguments) {
- this.application = application;
- this.arguments = new FeatureArguments(arguments);
- }
-
- public FeatureArguments arguments() { return arguments; }
-
- public boolean hasStoredModel() {
- try {
- return application.getFile(arguments.expressionPath()).exists();
- }
- catch (UnsupportedOperationException e) {
- return false;
- }
- }
-
- /**
- * Returns the directory which contains the source model to use for these arguments
- */
- public File onnxModelDir() {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
- }
-
- /**
- * Adds this expression to the application package, such that it can be read later.
- */
- public void writeConverted(RankingExpression expression) {
- application.getFile(arguments.expressionPath())
- .writeFile(new StringReader(expression.getRoot().toString()));
- }
-
- /** Reads the previously stored ranking expression for these arguments */
- public RankingExpression readConverted() {
- try {
- return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
- }
- catch (IOException e) {
- throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
-
- /** Adds this macro expression to the application package to it can be read later. */
- public void writeMacro(String name, RankingExpression expression) {
- application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
- expression.getRoot().toString() + "\n");
- }
-
- /** Reads the previously stored macro expressions for these arguments */
- public List<Pair<String, RankingExpression>> readMacros() {
- try {
- ApplicationFile file = application.getFile(arguments.macrosPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, RankingExpression>> macros = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- try {
- RankingExpression expression = new RankingExpression(parts[1]);
- macros.add(new Pair<>(name, expression));
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
- return macros;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Reads the information about all the large (aka ranking) constants stored in the application package
- * (the constant value itself is replicated with file distribution).
- */
- public List<RankingConstant> readLargeConstants() {
- try {
- List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
- String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
- constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Adds this constant to the application package as a file,
- * such that it can be distributed using file distribution.
- *
- * @return the path to the stored constant, relative to the application package root
- */
- public Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
-
- // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
- Path constantPath = constantsPath.append(name + ".tbf");
-
- // Remember the constant in a file we replicate in ZooKeeper
- application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
- .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
-
- // Write content explicitly as a file on the file system as this is distributed using file distribution
- createIfNeeded(constantsPath);
- IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
- return correct(constantPath);
- }
-
- private List<Pair<String, Tensor>> readSmallConstants() {
- try {
- ApplicationFile file = application.getFile(arguments.smallConstantsPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, Tensor>> constants = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- TensorType type = TensorType.fromSpec(parts[1]);
- Tensor tensor = Tensor.from(type, parts[2]);
- constants.add(new Pair<>(name, tensor));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Append this constant to the single file used for small constants distributed as config
- */
- public void writeSmallConstant(String name, Tensor constant) {
- // Secret file format for remembering constants:
- application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
- constant.type().toString() + "\t" +
- constant.toString() + "\n");
- }
-
- /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
- private Path correct(Path path) {
- if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
- && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
- return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
- }
- else {
- return path;
- }
- }
-
- private void createIfNeeded(Path path) {
- File dir = application.getFileReference(path);
- if ( ! dir.exists()) {
- if (!dir.mkdirs())
- throw new IllegalStateException("Could not create " + dir);
- }
- }
-
- }
-
- /** Encapsulates the 1, 2 or 3 arguments to a onnx feature */
- private static class FeatureArguments {
-
- private final Path modelPath;
-
- /** Optional arguments */
- private final Optional<String> output;
-
- public FeatureArguments(Arguments arguments) {
+ static class OnnxFeatureArguments extends FeatureArguments {
+ public OnnxFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("An onnx node must take an argument pointing to " +
- "the onnx model directory under [application]/models");
+ "the tensorflow model directory under [application]/models");
if (arguments.expressions().size() > 3)
throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
modelPath = Path.fromString(asString(arguments.expressions().get(0)));
output = optionalArgument(1, arguments);
+ signature = Optional.of("default");
}
-
- /** Returns modelPath with slashes replaced by underscores */
- public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
-
- /** Returns relative path to this model below the "models/" dir in the application package */
- public Path modelPath() { return modelPath; }
- public Optional<String> output() { return output; }
-
- /** Path to the small constants file */
- public Path smallConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
- }
-
- /** Path to the large (ranking) constants directory */
- public Path largeConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
- }
-
- /** Path to the macros file */
- public Path macrosPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
- }
-
- public Path expressionPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
- .append(modelPath).append("expressions").append(expressionFileName());
- }
-
- private String expressionFileName() {
- StringBuilder fileName = new StringBuilder();
- output.ifPresent(s -> fileName.append(s).append("."));
- if (fileName.length() == 0) // single signature and output
- fileName.append("single.");
- fileName.append("expression");
- return fileName.toString();
- }
-
- private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
- if (argumentIndex >= arguments.expressions().size())
- return Optional.empty();
- return Optional.of(asString(arguments.expressions().get(argumentIndex)));
- }
-
- private String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
- }
-
- private String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
- }
-
- private boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
- }
-
}
}
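For orientation when reading the two converter diffs, the shape of the refactoring is visible in the hunks above and below: each framework-specific converter now only parses its feature arguments and imports its model type, while the shared MLImportFeatureConverter base caches imported models and performs the common transformation. The stand-alone Java sketch below is not part of this commit; all class and method names in it are simplified stand-ins, not the actual Vespa classes.

import java.util.HashMap;
import java.util.Map;

public class ConverterRefactoringSketch {

    static abstract class MLImportConverter {
        // Mirrors the importedModels cache in the diff: each model path is imported at most once
        private final Map<String, String> importedModels = new HashMap<>();

        final String transform(String modelPath) {
            String model = importedModels.computeIfAbsent(modelPath, this::importModel);
            return transformFromImportedModel(model);
        }

        // Only this step differs between the TensorFlow and ONNX converters after the refactoring
        abstract String importModel(String modelPath);

        // Shared constant, macro and expression handling would live here in the real base class
        String transformFromImportedModel(String model) {
            return "converted(" + model + ")";
        }
    }

    static class TensorFlowConverter extends MLImportConverter {
        @Override String importModel(String modelPath) { return "tensorflow:" + modelPath; }
    }

    static class OnnxConverter extends MLImportConverter {
        @Override String importModel(String modelPath) { return "onnx:" + modelPath; }
    }

    public static void main(String[] args) {
        System.out.println(new TensorFlowConverter().transform("mnist/saved"));
        System.out.println(new OnnxConverter().transform("mnist_softmax.onnx"));
    }
}

The computeIfAbsent cache corresponds to the importedModels map kept per converter, so a model referenced from several rank profiles is imported only once.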
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 41da32f64c3..27e1ad51b33 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -1,59 +1,19 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.expressiontransforms;
-import com.google.common.base.Joiner;
-import com.yahoo.collections.Pair;
-import com.yahoo.config.application.api.ApplicationFile;
-import com.yahoo.config.application.api.ApplicationPackage;
-import com.yahoo.config.model.application.provider.FilesApplicationPackage;
-import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
-import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
-import com.yahoo.searchdefinition.RankingConstant;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel;
+import com.yahoo.searchlib.rankingexpression.integration.ml.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.tensor.serialization.TypedBinaryFormat;
-import java.io.BufferedReader;
-import java.io.File;
-import java.io.IOException;
-import java.io.StringReader;
import java.io.UncheckedIOException;
-import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
-import java.util.stream.Collectors;
/**
* Replaces instances of the tensorflow(model-path, signature, output)
@@ -62,12 +22,12 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+public class TensorFlowFeatureConverter extends MLImportFeatureConverter {
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<Path, TensorFlowModel> importedModels = new HashMap<>();
+ private final Map<Path, ImportedModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -83,7 +43,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
if ( ! feature.getName().equals("tensorflow")) return feature;
try {
- ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
+ FeatureArguments arguments = new TensorFlowFeatureArguments(feature.getArguments());
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), arguments);
if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files
return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles());
else
@@ -95,565 +56,19 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
private ExpressionNode transformFromTensorFlowModel(ModelStore store,
- RankProfile profile,
- QueryProfileRegistry queryProfiles) {
- TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
- k -> tensorFlowImporter.importModel(store.arguments().modelName(),
- store.tensorFlowModelDir()));
-
- // Add constants
- Set<String> constantsReplacedByMacros = new HashSet<>();
- model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
- model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
- constantsReplacedByMacros, k, v));
-
- // Find the specified expression
- Signature signature = chooseSignature(model, store.arguments().signature());
- String output = chooseOutput(signature, store.arguments().output());
- if (signature.skippedOutputs().containsKey(output)) {
- String message = "Could not import TensorFlow model output '" + output + "'";
- if (!signature.skippedOutputs().get(output).isEmpty()) {
- message += ": " + signature.skippedOutputs().get(output);
- }
- if (!signature.importWarnings().isEmpty()) {
- message += ": " + String.join(", ", signature.importWarnings());
- }
- throw new IllegalArgumentException(message);
- }
-
- RankingExpression expression = model.expressions().get(output);
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- verifyRequiredMacros(expression, model, profile, queryProfiles);
- addGeneratedMacros(model, profile);
- reduceBatchDimensions(expression, model, profile, queryProfiles);
-
- model.macros().forEach((k, v) -> transformGeneratedMacro(store, constantsReplacedByMacros, k, v));
-
- store.writeConverted(expression);
- return expression.getRoot();
- }
-
- private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
- for (Pair<String, Tensor> constant : store.readSmallConstants())
- profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
-
- for (RankingConstant constant : store.readLargeConstants()) {
- if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
- profile.getSearch().addRankingConstant(constant);
- }
-
- for (Pair<String, RankingExpression> macro : store.readMacros()) {
- addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
- }
-
- return store.readConverted().getRoot();
- }
-
- /**
- * Returns the specified, existing signature, or the only signature if none is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) {
- if ( ! signatureName.isPresent()) {
- if (importResult.signatures().size() == 0)
- throw new IllegalArgumentException("No signatures are available");
- if (importResult.signatures().size() > 1)
- throw new IllegalArgumentException("Model has multiple signatures (" +
- Joiner.on(", ").join(importResult.signatures().keySet()) +
- "), one must be specified " +
- "as a second argument to tensorflow()");
- return importResult.signatures().values().stream().findFirst().get();
- }
- else {
- Signature signature = importResult.signatures().get(signatureName.get());
- if (signature == null)
- throw new IllegalArgumentException("Model does not have the specified signature '" +
- signatureName.get() + "'");
- return signature;
- }
- }
-
- /**
- * Returns the specified, existing output expression, or the only output expression if no output name is specified.
- * Throws IllegalArgumentException in all other cases.
- */
- private String chooseOutput(Signature signature, Optional<String> outputName) {
- if ( ! outputName.isPresent()) {
- if (signature.outputs().size() == 0)
- throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
- if (signature.outputs().size() > 1)
- throw new IllegalArgumentException(signature + " has multiple outputs (" +
- Joiner.on(", ").join(signature.outputs().keySet()) +
- "), one must be specified " +
- "as a third argument to tensorflow()");
- return signature.outputs().get(signature.outputs().keySet().stream().findFirst().get());
- }
- else {
- String output = signature.outputs().get(outputName.get());
- if (output == null) {
- if (signature.skippedOutputs().containsKey(outputName.get()))
- throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
- signature.skippedOutputs().get(outputName.get()));
- else
- throw new IllegalArgumentException("Model does not have the specified output '" +
- outputName.get() + "'");
- }
- return output;
- }
- }
-
- private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
- store.writeSmallConstant(constantName, constantValue);
- profile.addConstant(constantName, asValue(constantValue));
- }
-
- private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
- Set<String> constantsReplacedByMacros,
- String constantName, Tensor constantValue) {
- RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
- if (macroOverridingConstant != null) {
- TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
- if ( ! macroType.equals(constantValue.type()))
- throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
- typeMismatchExplanation(constantValue.type(), macroType));
- constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
- }
- else {
- Path constantPath = store.writeLargeConstant(constantName, constantValue);
- if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
- profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
- }
- }
- }
-
- private void transformGeneratedMacro(ModelStore store,
- Set<String> constantsReplacedByMacros,
- String macroName, RankingExpression expression) {
-
- expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- store.writeMacro(macroName, expression);
- }
-
- private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
- if (profile.getMacros().containsKey(macroName)) {
- throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists.");
- }
- profile.addMacro(macroName, false); // todo: inline if only used once
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- macro.setRankingExpression(expression);
- macro.setTextualExpression(expression.getRoot().toString());
- }
-
- private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
- if (signature.skippedOutputs().isEmpty()) return "";
- StringBuilder b = new StringBuilder(": ");
- signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
- return b.toString();
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
+ ImportedModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ k -> tensorFlowImporter.importModel(store.arguments().modelName(),
+ store.modelDir()));
+ return transformFromImportedModel(model, store, profile, queryProfiles);
}
- /**
- * Verify that the macros referred in the given expression exists in the given rank profile,
- * and return tensors of the types specified in requiredMacros.
- */
- private void verifyRequiredMacros(RankingExpression expression, TensorFlowModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- TensorType requiredType = model.requiredMacros().get(macroName);
- if (requiredType == null) continue; // Not a required macro
-
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null)
- throw new IllegalArgumentException("Model refers placeholder '" + macroName +
- "' of type " + requiredType + " but this macro is not present in " +
- profile);
- // TODO: We should verify this in the function reference(s) from which this is invoked (starting from first/second
- // phase and summary features), as it may only resolve correctly given those bindings
- // Or, probably better, annotate the macros with type constraints here and verify during general
- // type verification
- TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
- if ( actualType == null)
- throw new IllegalArgumentException("Model refers placeholder '" + macroName +
- "' of type " + requiredType +
- " which must be produced by a macro in the rank profile, but " +
- "this macro references a feature which is not declared");
- if ( ! actualType.isAssignableTo(requiredType))
- throw new IllegalArgumentException("Model refers placeholder '" + macroName + "'. " +
- typeMismatchExplanation(requiredType, actualType));
- }
- }
-
- private String typeMismatchExplanation(TensorType requiredType, TensorType actualType) {
- return "The required type of this is " + requiredType + ", but this macro returns " + actualType +
- (actualType.rank() == 0 ? ". This is often due to missing declaration of query tensor features " +
- "in query profile types - see the documentation."
- : "");
- }
-
- /**
- * Add the generated macros to the rank profile
- */
- private void addGeneratedMacros(TensorFlowModel model, RankProfile profile) {
- model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
- }
-
- /**
- * Check if batch dimensions of inputs can be reduced out. If the input
- * macro specifies that a single exemplar should be evaluated, we can
- * reduce the batch dimension out.
- */
- private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated macros for inputs to reduce
- Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames, model);
- for (String macroName : macroNames) {
- if ( ! model.macros().containsKey(macroName)) {
- continue;
- }
- RankProfile.Macro macro = profile.getMacros().get(macroName);
- if (macro == null) {
- throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
- "but this macro is not present in " + profile);
- }
- RankingExpression macroExpression = macro.getRankingExpression();
- macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model,
- TypeContext<Reference> typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.requiredMacros().containsKey(referenceNode.getName())) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- }
- }
- return new TensorFunctionNode(result);
- }
-
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- * Todo: determine when this is not necessary!
- */
- private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) {
- return node;
- }
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
-
- /**
- * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
- * This method does that for the given expression and returns the result.
- */
- private RankingExpression replaceConstantsByMacros(RankingExpression expression,
- Set<String> constantsReplacedByMacros) {
- if (constantsReplacedByMacros.isEmpty()) return expression;
- return new RankingExpression(expression.getName(),
- replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
- }
-
- private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
- if (node instanceof ReferenceNode) {
- Reference reference = ((ReferenceNode)node).reference();
- if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
- String argument = reference.simpleArgument().get();
- if (constantsReplacedByMacros.contains(argument))
- return new ReferenceNode(argument);
- }
- }
- if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
- CompositeNode composite = (CompositeNode)node;
- return composite.setChildren(composite.children().stream()
- .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
- .collect(Collectors.toList()));
- }
- return node;
- }
-
- private void addMacroNamesIn(ExpressionNode node, Set<String> names, TensorFlowModel model) {
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode)node;
- if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
- names.add(referenceNode.getName());
- if (model.macros().containsKey(referenceNode.getName())) {
- addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
- }
- }
- }
- else if (node instanceof CompositeNode) {
- for (ExpressionNode child : ((CompositeNode)node).children())
- addMacroNamesIn(child, names, model);
- }
- }
-
- private Value asValue(Tensor tensor) {
- if (tensor.type().rank() == 0)
- return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
- else
- return new TensorValue(tensor);
- }
-
- /**
- * Provides read/write access to the correct directories of the application package given by the feature arguments
- */
- private static class ModelStore {
-
- private final ApplicationPackage application;
- private final FeatureArguments arguments;
-
- public ModelStore(ApplicationPackage application, Arguments arguments) {
- this.application = application;
- this.arguments = new FeatureArguments(arguments);
- }
-
-
-
- public FeatureArguments arguments() { return arguments; }
-
- public boolean hasStoredModel() {
- try {
- return application.getFile(arguments.expressionPath()).exists();
- }
- catch (UnsupportedOperationException e) {
- return false;
- }
- }
-
- /**
- * Returns the directory which contains the source model
- * to use for these arguments
- */
- public File tensorFlowModelDir() {
- return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
- }
-
- /**
- * Adds this expression to the application package, such that it can be read later.
- */
- public void writeConverted(RankingExpression expression) {
- application.getFile(arguments.expressionPath())
- .writeFile(new StringReader(expression.getRoot().toString()));
- }
-
- /** Reads the previously stored ranking expression for these arguments */
- public RankingExpression readConverted() {
- try {
- return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
- }
- catch (IOException e) {
- throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
-
- /** Adds this macro expression to the application package so it can be read later. */
- public void writeMacro(String name, RankingExpression expression) {
- application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
- expression.getRoot().toString() + "\n");
- }
-
- /** Reads the previously stored macro expressions for these arguments */
- public List<Pair<String, RankingExpression>> readMacros() {
- try {
- ApplicationFile file = application.getFile(arguments.macrosPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, RankingExpression>> macros = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- try {
- RankingExpression expression = new RankingExpression(parts[1]);
- macros.add(new Pair<>(name, expression));
- }
- catch (ParseException e) {
- throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
- }
- }
- return macros;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Reads the information about all the large (aka ranking) constants stored in the application package
- * (the constant value itself is replicated with file distribution).
- */
- public List<RankingConstant> readLargeConstants() {
- try {
- List<RankingConstant> constants = new ArrayList<>();
- for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
- String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
- constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Adds this constant to the application package as a file,
- * such that it can be distributed using file distribution.
- *
- * @return the path to the stored constant, relative to the application package root
- */
- public Path writeLargeConstant(String name, Tensor constant) {
- Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
-
- // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
- Path constantPath = constantsPath.append(name + ".tbf");
-
- // Remember the constant in a file we replicate in ZooKeeper
- application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
- .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
-
- // Write content explicitly as a file on the file system as this is distributed using file distribution
- createIfNeeded(constantsPath);
- IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
- return correct(constantPath);
- }
-
- private List<Pair<String, Tensor>> readSmallConstants() {
- try {
- ApplicationFile file = application.getFile(arguments.smallConstantsPath());
- if (!file.exists()) return Collections.emptyList();
-
- List<Pair<String, Tensor>> constants = new ArrayList<>();
- BufferedReader reader = new BufferedReader(file.createReader());
- String line;
- while (null != (line = reader.readLine())) {
- String[] parts = line.split("\t");
- String name = parts[0];
- TensorType type = TensorType.fromSpec(parts[1]);
- Tensor tensor = Tensor.from(type, parts[2]);
- constants.add(new Pair<>(name, tensor));
- }
- return constants;
- }
- catch (IOException e) {
- throw new UncheckedIOException(e);
- }
- }
-
- /**
- * Append this constant to the single file used for small constants distributed as config
- */
- public void writeSmallConstant(String name, Tensor constant) {
- // Secret file format for remembering constants:
- application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
- constant.type().toString() + "\t" +
- constant.toString() + "\n");
- }
-
- /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
- private Path correct(Path path) {
- if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
- && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
- return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
- }
- else {
- return path;
- }
- }
-
- private void createIfNeeded(Path path) {
- File dir = application.getFileReference(path);
- if ( ! dir.exists()) {
- if (!dir.mkdirs())
- throw new IllegalStateException("Could not create " + dir);
- }
- }
-
- }
-
- /** Encapsulates the 1, 2 or 3 arguments to a tensorflow feature */
- private static class FeatureArguments {
-
- private final Path modelPath;
-
- /** Optional arguments */
- private final Optional<String> signature, output;
-
- public FeatureArguments(Arguments arguments) {
+ static class TensorFlowFeatureArguments extends FeatureArguments {
+ public TensorFlowFeatureArguments(Arguments arguments) {
if (arguments.isEmpty())
throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
- "the tensorflow model directory under [application]/models");
+ "the tensorflow model directory under [application]/models");
if (arguments.expressions().size() > 3)
throw new IllegalArgumentException("A tensorflow feature can have at most 3 arguments");
@@ -661,68 +76,6 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
signature = optionalArgument(1, arguments);
output = optionalArgument(2, arguments);
}
-
- /** Returns modelPath with slashes replaced by underscores */
- public String modelName() { return modelPath.toString().replace('/', '_'); }
-
- /** Returns relative path to this model below the "models/" dir in the application package */
- public Path modelPath() { return modelPath; }
- public Optional<String> signature() { return signature; }
- public Optional<String> output() { return output; }
-
- /** Path to the small constants file */
- public Path smallConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
- }
-
- /** Path to the large (ranking) constants directory */
- public Path largeConstantsPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
- }
-
- /** Path to the macros file */
- public Path macrosPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
- }
-
- public Path expressionPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
- .append(modelPath).append("expressions").append(expressionFileName());
- }
-
- private String expressionFileName() {
- StringBuilder fileName = new StringBuilder();
- signature.ifPresent(s -> fileName.append(s).append("."));
- output.ifPresent(s -> fileName.append(s).append("."));
- if (fileName.length() == 0) // single signature and output
- fileName.append("single.");
- fileName.append("expression");
- return fileName.toString();
- }
-
- private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
- if (argumentIndex >= arguments.expressions().size())
- return Optional.empty();
- return Optional.of(asString(arguments.expressions().get(argumentIndex)));
- }
-
- private String asString(ExpressionNode node) {
- if ( ! (node instanceof ConstantNode))
- throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node);
- return stripQuotes(((ConstantNode)node).sourceString());
- }
-
- private String stripQuotes(String s) {
- if ( ! isQuoteSign(s.codePointAt(0))) return s;
- if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
- throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
- return s.substring(1, s.length()-1);
- }
-
- private boolean isQuoteSign(int c) {
- return c == '\'' || c == '"';
- }
-
}
}
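The ModelStore code removed above (now shared by both converters) persists small constants in a single tab-separated file, one line per constant holding name, type spec and value. The following is a minimal self-contained sketch of that round trip, using plain strings in place of Vespa's TensorType and Tensor; the class, file and helper names here are assumptions for illustration only.

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

public class SmallConstantsFileSketch {

    static void writeSmallConstant(Path file, String name, String typeSpec, String value) throws IOException {
        // Same line format as writeSmallConstant in the diff: name \t type \t value \n
        Files.writeString(file, name + "\t" + typeSpec + "\t" + value + "\n",
                          StandardOpenOption.CREATE, StandardOpenOption.APPEND);
    }

    static List<String[]> readSmallConstants(Path file) throws IOException {
        if (!Files.exists(file)) return Collections.emptyList();
        List<String[]> constants = new ArrayList<>();
        for (String line : Files.readAllLines(file))
            constants.add(line.split("\t")); // parts: [name, type spec, value]
        return constants;
    }

    public static void main(String[] args) throws IOException {
        Path file = Files.createTempFile("constants", ".txt");
        writeSmallConstant(file, "mnist_saved_dnn_hidden1_mul_x", "tensor()", "0.01");
        for (String[] parts : readSmallConstants(file))
            System.out.println(parts[0] + "  " + parts[1] + "  " + parts[2]);
    }
}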
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 1c54d12d8b3..d9beab6e2f2 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -37,15 +37,6 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testOnnxReference() throws ParseException {
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx')");
- search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
- assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
- }
-
- @Test
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx('mnist_softmax.onnx')",
@@ -122,13 +113,6 @@ public class RankingExpressionWithOnnxTestCase {
}
@Test
- public void testOnnxReferenceSpecifyingOutput() {
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "onnx('mnist_softmax.onnx', 'add')");
- search.assertFirstPhaseExpression(vespaExpression, "my_profile");
- }
-
- @Test
public void testOnnxReferenceMissingMacro() throws ParseException {
try {
RankProfileSearchFixture search = new RankProfileSearchFixture(
@@ -145,7 +129,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -163,8 +147,8 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " +
- "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])",
+ "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
+ "but this macro returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index d288a396732..7228af2b0de 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -162,7 +162,7 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved'): " +
- "Model refers placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -179,7 +179,7 @@ public class RankingExpressionWithTensorFlowTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use tensorflow model from " +
"tensorflow('mnist_softmax/saved'): " +
- "Model refers placeholder 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
+ "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
"but this macro returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
@@ -305,9 +305,9 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testMacroGeneration() {
- final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist/saved')",
@@ -316,15 +316,15 @@ public class RankingExpressionWithTensorFlowTestCase {
"input",
new StoringApplicationPackage(applicationDir));
search.assertFirstPhaseExpression(expression, "my_profile");
- search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
}
@Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden2_add, reduce(constant(mnist_saved_dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden2_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(tf_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(imported_ml_macro_mnist_saved_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), imported_ml_macro_mnist_saved_dnn_hidden1_add, f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
@@ -335,8 +335,8 @@ public class RankingExpressionWithTensorFlowTestCase {
application);
search.assertFirstPhaseExpression(expression, "my_profile");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
- search.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- search.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ search.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
// At this point the expression is stored - copy application to another location which does not have a models dir
Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
@@ -353,8 +353,8 @@ public class RankingExpressionWithTensorFlowTestCase {
storedApplication);
searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
assertSmallConstant("mnist_saved_dnn_hidden1_mul_x", TensorType.fromSpec("tensor()"), search);
- searchFromStored.assertMacro(macroExpression1, "tf_macro_mnist_saved_dnn_hidden1_add", "my_profile");
- searchFromStored.assertMacro(macroExpression2, "tf_macro_mnist_saved_dnn_hidden2_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression1, "imported_ml_macro_mnist_saved_dnn_hidden1_add", "my_profile");
+ searchFromStored.assertMacro(macroExpression2, "imported_ml_macro_mnist_saved_dnn_hidden2_add", "my_profile");
}
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
@@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
- public static class StoringApplicationPackageFile extends ApplicationFile {
+ static class StoringApplicationPackageFile extends ApplicationFile {
/** The path to the application package root */
private final Path root;