author    Lester Solbakken <lesters@oath.com>  2018-05-28 11:33:17 +0200
committer Lester Solbakken <lesters@oath.com>  2018-05-28 11:33:17 +0200
commit    88d06ec474f727d41963b6aa65c2382ccc01c3f5 (patch)
tree      d2a871f2e6870daadf674fca0f350692cbdc42a3
parent    3c1334090cef6fb0891515040ad900702275ccea (diff)
Add ONNX pseudo ranking feature
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java      1
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java      693
-rw-r--r--  config-model/src/test/integration/onnx/models/mnist_softmax.onnx                                          bin 0 -> 31758 bytes
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java   333
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java  4
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java          127
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java             57
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java       7
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java  20
-rw-r--r--  searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java     18
10 files changed, 1206 insertions, 54 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
index 67d60b08ab0..6ca16c1559d 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java
@@ -21,6 +21,7 @@ public class ExpressionTransforms {
private final List<ExpressionTransformer> transforms =
ImmutableList.of(new TensorFlowFeatureConverter(),
+ new OnnxFeatureConverter(),
new ConstantDereferencer(),
new ConstantTensorTransformer(),
new MacroInliner(),
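For illustration only (not part of this commit): the kind of rank profile the new OnnxFeatureConverter rewrites, written as a Java string in the style of the test fixtures further down. The model file name and the attribute are hypothetical.

// Hypothetical sketch: the onnx(...) reference below is replaced by the imported
// tensor expression at deployment time; the macro supplies the model's "Placeholder" input.
class OnnxFeatureUsageSketch {
    static final String RANK_PROFILE =
            "  rank-profile my_profile {\n" +
            "    macro Placeholder() {\n" +
            "      expression: attribute(image)\n" +
            "    }\n" +
            "    first-phase {\n" +
            "      expression: onnx('my_model.onnx')\n" +
            "    }\n" +
            "  }";
}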
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
new file mode 100644
index 00000000000..6d44b130996
--- /dev/null
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -0,0 +1,693 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.expressiontransforms;
+
+import com.google.common.base.Joiner;
+import com.yahoo.collections.Pair;
+import com.yahoo.config.application.api.ApplicationFile;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.FilesApplicationPackage;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.FeatureNames;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.Arguments;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+
+import java.io.BufferedReader;
+import java.io.File;
+import java.io.IOException;
+import java.io.StringReader;
+import java.io.UncheckedIOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Replaces instances of the onnx(model-path, output)
+ * pseudofeature with the native Vespa ranking expression implementing
+ * the same computation.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
+
+ private final OnnxImporter onnxImporter = new OnnxImporter();
+
+ /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
+ private final Map<Path, OnnxModel> importedModels = new HashMap<>();
+
+ @Override
+ public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
+ if (node instanceof ReferenceNode)
+ return transformFeature((ReferenceNode) node, context);
+ else if (node instanceof CompositeNode)
+ return super.transformChildren((CompositeNode) node, context);
+ else
+ return node;
+ }
+
+ private ExpressionNode transformFeature(ReferenceNode feature, RankProfileTransformContext context) {
+ if ( ! feature.getName().equals("onnx")) return feature;
+
+ try {
+ ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments());
+ if ( ! store.hasStoredModel()) // not converted yet - access Onnx model files
+ return transformFromOnnxModel(store, context.rankProfile(), context.queryProfiles());
+ else
+ return transformFromStoredModel(store, context.rankProfile());
+ }
+ catch (IllegalArgumentException | UncheckedIOException e) {
+ throw new IllegalArgumentException("Could not use Onnx model from " + feature, e);
+ }
+ }
+
+ private ExpressionNode transformFromOnnxModel(ModelStore store,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles) {
+ OnnxModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
+ k -> onnxImporter.importModel(store.arguments().modelName(),
+ store.onnxModelDir()));
+
+ // Add constants
+ Set<String> constantsReplacedByMacros = new HashSet<>();
+ model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
+ model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
+ constantsReplacedByMacros, k, v));
+
+ // Find the specified expression
+ String output = chooseOutput(model, store.arguments().output());
+ if (model.skippedOutputs().containsKey(output)) {
+ String message = "Could not import Onnx model output '" + output + "'";
+ if (!model.skippedOutputs().get(output).isEmpty()) {
+ message += ": " + model.skippedOutputs().get(output);
+ }
+ if (!model.importWarnings().isEmpty()) {
+ message += ": " + String.join(", ", model.importWarnings());
+ }
+ throw new IllegalArgumentException(message);
+ }
+
+ RankingExpression expression = model.expressions().get(output);
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ verifyRequiredMacros(expression, model, profile, queryProfiles);
+ addGeneratedMacros(model, profile);
+ reduceBatchDimensions(expression, model, profile, queryProfiles);
+
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v));
+
+ store.writeConverted(expression);
+ return expression.getRoot();
+ }
+
+ private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+ for (Pair<String, Tensor> constant : store.readSmallConstants())
+ profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
+
+ for (RankingConstant constant : store.readLargeConstants()) {
+ if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName()))
+ profile.getSearch().addRankingConstant(constant);
+ }
+
+ for (Pair<String, RankingExpression> macro : store.readMacros()) {
+ addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
+ }
+
+ return store.readConverted().getRoot();
+ }
+
+    /**
+     * Returns the name of the specified, existing output expression, or of the only output expression
+     * if no output name is specified. Throws IllegalArgumentException in all other cases.
+     */
+ private String chooseOutput(OnnxModel model, Optional<String> outputName) {
+ if ( ! outputName.isPresent()) {
+ if (model.outputs().size() == 0)
+ throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(model));
+ if (model.outputs().size() > 1)
+ throw new IllegalArgumentException("Onnx model has multiple outputs (" +
+ Joiner.on(", ").join(model.outputs().keySet()) +
+ "), one must be specified " +
+ "as a second argument to onnx()");
+ return model.outputs().get(model.outputs().keySet().stream().findFirst().get());
+ }
+ else {
+ String output = model.outputs().get(outputName.get());
+ if (output == null) {
+ if (model.skippedOutputs().containsKey(outputName.get()))
+ throw new IllegalArgumentException("Could not use output '" + outputName.get() + "': " +
+ model.skippedOutputs().get(outputName.get()));
+ else
+ throw new IllegalArgumentException("Model does not have the specified output '" +
+ outputName.get() + "'");
+ }
+ return output;
+ }
+ }
+
+ private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ store.writeSmallConstant(constantName, constantValue);
+ profile.addConstant(constantName, asValue(constantValue));
+ }
+
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ Set<String> constantsReplacedByMacros,
+ String constantName, Tensor constantValue) {
+ RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
+ if (macroOverridingConstant != null) {
+ TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
+ if ( ! macroType.equals(constantValue.type()))
+ throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
+ "The required type of this is " + constantValue.type() +
+ ", but the macro returns " + macroType);
+ constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
+ }
+ else {
+ Path constantPath = store.writeLargeConstant(constantName, constantValue);
+ if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
+ profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
+ constantPath.toString()));
+ }
+ }
+ }
+
+ private void transformGeneratedMacro(ModelStore store, RankProfile profile,
+ Set<String> constantsReplacedByMacros,
+ String macroName, RankingExpression expression) {
+
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
+ store.writeMacro(macroName, expression);
+ }
+
+ private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ if (profile.getMacros().containsKey(macroName)) {
+ throw new IllegalArgumentException("Generated Onnx macro '" + macroName + "' already exists.");
+ }
+ profile.addMacro(macroName, false); // todo: inline if only used once
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ macro.setRankingExpression(expression);
+ macro.setTextualExpression(expression.getRoot().toString());
+ }
+
+ private String skippedOutputsDescription(OnnxModel model) {
+ if (model.skippedOutputs().isEmpty()) return "";
+ StringBuilder b = new StringBuilder(": ");
+ model.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));
+ return b.toString();
+ }
+
+    /**
+     * Verifies that the macros referred to in the given expression exist in the given rank profile,
+     * and produce tensors of the types specified in requiredMacros.
+     */
+ private void verifyRequiredMacros(RankingExpression expression, OnnxModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ TensorType requiredType = model.requiredMacros().get(macroName);
+ if (requiredType == null) continue; // Not a required macro
+
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null)
+ throw new IllegalArgumentException("Model refers Placeholder '" + macroName +
+ "' of type " + requiredType + " but this macro is not present in " +
+ profile);
+            // TODO: We should verify this in the function reference(s) from which this is invoked (starting from
+            // first/second phase and summary features), as it may only resolve correctly given those bindings.
+            // Or, probably better, annotate the macros with type constraints here and verify during general
+            // type verification.
+ TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles));
+ if ( actualType == null)
+ throw new IllegalArgumentException("Model refers input '" + macroName +
+ "' of type " + requiredType +
+ " which must be produced by a macro in the rank profile, but " +
+ "this macro references a feature which is not declared");
+ if ( ! actualType.isAssignableTo(requiredType))
+ throw new IllegalArgumentException("Model refers input '" + macroName +
+ "' of type " + requiredType +
+ " which must be produced by a macro in the rank profile, but " +
+ "this macro produces type " + actualType);
+ }
+ }
+
+ /**
+ * Add the generated macros to the rank profile
+ */
+ private void addGeneratedMacros(OnnxModel model, RankProfile profile) {
+ model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
+ }
+
+ /**
+ * Check if batch dimensions of inputs can be reduced out. If the input
+ * macro specifies that a single exemplar should be evaluated, we can
+ * reduce the batch dimension out.
+ */
+ private void reduceBatchDimensions(RankingExpression expression, OnnxModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
+ TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+
+ // Check generated macros for inputs to reduce
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ if ( ! model.macros().containsKey(macroName)) {
+ continue;
+ }
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null) {
+ throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
+ "but this macro is not present in " + profile);
+ }
+ RankingExpression macroExpression = macro.getRankingExpression();
+ macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
+ }
+
+ // Check expression for inputs to reduce
+ ExpressionNode root = expression.getRoot();
+ root = reduceBatchDimensionsAtInput(root, model, typeContext);
+ TensorType typeAfterReducing = root.type(typeContext);
+ root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+ expression.setRoot(root);
+ }
+
+ private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, OnnxModel model,
+ TypeContext<Reference> typeContext) {
+ if (node instanceof TensorFunctionNode) {
+ TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
+ if (tensorFunction instanceof Rename) {
+ List<ExpressionNode> children = ((TensorFunctionNode)node).children();
+ if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) children.get(0);
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(tensorFunction, typeContext);
+ }
+ }
+ }
+ }
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) node;
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
+ }
+ }
+ if (node instanceof CompositeNode) {
+ List<ExpressionNode> children = ((CompositeNode)node).children();
+ List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
+ for (ExpressionNode child : children) {
+ transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
+ }
+ return ((CompositeNode)node).setChildren(transformedChildren);
+ }
+ return node;
+ }
+
+ private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ TensorFunction result = function;
+ TensorType type = function.type(context);
+ if (type.dimensions().size() > 1) {
+ List<String> reduceDimensions = new ArrayList<>();
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1) {
+ reduceDimensions.add(dimension.name());
+ }
+ }
+ if (reduceDimensions.size() > 0) {
+ result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+ }
+ }
+ return new TensorFunctionNode(result);
+ }
+
+ /**
+ * If batch dimensions have been reduced away above, bring them back here
+ * for any following computation of the tensor.
+ * Todo: determine when this is not necessary!
+ */
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ if (after.equals(before)) {
+ return node;
+ }
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : before.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
+ typeBuilder.indexed(dimension.name(), 1);
+ }
+ }
+ TensorType expandDimensionsType = typeBuilder.build();
+ if (expandDimensionsType.dimensions().size() > 0) {
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
+ Generate generatedFunction = new Generate(expandDimensionsType,
+ new GeneratorLambdaFunctionNode(expandDimensionsType,
+ generatedExpression)
+ .asLongListToDoubleOperator());
+ Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
+ return new TensorFunctionNode(expand);
+ }
+ return node;
+ }
+
+ /**
+ * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
+ * This method does that for the given expression and returns the result.
+ */
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ Set<String> constantsReplacedByMacros) {
+ if (constantsReplacedByMacros.isEmpty()) return expression;
+ return new RankingExpression(expression.getName(),
+ replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
+ }
+
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ if (node instanceof ReferenceNode) {
+ Reference reference = ((ReferenceNode)node).reference();
+ if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
+ String argument = reference.simpleArgument().get();
+ if (constantsReplacedByMacros.contains(argument))
+ return new ReferenceNode(argument);
+ }
+ }
+ if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
+ CompositeNode composite = (CompositeNode)node;
+ return composite.setChildren(composite.children().stream()
+ .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
+ .collect(Collectors.toList()));
+ }
+ return node;
+ }
+
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names, OnnxModel model) {
+ if (node instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode)node;
+ if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
+ names.add(referenceNode.getName());
+ if (model.macros().containsKey(referenceNode.getName())) {
+ addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
+ }
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children())
+ addMacroNamesIn(child, names, model);
+ }
+ }
+
+ private Value asValue(Tensor tensor) {
+ if (tensor.type().rank() == 0)
+ return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors
+ else
+ return new TensorValue(tensor);
+ }
+
+ /**
+ * Provides read/write access to the correct directories of the application package given by the feature arguments
+ */
+ private static class ModelStore {
+
+ private final ApplicationPackage application;
+ private final FeatureArguments arguments;
+
+ public ModelStore(ApplicationPackage application, Arguments arguments) {
+ this.application = application;
+ this.arguments = new FeatureArguments(arguments);
+ }
+
+ public FeatureArguments arguments() { return arguments; }
+
+ public boolean hasStoredModel() {
+ try {
+ return application.getFile(arguments.expressionPath()).exists();
+ }
+ catch (UnsupportedOperationException e) {
+ return false;
+ }
+ }
+
+ /**
+ * Returns the directory which contains the source model to use for these arguments
+ */
+ public File onnxModelDir() {
+ return application.getFileReference(ApplicationPackage.MODELS_DIR.append(arguments.modelPath()));
+ }
+
+ /**
+ * Adds this expression to the application package, such that it can be read later.
+ */
+ public void writeConverted(RankingExpression expression) {
+ application.getFile(arguments.expressionPath())
+ .writeFile(new StringReader(expression.getRoot().toString()));
+ }
+
+ /** Reads the previously stored ranking expression for these arguments */
+ public RankingExpression readConverted() {
+ try {
+ return new RankingExpression(application.getFile(arguments.expressionPath()).createReader());
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException("Could not read " + arguments.expressionPath(), e);
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+
+        /** Adds this macro expression to the application package so it can be read later. */
+ public void writeMacro(String name, RankingExpression expression) {
+ application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
+ expression.getRoot().toString() + "\n");
+ }
+
+ /** Reads the previously stored macro expressions for these arguments */
+ public List<Pair<String, RankingExpression>> readMacros() {
+ try {
+ ApplicationFile file = application.getFile(arguments.macrosPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, RankingExpression>> macros = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ try {
+ RankingExpression expression = new RankingExpression(parts[1]);
+ macros.add(new Pair<>(name, expression));
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+ return macros;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Reads the information about all the large (aka ranking) constants stored in the application package
+ * (the constant value itself is replicated with file distribution).
+ */
+ public List<RankingConstant> readLargeConstants() {
+ try {
+ List<RankingConstant> constants = new ArrayList<>();
+ for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) {
+ String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
+ constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Adds this constant to the application package as a file,
+ * such that it can be distributed using file distribution.
+ *
+ * @return the path to the stored constant, relative to the application package root
+ */
+ public Path writeLargeConstant(String name, Tensor constant) {
+ Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
+
+ // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
+ Path constantPath = constantsPath.append(name + ".tbf");
+
+ // Remember the constant in a file we replicate in ZooKeeper
+ application.getFile(arguments.largeConstantsPath().append(name + ".constant"))
+ .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath)));
+
+ // Write content explicitly as a file on the file system as this is distributed using file distribution
+ createIfNeeded(constantsPath);
+ IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
+ return correct(constantPath);
+ }
+
+ private List<Pair<String, Tensor>> readSmallConstants() {
+ try {
+ ApplicationFile file = application.getFile(arguments.smallConstantsPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, Tensor>> constants = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ TensorType type = TensorType.fromSpec(parts[1]);
+ Tensor tensor = Tensor.from(type, parts[2]);
+ constants.add(new Pair<>(name, tensor));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
+ * Append this constant to the single file used for small constants distributed as config
+ */
+ public void writeSmallConstant(String name, Tensor constant) {
+ // Secret file format for remembering constants:
+ application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" +
+ constant.type().toString() + "\t" +
+ constant.toString() + "\n");
+ }
+
+ /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
+ private Path correct(Path path) {
+ if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed)
+ && ! path.elements().contains(FilesApplicationPackage.preprocessed)) {
+ return Path.fromString(FilesApplicationPackage.preprocessed).append(path);
+ }
+ else {
+ return path;
+ }
+ }
+
+ private void createIfNeeded(Path path) {
+ File dir = application.getFileReference(path);
+ if ( ! dir.exists()) {
+ if (!dir.mkdirs())
+ throw new IllegalStateException("Could not create " + dir);
+ }
+ }
+
+ }
+
+    /** Encapsulates the 1 or 2 arguments to an onnx feature */
+ private static class FeatureArguments {
+
+ private final Path modelPath;
+
+ /** Optional arguments */
+ private final Optional<String> output;
+
+ public FeatureArguments(Arguments arguments) {
+ if (arguments.isEmpty())
+ throw new IllegalArgumentException("An onnx node must take an argument pointing to " +
+ "the onnx model directory under [application]/models");
+ if (arguments.expressions().size() > 3)
+ throw new IllegalArgumentException("An onnx feature can have at most 2 arguments");
+
+ modelPath = Path.fromString(asString(arguments.expressions().get(0)));
+ output = optionalArgument(1, arguments);
+ }
+
+        /** Returns modelPath with slashes and dots replaced by underscores */
+ public String modelName() { return modelPath.toString().replace('/', '_').replace('.', '_'); }
+
+ /** Returns relative path to this model below the "models/" dir in the application package */
+ public Path modelPath() { return modelPath; }
+ public Optional<String> output() { return output; }
+
+ /** Path to the small constants file */
+ public Path smallConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt");
+ }
+
+ /** Path to the large (ranking) constants directory */
+ public Path largeConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+ }
+
+ /** Path to the macros file */
+ public Path macrosPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("macros.txt");
+ }
+
+ public Path expressionPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
+ .append(modelPath).append("expressions").append(expressionFileName());
+ }
+
+ private String expressionFileName() {
+ StringBuilder fileName = new StringBuilder();
+ output.ifPresent(s -> fileName.append(s).append("."));
+            if (fileName.length() == 0) // no output specified; the model has a single output
+ fileName.append("single.");
+ fileName.append("expression");
+ return fileName.toString();
+ }
+
+ private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
+ if (argumentIndex >= arguments.expressions().size())
+ return Optional.empty();
+ return Optional.of(asString(arguments.expressions().get(argumentIndex)));
+ }
+
+ private String asString(ExpressionNode node) {
+ if ( ! (node instanceof ConstantNode))
+ throw new IllegalArgumentException("Expected a constant string as onnx argument, but got '" + node);
+ return stripQuotes(((ConstantNode)node).sourceString());
+ }
+
+ private String stripQuotes(String s) {
+ if ( ! isQuoteSign(s.codePointAt(0))) return s;
+ if ( ! isQuoteSign(s.codePointAt(s.length() - 1 )))
+ throw new IllegalArgumentException("onnx argument [" + s + "] is missing endquote");
+ return s.substring(1, s.length()-1);
+ }
+
+ private boolean isQuoteSign(int c) {
+ return c == '\'' || c == '"';
+ }
+
+ }
+
+}
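As a summary of the ModelStore logic above, a hedged sketch of where the generated artifacts for onnx('mnist_softmax.onnx') (the model used by the tests) end up. The class is illustrative only and merely reuses the path helpers defined in FeatureArguments.

import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.path.Path;

// Illustrative only, not part of this commit.
class OnnxModelStoreLayoutSketch {
    public static void main(String[] args) {
        Path model = Path.fromString("mnist_softmax.onnx");
        // Converted ranking expression; the file name is "single.expression" when no output is named
        Path expression = ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
                                            .append(model).append("expressions").append("single.expression");
        // Generated macros, one "<name>\t<expression>" line per macro
        Path macros = ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(model).append("macros.txt");
        // Small constants, one "<name>\t<type>\t<value>" line per constant, distributed as config
        Path smallConstants = ApplicationPackage.MODELS_GENERATED_DIR.append(model).append("constants.txt");
        // Large constants: one "<name>.constant" pointer file here, plus a "<name>.tbf" binary
        // under models.generated/<model>/constants distributed with file distribution
        Path largeConstants = ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(model).append("constants");
        System.out.println(expression + "\n" + macros + "\n" + smallConstants + "\n" + largeConstants);
    }
}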
diff --git a/config-model/src/test/integration/onnx/models/mnist_softmax.onnx b/config-model/src/test/integration/onnx/models/mnist_softmax.onnx
new file mode 100644
index 00000000000..a86019bf53a
--- /dev/null
+++ b/config-model/src/test/integration/onnx/models/mnist_softmax.onnx
Binary files differ
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
new file mode 100644
index 00000000000..5f2f9e9ffaa
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -0,0 +1,333 @@
+package com.yahoo.searchdefinition.processing;
+
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import com.yahoo.yolean.Exceptions;
+import org.junit.After;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.Optional;
+
+import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage;
+
+import static junit.framework.TestCase.assertTrue;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+public class RankingExpressionWithOnnxTestCase {
+
+ private final Path applicationDir = Path.fromString("src/test/integration/onnx/");
+ private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_onnx_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_onnx_Variable_1), f(a,b)(a + b))";
+
+ @After
+ public void removeGeneratedConstantTensorFiles() {
+ IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+
+ @Test
+ public void testOnnxReference() throws ParseException {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+ @Test
+ public void testOnnxReferenceWithConstantFeature() {
+ RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
+ "onnx('mnist_softmax.onnx')",
+ "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
+ null);
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+ @Test
+ public void testOnnxReferenceWithQueryFeature() {
+ String queryProfile = "<query-profile id='default' type='root'/>";
+ String queryProfileType = "<query-profile-type id='root'>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" +
+ "</query-profile-type>";
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
+ queryProfile,
+ queryProfileType);
+ RankProfileSearchFixture search = fixtureWith("query(mytensor)",
+ "onnx('mnist_softmax.onnx')",
+ null,
+ null,
+ "Placeholder",
+ application);
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+ @Test
+ public void testOnnxReferenceWithDocumentFeature() {
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
+ RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
+ "onnx('mnist_softmax.onnx')",
+ null,
+ "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }",
+ "Placeholder",
+ application);
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+
+ @Test
+ public void testOnnxReferenceWithFeatureCombination() {
+ String queryProfile = "<query-profile id='default' type='root'/>";
+ String queryProfileType = "<query-profile-type id='root'>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" +
+ "</query-profile-type>";
+ StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
+ queryProfile,
+ queryProfileType);
+ RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)",
+ "onnx('mnist_softmax.onnx')",
+ "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
+ "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }",
+ "Placeholder",
+ application);
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+
+ @Test
+ public void testNestedOnnxReference() {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "5 + sum(onnx('mnist_softmax.onnx'))");
+ search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+ }
+
+ @Test
+ public void testOnnxReferenceSpecifyingOutput() {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx', 'add')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ }
+
+ @Test
+ public void testOnnxReferenceMissingMacro() throws ParseException {
+ try {
+ RankProfileSearchFixture search = new RankProfileSearchFixture(
+ new StoringApplicationPackage(applicationDir),
+ new QueryProfileRegistry(),
+ " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: onnx('mnist_softmax.onnx')" +
+ " }\n" +
+ " }");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ fail("Expecting exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
+ "onnx('mnist_softmax.onnx'): " +
+ "Model refers Placeholder 'Placeholder' of type tensor(d0[],d1[784]) but this macro is " +
+ "not present in rank profile 'my_profile'",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+
+ @Test
+ public void testOnnxReferenceWithWrongMacroType() {
+ try {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)",
+ "onnx('mnist_softmax.onnx')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ fail("Expecting exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
+ "onnx('mnist_softmax.onnx'): " +
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) which must be produced " +
+ "by a macro in the rank profile, but this macro produces type tensor(d0[2],d5[10])",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void testOnnxReferenceSpecifyingNonExistingOutput() {
+ try {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx', 'y')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ fail("Expecting exception");
+ }
+ catch (IllegalArgumentException expected) {
+ assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
+ "onnx('mnist_softmax.onnx','y'): " +
+ "Model does not have the specified output 'y'",
+ Exceptions.toMessageString(expected));
+ }
+ }
+
+ @Test
+ public void testImportingFromStoredExpressions() throws IOException {
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx')");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
+
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+ assertLargeConstant("mnist_softmax_onnx_Variable", search, Optional.of(7840L));
+
+        // At this point the expression is stored - copy the application to another location which does not have a models dir
+ Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
+ try {
+ storedApplicationDirectory.toFile().mkdirs();
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
+ RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ "onnx('mnist_softmax.onnx')",
+ null,
+ null,
+ "Placeholder",
+ storedApplication);
+ searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile");
+            // Verify that the constants exist, but don't verify their content as we are not
+            // simulating file distribution in this test
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", searchFromStored, Optional.empty());
+ assertLargeConstant("mnist_softmax_onnx_Variable", searchFromStored, Optional.empty());
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
+ }
+ }
+
+ @Test
+ public void testImportingFromStoredExpressionsWithMacroOverridingConstant() throws IOException {
+ String rankProfile =
+ " rank-profile my_profile {\n" +
+ " macro Placeholder() {\n" +
+ " expression: tensor(d0[2],d1[784])(0.0)\n" +
+ " }\n" +
+ " macro mnist_softmax_onnx_Variable() {\n" +
+ " expression: tensor(d1[10],d2[784])(0.0)\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: onnx('mnist_softmax.onnx')" +
+ " }\n" +
+ " }";
+
+
+ String vespaExpressionWithoutConstant =
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), mnist_softmax_onnx_Variable, f(a,b)(a * b)), sum, d2), constant(mnist_softmax_onnx_Variable_1), f(a,b)(a + b))";
+ RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir));
+ search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
+
+ assertNull("Constant overridden by macro is not added",
+ search.search().getRankingConstants().get("mnist_softmax_onnx_Variable"));
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", search, Optional.of(10L));
+
+        // At this point the expression is stored - copy the application to another location which does not have a models dir
+ Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
+ try {
+ storedApplicationDirectory.toFile().mkdirs();
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
+ RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication);
+ searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile");
+ assertNull("Constant overridden by macro is not added",
+ searchFromStored.search().getRankingConstants().get("mnist_softmax_onnx_Variable"));
+ assertLargeConstant("mnist_softmax_onnx_Variable_1", searchFromStored, Optional.of(10L));
+ } finally {
+ IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
+ }
+ }
+
+ /**
+ * Verifies that the constant with the given name exists, and - only if an expected size is given -
+ * that the content of the constant is available and has the expected size.
+ */
+ private void assertLargeConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) {
+ try {
+ Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax.onnx/constants").append(name + ".tbf");
+ RankingConstant rankingConstant = search.search().getRankingConstants().get(name);
+ assertEquals(name, rankingConstant.getName());
+ assertTrue(rankingConstant.getFileName().endsWith(constantApplicationPackagePath.toString()));
+
+ if (expectedSize.isPresent()) {
+ Path constantPath = applicationDir.append(constantApplicationPackagePath);
+ assertTrue("Constant file '" + constantPath + "' has been written",
+ constantPath.toFile().exists());
+ Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(),
+ GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile())));
+ assertEquals(expectedSize.get().longValue(), deserializedConstant.size());
+ }
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) {
+ return fixtureWith(placeholderExpression, firstPhaseExpression, null, null, "Placeholder",
+ new StoringApplicationPackage(applicationDir));
+ }
+
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
+ String constant, String field) {
+ return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder",
+ new StoringApplicationPackage(applicationDir));
+ }
+
+ private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) {
+ try {
+ return new RankProfileSearchFixture(application, application.getQueryProfiles(),
+ rankProfile, null, null);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ private RankProfileSearchFixture fixtureWith(String macroExpression,
+ String firstPhaseExpression,
+ String constant,
+ String field,
+ String macroName,
+ StoringApplicationPackage application) {
+ try {
+ return new RankProfileSearchFixture(
+ application,
+ application.getQueryProfiles(),
+ " rank-profile my_profile {\n" +
+ " macro " + macroName + "() {\n" +
+ " expression: " + macroExpression +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: " + firstPhaseExpression +
+ " }\n" +
+ " }",
+ constant,
+ field);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 55754605843..d288a396732 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -439,7 +439,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
}
- private static class StoringApplicationPackage extends MockApplicationPackage {
+ static class StoringApplicationPackage extends MockApplicationPackage {
private final File root;
@@ -465,7 +465,7 @@ public class RankingExpressionWithTensorFlowTestCase {
}
- private static class StoringApplicationPackageFile extends ApplicationFile {
+ public static class StoringApplicationPackageFile extends ApplicationFile {
/** The path to the application package root */
private final Path root;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
index 047d1b187f5..295f9228316 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
@@ -13,8 +13,10 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
import onnx.Onnx;
+import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Collection;
@@ -22,6 +24,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.logging.Logger;
import java.util.stream.Collectors;
/**
@@ -31,48 +34,64 @@ import java.util.stream.Collectors;
*/
public class OnnxImporter {
- public OnnxModel importModel(String modelPath, String outputNode) {
+ private static final Logger log = Logger.getLogger(OnnxImporter.class.getName());
+
+ public OnnxModel importModel(String modelName, File modelDir) {
+ return importModel(modelName, modelDir.toString());
+ }
+
+ public OnnxModel importModel(String modelName, String modelPath) {
try (FileInputStream inputStream = new FileInputStream(modelPath)) {
Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
- return importModel(model, outputNode);
+ return importModel(modelName, model);
} catch (IOException e) {
throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
}
}
- public OnnxModel importModel(Onnx.ModelProto model, String outputNode) {
- return importGraph(model.getGraph(), outputNode);
+ public OnnxModel importModel(String modelName, Onnx.ModelProto model) {
+ return importGraph(modelName, model.getGraph());
}
- private static OnnxModel importGraph(Onnx.GraphProto graph, String outputNode) {
- OnnxModel model = new OnnxModel(outputNode);
+ private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) {
+ OnnxModel model = new OnnxModel(modelName);
OperationIndex index = new OperationIndex();
- OnnxOperation output = importNode(outputNode, graph, index);
- output.type().orElseThrow(() -> new IllegalArgumentException("Output of '" + outputNode + "' has no type."))
- .verifyType(getOutputNode(outputNode, graph).getType());
+ importNodes(graph, model, index);
+ verifyOutputTypes(graph, model, index);
+ findDimensionNames(model, index);
+ importExpressions(model, index);
- findDimensionNames(output);
- importExpressions(output, model);
+ reportWarnings(model, index);
return model;
}
- private static OnnxOperation importNode(String nodeName, Onnx.GraphProto graph, OperationIndex index) {
- if (index.alreadyImported(nodeName)) {
- return index.get(nodeName);
+ private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
+ for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
+ importNode(valueInfo.getName(), graph, model, index);
+ }
+ }
+
+ private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
+ if (index.alreadyImported(name)) {
+ return index.get(name);
}
OnnxOperation operation;
- if (isArgumentTensor(nodeName, graph)) {
- operation = new Argument(getArgumentTensor(nodeName, graph));
- } else if (isConstantTensor(nodeName, graph)) {
- operation = new Constant(getConstantTensor(nodeName, graph));
+ if (isArgumentTensor(name, graph)) {
+ operation = new Argument(getArgumentTensor(name, graph));
+ model.input(OnnxOperation.namePartOf(name), operation.vespaName());
+ } else if (isConstantTensor(name, graph)) {
+ operation = new Constant(model.name(), getConstantTensor(name, graph));
} else {
- Onnx.NodeProto node = getNodeFromGraph(nodeName, graph);
- List<OnnxOperation> inputs = importNodeInputs(node, graph, index);
+ Onnx.NodeProto node = getNodeFromGraph(name, graph);
+ List<OnnxOperation> inputs = importNodeInputs(node, graph, model, index);
operation = OperationMapper.get(node, inputs);
+ if (isOutputNode(name, graph)) {
+ model.output(OnnxOperation.namePartOf(name), operation.vespaName());
+ }
}
- index.put(nodeName, operation);
+ index.put(operation.vespaName(), operation);
return operation;
}
@@ -113,8 +132,11 @@ public class OnnxImporter {
private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
- Onnx.NodeProto node = getNodeFromGraph(valueInfo.getName(), graph);
- if (node.getName().equals(name)) {
+ if (valueInfo.getName().equals(name)) {
+ return valueInfo;
+ }
+ String nodeName = OnnxOperation.namePartOf(valueInfo.getName());
+ if (nodeName.equals(name)) {
return valueInfo;
}
}
@@ -123,18 +145,34 @@ public class OnnxImporter {
private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node,
Onnx.GraphProto graph,
+ OnnxModel model,
OperationIndex index) {
return node.getInputList().stream()
- .map(nodeName -> importNode(nodeName, graph, index))
+ .map(nodeName -> importNode(nodeName, graph, model, index))
.collect(Collectors.toList());
}
+ private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
+ for (String outputName : model.outputs().values()) {
+ OnnxOperation operation = index.get(outputName);
+ Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, graph);
+ operation.type().orElseThrow(
+ () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."))
+ .verifyType(onnxNode.getType());
+ }
+ }
+
+
/** Find dimension names to avoid excessive renaming while evaluating the model. */
- private static void findDimensionNames(OnnxOperation output) {
+ private static void findDimensionNames(OnnxModel model, OperationIndex index) {
DimensionRenamer renamer = new DimensionRenamer();
- addDimensionNameConstraints(output, renamer);
+ for (String output : model.outputs().values()) {
+ addDimensionNameConstraints(index.get(output), renamer);
+ }
renamer.solve();
- renameDimensions(output, renamer);
+ for (String output : model.outputs().values()) {
+ renameDimensions(index.get(output), renamer);
+ }
}
private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) {
@@ -151,10 +189,17 @@ public class OnnxImporter {
}
}
- private static void importExpressions(OnnxOperation output, OnnxModel model) {
- Optional<TensorFunction> function = importExpression(output, model);
- if (!function.isPresent()) {
- throw new IllegalArgumentException("No valid output function could be found.");
+ private static void importExpressions(OnnxModel model, OperationIndex index) {
+ for (String outputName : model.outputs().values()) {
+ try {
+ Optional<TensorFunction> function = importExpression(index.get(outputName), model);
+ if (!function.isPresent()) {
+ model.skippedOutput(outputName, "No valid output function could be found.");
+ }
+ }
+ catch (IllegalArgumentException e) {
+ model.skippedOutput(outputName, Exceptions.toMessageString(e));
+ }
}
}
@@ -167,7 +212,7 @@ public class OnnxImporter {
}
importInputExpressions(operation, model);
importRankingExpression(operation, model);
- importInputExpression(operation, model);
+ importArgumentExpression(operation, model);
return operation.function();
}
@@ -204,7 +249,7 @@ public class OnnxImporter {
if (!model.expressions().containsKey(name)) {
TensorFunction function = operation.function().get();
- if (name.equals(model.output())) {
+ if (model.outputs().containsKey(name)) {
OrderedTensorType operationType = operation.type().get();
OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
if ( ! operationType.equals(standardNamingType)) {
@@ -228,7 +273,7 @@ public class OnnxImporter {
}
}
- private static void importInputExpression(OnnxOperation operation, OnnxModel model) {
+ private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) {
if (operation.isInput()) {
// All inputs must have dimensions with standard naming convention: d0, d1, ...
OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
@@ -237,6 +282,20 @@ public class OnnxImporter {
}
}
+ private static void reportWarnings(OnnxModel model, OperationIndex index) {
+ for (String output : model.outputs().values()) {
+ reportWarnings(model, index.get(output));
+ }
+ }
+
+ private static void reportWarnings(OnnxModel model, OnnxOperation operation) {
+ for (String warning : operation.warnings()) {
+ model.importWarning(warning);
+ }
+ for (OnnxOperation input : operation.inputs()) {
+ reportWarnings(model, input);
+ }
+ }
private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
boolean hasPortNumber = nodeName.contains(":");
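A hedged usage sketch of the reworked importer API above; the model name and file path are illustrative and error handling is omitted.

import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter;
import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel;

// Illustrative only: all graph outputs are now imported up front, and outputs that
// cannot be converted are recorded with a reason instead of aborting the import.
class OnnxImportSketch {
    public static void main(String[] args) {
        OnnxModel model = new OnnxImporter().importModel("mnist_softmax_onnx",
                                                         "models/mnist_softmax.onnx");
        System.out.println("outputs: " + model.outputs().keySet());
        model.skippedOutputs().forEach((name, reason) ->
                System.out.println("skipped '" + name + "': " + reason));
        model.importWarnings().forEach(warning -> System.out.println("warning: " + warning));
    }
}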
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
index df108fcbbe7..027c1d7ff9d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
@@ -14,29 +14,73 @@ import java.util.regex.Pattern;
/**
* The result of importing an ONNX model into Vespa.
*
+ * @author bratseth
* @author lesters
*/
public class OnnxModel {
- public OnnxModel(String outputNode) {
- this.output = outputNode;
+ private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
+
+ private final String name;
+
+ public OnnxModel(String name) {
+ if ( ! nameRegexp.matcher(name).matches())
+ throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
+ name + "'");
+ this.name = name;
}
- private final String output;
+ /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
+ public String name() { return name; }
+
+ private final Map<String, String> inputs = new HashMap<>();
+ private final Map<String, String> outputs = new HashMap<>();
+ private final Map<String, String> skippedOutputs = new HashMap<>();
+ private final List<String> importWarnings = new ArrayList<>();
+
private final Map<String, TensorType> arguments = new HashMap<>();
private final Map<String, Tensor> smallConstants = new HashMap<>();
private final Map<String, Tensor> largeConstants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
+ private final Map<String, RankingExpression> macros = new HashMap<>();
private final Map<String, TensorType> requiredMacros = new HashMap<>();
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, String expressionName) { outputs.put(name, expressionName); }
+ void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
+ void importWarning(String warning) { importWarnings.add(warning); }
void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ void macro(String name, RankingExpression expression) { macros.put(name, expression); }
void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
- /** Return the name of the output node for this model */
- public String output() { return output; }
+ /**
+ * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
+ * to the name of the argument (Placeholder) it references in this model.
+ */
+ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
+
+ /** Returns arguments().get(inputs.get(name)), i.e. the type of the argument this input references */
+ public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); }
+
+ /** Returns an immutable map from the output names of this to the names of the expressions they reference */
+ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+
+ /**
+ * Returns an immutable map of the outputs of this which could not be imported,
+ * with a string detailing the reason for each.
+ */
+ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
+
+ /**
+ * Returns an immutable list of possibly non-fatal warnings encountered during import.
+ */
+ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
+
+ /** Returns expressions().get(outputs.get(outputName)), i.e. the expression this output references */
+ public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); }
/** Returns an immutable map of the arguments (inputs) of this */
public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
@@ -57,6 +101,9 @@ public class OnnxModel {
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
+ /** Returns an immutable map of macros that are part of this model */
+ public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); }
+
/** Returns an immutable map of the macros that must be provided by the environment running this model */
public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
index ab650bf8d77..b5494477227 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
@@ -15,18 +15,19 @@ import java.util.Optional;
public class Constant extends OnnxOperation {
+ final String modelName;
final Onnx.TensorProto tensorProto;
- public Constant(Onnx.TensorProto tensorProto) {
+ public Constant(String modelName, Onnx.TensorProto tensorProto) {
super(null, Collections.emptyList());
+ this.modelName = modelName;
this.tensorProto = tensorProto;
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
-// return modelName() + "_" + super.vespaName();
- return vespaName(tensorProto.getName());
+ return modelName + "_" + vespaName(tensorProto.getName());
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
index 2c8003f5951..3c9f01c5e1c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
@@ -92,7 +92,7 @@ public abstract class OnnxOperation {
/** Retrieve the valid Vespa name of this node */
public String vespaName() { return vespaName(node.getName()); }
- public String vespaName(String name) { return name != null ? name.replace('/', '_').replace(':','_') : null; }
+ public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
/** Retrieve the list of warnings produced during its lifetime */
public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
@@ -116,4 +116,22 @@ public abstract class OnnxOperation {
return verifyInputs(expected, OnnxOperation::function);
}
+ /**
+ * An input or output name may have the form name:index.
+ * This returns the name part without the index.
+ */
+ public static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
+ return name.split(":")[0];
+ }
+
+ /**
+ * This returns the output index part. Indexes are used for nodes with
+ * multiple outputs.
+ */
+ public static int indexPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+ }
+
}
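A small usage sketch of the name:index helpers added above. The node names are made up, but the expected values follow directly from the implementations of namePartOf and indexPartOf.

import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;

public class NameParsingExample {

    public static void main(String[] args) {
        // "softmax:1" refers to the second output of the node "softmax".
        System.out.println(OnnxOperation.namePartOf("softmax:1"));   // prints: softmax
        System.out.println(OnnxOperation.indexPartOf("softmax:1"));  // prints: 1

        // Names without an explicit index default to output 0, and a leading '^' is stripped.
        System.out.println(OnnxOperation.namePartOf("^add"));        // prints: add
        System.out.println(OnnxOperation.indexPartOf("add"));        // prints: 0
    }

}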
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
index e118c2b885a..4b68cd40a08 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -24,18 +24,18 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() throws IOException {
- OnnxModel model = new OnnxImporter().importModel("src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx", "add");
+ OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
assertEquals(2, model.largeConstants().size());
- Tensor constant0 = model.largeConstants().get("Variable_0");
+ Tensor constant0 = model.largeConstants().get("test_Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.largeConstants().get("Variable_1_0");
+ Tensor constant1 = model.largeConstants().get("test_Variable_1");
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
constant1.type());
@@ -43,15 +43,15 @@ public class OnnxMnistSoftmaxImportTestCase {
// Check required macros (inputs)
assertEquals(1, model.requiredMacros().size());
- assertTrue(model.requiredMacros().containsKey("Placeholder_0"));
+ assertTrue(model.requiredMacros().containsKey("Placeholder"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.requiredMacros().get("Placeholder_0"));
+ model.requiredMacros().get("Placeholder"));
// Check outputs
- RankingExpression output = model.expressions().get("add");
+ RankingExpression output = model.outputExpression("add");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("join(reduce(join(rename(Placeholder_0, (d0, d1), (d0, d2)), constant(Variable_0), f(a,b)(a * b)), sum, d2), constant(Variable_1_0), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
output.getRoot().toString());
}
@@ -62,7 +62,7 @@ public class OnnxMnistSoftmaxImportTestCase {
Tensor argument = placeholderArgument();
Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add");
- Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder_0", "add");
+ Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder", "add");
assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult);
}
@@ -74,7 +74,7 @@ public class OnnxMnistSoftmaxImportTestCase {
}
private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
- OnnxModel model = new OnnxImporter().importModel(path, output);
+ OnnxModel model = new OnnxImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}