author    Jon Bratseth <bratseth@oath.com>    2018-05-16 10:17:02 +0200
committer GitHub <noreply@github.com>    2018-05-16 10:17:02 +0200
commit    71332378a97fddfb28fe5dcdec07a0ee27e08c33 (patch)
tree      c19d9dcbdb1e3cf0842d69210ec4f3ba5479357b
parent    8026fff81e084e02e6dcee34d663348ddabbcc7e (diff)
parent    a66747f46c01d576436edb45a85f30b2b9cf7e28 (diff)
Merge pull request #5882 from vespa-engine/lesters/initial-onnx-import
Initial import of Onnx models
-rw-r--r--  searchlib/pom.xml | 20
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java | 265
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java | 63
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java | 210
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java | 24
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java | 250
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java | 79
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java | 64
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java | 59
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java | 122
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java | 75
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java | 32
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java | 119
-rw-r--r--  searchlib/src/main/protobuf/onnx.proto | 464
-rw-r--r--  searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx | bin 0 -> 31758 bytes
-rw-r--r--  searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java | 109
16 files changed, 1955 insertions(+), 0 deletions(-)
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index 1615f248910..fb9a81b51a5 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -117,6 +117,26 @@
</execution>
</executions>
</plugin>
+ <plugin>
+ <groupId>com.github.os72</groupId>
+ <artifactId>protoc-jar-maven-plugin</artifactId>
+ <version>3.5.1.1</version>
+ <executions>
+ <execution>
+ <phase>generate-sources</phase>
+ <goals>
+ <goal>run</goal>
+ </goals>
+ <configuration>
+ <addSources>main</addSources>
+ <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory>
+ <inputDirectories>
+ <include>src/main/protobuf</include>
+ </inputDirectories>
+ </configuration>
+ </execution>
+ </executions>
+ </plugin>
</plugins>
</build>
</project>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
new file mode 100644
index 00000000000..047d1b187f5
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java
@@ -0,0 +1,265 @@
+package com.yahoo.searchlib.rankingexpression.integration.onnx;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Argument;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OperationMapper;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+
+import java.io.FileInputStream;
+import java.io.IOException;
+import java.util.Collection;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Converts an ONNX model into a ranking expression and a set of constants.
+ *
+ * @author lesters
+ */
+public class OnnxImporter {
+
+ public OnnxModel importModel(String modelPath, String outputNode) {
+ try (FileInputStream inputStream = new FileInputStream(modelPath)) {
+ Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream);
+ return importModel(model, outputNode);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e);
+ }
+ }
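+
+    // Example usage (an illustrative sketch; the file name and output node
+    // name below are hypothetical):
+    //
+    //   OnnxModel model = new OnnxImporter().importModel("mnist_softmax.onnx", "add");
+    //   RankingExpression output = model.expressions().get(model.output());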
+
+ public OnnxModel importModel(Onnx.ModelProto model, String outputNode) {
+ return importGraph(model.getGraph(), outputNode);
+ }
+
+ private static OnnxModel importGraph(Onnx.GraphProto graph, String outputNode) {
+ OnnxModel model = new OnnxModel(outputNode);
+ OperationIndex index = new OperationIndex();
+
+ OnnxOperation output = importNode(outputNode, graph, index);
+ output.type().orElseThrow(() -> new IllegalArgumentException("Output of '" + outputNode + "' has no type."))
+ .verifyType(getOutputNode(outputNode, graph).getType());
+
+ findDimensionNames(output);
+ importExpressions(output, model);
+
+ return model;
+ }
+
+ private static OnnxOperation importNode(String nodeName, Onnx.GraphProto graph, OperationIndex index) {
+ if (index.alreadyImported(nodeName)) {
+ return index.get(nodeName);
+ }
+ OnnxOperation operation;
+ if (isArgumentTensor(nodeName, graph)) {
+ operation = new Argument(getArgumentTensor(nodeName, graph));
+ } else if (isConstantTensor(nodeName, graph)) {
+ operation = new Constant(getConstantTensor(nodeName, graph));
+ } else {
+ Onnx.NodeProto node = getNodeFromGraph(nodeName, graph);
+ List<OnnxOperation> inputs = importNodeInputs(node, graph, index);
+ operation = OperationMapper.get(node, inputs);
+ }
+ index.put(nodeName, operation);
+
+ return operation;
+ }
+
+ private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
+ Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
+ Onnx.TensorProto tensor = getConstantTensor(name, graph);
+ return value != null && tensor == null;
+ }
+
+ private static boolean isConstantTensor(String name, Onnx.GraphProto graph) {
+ Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
+ Onnx.TensorProto tensor = getConstantTensor(name, graph);
+ return value != null && tensor != null;
+ }
+
+ private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) {
+ for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) {
+ if (valueInfo.getName().equals(name)) {
+ return valueInfo;
+ }
+ }
+ return null;
+ }
+
+ private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) {
+ for (Onnx.TensorProto tensorProto : graph.getInitializerList()) {
+ if (tensorProto.getName().equals(name)) {
+ return tensorProto;
+ }
+ }
+ return null;
+ }
+
+ private static boolean isOutputNode(String name, Onnx.GraphProto graph) {
+ return getOutputNode(name, graph) != null;
+ }
+
+ private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) {
+ for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) {
+ Onnx.NodeProto node = getNodeFromGraph(valueInfo.getName(), graph);
+ if (node.getName().equals(name)) {
+ return valueInfo;
+ }
+ }
+ return null;
+ }
+
+ private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node,
+ Onnx.GraphProto graph,
+ OperationIndex index) {
+ return node.getInputList().stream()
+ .map(nodeName -> importNode(nodeName, graph, index))
+ .collect(Collectors.toList());
+ }
+
+ /** Find dimension names to avoid excessive renaming while evaluating the model. */
+ private static void findDimensionNames(OnnxOperation output) {
+ DimensionRenamer renamer = new DimensionRenamer();
+ addDimensionNameConstraints(output, renamer);
+ renamer.solve();
+ renameDimensions(output, renamer);
+ }
+
+ private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.addDimensionNameConstraints(renamer);
+ }
+ }
+
+ private static void renameDimensions(OnnxOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.renameDimensions(renamer);
+ }
+ }
+
+ private static void importExpressions(OnnxOperation output, OnnxModel model) {
+ Optional<TensorFunction> function = importExpression(output, model);
+ if (!function.isPresent()) {
+ throw new IllegalArgumentException("No valid output function could be found.");
+ }
+ }
+
+ private static Optional<TensorFunction> importExpression(OnnxOperation operation, OnnxModel model) {
+ if (!operation.type().isPresent()) {
+ return Optional.empty();
+ }
+ if (operation.isConstant()) {
+ return importConstant(operation, model);
+ }
+ importInputExpressions(operation, model);
+ importRankingExpression(operation, model);
+ importInputExpression(operation, model);
+
+ return operation.function();
+ }
+
+ private static void importInputExpressions(OnnxOperation operation, OnnxModel model) {
+ operation.inputs().forEach(input -> importExpression(input, model));
+ }
+
+ private static Optional<TensorFunction> importConstant(OnnxOperation operation, OnnxModel model) {
+ String name = operation.vespaName();
+ if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
+ return operation.function();
+ }
+
+ Value value = operation.getConstantValue().orElseThrow(() ->
+ new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
+ "is constant but does not have a value."));
+ if ( ! (value instanceof TensorValue)) {
+ return operation.function(); // scalar values are inserted directly into the expression
+ }
+
+ Tensor tensor = value.asTensor();
+ if (tensor.type().rank() == 0) {
+ model.smallConstant(name, tensor);
+ } else {
+ model.largeConstant(name, tensor);
+ }
+ return operation.function();
+ }
+
+ private static void importRankingExpression(OnnxOperation operation, OnnxModel model) {
+ if (operation.function().isPresent()) {
+ String name = operation.vespaName();
+ if (!model.expressions().containsKey(name)) {
+ TensorFunction function = operation.function().get();
+
+ if (name.equals(model.output())) {
+ OrderedTensorType operationType = operation.type().get();
+ OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
+ if ( ! operationType.equals(standardNamingType)) {
+ List<String> renameFrom = operationType.dimensionNames();
+ List<String> renameTo = standardNamingType.dimensionNames();
+ function = new Rename(function, renameFrom, renameTo);
+ }
+ }
+
+ try {
+                    // We add all imported intermediate nodes as separate expressions. Only
+ // those referenced from the output will be used. We parse the
+ // TensorFunction here to convert it to a RankingExpression tree.
+ model.expression(name, new RankingExpression(name, function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
+ }
+ }
+
+ private static void importInputExpression(OnnxOperation operation, OnnxModel model) {
+ if (operation.isInput()) {
+ // All inputs must have dimensions with standard naming convention: d0, d1, ...
+ OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
+ model.argument(operation.vespaName(), standardNamingConvention.type());
+ model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
+ }
+ }
+
+ private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) {
+ boolean hasPortNumber = nodeName.contains(":");
+ for (Onnx.NodeProto node : graph.getNodeList()) {
+ if (hasPortNumber) {
+ for (String outputName : node.getOutputList()) {
+ if (outputName.equals(nodeName)) {
+ return node;
+ }
+ }
+ } else if (node.getName().equals(nodeName)) {
+ return node;
+ }
+ }
+ throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph");
+ }
+
+ private static class OperationIndex {
+ private final Map<String, OnnxOperation> index = new HashMap<>();
+ public OnnxOperation put(String key, OnnxOperation operation) { return index.put(key, operation); }
+ public OnnxOperation get(String key) { return index.get(key); }
+ public boolean alreadyImported(String key) { return index.containsKey(key); }
+ public Collection<OnnxOperation> operations() { return index.values(); }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
new file mode 100644
index 00000000000..df108fcbbe7
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java
@@ -0,0 +1,63 @@
+package com.yahoo.searchlib.rankingexpression.integration.onnx;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * The result of importing an ONNX model into Vespa.
+ *
+ * @author lesters
+ */
+public class OnnxModel {
+
+ public OnnxModel(String outputNode) {
+ this.output = outputNode;
+ }
+
+ private final String output;
+ private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, Tensor> smallConstants = new HashMap<>();
+ private final Map<String, Tensor> largeConstants = new HashMap<>();
+ private final Map<String, RankingExpression> expressions = new HashMap<>();
+ private final Map<String, TensorType> requiredMacros = new HashMap<>();
+
+ void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
+ void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
+ void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
+ void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); }
+
+ /** Return the name of the output node for this model */
+ public String output() { return output; }
+
+ /** Returns an immutable map of the arguments (inputs) of this */
+ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+
+ /**
+ * Returns an immutable map of the small constants of this.
+ */
+ public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
+
+ /**
+ * Returns an immutable map of the large constants of this.
+ */
+ public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
+
+ /**
+ * Returns an immutable map of the expressions of this - corresponding to ONNX nodes
+ * which are not inputs or constants.
+ */
+ public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
+
+ /** Returns an immutable map of the macros that must be provided by the environment running this model */
+ public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java
new file mode 100644
index 00000000000..2524417cee0
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java
@@ -0,0 +1,210 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * A constraint satisfier to find suitable dimension names to reduce the
+ * amount of necessary renaming during evaluation of an imported model.
+ *
+ * @author lesters
+ */
+public class DimensionRenamer {
+
+ private final String dimensionPrefix;
+ private final Map<String, List<Integer>> variables = new HashMap<>();
+ private final Map<Arc, Constraint> constraints = new HashMap<>();
+ private final Map<String, Integer> renames = new HashMap<>();
+
+ private int iterations = 0;
+
+ public DimensionRenamer() {
+ this("d");
+ }
+
+ public DimensionRenamer(String dimensionPrefix) {
+ this.dimensionPrefix = dimensionPrefix;
+ }
+
+ /**
+ * Add a dimension name variable.
+ */
+ public void addDimension(String name) {
+ variables.computeIfAbsent(name, d -> new ArrayList<>());
+ }
+
+ /**
+ * Add a constraint between dimension names.
+ */
+ public void addConstraint(String from, String to, Constraint pred, OnnxOperation operation) {
+ Arc arc = new Arc(from, to, operation);
+ Arc opposite = arc.opposite();
+ constraints.put(arc, pred);
+ constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
+ }
+
+ /**
+ * Retrieve resulting name of dimension after solving for constraints.
+ */
+ public Optional<String> dimensionNameOf(String name) {
+ if (!renames.containsKey(name)) {
+ return Optional.empty();
+ }
+ return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
+ }
+
+ /**
+ * Perform iterative arc consistency until we have found a solution. After
+     * an initial iteration, the variables (dimensions) may still have multiple
+ * valid values. Find a single valid assignment by iteratively locking one
+ * dimension after another, and running the arc consistency algorithm
+ * multiple times.
+ *
+ * This requires having constraints that result in an absolute ordering:
+ * equals, lesserThan and greaterThan do that, but adding notEquals does
+ * not typically result in a guaranteed ordering. If that is needed, the
+ * algorithm below needs to be adapted with a backtracking (tree) search
+ * to find solutions.
+ */
+ public void solve(int maxIterations) {
+ initialize();
+
+        // TODO: evaluate whether a heuristic such as min-conflicts would improve efficiency
+
+ for (String dimension : variables.keySet()) {
+ List<Integer> values = variables.get(dimension);
+ if (values.size() > 1) {
+ if (!ac3()) {
+ throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
+ }
+ values.sort(Integer::compare);
+ variables.put(dimension, Collections.singletonList(values.get(0)));
+ }
+ renames.put(dimension, variables.get(dimension).get(0));
+ if (iterations > maxIterations) {
+ throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
+ maxIterations + " iterations");
+ }
+ }
+
+ // Todo: handle failure more gracefully:
+ // If a solution can't be found, look at the operation node in the arc
+ // with the most remaining constraints, and inject a rename operation.
+ // Then run this algorithm again.
+ }
+
+ public void solve() {
+ solve(100000);
+ }
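+
+    // Minimal usage sketch. The operation argument to addConstraint may be null
+    // here, as the solver itself never reads it:
+    //
+    //   DimensionRenamer renamer = new DimensionRenamer();
+    //   renamer.addDimension("a");
+    //   renamer.addDimension("b");
+    //   renamer.addConstraint("a", "b", DimensionRenamer::lesserThan, null);
+    //   renamer.solve();
+    //   renamer.dimensionNameOf("a");  // Optional.of("d0")
+    //   renamer.dimensionNameOf("b");  // Optional.of("d1")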
+
+ private void initialize() {
+ for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
+ List<Integer> values = variable.getValue();
+ for (int i = 0; i < variables.size(); ++i) {
+ values.add(i); // invariant: values are in increasing order
+ }
+ }
+ }
+
+ private boolean ac3() {
+ Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
+ while (!workList.isEmpty()) {
+ Arc arc = workList.pop();
+ iterations += 1;
+ if (revise(arc)) {
+ if (variables.get(arc.from).size() == 0) {
+ return false; // no solution found
+ }
+ for (Arc constraint : constraints.keySet()) {
+ if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
+ workList.add(constraint);
+ }
+ }
+ }
+ }
+ return true;
+ }
+
+ private boolean revise(Arc arc) {
+ boolean revised = false;
+        for (Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
+ Integer from = fromIterator.next();
+ boolean satisfied = false;
+ for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
+ Integer to = toIterator.next();
+ if (constraints.get(arc).test(from, to)) {
+ satisfied = true;
+ }
+ }
+ if (!satisfied) {
+ fromIterator.remove();
+ revised = true;
+ }
+ }
+ return revised;
+ }
+
+ public interface Constraint {
+ boolean test(Integer x, Integer y);
+ }
+
+ public static boolean equals(Integer x, Integer y) {
+ return Objects.equals(x, y);
+ }
+
+ public static boolean lesserThan(Integer x, Integer y) {
+ return x < y;
+ }
+
+ public static boolean greaterThan(Integer x, Integer y) {
+ return x > y;
+ }
+
+ private static class Arc {
+
+ private final String from;
+ private final String to;
+ private final OnnxOperation operation;
+
+ Arc(String from, String to, OnnxOperation operation) {
+ this.from = from;
+ this.to = to;
+ this.operation = operation;
+ }
+
+ Arc opposite() {
+ return new Arc(to, from, operation);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(from, to);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+            if ( ! (obj instanceof Arc)) {
+ return false;
+ }
+ Arc other = (Arc) obj;
+ return Objects.equals(from, other.from) && Objects.equals(to, other.to);
+ }
+
+ @Override
+ public String toString() {
+ return String.format("%s -> %s", from, to);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java
new file mode 100644
index 00000000000..3ee3f6aa32e
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java
@@ -0,0 +1,24 @@
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.MatMul;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import onnx.Onnx;
+
+import java.util.List;
+
+public class OperationMapper {
+
+ public static OnnxOperation get(Onnx.NodeProto node, List<OnnxOperation> inputs) {
+ switch (node.getOpType().toLowerCase()) {
+ case "add": return new Join(node, inputs, ScalarFunctions.add());
+ case "matmul": return new MatMul(node, inputs);
+ }
+
+ OnnxOperation op = new NoOp(node, inputs);
+ op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
+ return op;
+ }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java
new file mode 100644
index 00000000000..f6e117bfd74
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java
@@ -0,0 +1,250 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+
+import com.yahoo.tensor.TensorType;
+import onnx.Onnx;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * A Vespa tensor type is ordered by the lexicographical ordering of dimension
+ * names. ONNX tensors have an explicit ordering of their dimensions.
+ * During import, we need to track the Vespa dimension that matches the
+ * corresponding ONNX dimension as the ordering can change after
+ * dimension renaming. That is the purpose of this class.
+ *
+ * @author lesters
+ */
+public class OrderedTensorType {
+
+ private final TensorType type;
+ private final List<TensorType.Dimension> dimensions;
+
+ private final long[] innerSizesOnnx;
+ private final long[] innerSizesVespa;
+ private final int[] dimensionMap;
+
+ private OrderedTensorType(List<TensorType.Dimension> dimensions) {
+ this.dimensions = Collections.unmodifiableList(dimensions);
+ this.type = new TensorType.Builder(dimensions).build();
+ this.innerSizesOnnx = new long[dimensions.size()];
+ this.innerSizesVespa = new long[dimensions.size()];
+ this.dimensionMap = createDimensionMap();
+ }
+
+ public TensorType type() { return this.type; }
+
+ public int rank() { return dimensions.size(); }
+
+ public List<TensorType.Dimension> dimensions() {
+ return dimensions;
+ }
+
+ public List<String> dimensionNames() {
+ return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
+ }
+
+ private int[] createDimensionMap() {
+ int numDimensions = dimensions.size();
+ if (numDimensions == 0) {
+ return null;
+ }
+ innerSizesOnnx[numDimensions - 1] = 1;
+ innerSizesVespa[numDimensions - 1] = 1;
+ for (int i = numDimensions - 1; --i >= 0; ) {
+ innerSizesOnnx[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOnnx[i+1];
+ innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
+ }
+ int[] mapping = new int[numDimensions];
+ for (int i = 0; i < numDimensions; ++i) {
+ TensorType.Dimension dim1 = dimensions().get(i);
+ for (int j = 0; j < numDimensions; ++j) {
+ TensorType.Dimension dim2 = type.dimensions().get(j);
+ if (dim1.equals(dim2)) {
+ mapping[i] = j;
+ break;
+ }
+ }
+ }
+ return mapping;
+ }
+
+ /**
+ * When dimension ordering between Vespa and Onnx differs, i.e.
+ * after dimension renaming, use the dimension map to read in values
+ * so that they are correctly laid out in memory for Vespa.
+ * Used when importing tensors from Onnx.
+ */
+ public int toDirectIndex(int index) {
+ if (dimensions.size() == 0) {
+ return 0;
+ }
+ if (dimensionMap == null) {
+ throw new IllegalArgumentException("Dimension map is not available");
+ }
+ int directIndex = 0;
+ long rest = index;
+ for (int i = 0; i < dimensions.size(); ++i) {
+ long address = rest / innerSizesOnnx[i];
+ directIndex += innerSizesVespa[dimensionMap[i]] * address;
+ rest %= innerSizesOnnx[i];
+ }
+ return directIndex;
+ }
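+
+    // Worked example: an ONNX tensor with dimension order (d1[2], d0[3]) gets the
+    // lexicographic Vespa order (d0[3], d1[2]). Then innerSizesOnnx = [3, 1],
+    // innerSizesVespa = [2, 1] (in Vespa order) and dimensionMap = [1, 0].
+    // The ONNX flat index 4, i.e. coordinates (d1=1, d0=1), maps to the direct
+    // index 1*1 + 1*2 = 3, which is exactly (d0=1, d1=1) in Vespa's layout.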
+
+ @Override
+ public boolean equals(Object obj) {
+        if ( ! (obj instanceof OrderedTensorType)) {
+            return false;
+        }
+        OrderedTensorType other = (OrderedTensorType) obj;
+        if (dimensions.size() != other.dimensions.size()) {
+ return false;
+ }
+ List<TensorType.Dimension> thisDimensions = this.dimensions();
+ List<TensorType.Dimension> otherDimensions = other.dimensions();
+ for (int i = 0; i < thisDimensions.size(); ++i) {
+ if (!thisDimensions.get(i).equals(otherDimensions.get(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ public void verifyType(Onnx.TypeProto typeProto) {
+ Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape();
+ if (shape != null) {
+ if (shape.getDimCount() != type.rank()) {
+ throw new IllegalArgumentException("Onnx shape of does not match Vespa shape");
+ }
+ for (int onnxIndex = 0; onnxIndex < dimensions.size(); ++onnxIndex) {
+ int vespaIndex = dimensionMap[onnxIndex];
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex);
+ TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
+ if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) {
+ throw new IllegalArgumentException("TensorFlow dimensions of does not match Vespa dimensions");
+ }
+ }
+ }
+ }
+
+    public OrderedTensorType rename(DimensionRenamer renamer) {
+ List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
+ for (TensorType.Dimension dimension : dimensions) {
+ String oldName = dimension.name();
+ Optional<String> newName = renamer.dimensionNameOf(oldName);
+ if (!newName.isPresent())
+ return this; // presumably, already renamed
+ TensorType.Dimension.Type dimensionType = dimension.type();
+ if (dimensionType == TensorType.Dimension.Type.indexedBound) {
+ renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
+ } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
+ renamedDimensions.add(TensorType.Dimension.indexed(newName.get()));
+ } else if (dimensionType == TensorType.Dimension.Type.mapped) {
+ renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
+ }
+ }
+ return new OrderedTensorType(renamedDimensions);
+ }
+
+ public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) {
+ return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ...
+ }
+
+ public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
+ Onnx.TensorShapeProto shape = type.getTensorType().getShape();
+ Builder builder = new Builder(shape);
+ for (int i = 0; i < shape.getDimCount(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
+ if (onnxDimension.getDimValue() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
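+
+    // For example, an ONNX tensor with shape [2, 3] becomes the Vespa type
+    // tensor(d0[2],d1[3]).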
+
+ public static OrderedTensorType fromOnnxType(List<Long> dims, String dimensionPrefix) {
+ Builder builder = new Builder();
+ for (int i = 0; i < dims.size(); ++ i) {
+ String dimensionName = dimensionPrefix + i;
+ Long dimSize = dims.get(i);
+ if (dimSize >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+ public static OrderedTensorType standardType(OrderedTensorType type) {
+ Builder builder = new Builder();
+ for (int i = 0; i < type.dimensions().size(); ++ i) {
+ TensorType.Dimension dim = type.dimensions().get(i);
+ String dimensionName = "d" + i;
+ if (dim.size().isPresent() && dim.size().get() >= 0) {
+ builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
+ } else {
+ builder.add(TensorType.Dimension.indexed(dimensionName));
+ }
+ }
+ return builder.build();
+ }
+
+ public static class Builder {
+
+ private final Onnx.TensorShapeProto shape;
+ private final List<TensorType.Dimension> dimensions;
+
+ public Builder(Onnx.TensorShapeProto shape) {
+ this.shape = shape;
+ this.dimensions = new ArrayList<>(shape.getDimCount());
+ }
+
+ public Builder() {
+ this.shape = null;
+ this.dimensions = new ArrayList<>();
+ }
+
+ public Builder add(TensorType.Dimension vespaDimension) {
+ if (shape != null) {
+ int index = dimensions.size();
+ Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(index);
+ long size = onnxDimension.getDimValue();
+ if (size >= 0) {
+ if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
+ throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
+ "dimension types");
+ }
+ if (!vespaDimension.size().isPresent()) {
+ throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
+ "not have a size");
+ }
+ if (vespaDimension.size().get() != size) {
+ throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
+ "dimension sizes. TensorFlow: " + size + " Vespa: " +
+ vespaDimension.size().get());
+ }
+ } else {
+ if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
+ throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " +
+ "dimension types");
+ }
+ }
+ }
+ this.dimensions.add(vespaDimension);
+ return this;
+ }
+
+ public OrderedTensorType build() {
+ return new OrderedTensorType(dimensions);
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java
new file mode 100644
index 00000000000..1c5fef456cb
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java
@@ -0,0 +1,79 @@
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer;
+
+import com.google.protobuf.ByteString;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import onnx.Onnx;
+
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.FloatBuffer;
+
+/**
+ * Converts Onnx tensors into Vespa tensors.
+ *
+ * @author lesters
+ */
+public class TensorConverter {
+
+ public static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) {
+ Values values = readValuesOf(tensorProto);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
+ for (int i = 0; i < values.size(); i++) {
+ builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i));
+ }
+ return builder.build();
+ }
+
+    /* TODO: support more types */
+ private static Values readValuesOf(Onnx.TensorProto tensorProto) {
+ if (tensorProto.hasRawData()) {
+ switch (tensorProto.getDataType()) {
+ case FLOAT: return new RawFloatValues(tensorProto);
+ }
+ } else {
+ switch (tensorProto.getDataType()) {
+ case FLOAT: return new FloatValues(tensorProto);
+ }
+ }
+ throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
+ tensorProto.getDataType() + " to a Vespa tensor");
+ }
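+
+    // Example: a FLOAT tensor holding [1.0, 2.0] may arrive either with
+    // float_data = [1.0, 2.0] or with raw_data set to the 8 little-endian
+    // bytes 00 00 80 3f 00 00 00 40; both are read through a Values wrapper.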
+
+ /** Allows reading values from buffers of various numeric types as bytes */
+ private static abstract class Values {
+ abstract double get(int i);
+ abstract int size();
+ }
+
+ private static abstract class RawValues extends Values {
+ ByteBuffer bytes(Onnx.TensorProto tensorProto) {
+ ByteString byteString = tensorProto.getRawData();
+ return byteString.asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
+ }
+ }
+
+ private static class RawFloatValues extends RawValues {
+ private final FloatBuffer values;
+ private final int size;
+ RawFloatValues(Onnx.TensorProto tensorProto) {
+ values = bytes(tensorProto).asFloatBuffer();
+ size = values.remaining();
+ }
+ @Override double get(int i) { return values.get(i); }
+ @Override int size() { return size; }
+ }
+
+ private static class FloatValues extends Values {
+ private final Onnx.TensorProto tensorProto;
+ FloatValues(Onnx.TensorProto tensorProto) {
+ this.tensorProto = tensorProto;
+ }
+ @Override double get(int i) { return tensorProto.getFloatData(i); }
+ @Override int size() { return tensorProto.getFloatDataCount(); }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java
new file mode 100644
index 00000000000..a8d8d63daf4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java
@@ -0,0 +1,64 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+
+import java.util.Collections;
+import java.util.List;
+
+public class Argument extends OnnxOperation {
+
+ private Onnx.ValueInfoProto valueInfo;
+ private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
+
+ public Argument(Onnx.ValueInfoProto valueInfoProto) {
+ super(null, Collections.emptyList());
+ valueInfo = valueInfoProto;
+ standardNamingType = OrderedTensorType.fromOnnxType(valueInfo.getType());
+ }
+
+ @Override
+ public String vespaName() {
+ return vespaName(valueInfo.getName());
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return OrderedTensorType.fromOnnxType(valueInfo.getType(), vespaName() + "_");
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
+ if (!standardNamingType.equals(type)) {
+ List<String> renameFrom = standardNamingType.dimensionNames();
+ List<String> renameTo = type.dimensionNames();
+ output = new Rename(output, renameFrom, renameTo);
+ }
+ return output;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public boolean isInput() {
+ return true;
+ }
+
+ @Override
+ public boolean isConstant() {
+ return false;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
new file mode 100644
index 00000000000..ab650bf8d77
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java
@@ -0,0 +1,59 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.TensorConverter;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+
+import java.util.Collections;
+import java.util.Optional;
+
+public class Constant extends OnnxOperation {
+
+ final Onnx.TensorProto tensorProto;
+
+ public Constant(Onnx.TensorProto tensorProto) {
+ super(null, Collections.emptyList());
+ this.tensorProto = tensorProto;
+ }
+
+    /** TODO: prefix constant names with "modelName_" to avoid name conflicts between models */
+ @Override
+ public String vespaName() {
+// return modelName() + "_" + super.vespaName();
+ return vespaName(tensorProto.getName());
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return OrderedTensorType.fromOnnxType(tensorProto.getDimsList(), vespaName() + "_");
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // will be added by function() since this is constant.
+ }
+
+ @Override
+ public Optional<Value> getConstantValue() {
+ return Optional.of(new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java
new file mode 100644
index 00000000000..fe2004a528d
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java
@@ -0,0 +1,122 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+
+public class Join extends OnnxOperation {
+
+ private final DoubleBinaryOperator operator;
+
+ public Join(Onnx.NodeProto node, List<OnnxOperation> inputs, DoubleBinaryOperator operator) {
+ super(node, inputs);
+ this.operator = operator;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ OrderedTensorType a = largestInput().type().get();
+ OrderedTensorType b = smallestInput().type().get();
+
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ int sizeDifference = a.rank() - b.rank();
+ for (int i = 0; i < a.rank(); ++i) {
+ TensorType.Dimension aDim = a.dimensions().get(i);
+ long size = aDim.size().orElse(-1L);
+
+ if (i - sizeDifference >= 0) {
+ TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference);
+ size = Math.max(size, bDim.size().orElse(-1L));
+ }
+
+ if (aDim.type() == TensorType.Dimension.Type.indexedBound) {
+ builder.add(TensorType.Dimension.indexed(aDim.name(), size));
+ } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) {
+ builder.add(TensorType.Dimension.indexed(aDim.name()));
+ } else if (aDim.type() == TensorType.Dimension.Type.mapped) {
+ builder.add(TensorType.Dimension.mapped(aDim.name()));
+ }
+ }
+ return builder.build();
+ }
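+
+    // Example of the broadcasting logic above: joining a of type (d0[2], d1[3])
+    // with b of type (d1[3]) aligns b against a's innermost dimensions, giving
+    // the result type (d0[2], d1[3]). A size-1 dimension, e.g. b of type (d1[1]),
+    // is sum-reduced away in lazyGetFunction before the join.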
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ if (!allInputFunctionsPresent(2)) {
+ return null;
+ }
+
+ OnnxOperation a = largestInput();
+ OnnxOperation b = smallestInput();
+
+ List<String> aDimensionsToReduce = new ArrayList<>();
+ List<String> bDimensionsToReduce = new ArrayList<>();
+ int sizeDifference = a.type().get().rank() - b.type().get().rank();
+ for (int i = 0; i < b.type().get().rank(); ++i) {
+ TensorType.Dimension bDim = b.type().get().dimensions().get(i);
+ TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference);
+ long bSize = bDim.size().orElse(-1L);
+ long aSize = aDim.size().orElse(-1L);
+ if (bSize == 1L && aSize != 1L) {
+ bDimensionsToReduce.add(bDim.name());
+ }
+ if (aSize == 1L && bSize != 1L) {
+                aDimensionsToReduce.add(aDim.name());
+ }
+ }
+
+ TensorFunction aReducedFunction = a.function().get();
+ if (aDimensionsToReduce.size() > 0) {
+ aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
+ }
+ TensorFunction bReducedFunction = b.function().get();
+ if (bDimensionsToReduce.size() > 0) {
+ bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
+ }
+
+ return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!allInputTypesPresent(2)) {
+ return;
+ }
+ OrderedTensorType a = largestInput().type().get();
+ OrderedTensorType b = smallestInput().type().get();
+ int sizeDifference = a.rank() - b.rank();
+ for (int i = 0; i < b.rank(); ++i) {
+ String bDim = b.dimensions().get(i).name();
+ String aDim = a.dimensions().get(i + sizeDifference).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+ }
+ }
+
+ private OnnxOperation largestInput() {
+ OrderedTensorType a = inputs.get(0).type().get();
+ OrderedTensorType b = inputs.get(1).type().get();
+ return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
+ }
+
+ private OnnxOperation smallestInput() {
+ OrderedTensorType a = inputs.get(0).type().get();
+ OrderedTensorType b = inputs.get(1).type().get();
+ return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java
new file mode 100644
index 00000000000..1b388e2ae89
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java
@@ -0,0 +1,75 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+
+import java.util.List;
+import java.util.Optional;
+
+public class MatMul extends OnnxOperation {
+
+ public MatMul(Onnx.NodeProto node, List<OnnxOperation> inputs) {
+ super(node, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
+ typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputTypesPresent(2)) {
+ return null;
+ }
+ OrderedTensorType aType = inputs.get(0).type().get();
+ OrderedTensorType bType = inputs.get(1).type().get();
+ if (aType.type().rank() < 2 || bType.type().rank() < 2)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
+ if (aType.type().rank() != bType.type().rank())
+ throw new IllegalArgumentException("Tensors in matmul must have the same rank");
+
+ Optional<TensorFunction> aFunction = inputs.get(0).function();
+ Optional<TensorFunction> bFunction = inputs.get(1).function();
+ if (!aFunction.isPresent() || !bFunction.isPresent()) {
+ return null;
+ }
+ return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!allInputTypesPresent(2)) {
+ return;
+ }
+ List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
+ List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+
+ String aDim0 = aDimensions.get(0).name();
+ String aDim1 = aDimensions.get(1).name();
+ String bDim0 = bDimensions.get(0).name();
+ String bDim1 = bDimensions.get(1).name();
+
+ // The second dimension of a should have the same name as the first dimension of b
+ renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
+
+ // The first dimension of a should have a different name than the second dimension of b
+ renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
+
+ // For efficiency, the dimensions to join over should be innermost - soft constraint
+ renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
+ renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
+ }
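+
+    // Taken in isolation, these constraints typically resolve to a: (d0, d2) and
+    // b: (d2, d1), so that the joined dimension d2 is innermost in a and the
+    // result (built from a's first and b's second dimension) gets type (d0, d1).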
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java
new file mode 100644
index 00000000000..b1136a0ce0a
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java
@@ -0,0 +1,32 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+
+import java.util.Collections;
+import java.util.List;
+
+public class NoOp extends OnnxOperation {
+
+ public NoOp(Onnx.NodeProto node, List<OnnxOperation> inputs) {
+ super(node, Collections.emptyList()); // don't propagate inputs
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return null;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null;
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
new file mode 100644
index 00000000000..2c8003f5951
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java
@@ -0,0 +1,119 @@
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.TensorFunction;
+import onnx.Onnx;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.function.Function;
+
+/**
+ * Wraps an ONNX node and produces the respective Vespa tensor operation.
+ * During import, a graph of these operations is constructed. Then, the
+ * types are used to deduce sensible dimension names using the
+ * DimensionRenamer. After the types have been renamed, the proper
+ * Vespa expressions can be extracted.
+ *
+ * @author lesters
+ */
+public abstract class OnnxOperation {
+
+ protected final Onnx.NodeProto node; // can be null for onnx inputs and constants
+ protected final List<OnnxOperation> inputs;
+ protected final List<OnnxOperation> outputs = new ArrayList<>();
+ protected final List<String> importWarnings = new ArrayList<>();
+
+ protected OrderedTensorType type;
+ protected TensorFunction function;
+ protected Value constantValue = null;
+
+ OnnxOperation(Onnx.NodeProto node, List<OnnxOperation> inputs) {
+ this.node = node;
+ this.inputs = Collections.unmodifiableList(inputs);
+ this.inputs.forEach(i -> i.outputs.add(this));
+ }
+
+ protected abstract OrderedTensorType lazyGetType();
+ protected abstract TensorFunction lazyGetFunction();
+
+ /** Returns the Vespa tensor type of this operation if it exists */
+ public Optional<OrderedTensorType> type() {
+ if (type == null) {
+ type = lazyGetType();
+ }
+ return Optional.ofNullable(type);
+ }
+
+ /** Returns the Vespa tensor function implementing all operations from this node with inputs */
+ public Optional<TensorFunction> function() {
+ if (function == null) {
+ if (isConstant()) {
+ ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
+ function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
+ } else {
+ function = lazyGetFunction();
+ }
+ }
+ return Optional.ofNullable(function);
+ }
+
+ /** Return Onnx node */
+ public Onnx.NodeProto node() { return node; }
+
+ /** Return unmodifiable list of inputs */
+ public List<OnnxOperation> inputs() { return inputs; }
+
+ /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
+ public List<OnnxOperation> outputs() { return Collections.unmodifiableList(outputs); }
+
+ /** Add dimension name constraints for this operation */
+ public void addDimensionNameConstraints(DimensionRenamer renamer) { }
+
+ /** Performs dimension rename for this operation */
+ public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
+
+ /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
+ public boolean isInput() { return false; }
+
+ /** Return true if this node is constant */
+ public boolean isConstant() { return inputs.stream().allMatch(OnnxOperation::isConstant); }
+
+ /** Gets the constant value if it exists */
+ public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
+
+ /** Retrieve the valid Vespa name of this node */
+ public String vespaName() { return vespaName(node.getName()); }
+ public String vespaName(String name) { return name != null ? name.replace('/', '_').replace(':','_') : null; }
+
+ /** Retrieve the list of warnings produced during its lifetime */
+ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
+
+    /** Adds an import warning */
+ public void warning(String warning) { importWarnings.add(warning); }
+
+ boolean verifyInputs(int expected, Function<OnnxOperation, Optional<?>> func) {
+ if (inputs.size() != expected) {
+ throw new IllegalArgumentException("Expected " + expected + " inputs " +
+ "for '" + node.getName() + "', got " + inputs.size());
+ }
+ return inputs.stream().map(func).allMatch(Optional::isPresent);
+ }
+
+ boolean allInputTypesPresent(int expected) {
+ return verifyInputs(expected, OnnxOperation::type);
+ }
+
+ boolean allInputFunctionsPresent(int expected) {
+ return verifyInputs(expected, OnnxOperation::function);
+ }
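+
+    // Sketch of how a new elementwise operation could be added (illustrative;
+    // assumes ScalarFunctions.relu() and the tensor Map function, plus a
+    // corresponding "relu" case in OperationMapper):
+    //
+    //   public class Relu extends OnnxOperation {
+    //       public Relu(Onnx.NodeProto node, List<OnnxOperation> inputs) { super(node, inputs); }
+    //       @Override protected OrderedTensorType lazyGetType() {
+    //           return allInputTypesPresent(1) ? inputs.get(0).type().get() : null;
+    //       }
+    //       @Override protected TensorFunction lazyGetFunction() {
+    //           if ( ! allInputFunctionsPresent(1)) return null;
+    //           return new com.yahoo.tensor.functions.Map(inputs.get(0).function().get(),
+    //                                                     ScalarFunctions.relu());
+    //       }
+    //   }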
+
+}
diff --git a/searchlib/src/main/protobuf/onnx.proto b/searchlib/src/main/protobuf/onnx.proto
new file mode 100644
index 00000000000..dc6542867e0
--- /dev/null
+++ b/searchlib/src/main/protobuf/onnx.proto
@@ -0,0 +1,464 @@
+//
+// WARNING: This file is automatically generated! Please edit onnx.in.proto.
+//
+
+
+// Copyright (c) Facebook Inc. and Microsoft Corporation.
+// Licensed under the MIT license.
+
+syntax = "proto2";
+
+package onnx;
+
+// Overview
+//
+// ONNX is an open specification that is comprised of the following components:
+//
+// 1) A definition of an extensible computation graph model.
+// 2) Definitions of standard data types.
+// 3) Definitions of built-in operators.
+//
+// This document describes the syntax of models and their computation graphs,
+// as well as the standard data types. Together, they are referred to as the ONNX
+// Intermediate Representation, or 'IR' for short.
+//
+// The normative semantic specification of the ONNX IR is found in docs/IR.md.
+// Definitions of the built-in neural network operators may be found in docs/Operators.md.
+
+// Notes
+//
+// Release
+//
+// We are still in the very early stage of defining ONNX. The current
+// version of ONNX is a starting point. While we are actively working
+// towards a complete spec, we would like to get the community involved
+// by sharing our working version of ONNX.
+//
+// Protobuf compatibility
+//
+// To simplify framework compatibility, ONNX is defined using the subset of protobuf
+// that is compatible with both protobuf v2 and v3. This means that we do not use any
+// protobuf features that are only available in one of the two versions.
+//
+// Here are the most notable contortions we have to carry out to work around
+// these limitations:
+//
+// - No 'map' (added protobuf 3.0). We instead represent mappings as lists
+// of key-value pairs, where order does not matter and duplicates
+// are not allowed.
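+// (See StringStringEntryProto below for an example of this pattern.)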
+
+
+// Versioning
+//
+// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md
+//
+// To be compatible with both proto2 and proto3, we use version numbers
+// defined as explicit enum values rather than relying on default values.
+enum Version {
+ // proto3 requires the first enum value to be zero.
+ // We add this just to appease the compiler.
+ _START_VERSION = 0;
+ // The version field is always serialized and we will use it to store the
+ // version that the graph is generated from. This helps us set up version
+ // control. We should use version as
+ // xx(major) - xx(minor) - xxxx(bugfix)
+ // and we are starting with 0x00000001 (0.0.1), which was the
+ // version we published on Oct 10, 2017.
+ IR_VERSION_2017_10_10 = 0x00000001;
+
+ // IR_VERSION 0.0.2 published on Oct 30, 2017
+ // - Added type discriminator to AttributeProto to support proto3 users
+ IR_VERSION_2017_10_30 = 0x00000002;
+
+ // IR VERSION 0.0.3 published on Nov 3, 2017
+ // - For operator versioning:
+ // - Added new message OperatorSetIdProto
+ // - Added opset_import in ModelProto
+ // - For vendor extensions, added domain in NodeProto
+ IR_VERSION = 0x00000003;
+}
+
+// Attributes
+//
+// A named attribute containing either a singular float, integer, string, graph,
+// or tensor value, or repeated float, integer, string, graph, or tensor values.
+// An AttributeProto MUST contain the name field, and *only one* of the
+// following content fields, effectively enforcing a C/C++ union equivalent.
+message AttributeProto {
+
+ // Note: this enum is structurally identical to the OpSchema::AttrType
+ // enum defined in schema.h. If you rev one, you likely need to rev the other.
+ enum AttributeType {
+ UNDEFINED = 0;
+ FLOAT = 1;
+ INT = 2;
+ STRING = 3;
+ TENSOR = 4;
+ GRAPH = 5;
+
+ FLOATS = 6;
+ INTS = 7;
+ STRINGS = 8;
+ TENSORS = 9;
+ GRAPHS = 10;
+ }
+
+ // The name field MUST be present for this version of the IR.
+ optional string name = 1; // namespace Attribute
+
+ // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function.
+ // In this case, this AttributeProto does not contain data, and it's a reference of attribute
+ // in parent scope.
+ // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph.
+ optional string ref_attr_name = 21;
+
+ // A human-readable documentation for this attribute. Markdown is allowed.
+ optional string doc_string = 13;
+
+ // The type field MUST be present for this version of the IR.
+ // For 0.0.1 versions of the IR, this field was not defined, and
+ // implementations needed to use has_field heuristics to determine
+ // which value field was in use. For IR_VERSION 0.0.2 or later, this
+ // field MUST be set and match the f|i|s|t|... field in use. This
+ // change was made to accommodate proto3 implementations.
+ optional AttributeType type = 20; // discriminator that indicates which field below is in use
+
+ // Exactly ONE of the following fields must be present for this version of the IR
+ optional float f = 2; // float
+ optional int64 i = 3; // int
+ optional bytes s = 4; // UTF-8 string
+ optional TensorProto t = 5; // tensor value
+ optional GraphProto g = 6; // graph
+ // Do not use field below, it's deprecated.
+ // optional ValueProto v = 12; // value - subsumes everything but graph
+
+ repeated float floats = 7; // list of floats
+ repeated int64 ints = 8; // list of ints
+ repeated bytes strings = 9; // list of UTF-8 strings
+ repeated TensorProto tensors = 10; // list of tensors
+ repeated GraphProto graphs = 11; // list of graph
+}
+
+// Defines information on value, including the name, the type, and
+// the shape of the value.
+message ValueInfoProto {
+ // This field MUST be present in this version of the IR.
+ optional string name = 1; // namespace Value
+ // This field MUST be present in this version of the IR.
+ optional TypeProto type = 2;
+ // A human-readable documentation for this value. Markdown is allowed.
+ optional string doc_string = 3;
+}
+
+// Nodes
+//
+// Computation graphs are made up of a DAG of nodes, which represent what is
+// commonly called a "layer" or "pipeline stage" in machine learning frameworks.
+//
+// For example, it can be a node of type "Conv" that takes in an image, a filter
+// tensor and a bias tensor, and produces the convolved output.
+message NodeProto {
+ repeated string input = 1; // namespace Value
+ repeated string output = 2; // namespace Value
+
+ // An optional identifier for this node in a graph.
+ // This field MAY be absent in this version of the IR.
+ optional string name = 3; // namespace Node
+
+ // The symbolic identifier of the Operator to execute.
+ optional string op_type = 4; // namespace Operator
+ // The domain of the OperatorSet that specifies the operator named by op_type.
+ optional string domain = 7; // namespace Domain
+
+ // Additional named attributes.
+ repeated AttributeProto attribute = 5;
+
+ // A human-readable documentation for this node. Markdown is allowed.
+ optional string doc_string = 6;
+}
+
+// Models
+//
+// ModelProto is a top-level file/container format for bundling a ML model and
+// associating its computation graph with metadata.
+//
+// The semantics of the model are described by the associated GraphProto.
+message ModelProto {
+ // The version of the IR this model targets. See Version enum above.
+ // This field MUST be present.
+ optional int64 ir_version = 1;
+
+ // The OperatorSets this model relies on.
+ // All ModelProtos MUST have at least one entry that
+ // specifies which version of the ONNX OperatorSet is
+ // being imported.
+ //
+ // All nodes in the ModelProto's graph will bind against the
+ // same-domain/same-op_type operator with the HIGHEST version
+ // in the referenced operator sets.
+ repeated OperatorSetIdProto opset_import = 8;
+
+ // The name of the framework or tool used to generate this model.
+ // This field SHOULD be present to indicate which implementation/tool/framework
+ // emitted the model.
+ optional string producer_name = 2;
+
+ // The version of the framework or tool used to generate this model.
+ // This field SHOULD be present to indicate which implementation/tool/framework
+ // emitted the model.
+ optional string producer_version = 3;
+
+ // Domain name of the model.
+ // We use reverse domain names as name space indicators. For example:
+ // `com.facebook.fair` or `com.microsoft.cognitiveservices`
+ //
+ // Together with `model_version` and GraphProto.name, this forms the unique identity of
+ // the graph.
+ optional string domain = 4;
+
+ // The version of the graph encoded. See Version enum above.
+ optional int64 model_version = 5;
+
+ // A human-readable documentation for this model. Markdown is allowed.
+ optional string doc_string = 6;
+
+ // The parameterized graph that is evaluated to execute the model.
+ optional GraphProto graph = 7;
+
+ // Named metadata values; keys should be distinct.
+ repeated StringStringEntryProto metadata_props = 14;
+};
+
+// StringStringEntryProto follows the pattern for cross-proto-version maps.
+// See https://developers.google.com/protocol-buffers/docs/proto3#maps
+message StringStringEntryProto {
+ optional string key = 1;
+ optional string value = 2;
+};
+
+// Graphs
+//
+// A graph defines the computational logic of a model and is comprised of a parameterized
+// list of nodes that form a directed acyclic graph based on their inputs and outputs.
+// This is the equivalent of the "network" or "graph" in many deep learning
+// frameworks.
+message GraphProto {
+ // The nodes in the graph, sorted topologically.
+ repeated NodeProto node = 1;
+
+ // The name of the graph.
+ optional string name = 2; // namespace Graph
+
+ // A list of named tensor values, used to specify constant inputs of the graph.
+ // Each TensorProto entry must have a distinct name (within the list) that
+ // also appears in the input list.
+ repeated TensorProto initializer = 5;
+
+ // A human-readable documentation for this graph. Markdown is allowed.
+ optional string doc_string = 10;
+
+ // The inputs and outputs of the graph.
+ repeated ValueInfoProto input = 11;
+ repeated ValueInfoProto output = 12;
+
+ // Information for the values in the graph. The ValueInfoProto.name values
+ // must be distinct. It is optional for a value to appear in the value_info list.
+ repeated ValueInfoProto value_info = 13;
+
+ // DO NOT USE the following fields, they were deprecated from earlier versions.
+ // repeated string input = 3;
+ // repeated string output = 4;
+ // optional int64 ir_version = 6;
+ // optional int64 producer_version = 7;
+ // optional string producer_tag = 8;
+ // optional string domain = 9;
+}
+
+// Tensors
+//
+// A serialized tensor value.
+message TensorProto {
+ enum DataType {
+ UNDEFINED = 0;
+ // Basic types.
+ FLOAT = 1; // float
+ UINT8 = 2; // uint8_t
+ INT8 = 3; // int8_t
+ UINT16 = 4; // uint16_t
+ INT16 = 5; // int16_t
+ INT32 = 6; // int32_t
+ INT64 = 7; // int64_t
+ STRING = 8; // string
+ BOOL = 9; // bool
+
+ // Advanced types
+ FLOAT16 = 10;
+ DOUBLE = 11;
+ UINT32 = 12;
+ UINT64 = 13;
+ COMPLEX64 = 14; // complex with float32 real and imaginary components
+ COMPLEX128 = 15; // complex with float64 real and imaginary components
+ // Future extensions go here.
+ }
+
+ // The shape of the tensor.
+ repeated int64 dims = 1;
+
+ // The data type of the tensor.
+ optional DataType data_type = 2;
+
+ // For very large tensors, we may want to store them in chunks, in which
+ // case the following fields will specify the segment that is stored in
+ // the current TensorProto.
+ message Segment {
+ optional int64 begin = 1;
+ optional int64 end = 2;
+ }
+ optional Segment segment = 3;
+
+ // Tensor content must be organized in row-major order.
+ //
+ // Depending on the data_type field, exactly one of the fields below with
+ // name ending in _data is used to store the elements of the tensor.
+
+ // For float and complex64 values
+ // Complex64 tensors are encoded as a single array of floats,
+ // with the real components appearing in odd numbered positions,
+ // and the corresponding imaginary component appearing in the
+ // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+ // is encoded as [1.0, 2.0, 3.0, 4.0])
+ // When this field is present, the data_type field MUST be FLOAT or COMPLEX64.
+ repeated float float_data = 4 [packed = true];
+
+ // For int32, uint8, int8, uint16, int16, bool, and float16 values
+ // float16 values must be bit-wise converted to an uint16_t prior
+ // to writing to the buffer.
+ // When this field is present, the data_type field MUST be
+ // INT32, INT16, INT8, UINT16, UINT8, BOOL, or FLOAT16
+ repeated int32 int32_data = 5 [packed = true];
+
+ // For strings.
+ // Each element of string_data is a UTF-8 encoded Unicode
+ // string. No trailing null, no leading BOM. The protobuf "string"
+ // scalar type is not used to match ML community conventions.
+ // When this field is present, the data_type field MUST be STRING
+ repeated bytes string_data = 6;
+
+ // For int64.
+ // When this field is present, the data_type field MUST be INT64
+ repeated int64 int64_data = 7 [packed = true];
+
+ // Optionally, a name for the tensor.
+ optional string name = 8; // namespace Value
+
+ // A human-readable documentation for this tensor. Markdown is allowed.
+ optional string doc_string = 12;
+
+ // Serializations can either use one of the fields above, or use this
+ // raw bytes field. The only exception is the string case, where one is
+ // required to store the content in the repeated bytes string_data field.
+ //
+ // When this raw_data field is used to store the tensor value, elements MUST
+ // be stored as fixed-width values in little-endian byte order.
+ // Floating-point data types MUST be stored in IEEE 754 format.
+ // Complex64 elements must be written as two consecutive FLOAT values, real component first.
+ // Complex128 elements must be written as two consecutive DOUBLE values, real component first.
+ // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false).
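+ // For example, a single FLOAT element with value 1.0 is stored in raw_data
+ // as the four little-endian bytes 0x00 0x00 0x80 0x3F (IEEE 754).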
+ //
+ // Note: the advantage of the specific fields above over the raw_data field is
+ // that in some cases (e.g. int data) protobuf achieves better packing via
+ // variable-length storage, which may lead to a smaller binary footprint.
+ // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED
+ optional bytes raw_data = 9;
+
+ // For double
+ // Complex128 tensors are encoded as a single array of doubles,
+ // with the real components appearing in odd numbered positions,
+ // and the corresponding imaginary component appearing in the
+ // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i]
+ // is encoded as [1.0, 2.0, 3.0, 4.0])
+ // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128
+ repeated double double_data = 10 [packed = true];
+
+ // For uint64 and uint32 values
+ // When this field is present, the data_type field MUST be
+ // UINT32 or UINT64
+ repeated uint64 uint64_data = 11 [packed = true];
+}
+
+// Defines a tensor shape. A dimension can be either an integer value
+// or a symbolic variable. A symbolic variable represents an unknown
+// dimension.
+message TensorShapeProto {
+ message Dimension {
+ oneof value {
+ int64 dim_value = 1;
+ string dim_param = 2; // namespace Shape
+ };
+ // Standard denotation can optionally be used to denote tensor
+ // dimensions with standard semantic descriptions to ensure
+ // that operations are applied to the correct axis of a tensor.
+ optional string denotation = 3;
+ };
+ repeated Dimension dim = 1;
+}
+
+// A set of pre-defined constants to be used as values for
+// the standard denotation field in TensorShapeProto.Dimension
+// for semantic description of the tensor dimension.
+message DenotationConstProto {
+ // Describe a batch number dimension.
+ optional string DATA_BATCH = 1 [default = "DATA_BATCH"];
+ // Describe a channel dimension.
+ optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"];
+ // Describe a time dimension.
+ optional string DATA_TIME = 3 [default = "DATA_TIME"];
+ // Describe a feature dimension. This is typically a feature
+ // dimension in RNN and/or spatial dimension in CNN.
+ optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"];
+ // Describe a filter in-channel dimension. This is the dimension
+ // that is identical (in size) to the channel dimension of the input
+ // image feature maps.
+ optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"];
+ // Describe a filter out channel dimension. This is the dimension
+ // that is identical (in size) to the channel dimension of the output
+ // image feature maps.
+ optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"];
+ // Describe a filter spatial dimension.
+ optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"];
+}
+
+// Types
+//
+// The standard ONNX data types.
+message TypeProto {
+
+ message Tensor {
+ // This field MUST NOT have the value of UNDEFINED
+ // This field MUST be present for this version of the IR.
+ optional TensorProto.DataType elem_type = 1;
+ optional TensorShapeProto shape = 2;
+ }
+
+ oneof value {
+ // The type of a tensor.
+ Tensor tensor_type = 1;
+ }
+}
+
+// Operator Sets
+//
+// OperatorSets are uniquely identified by a (domain, opset_version) pair.
+message OperatorSetIdProto {
+ // The domain of the operator set being identified.
+ // The empty string ("") or absence of this field implies the operator
+ // set that is defined as part of the ONNX specification.
+ // This field MUST be present in this version of the IR when referring to any other operator set.
+ optional string domain = 1;
+
+ // The version of the operator set being identified.
+ // This field MUST be present in this version of the IR.
+ optional int64 version = 2;
+}
\ No newline at end of file
diff --git a/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx b/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx
new file mode 100644
index 00000000000..a86019bf53a
--- /dev/null
+++ b/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx
Binary files differ
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
new file mode 100644
index 00000000000..e118c2b885a
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -0,0 +1,109 @@
+package com.yahoo.searchlib.rankingexpression.integration.onnx;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+import org.tensorflow.SavedModelBundle;
+
+import java.io.IOException;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * @author lesters
+ */
+public class OnnxMnistSoftmaxImportTestCase {
+
+ @Test
+ public void testMnistSoftmaxImport() throws IOException {
+ OnnxModel model = new OnnxImporter().importModel("src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx", "add");
+
+ // Check constants
+ assertEquals(2, model.largeConstants().size());
+
+ Tensor constant0 = model.largeConstants().get("Variable_0");
+ assertNotNull(constant0);
+ assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
+ constant0.type());
+ assertEquals(7840, constant0.size());
+
+ Tensor constant1 = model.largeConstants().get("Variable_1_0");
+ assertNotNull(constant1);
+ assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
+ constant1.type());
+ assertEquals(10, constant1.size());
+
+ // Check required macros (inputs)
+ assertEquals(1, model.requiredMacros().size());
+ assertTrue(model.requiredMacros().containsKey("Placeholder_0"));
+ assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
+ model.requiredMacros().get("Placeholder_0"));
+
+ // Check outputs
+ RankingExpression output = model.expressions().get("add");
+ assertNotNull(output);
+ assertEquals("add", output.getName());
+ assertEquals("join(reduce(join(rename(Placeholder_0, (d0, d1), (d0, d2)), constant(Variable_0), f(a,b)(a * b)), sum, d2), constant(Variable_1_0), f(a,b)(a + b))",
+ output.getRoot().toString());
+ }
+
+ @Test
+ public void testComparisonBetweenOnnxAndTensorflow() throws IOException {
+ String tfModelPath = "src/test/files/integration/tensorflow/mnist_softmax/saved";
+ String onnxModelPath = "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx";
+
+ Tensor argument = placeholderArgument();
+ Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add");
+ Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder_0", "add");
+
+ assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult);
+ }
+
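+ // The helpers below run the same argument tensor through the TensorFlow and
+ // ONNX models: they bind the argument to the given input name and evaluate
+ // the ranking expression registered for the given output.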
+ private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) {
+ SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve");
+ TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel);
+ return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
+ }
+
+ private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) throws IOException {
+ OnnxModel model = new OnnxImporter().importModel(path, output);
+ return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
+ }
+
+ private Tensor evaluateExpression(RankingExpression expression, Context context, Tensor argument, String input) {
+ context.put(input, new TensorValue(argument));
+ return expression.evaluate(context).asTensor();
+ }
+
+ private Context contextFrom(TensorFlowModel result) {
+ MapContext context = new MapContext();
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ return context;
+ }
+
+ private Context contextFrom(OnnxModel result) {
+ MapContext context = new MapContext();
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ return context;
+ }
+
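+ /** Builds a 1x784 input tensor where cell (0, d1) has the value d1 / 784 */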
+ private Tensor placeholderArgument() {
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", 784).build());
+ for (int d0 = 0; d0 < 1; d0++)
+ for (int d1 = 0; d1 < 784; d1++)
+ b.cell(d1 * 1.0 / 784, d0, d1);
+ return b.build();
+ }
+
+}