path: root/config-model/src
author      Lester Solbakken <lesters@oath.com>  2018-05-28 11:33:17 +0200
committer   Lester Solbakken <lesters@oath.com>  2018-05-28 11:33:17 +0200
commit      88d06ec474f727d41963b6aa65c2382ccc01c3f5 (patch)
tree        d2a871f2e6870daadf674fca0f350692cbdc42a3 /config-model/src
parent      3c1334090cef6fb0891515040ad900702275ccea (diff)
Add ONNX pseudo ranking feature
Diffstat (limited to 'config-model/src')
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java  1
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java  693
-rw-r--r--  config-model/src/test/integration/onnx/models/mnist_softmax.onnx  bin 0 -> 31758 bytes
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java  333
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java  4
5 files changed, 1029 insertions, 2 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
index 67d60b08ab0..6ca16c1559d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
@@ -21,6 +21,7 @@ public class ExpressionTransforms {
private final List<ExpressionTransformer> transforms =
ImmutableList.of(new TensorFlowFeatureConverter(),
+ new OnnxFeatureConverter(),
new ConstantDereferencer(),
new ConstantTensorTransformer(),
new MacroInliner(),
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
new file mode 100644
index 00000000000..6d44b130996
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -0,0 +1,693 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.expressiontransforms;
+
+import com.google.common.base.Joiner;
+import com.yahoo.collections.Pair;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.FeatureNames;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.StringReader;
+import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Replaces instances of the onnx(model-path, output)
+ * pseudofeature with the native Vespa ranking expression implementing
+ * the same computation.
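+ *
+ * An illustrative example (using the mnist_softmax.onnx model exercised by the accompanying test):
+ *
+ *     rank-profile my_profile {
+ *         first-phase {
+ *             expression: onnx('mnist_softmax.onnx')
+ *         }
+ *     }
+ *
+ * Here the first-phase expression is rewritten into the tensor expression imported from the
+ * model, while the model's constants and generated macros are added to the rank profile.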
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+
+ private final OnnxImporter onnxImporter = new OnnxImporter();
+
+ /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
+ private final Map<Path, OnnxModel> importedModels = new HashMap<>();
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ if (node instanceof ReferenceNode)
+ return transformFeature((ReferenceNode) node, context);
+ else if (node instanceof CompositeNode)
+ return super.transformChildren((CompositeNode) node, context);
+ else
+ return node;
+ }
+
+ private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
+ if ( ! feature.getName().equals("onnx")) return feature;
+
+ try {
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
+ if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files
+ return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles());
+ else
+ return transformFromStoredModel(store, context.rankProfile());
+ }
+ catch (IllegalArgumentException | UncheckedIOException e) {
+ throw new IllegalArgumentException("Could not use Onnx model from " + feature, e);
+ }
+ }
+
+ private ExpressionNode transformFromOnnxModel(ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
+ OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ k -> onnxImporter.importModel(store.arguments().modelName(),
+ store.onnxModelDir()));
+
+ // Add constants
+ Set<String> constantsReplacedByMacros = new HashSet<>();
+ model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
+ model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
+ constantsReplacedByMacros, k, v));
+
+ // Find the specified expression
+ String output = chooseOutput(model, store.arguments().output());
+ if (model.skippedOutputs().containsKey(output)) {
+ String message = "Could not import Onnx model output '" + output + "'";
+ if (!model.skippedOutputs().get(output).isEmpty()) {
+ message += ": " + model.skippedOutputs().get(output);
+ }
+ if (!model.importWarnings().isEmpty()) {
+ message += ": " + String.join(", ", model.importWarnings());
+ }
+ throw new IllegalArgumentException(message);
+ }
+
+ RankingExpression expression = model.expressions().get(output);
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ verifyRequiredMacros(expression, model, profile, queryProfiles);
+ addGeneratedMacros(model, profile);
+ reduceBatchDimensions(expression, model, profile, queryProfiles);
+
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v));
+
+ store.writeConverted(expression);
+ return expression.getRoot();
+ }
+
+ private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+ for (Pair<String, Tensor> constant : store.readSmallConstants())
+ profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
+
+ for (RankingConstant constant : store.readLargeConstants()) {
+ if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
+ profile.getSearch().addRankingConstant(constant);
+ }
+
+ for (Pair<String, RankingExpression> macro : store.readMacros()) {
+ addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
+ }
+
+ return store.readConverted().getRoot();
+ }
+
+ /**
+ * Returns the specified, existing output expression, or the only output expression if no output name is specified.
+ * Throws IllegalArgumentException in all other cases.
+ */
+ private String chooseOutput(OnnxModel model, Optional<String> outputName) {
+ if ( ! outputName.isPresent()) {
+ if (model.outputs().size() == 0)
+ throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model));
+ if (model.outputs().size() > 1)
+ throw new IllegalArgumentException("Onnx model has multiple outputs (" +
+ Joiner.on(", ").join(model.outputs().keySet()) +
+ "), one must be specified " +
+ "as a second argument to onnx()");
+ return model.outputs().get(model.outputs().keySet().stream().findFirst().get());
+ }
+ else {
+ String output = model.outputs().get(outputName.get());
+ if (output == null) {
+ if (model.skippedOutputs().containsKey(outputName.get()))
+ throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
+ model.skippedOutputs().get(outputName.get()));
+ else
+ throw new IllegalArgumentException("Model does not have the specified output '" +
+ outputName.get() + "'");
+ }
+ return output;
+ }
+ }
+
+ private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ store.writeSmallConstant(constantName, constantValue);
+ profile.addConstant(constantName, asValue(constantValue));
+ }
+
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ Set<String> constantsReplacedByMacros,
+ String constantName, Tensor constantValue) {
+ RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
+ if (macroOverridingConstant != null) {
+ TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
+ if ( ! macroType.equals(constantValue.type()))
+ throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
+ "The required type of this is " + constantValue.type() +
+ ", but the macro returns " + macroType);
+ constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
+ }
+ else {
+ Path constantPath = store.writeLargeConstant(constantName, constantValue);
+ if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
+ profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
+ constantPath.toString()));
+ }
+ }
+ }
+
+ private void transformGeneratedMacro(ModelStore store, RankProfile profile,
+ Set<String> constantsReplacedByMacros,
+ String macroName, RankingExpression expression) {
+
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ store.writeMacro(macroName, expression);
+ }
+
+ private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ if (profile.getMacros().containsKey(macroName)) {
+ throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists.");
+ }
+ profile.addMacro(macroName, false); // todo: inline if only used once
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ macro.setRankingExpression(expression);
+ macro.setTextualExpression(expression.getRoot().toString());
+ }
+
+ private String skippedOutputsDescription(OnnxModel model) {
+ if (model.skippedOutputs().isEmpty()) return "";
+ StringBuilder b = new StringBuilder(": ");
+ model.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
+ return b.toString();
+ }
+
+ /**
+     * Verifies that the macros referred to in the given expression exist in the given rank profile,
+     * and that they return tensors of the types specified in requiredMacros.
+ */
+ private void verifyRequiredMacros(RankingExpression expression, OnnxModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ TensorType requiredType = model.requiredMacros().get(macroName);
+ if (requiredType == null) continue; // Not a required macro
+
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null)
+ throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
+ "' of type " + requiredType + " but this macro is not present in " +
+ profile);
+            // TODO: We should verify this in the function reference(s) from which this is invoked (starting from
+            // first/second phase and summary features), as it may only resolve correctly given those bindings.
+ // Or, probably better, annotate the macros with type constraints here and verify during general
+ // type verification
+ TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
+ if ( actualType == null)
+ throw new IllegalArgumentException("Model refers input '" + macroName +
+ "' of type " + requiredType +
+ " which must be produced by a macro in the rank profile, but " +
+ "this macro references a feature which is not declared");
+ if ( ! actualType.isAssignableTo(requiredType))
+ throw new IllegalArgumentException("Model refers input '" + macroName +
+ "' of type " + requiredType +
+ " which must be produced by a macro in the rank profile, but " +
+ "this macro produces type " + actualType);
+ }
+ }
+
+ /**
+ * Add the generated macros to the rank profile
+ */
+ private void addGeneratedMacros(OnnxModel model, RankProfile profile) {
+ model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
+ }
+
+ /**
+ * Check if batch dimensions of inputs can be reduced out. If the input
+ * macro specifies that a single exemplar should be evaluated, we can
+ * reduce the batch dimension out.
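+     *
+     * For example (illustrative), an input macro of type tensor(d0[1],d1[784]) has its size-1
+     * batch dimension d0 summed away, leaving tensor(d1[784]) for the subsequent computation.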
+ */
+ private void reduceBatchDimensions(RankingExpression expression, OnnxModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
+ TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+
+ // Check generated macros for inputs to reduce
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ if ( ! model.macros().containsKey(macroName)) {
+ continue;
+ }
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null) {
+ throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
+                                                   "' but this macro is not present in " + profile);
+ }
+ RankingExpression macroExpression = macro.getRankingExpression();
+ macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
+ }
+
+ // Check expression for inputs to reduce
+ ExpressionNode root = expression.getRoot();
+ root = reduceBatchDimensionsAtInput(root, model, typeContext);
+ TensorType typeAfterReducing = root.type(typeContext);
+ root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+ expression.setRoot(root);
+ }
+
+ private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model,
+ TypeContext<Reference> typeContext) {
+ if (node instanceof TensorFunctionNode) {
+ TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
+ if (tensorFunction instanceof Rename) {
+ List<ExpressionNode> children = ((TensorFunctionNode)node).children();
+ if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) children.get(0);
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(tensorFunction, typeContext);
+ }
+ }
+ }
+ }
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) node;
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
+ }
+ }
+ if (node instanceof CompositeNode) {
+ List<ExpressionNode> children = ((CompositeNode)node).children();
+ List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
+ for (ExpressionNode child : children) {
+ transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
+ }
+ return ((CompositeNode)node).setChildren(transformedChildren);
+ }
+ return node;
+ }
+
+ private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ TensorFunction result = function;
+ TensorType type = function.type(context);
+ if (type.dimensions().size() > 1) {
+ List<String> reduceDimensions = new ArrayList<>();
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1) {
+ reduceDimensions.add(dimension.name());
+ }
+ }
+ if (reduceDimensions.size() > 0) {
+ result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+ }
+ }
+ return new TensorFunctionNode(result);
+ }
+
+ /**
+ * If batch dimensions have been reduced away above, bring them back here
+ * for any following computation of the tensor.
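+     *
+     * For example (illustrative), a reduced result of type tensor(d1[10]) is joined (via multiply)
+     * with a generated all-ones tensor of type tensor(d0[1]) to restore the expected type
+     * tensor(d0[1],d1[10]).
+     *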
+ * Todo: determine when this is not necessary!
+ */
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ if (after.equals(before)) {
+ return node;
+ }
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : before.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
+ typeBuilder.indexed(dimension.name(), 1);
+ }
+ }
+ TensorType expandDimensionsType = typeBuilder.build();
+ if (expandDimensionsType.dimensions().size() > 0) {
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
+ Generate generatedFunction = new Generate(expandDimensionsType,
+ new GeneratorLambdaFunctionNode(expandDimensionsType,
+ generatedExpression)
+ .asLongListToDoubleOperator());
+ Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
+ return new TensorFunctionNode(expand);
+ }
+ return node;
+ }
+
+ /**
+ * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
+ * This method does that for the given expression and returns the result.
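+     *
+     * For example, when the macro mnist_softmax_onnx_Variable overrides the constant of the same
+     * name (as in the accompanying test), constant(mnist_softmax_onnx_Variable) is replaced by the
+     * macro reference mnist_softmax_onnx_Variable.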
+ */
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ Set<String> constantsReplacedByMacros) {
+ if (constantsReplacedByMacros.isEmpty()) return expression;
+ return new RankingExpression(expression.getName(),
+ replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
+ }
+
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ if (node instanceof ReferenceNode) {
+ Reference reference = ((ReferenceNode)node).reference();
+ if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
+ String argument = reference.simpleArgument().get();
+ if (constantsReplacedByMacros.contains(argument))
+ return new ReferenceNode(argument);
+ }
+ }
+ if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
+ CompositeNode composite = (CompositeNode)node;
+ return composite.setChildren(composite.children().stream()
+ .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
+ .collect(Collectors.toList()));
+ }
+ return node;
+ }
+
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names, OnnxModel model) {
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode)node;
+ if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
+ names.add(referenceNode.getName());
+ if (model.macros().containsKey(referenceNode.getName())) {
+ addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
+ }
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children())
+ addMacroNamesIn(child, names, model);
+ }
+ }
+
+ private Value asValue(Tensor tensor) {
+ if (tensor.type().rank() == 0)
+ return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
+ else
+ return new TensorValue(tensor);
+ }
+
+ /**
+ * Provides read/write access to the correct directories of the application package given by the feature arguments
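+     *
+     * The generated artifacts are stored per model under the models.generated directories of the
+     * application package (see FeatureArguments below): constants.txt for small constants, a
+     * constants/ directory for large file-distributed constants, macros.txt for generated macros
+     * and an expressions/ directory for the converted ranking expressions.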
+ */
+ private static class ModelStore {
+
+ private final ApplicationPackage application;
+ private final FeatureArguments arguments;
+
+ public ModelStore(ApplicationPackage application, Arguments arguments) {
+ this.application = application;
+ this.arguments = new FeatureArguments(arguments);
+ }
+
+ public FeatureArguments arguments() { return arguments; }
+
+ public boolean hasStoredModel() {
+ try {
+ return application.getFile(arguments.expressionPath()).exists();
+ }
+ catch (UnsupportedOperationException e) {
+ return false;
+ }
+ }
+
+ /**
+ * Returns the directory which contains the source model to use for these arguments
+ */
+ public File onnxModelDir() {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
+ }
+
+ /**
+ * Adds this expression to the application package, such that it can be read later.
+ */
+ public void writeConverted(RankingExpression expression) {
+ application.getFile(arguments.expressionPath())
+ .writeFile(new StringReader(expression.getRoot().toString()));
+ }
+
+ /** Reads the previously stored ranking expression for these arguments */
+ public RankingExpression readConverted() {
+ try {
+ return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+
+        /** Adds this macro expression to the application package so it can be read later. */
+ public void writeMacro(String name, RankingExpression expression) {
+ application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
+ expression.getRoot().toString() + "\n");
+ }
+
+ /** Reads the previously stored macro expressions for these arguments */
+ public List<Pair<String, RankingExpression>> readMacros() {
+ try {
+ ApplicationFile file = application.getFile(arguments.macrosPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, RankingExpression>> macros = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ try {
+ RankingExpression expression = new RankingExpression(parts[1]);
+ macros.add(new Pair<>(name, expression));
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+ return macros;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Reads the information about all the large (aka ranking) constants stored in the application package
+ * (the constant value itself is replicated with file distribution).
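+         *
+         * Each entry is a single name:type:path line as written by writeLargeConstant, e.g.
+         * (illustrative) myConstant:tensor(d0[10]):models.generated/mymodel/constants/myConstant.tbf.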
+ */
+ public List<RankingConstant> readLargeConstants() {
+ try {
+ List<RankingConstant> constants = new ArrayList<>();
+ for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
+ String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
+ constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Adds this constant to the application package as a file,
+ * such that it can be distributed using file distribution.
+ *
+ * @return the path to the stored constant, relative to the application package root
+ */
+ public Path writeLargeConstant(String name, Tensor constant) {
+ Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
+
+ // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
+ Path constantPath = constantsPath.append(name + ".tbf");
+
+ // Remember the constant in a file we replicate in ZooKeeper
+ application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
+ .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
+
+ // Write content explicitly as a file on the file system as this is distributed using file distribution
+ createIfNeeded(constantsPath);
+ IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
+ return correct(constantPath);
+ }
+
+ private List<Pair<String, Tensor>> readSmallConstants() {
+ try {
+ ApplicationFile file = application.getFile(arguments.smallConstantsPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, Tensor>> constants = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ TensorType type = TensorType.fromSpec(parts[1]);
+ Tensor tensor = Tensor.from(type, parts[2]);
+ constants.add(new Pair<>(name, tensor));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Append this constant to the single file used for small constants distributed as config
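+         * (one tab-separated name/type/value line per constant, as read back by readSmallConstants)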
+ */
+ public void writeSmallConstant(String name, Tensor constant) {
+ // Secret file format for remembering constants:
+ application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
+ }
+
+ /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
+ private Path correct(Path path) {
+ if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
+ && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
+ return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
+ }
+ else {
+ return path;
+ }
+ }
+
+ private void createIfNeeded(Path path) {
+ File dir = application.getFileReference(path);
+ if ( ! dir.exists()) {
+ if (!dir.mkdirs())
+ throw new IllegalStateException("Could not create " + dir);
+ }
+ }
+
+ }
+
+    /** Encapsulates the 1 or 2 arguments to an onnx feature */
+ private static class FeatureArguments {
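+
+        // Typical uses (from the accompanying test):
+        //   onnx('mnist_softmax.onnx')
+        //   onnx('mnist_softmax.onnx', 'add')
+        // where the optional second argument names the model output to use.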
+
+ private final Path modelPath;
+
+        /** Optional output argument */
+ private final Optional<String> output;
+
+ public FeatureArguments(Arguments arguments) {
+ if (arguments.isEmpty())
+ throw new IllegalArgumentException("An onnx node must take an argument pointing to " +
+                                                   "the onnx model file under [application]/models");
+            if (arguments.expressions().size() > 2)
+ throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
+
+ modelPath = Path.fromString(asString(arguments.expressions().get(0)));
+ output = optionalArgument(1, arguments);
+ }
+
+        /** Returns modelPath with slashes and dots replaced by underscores */
+ public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
+
+ /** Returns relative path to this model below the "models/" dir in the application package */
+ public Path modelPath() { return modelPath; }
+ public Optional<String> output() { return output; }
+
+ /** Path to the small constants file */
+ public Path smallConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
+ }
+
+ /** Path to the large (ranking) constants directory */
+ public Path largeConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+ }
+
+ /** Path to the macros file */
+ public Path macrosPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
+ }
+
+ public Path expressionPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
+ .append(modelPath).append("expressions").append(expressionFileName());
+ }
+
+ private String expressionFileName() {
+ StringBuilder fileName = new StringBuilder();
+ output.ifPresent(s -> fileName.append(s).append("."));
+            if (fileName.length() == 0) // single output
+ fileName.append("single.");
+ fileName.append("expression");
+ return fileName.toString();
+ }
+
+ private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ if (argumentIndex >= arguments.expressions().size())
+ return Optional.empty();
+ return Optional.of(asString(arguments.expressions().get(argumentIndex)));
+ }
+
+ private String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+                throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node + "'");
+ return stripQuotes(((ConstantNode)node).sourceString());
+ }
+
+ private String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
+ }
+
+ private boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
+ }
+
+ }
+
+}
diff --git a/config-model/src/test/integration/onnx/models/mnist_softmax.onnx b/config-model/src/test/integration/onnx/models/mnist_softmax.onnx
new file mode 100644
index 00000000000..a86019bf53a
--- /dev/null
+++ b/config-model/src/test/integration/onnx/models/mnist_softmax.onnx
Binary files differ
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
new file mode 100644
index 00000000000..5f2f9e9ffaa
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -0,0 +1,333 @@
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import com.yahoo.yolean.Exceptions;
+import org.junit.After;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Optional;
+
+import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage;
+
+import static junit.framework.TestCase.assertTrue;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+public class RankingExpressionWithOnnxTestCase {
+
+ private final Path applicationDir = Path.fromString("src/test/integration/onnx/");
+ private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_onnx_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_onnx_Variable_1), f(a,b)(a + b))";
+
+ @After
+ public void removeGeneratedConstantTensorFiles() {
+ IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+
+ @Test
+ public void testOnnxReference() throws ParseException {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+ @Test
+ public void testOnnxReferenceWithConstantFeature() {
+ RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
+ "onnx('mnist_softmax.onnx')",
+ "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
+ null);
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+ @Test
+ public void testOnnxReferenceWithQueryFeature() {
+ String queryProfile = "<query-profile id='default' type='root'/>";
+ String queryProfileType = "<query-profile-type id='root'>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" +
+ "</query-profile-type>";
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
+ queryProfile,
+ queryProfileType);
+ RankProfileSearchFixture search = fixtureWith("query(mytensor)",
+ "onnx('mnist_softmax.onnx')",
+ null,
+ null,
+ "Placeholder",
+ application);
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+ @Test
+ public void testOnnxReferenceWithDocumentFeature() {
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
+ RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
+ "onnx('mnist_softmax.onnx')",
+ null,
+ "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }",
+ "Placeholder",
+ application);
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+
+ @Test
+ public void testOnnxReferenceWithFeatureCombination() {
+ String queryProfile = "<query-profile id='default' type='root'/>";
+ String queryProfileType = "<query-profile-type id='root'>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" +
+ "</query-profile-type>";
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
+ queryProfile,
+ queryProfileType);
+ RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)",
+ "onnx('mnist_softmax.onnx')",
+ "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
+ "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }",
+ "Placeholder",
+ application);
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+
+ @Test
+ public void testNestedOnnxReference() {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "5 + sum(onnx('mnist_softmax.onnx'))");
+ search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+ @Test
+ public void testOnnxReferenceSpecifyingOutput() {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx', 'add')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ }
+
+ @Test
+ public void testOnnxReferenceMissingMacro() throws ParseException {
+ try {
+ RankProfileSearchFixture search = new RankProfileSearchFixture(
+ new StoringApplicationPackage(applicationDir),
+ new QueryProfileRegistry(),
+ " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: onnx('mnist_softmax.onnx')" +
+ " }\n" +
+ " }");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ fail("Expecting exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
+ "onnx('mnist_softmax.onnx'): " +
+ "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "not present in rank profile 'my_profile'",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+
+ @Test
+ public void testOnnxReferenceWithWrongMacroType() {
+ try {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)",
+ "onnx('mnist_softmax.onnx')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ fail("Expecting exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
+ "onnx('mnist_softmax.onnx'): " +
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " +
+ "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void testOnnxReferenceSpecifyingNonExistingOutput() {
+ try {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx', 'y')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ fail("Expecting exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
+ "onnx('mnist_softmax.onnx','y'): " +
+ "Model does not have the specified output 'y'",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void testImportingFromStoredExpressions() throws IOException {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+
+        // At this point the expression is stored - copy the application to another location which does not have a models dir
+ Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
+ try {
+ storedApplicationDirectory.toFile().mkdirs();
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
+ RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx')",
+ null,
+ null,
+ "Placeholder",
+ storedApplication);
+ searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
+            // Verify that the constants exist, but don't verify the content as we are not
+ // simulating file distribution in this test
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", searchFromStored, Optional.empty());
+ assertLargeConstant("mnist_softmax_onnx_Variable", searchFromStored, Optional.empty());
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
+ }
+ }
+
+ @Test
+ public void testImportingFromStoredExpressionsWithMacroOverridingConstant() throws IOException {
+ String rankProfile =
+ " rank-profile my_profile {\n" +
+ " macro Placeholder() {\n" +
+ " expression: tensor(d0[2],d1[784])(0.0)\n" +
+ " }\n" +
+ " macro mnist_softmax_onnx_Variable() {\n" +
+ " expression: tensor(d1[10],d2[784])(0.0)\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: onnx('mnist_softmax.onnx')" +
+ " }\n" +
+ " }";
+
+
+ String vespaExpressionWithoutConstant =
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_onnx_Variable, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_onnx_Variable_1), f(a,b)(a + b))";
+ RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
+ search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
+
+ assertNull("Constant overridden by macro is not added",
+ search.search().getRankingConstants().get("mnist_softmax_onnx_Variable"));
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+
+        // At this point the expression is stored - copy the application to another location which does not have a models dir
+ Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
+ try {
+ storedApplicationDirectory.toFile().mkdirs();
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
+ RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication);
+ searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
+ assertNull("Constant overridden by macro is not added",
+ searchFromStored.search().getRankingConstants().get("mnist_softmax_onnx_Variable"));
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", searchFromStored, Optional.of(10L));
+ } finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
+ }
+ }
+
+ /**
+ * Verifies that the constant with the given name exists, and - only if an expected size is given -
+ * that the content of the constant is available and has the expected size.
+ */
+ private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
+ try {
+ Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax.onnx/constants").append(name + ".tbf");
+ RankingConstant rankingConstant = search.search().getRankingConstants().get(name);
+ assertEquals(name, rankingConstant.getName());
+ assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString()));
+
+ if (expectedSize.isPresent()) {
+ Path constantPath = applicationDir.append(constantApplicationPackagePath);
+ assertTrue("Constant file '" + constantPath + "' has been written",
+ constantPath.toFile().exists());
+ Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(),
+ GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile())));
+ assertEquals(expectedSize.get().longValue(), deserializedConstant.size());
+ }
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
+ return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder",
+ new StoringApplicationPackage(applicationDir));
+ }
+
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
+ String constant, String field) {
+ return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder",
+ new StoringApplicationPackage(applicationDir));
+ }
+
+ private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) {
+ try {
+ return new RankProfileSearchFixture(application, application.getQueryProfiles(),
+ rankProfile, null, null);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ private RankProfileSearchFixture fixtureWith(String macroExpression,
+ String firstPhaseExpression,
+ String constant,
+ String field,
+ String macroName,
+ StoringApplicationPackage application) {
+ try {
+ return new RankProfileSearchFixture(
+ application,
+ application.getQueryProfiles(),
+ " rank-profile my_profile {\n" +
+ " macro " + macroName + "() {\n" +
+ " expression: " + macroExpression +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: " + firstPhaseExpression +
+ " }\n" +
+ " }",
+ constant,
+ field);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 55754605843..d288a396732 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -439,7 +439,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
}
- private static class StoringApplicationPackage extends MockApplicationPackage {
+ static class StoringApplicationPackage extends MockApplicationPackage {
private final File root;
@@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
- private static class StoringApplicationPackageFile extends ApplicationFile {
+ public static class StoringApplicationPackageFile extends ApplicationFile {
/** The path to the application package root */
private final Path root;