summary refs log tree commit diff stats
path: root/searchlib
diff options
context:
space:
mode:
author Lester Solbakken <lesters@oath.com> 2018-05-28 11:33:17 +0200
committer Lester Solbakken <lesters@oath.com> 2018-05-28 11:33:17 +0200
commit 88d06ec474f727d41963b6aa65c2382ccc01c3f5 (patch)
tree d2a871f2e6870daadf674fca0f350692cbdc42a3 /searchlib
parent 3c1334090cef6fb0891515040ad900702275ccea (diff)
Add ONNX pseudo ranking feature
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java 127
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java 57
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java 7
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java 20
-rw-r--r-- searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java 18
5 files changed, 177 insertions, 52 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
index 047d1b187f5..295f9228316 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
@@ -13,8 +13,10 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.yolean.Exceptions;
import onnx.Onnx;
+import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.util.Collection;
@@ -22,6 +24,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.logging.Logger;
import java.util.stream.Collectors;
/**
@@ -31,48 +34,64 @@ import java.util.stream.Collectors;
*/
public class OnnxImporter {
- public OnnxModel importModel(String modelPath, String outputNode) {
+ private static final Logger log = Logger.getLogger(OnnxImporter.class.getName());
+
+ public OnnxModel importModel(String modelName, File modelDir) {
+ return importModel(modelName, modelDir.toString());
+ }
+
+ public OnnxModel importModel(String modelName, String modelPath) {
try (FileInputStream inputStream = new FileInputStream(modelPath)) {
Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
- return importModel(model, outputNode);
+ return importModel(modelName, model);
} catch (IOException e) {
throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
}
}
- public OnnxModel importModel(Onnx.ModelProto model, String outputNode) {
- return importGraph(model.getGraph(), outputNode);
+ public OnnxModel importModel(String modelName, Onnx.ModelProto model) {
+ return importGraph(modelName, model.getGraph());
}
- private static OnnxModel importGraph(Onnx.GraphProto graph, String outputNode) {
- OnnxModel model = new OnnxModel(outputNode);
+ private static OnnxModel importGraph(String modelName, Onnx.GraphProto graph) {
+ OnnxModel model = new OnnxModel(modelName);
OperationIndex index = new OperationIndex();
- OnnxOperation output = importNode(outputNode, graph, index);
- output.type().orElseThrow(() -> new IllegalArgumentException("Output of '" + outputNode + "' has no type."))
- .verifyType(getOutputNode(outputNode, graph).getType());
+ importNodes(graph, model, index);
+ verifyOutputTypes(graph, model, index);
+ findDimensionNames(model, index);
+ importExpressions(model, index);
- findDimensionNames(output);
- importExpressions(output, model);
+ reportWarnings(model, index);
return model;
}
- private static OnnxOperation importNode(String nodeName, Onnx.GraphProto graph, OperationIndex index) {
- if (index.alreadyImported(nodeName)) {
- return index.get(nodeName);
+ private static void importNodes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
+ for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
+ importNode(valueInfo.getName(), graph, model, index);
+ }
+ }
+
+ private static OnnxOperation importNode(String name, Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
+ if (index.alreadyImported(name)) {
+ return index.get(name);
}
OnnxOperation operation;
- if (isArgumentTensor(nodeName, graph)) {
- operation = new Argument(getArgumentTensor(nodeName, graph));
- } else if (isConstantTensor(nodeName, graph)) {
- operation = new Constant(getConstantTensor(nodeName, graph));
+ if (isArgumentTensor(name, graph)) {
+ operation = new Argument(getArgumentTensor(name, graph));
+ model.input(OnnxOperation.namePartOf(name), operation.vespaName());
+ } else if (isConstantTensor(name, graph)) {
+ operation = new Constant(model.name(), getConstantTensor(name, graph));
} else {
- Onnx.NodeProto node = getNodeFromGraph(nodeName, graph);
- List<OnnxOperation> inputs = importNodeInputs(node, graph, index);
+ Onnx.NodeProto node = getNodeFromGraph(name, graph);
+ List<OnnxOperation> inputs = importNodeInputs(node, graph, model, index);
operation = OperationMapper.get(node, inputs);
+ if (isOutputNode(name, graph)) {
+ model.output(OnnxOperation.namePartOf(name), operation.vespaName());
+ }
}
- index.put(nodeName, operation);
+ index.put(operation.vespaName(), operation);
return operation;
}
@@ -113,8 +132,11 @@ public class OnnxImporter {
private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
- Onnx.NodeProto node = getNodeFromGraph(valueInfo.getName(), graph);
- if (node.getName().equals(name)) {
+ if (valueInfo.getName().equals(name)) {
+ return valueInfo;
+ }
+ String nodeName = OnnxOperation.namePartOf(valueInfo.getName());
+ if (nodeName.equals(name)) {
return valueInfo;
}
}
@@ -123,18 +145,34 @@ public class OnnxImporter {
private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node,
Onnx.GraphProto graph,
+ OnnxModel model,
OperationIndex index) {
return node.getInputList().stream()
- .map(nodeName -> importNode(nodeName, graph, index))
+ .map(nodeName -> importNode(nodeName, graph, model, index))
.collect(Collectors.toList());
}
+ private static void verifyOutputTypes(Onnx.GraphProto graph, OnnxModel model, OperationIndex index) {
+ for (String outputName : model.outputs().values()) {
+ OnnxOperation operation = index.get(outputName);
+ Onnx.ValueInfoProto onnxNode = getOutputNode(outputName, graph);
+ operation.type().orElseThrow(
+ () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."))
+ .verifyType(onnxNode.getType());
+ }
+ }
+
+
/** Find dimension names to avoid excessive renaming while evaluating the model. */
- private static void findDimensionNames(OnnxOperation output) {
+ private static void findDimensionNames(OnnxModel model, OperationIndex index) {
DimensionRenamer renamer = new DimensionRenamer();
- addDimensionNameConstraints(output, renamer);
+ for (String output : model.outputs().values()) {
+ addDimensionNameConstraints(index.get(output), renamer);
+ }
renamer.solve();
- renameDimensions(output, renamer);
+ for (String output : model.outputs().values()) {
+ renameDimensions(index.get(output), renamer);
+ }
}
private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) {
@@ -151,10 +189,17 @@ public class OnnxImporter {
}
}
- private static void importExpressions(OnnxOperation output, OnnxModel model) {
- Optional<TensorFunction> function = importExpression(output, model);
- if (!function.isPresent()) {
- throw new IllegalArgumentException("No valid output function could be found.");
+ private static void importExpressions(OnnxModel model, OperationIndex index) {
+ for (String outputName : model.outputs().values()) {
+ try {
+ Optional<TensorFunction> function = importExpression(index.get(outputName), model);
+ if (!function.isPresent()) {
+ model.skippedOutput(outputName, "No valid output function could be found.");
+ }
+ }
+ catch (IllegalArgumentException e) {
+ model.skippedOutput(outputName, Exceptions.toMessageString(e));
+ }
}
}
@@ -167,7 +212,7 @@ public class OnnxImporter {
}
importInputExpressions(operation, model);
importRankingExpression(operation, model);
- importInputExpression(operation, model);
+ importArgumentExpression(operation, model);
return operation.function();
}
@@ -204,7 +249,7 @@ public class OnnxImporter {
if (!model.expressions().containsKey(name)) {
TensorFunction function = operation.function().get();
- if (name.equals(model.output())) {
+ if (model.outputs().containsKey(name)) {
OrderedTensorType operationType = operation.type().get();
OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
if ( ! operationType.equals(standardNamingType)) {
@@ -228,7 +273,7 @@ public class OnnxImporter {
}
}
- private static void importInputExpression(OnnxOperation operation, OnnxModel model) {
+ private static void importArgumentExpression(OnnxOperation operation, OnnxModel model) {
if (operation.isInput()) {
// All inputs must have dimensions with standard naming convention: d0, d1, ...
OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
@@ -237,6 +282,20 @@ public class OnnxImporter {
}
}
+ private static void reportWarnings(OnnxModel model, OperationIndex index) {
+ for (String output : model.outputs().values()) {
+ reportWarnings(model, index.get(output));
+ }
+ }
+
+ private static void reportWarnings(OnnxModel model, OnnxOperation operation) {
+ for (String warning : operation.warnings()) {
+ model.importWarning(warning);
+ }
+ for (OnnxOperation input : operation.inputs()) {
+ reportWarnings(model, input);
+ }
+ }
private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
boolean hasPortNumber = nodeName.contains(":");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
index df108fcbbe7..027c1d7ff9d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
@@ -14,29 +14,73 @@ import java.util.regex.Pattern;
/**
* The result of importing an ONNX model into Vespa.
*
+ * @author bratseth
* @author lesters
*/
public class OnnxModel {
- public OnnxModel(String outputNode) {
- this.output = outputNode;
+ private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
+
+ private final String name;
+
+ public OnnxModel(String name) {
+ if ( ! nameRegexp.matcher(name).matches())
+ throw new IllegalArgumentException("A TensorFlow model name can only contain [A-Za-z0-9_], but is '" +
+ name + "'");
+ this.name = name;
}
- private final String output;
+ /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
+ public String name() { return name; }
+
+ private final Map<String, String> inputs = new HashMap<>();
+ private final Map<String, String> outputs = new HashMap<>();
+ private final Map<String, String> skippedOutputs = new HashMap<>();
+ private final List<String> importWarnings = new ArrayList<>();
+
private final Map<String, TensorType> arguments = new HashMap<>();
private final Map<String, Tensor> smallConstants = new HashMap<>();
private final Map<String, Tensor> largeConstants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
+ private final Map<String, RankingExpression> macros = new HashMap<>();
private final Map<String, TensorType> requiredMacros = new HashMap<>();
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, String expressionName) { outputs.put(name, expressionName); }
+ void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
+ void importWarning(String warning) { importWarnings.add(warning); }
void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ void macro(String name, RankingExpression expression) { macros.put(name, expression); }
void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
- /** Return the name of the output node for this model */
- public String output() { return output; }
+ /**
+ * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
+ * to argument (Placeholder) name in the owner of this
+ */
+ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
+
+ /** Returns arguments().get(inputs.get(inputName)), i.e. the type of the argument this input references */
+ public TensorType inputArgument(String inputName) { return arguments().get(inputs.get(inputName)); }
+
+ /** Returns an immutable map of the outputs of this, from output name to expression name */
+ public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+
+ /**
+ * Returns an immutable map of the outputs of this which could not be imported,
+ * each mapped to a string detailing the reason
+ */
+ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
+
+ /**
+ * Returns an immutable list of possibly non-fatal warnings encountered during import.
+ */
+ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
+
+ /** Returns expressions().get(outputs.get(outputName)), i.e. the expression this output references */
+ public RankingExpression outputExpression(String outputName) { return expressions().get(outputs.get(outputName)); }
/** Returns an immutable map of the arguments (inputs) of this */
public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
@@ -57,6 +101,9 @@ public class OnnxModel {
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
+ /** Returns an immutable map of macros that are part of this model */
+ public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); }
+
/** Returns an immutable map of the macros that must be provided by the environment running this model */
public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
index ab650bf8d77..b5494477227 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
@@ -15,18 +15,19 @@ import java.util.Optional;
public class Constant extends OnnxOperation {
+ final String modelName;
final Onnx.TensorProto tensorProto;
- public Constant(Onnx.TensorProto tensorProto) {
+ public Constant(String modelName, Onnx.TensorProto tensorProto) {
super(null, Collections.emptyList());
+ this.modelName = modelName;
this.tensorProto = tensorProto;
}
/** todo: Constant names are prefixed by "modelName_" to avoid name conflicts between models */
@Override
public String vespaName() {
-// return modelName() + "_" + super.vespaName();
- return vespaName(tensorProto.getName());
+ return modelName + "_" + vespaName(tensorProto.getName());
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
index 2c8003f5951..3c9f01c5e1c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
@@ -92,7 +92,7 @@ public abstract class OnnxOperation {
/** Retrieve the valid Vespa name of this node */
public String vespaName() { return vespaName(node.getName()); }
- public String vespaName(String name) { return name != null ? name.replace('/', '_').replace(':','_') : null; }
+ public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
/** Retrieve the list of warnings produced during its lifetime */
public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
@@ -116,4 +116,22 @@ public abstract class OnnxOperation {
return verifyInputs(expected, OnnxOperation::function);
}
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ public static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
+ return name.split(":")[0];
+ }
+
+ /**
+ * This returns the output index part, or 0 if the name has no index.
+ * Indexes are used for nodes with multiple outputs.
+ */
+ public static int indexPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+ }
+
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
index e118c2b885a..4b68cd40a08 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -24,18 +24,18 @@ public class OnnxMnistSoftmaxImportTestCase {
@Test
public void testMnistSoftmaxImport() throws IOException {
- OnnxModel model = new OnnxImporter().importModel("src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx", "add");
+ OnnxModel model = new OnnxImporter().importModel("test", "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx");
// Check constants
assertEquals(2, model.largeConstants().size());
- Tensor constant0 = model.largeConstants().get("Variable_0");
+ Tensor constant0 = model.largeConstants().get("test_Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.largeConstants().get("Variable_1_0");
+ Tensor constant1 = model.largeConstants().get("test_Variable_1");
assertNotNull(constant1);
assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
constant1.type());
@@ -43,15 +43,15 @@ public class OnnxMnistSoftmaxImportTestCase {
// Check required macros (inputs)
assertEquals(1, model.requiredMacros().size());
- assertTrue(model.requiredMacros().containsKey("Placeholder_0"));
+ assertTrue(model.requiredMacros().containsKey("Placeholder"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.requiredMacros().get("Placeholder_0"));
+ model.requiredMacros().get("Placeholder"));
// Check outputs
- RankingExpression output = model.expressions().get("add");
+ RankingExpression output = model.outputExpression("add");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("join(reduce(join(rename(Placeholder_0, (d0, d1), (d0, d2)), constant(Variable_0), f(a,b)(a * b)), sum, d2), constant(Variable_1_0), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
output.getRoot().toString());
}
@@ -62,7 +62,7 @@ public class OnnxMnistSoftmaxImportTestCase {
Tensor argument = placeholderArgument();
Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add");
- Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder_0", "add");
+ Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder", "add");
assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult);
}
@@ -74,7 +74,7 @@ public class OnnxMnistSoftmaxImportTestCase {
}
private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
- OnnxModel model = new OnnxImporter().importModel(path, output);
+ OnnxModel model = new OnnxImporter().importModel("test", path);
return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
}