16 files changed, 1955 insertions, 0 deletions
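For orientation, a minimal usage sketch of the importer introduced by this patch might look roughly as follows. This is not part of the patch itself; the model file name and output node name are hypothetical, and only calls that appear in the diff below (OnnxImporter.importModel, OnnxModel.output, expressions, requiredMacros, largeConstants) are assumed.

    import com.yahoo.searchlib.rankingexpression.RankingExpression;
    import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxImporter;
    import com.yahoo.searchlib.rankingexpression.integration.onnx.OnnxModel;
    import com.yahoo.tensor.Tensor;
    import com.yahoo.tensor.TensorType;
    import java.util.Map;

    // Import the ONNX graph and pick the ranking expression generated for the requested output node.
    // "models/example.onnx" and "output" are placeholder names.
    OnnxModel model = new OnnxImporter().importModel("models/example.onnx", "output");
    RankingExpression expression = model.expressions().get(model.output());

    // Inputs (macros) the serving environment must provide, and large constants to deploy with the expression.
    Map<String, TensorType> requiredMacros = model.requiredMacros();
    Map<String, Tensor> largeConstants = model.largeConstants();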
diff --git a/searchlib/pom.xml b/searchlib/pom.xml index 1615f248910..fb9a81b51a5 100644 --- a/searchlib/pom.xml +++ b/searchlib/pom.xml @@ -117,6 +117,26 @@ </execution> </executions> </plugin> + <plugin> + <groupId>com.github.os72</groupId> + <artifactId>protoc-jar-maven-plugin</artifactId> + <version>3.5.1.1</version> + <executions> + <execution> + <phase>generate-sources</phase> + <goals> + <goal>run</goal> + </goals> + <configuration> + <addSources>main</addSources> + <outputDirectory>${project.build.directory}/generated-sources/protobuf</outputDirectory> + <inputDirectories> + <include>src/main/protobuf</include> + </inputDirectories> + </configuration> + </execution> + </executions> + </plugin> </plugins> </build> </project> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java new file mode 100644 index 00000000000..047d1b187f5 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxImporter.java @@ -0,0 +1,265 @@ +package com.yahoo.searchlib.rankingexpression.integration.onnx; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Constant; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Argument; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OperationMapper; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; + +import java.io.FileInputStream; +import java.io.IOException; +import java.util.Collection; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Converts a ONNX model into a ranking expression and set of constants. 
+ * + * @author lesters + */ +public class OnnxImporter { + + public OnnxModel importModel(String modelPath, String outputNode) { + try (FileInputStream inputStream = new FileInputStream(modelPath)) { + Onnx.ModelProto model = Onnx.ModelProto.parseFrom(inputStream); + return importModel(model, outputNode); + } catch (IOException e) { + throw new IllegalArgumentException("Could not import ONNX model from '" + modelPath + "'", e); + } + } + + public OnnxModel importModel(Onnx.ModelProto model, String outputNode) { + return importGraph(model.getGraph(), outputNode); + } + + private static OnnxModel importGraph(Onnx.GraphProto graph, String outputNode) { + OnnxModel model = new OnnxModel(outputNode); + OperationIndex index = new OperationIndex(); + + OnnxOperation output = importNode(outputNode, graph, index); + output.type().orElseThrow(() -> new IllegalArgumentException("Output of '" + outputNode + "' has no type.")) + .verifyType(getOutputNode(outputNode, graph).getType()); + + findDimensionNames(output); + importExpressions(output, model); + + return model; + } + + private static OnnxOperation importNode(String nodeName, Onnx.GraphProto graph, OperationIndex index) { + if (index.alreadyImported(nodeName)) { + return index.get(nodeName); + } + OnnxOperation operation; + if (isArgumentTensor(nodeName, graph)) { + operation = new Argument(getArgumentTensor(nodeName, graph)); + } else if (isConstantTensor(nodeName, graph)) { + operation = new Constant(getConstantTensor(nodeName, graph)); + } else { + Onnx.NodeProto node = getNodeFromGraph(nodeName, graph); + List<OnnxOperation> inputs = importNodeInputs(node, graph, index); + operation = OperationMapper.get(node, inputs); + } + index.put(nodeName, operation); + + return operation; + } + + private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { + Onnx.ValueInfoProto value = getArgumentTensor(name, graph); + Onnx.TensorProto tensor = getConstantTensor(name, graph); + return value != null && tensor == null; + } + + private static boolean isConstantTensor(String name, Onnx.GraphProto graph) { + Onnx.ValueInfoProto value = getArgumentTensor(name, graph); + Onnx.TensorProto tensor = getConstantTensor(name, graph); + return value != null && tensor != null; + } + + private static Onnx.ValueInfoProto getArgumentTensor(String name, Onnx.GraphProto graph) { + for (Onnx.ValueInfoProto valueInfo : graph.getInputList()) { + if (valueInfo.getName().equals(name)) { + return valueInfo; + } + } + return null; + } + + private static Onnx.TensorProto getConstantTensor(String name, Onnx.GraphProto graph) { + for (Onnx.TensorProto tensorProto : graph.getInitializerList()) { + if (tensorProto.getName().equals(name)) { + return tensorProto; + } + } + return null; + } + + private static boolean isOutputNode(String name, Onnx.GraphProto graph) { + return getOutputNode(name, graph) != null; + } + + private static Onnx.ValueInfoProto getOutputNode(String name, Onnx.GraphProto graph) { + for (Onnx.ValueInfoProto valueInfo : graph.getOutputList()) { + Onnx.NodeProto node = getNodeFromGraph(valueInfo.getName(), graph); + if (node.getName().equals(name)) { + return valueInfo; + } + } + return null; + } + + private static List<OnnxOperation> importNodeInputs(Onnx.NodeProto node, + Onnx.GraphProto graph, + OperationIndex index) { + return node.getInputList().stream() + .map(nodeName -> importNode(nodeName, graph, index)) + .collect(Collectors.toList()); + } + + /** Find dimension names to avoid excessive renaming while evaluating the model. 
*/ + private static void findDimensionNames(OnnxOperation output) { + DimensionRenamer renamer = new DimensionRenamer(); + addDimensionNameConstraints(output, renamer); + renamer.solve(); + renameDimensions(output, renamer); + } + + private static void addDimensionNameConstraints(OnnxOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.addDimensionNameConstraints(renamer); + } + } + + private static void renameDimensions(OnnxOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.renameDimensions(renamer); + } + } + + private static void importExpressions(OnnxOperation output, OnnxModel model) { + Optional<TensorFunction> function = importExpression(output, model); + if (!function.isPresent()) { + throw new IllegalArgumentException("No valid output function could be found."); + } + } + + private static Optional<TensorFunction> importExpression(OnnxOperation operation, OnnxModel model) { + if (!operation.type().isPresent()) { + return Optional.empty(); + } + if (operation.isConstant()) { + return importConstant(operation, model); + } + importInputExpressions(operation, model); + importRankingExpression(operation, model); + importInputExpression(operation, model); + + return operation.function(); + } + + private static void importInputExpressions(OnnxOperation operation, OnnxModel model) { + operation.inputs().forEach(input -> importExpression(input, model)); + } + + private static Optional<TensorFunction> importConstant(OnnxOperation operation, OnnxModel model) { + String name = operation.vespaName(); + if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + return operation.function(); + } + + Value value = operation.getConstantValue().orElseThrow(() -> + new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + + "is constant but does not have a value.")); + if ( ! (value instanceof TensorValue)) { + return operation.function(); // scalar values are inserted directly into the expression + } + + Tensor tensor = value.asTensor(); + if (tensor.type().rank() == 0) { + model.smallConstant(name, tensor); + } else { + model.largeConstant(name, tensor); + } + return operation.function(); + } + + private static void importRankingExpression(OnnxOperation operation, OnnxModel model) { + if (operation.function().isPresent()) { + String name = operation.vespaName(); + if (!model.expressions().containsKey(name)) { + TensorFunction function = operation.function().get(); + + if (name.equals(model.output())) { + OrderedTensorType operationType = operation.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); + if ( ! operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + function = new Rename(function, renameFrom, renameTo); + } + } + + try { + // We add all intermediate nodes imported as separate expressions. Only + // those referenced from the output will be used. We parse the + // TensorFunction here to convert it to a RankingExpression tree. 
+ model.expression(name, new RankingExpression(name, function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Tensorflow function " + function + + " cannot be parsed as a ranking expression", e); + } + } + } + } + + private static void importInputExpression(OnnxOperation operation, OnnxModel model) { + if (operation.isInput()) { + // All inputs must have dimensions with standard naming convention: d0, d1, ... + OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); + model.argument(operation.vespaName(), standardNamingConvention.type()); + model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); + } + } + + + private static Onnx.NodeProto getNodeFromGraph(String nodeName, Onnx.GraphProto graph) { + boolean hasPortNumber = nodeName.contains(":"); + for (Onnx.NodeProto node : graph.getNodeList()) { + if (hasPortNumber) { + for (String outputName : node.getOutputList()) { + if (outputName.equals(nodeName)) { + return node; + } + } + } else if (node.getName().equals(nodeName)) { + return node; + } + } + throw new IllegalArgumentException("Node '" + nodeName + "' not found in ONNX graph"); + } + + private static class OperationIndex { + private final Map<String, OnnxOperation> index = new HashMap<>(); + public OnnxOperation put(String key, OnnxOperation operation) { return index.put(key, operation); } + public OnnxOperation get(String key) { return index.get(key); } + public boolean alreadyImported(String key) { return index.containsKey(key); } + public Collection<OnnxOperation> operations() { return index.values(); } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java new file mode 100644 index 00000000000..df108fcbbe7 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxModel.java @@ -0,0 +1,63 @@ +package com.yahoo.searchlib.rankingexpression.integration.onnx; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Pattern; + +/** + * The result of importing an ONNX model into Vespa. 
+ * + * @author lesters + */ +public class OnnxModel { + + public OnnxModel(String outputNode) { + this.output = outputNode; + } + + private final String output; + private final Map<String, TensorType> arguments = new HashMap<>(); + private final Map<String, Tensor> smallConstants = new HashMap<>(); + private final Map<String, Tensor> largeConstants = new HashMap<>(); + private final Map<String, RankingExpression> expressions = new HashMap<>(); + private final Map<String, TensorType> requiredMacros = new HashMap<>(); + + void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } + void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } + void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } + void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } + + /** Return the name of the output node for this model */ + public String output() { return output; } + + /** Returns an immutable map of the arguments (inputs) of this */ + public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } + + /** + * Returns an immutable map of the small constants of this. + */ + public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } + + /** + * Returns an immutable map of the large constants of this. + */ + public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } + + /** + * Returns an immutable map of the expressions of this - corresponding to ONNX nodes + * which are not inputs or constants. + */ + public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } + + /** Returns an immutable map of the macros that must be provided by the environment running this model */ + public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java new file mode 100644 index 00000000000..2524417cee0 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/DimensionRenamer.java @@ -0,0 +1,210 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; + +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * A constraint satisfier to find suitable dimension names to reduce the + * amount of necessary renaming during evaluation of an imported model. 
+ * + * @author lesters + */ +public class DimensionRenamer { + + private final String dimensionPrefix; + private final Map<String, List<Integer>> variables = new HashMap<>(); + private final Map<Arc, Constraint> constraints = new HashMap<>(); + private final Map<String, Integer> renames = new HashMap<>(); + + private int iterations = 0; + + public DimensionRenamer() { + this("d"); + } + + public DimensionRenamer(String dimensionPrefix) { + this.dimensionPrefix = dimensionPrefix; + } + + /** + * Add a dimension name variable. + */ + public void addDimension(String name) { + variables.computeIfAbsent(name, d -> new ArrayList<>()); + } + + /** + * Add a constraint between dimension names. + */ + public void addConstraint(String from, String to, Constraint pred, OnnxOperation operation) { + Arc arc = new Arc(from, to, operation); + Arc opposite = arc.opposite(); + constraints.put(arc, pred); + constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric + } + + /** + * Retrieve resulting name of dimension after solving for constraints. + */ + public Optional<String> dimensionNameOf(String name) { + if (!renames.containsKey(name)) { + return Optional.empty(); + } + return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); + } + + /** + * Perform iterative arc consistency until we have found a solution. After + * an initial iteration, the variables (dimensions) will have multiple + * valid values. Find a single valid assignment by iteratively locking one + * dimension after another, and running the arc consistency algorithm + * multiple times. + * + * This requires having constraints that result in an absolute ordering: + * equals, lesserThan and greaterThan do that, but adding notEquals does + * not typically result in a guaranteed ordering. If that is needed, the + * algorithm below needs to be adapted with a backtracking (tree) search + * to find solutions. + */ + public void solve(int maxIterations) { + initialize(); + + // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts + + for (String dimension : variables.keySet()) { + List<Integer> values = variables.get(dimension); + if (values.size() > 1) { + if (!ac3()) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution."); + } + values.sort(Integer::compare); + variables.put(dimension, Collections.singletonList(values.get(0))); + } + renames.put(dimension, variables.get(dimension).get(0)); + if (iterations > maxIterations) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + + maxIterations + " iterations"); + } + } + + // Todo: handle failure more gracefully: + // If a solution can't be found, look at the operation node in the arc + // with the most remaining constraints, and inject a rename operation. + // Then run this algorithm again. 
+ } + + public void solve() { + solve(100000); + } + + private void initialize() { + for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { + List<Integer> values = variable.getValue(); + for (int i = 0; i < variables.size(); ++i) { + values.add(i); // invariant: values are in increasing order + } + } + } + + private boolean ac3() { + Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); + while (!workList.isEmpty()) { + Arc arc = workList.pop(); + iterations += 1; + if (revise(arc)) { + if (variables.get(arc.from).size() == 0) { + return false; // no solution found + } + for (Arc constraint : constraints.keySet()) { + if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { + workList.add(constraint); + } + } + } + } + return true; + } + + private boolean revise(Arc arc) { + boolean revised = false; + for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { + Integer from = fromIterator.next(); + boolean satisfied = false; + for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { + Integer to = toIterator.next(); + if (constraints.get(arc).test(from, to)) { + satisfied = true; + } + } + if (!satisfied) { + fromIterator.remove(); + revised = true; + } + } + return revised; + } + + public interface Constraint { + boolean test(Integer x, Integer y); + } + + public static boolean equals(Integer x, Integer y) { + return Objects.equals(x, y); + } + + public static boolean lesserThan(Integer x, Integer y) { + return x < y; + } + + public static boolean greaterThan(Integer x, Integer y) { + return x > y; + } + + private static class Arc { + + private final String from; + private final String to; + private final OnnxOperation operation; + + Arc(String from, String to, OnnxOperation operation) { + this.from = from; + this.to = to; + this.operation = operation; + } + + Arc opposite() { + return new Arc(to, from, operation); + } + + @Override + public int hashCode() { + return Objects.hash(from, to); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof Arc)) { + return false; + } + Arc other = (Arc) obj; + return Objects.equals(from, other.from) && Objects.equals(to, other.to); + } + + @Override + public String toString() { + return String.format("%s -> %s", from, to); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java new file mode 100644 index 00000000000..3ee3f6aa32e --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OperationMapper.java @@ -0,0 +1,24 @@ +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; + +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.Join; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.MatMul; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.NoOp; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations.OnnxOperation; +import com.yahoo.tensor.functions.ScalarFunctions; +import onnx.Onnx; + +import java.util.List; + +public class OperationMapper { + + public static OnnxOperation get(Onnx.NodeProto node, List<OnnxOperation> inputs) { + switch (node.getOpType().toLowerCase()) { + case "add": return new Join(node, 
inputs, ScalarFunctions.add()); + case "matmul": return new MatMul(node, inputs); + } + + OnnxOperation op = new NoOp(node, inputs); + op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); + return op; + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java new file mode 100644 index 00000000000..f6e117bfd74 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/OrderedTensorType.java @@ -0,0 +1,250 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; + +import com.yahoo.tensor.TensorType; +import onnx.Onnx; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * A Vespa tensor type is ordered by the lexicographical ordering of dimension + * names. ONNX tensors have an explicit ordering of their dimensions. + * During import, we need to track the Vespa dimension that matches the + * corresponding ONNX dimension as the ordering can change after + * dimension renaming. That is the purpose of this class. + * + * @author lesters + */ +public class OrderedTensorType { + + private final TensorType type; + private final List<TensorType.Dimension> dimensions; + + private final long[] innerSizesOnnx; + private final long[] innerSizesVespa; + private final int[] dimensionMap; + + private OrderedTensorType(List<TensorType.Dimension> dimensions) { + this.dimensions = Collections.unmodifiableList(dimensions); + this.type = new TensorType.Builder(dimensions).build(); + this.innerSizesOnnx = new long[dimensions.size()]; + this.innerSizesVespa = new long[dimensions.size()]; + this.dimensionMap = createDimensionMap(); + } + + public TensorType type() { return this.type; } + + public int rank() { return dimensions.size(); } + + public List<TensorType.Dimension> dimensions() { + return dimensions; + } + + public List<String> dimensionNames() { + return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList()); + } + + private int[] createDimensionMap() { + int numDimensions = dimensions.size(); + if (numDimensions == 0) { + return null; + } + innerSizesOnnx[numDimensions - 1] = 1; + innerSizesVespa[numDimensions - 1] = 1; + for (int i = numDimensions - 1; --i >= 0; ) { + innerSizesOnnx[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOnnx[i+1]; + innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; + } + int[] mapping = new int[numDimensions]; + for (int i = 0; i < numDimensions; ++i) { + TensorType.Dimension dim1 = dimensions().get(i); + for (int j = 0; j < numDimensions; ++j) { + TensorType.Dimension dim2 = type.dimensions().get(j); + if (dim1.equals(dim2)) { + mapping[i] = j; + break; + } + } + } + return mapping; + } + + /** + * When dimension ordering between Vespa and Onnx differs, i.e. + * after dimension renaming, use the dimension map to read in values + * so that they are correctly laid out in memory for Vespa. + * Used when importing tensors from Onnx. 
+ */ + public int toDirectIndex(int index) { + if (dimensions.size() == 0) { + return 0; + } + if (dimensionMap == null) { + throw new IllegalArgumentException("Dimension map is not available"); + } + int directIndex = 0; + long rest = index; + for (int i = 0; i < dimensions.size(); ++i) { + long address = rest / innerSizesOnnx[i]; + directIndex += innerSizesVespa[dimensionMap[i]] * address; + rest %= innerSizesOnnx[i]; + } + return directIndex; + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof OrderedTensorType)) { + return false; + } + OrderedTensorType other = (OrderedTensorType) obj; + if (dimensions.size() != dimensions.size()) { + return false; + } + List<TensorType.Dimension> thisDimensions = this.dimensions(); + List<TensorType.Dimension> otherDimensions = other.dimensions(); + for (int i = 0; i < thisDimensions.size(); ++i) { + if (!thisDimensions.get(i).equals(otherDimensions.get(i))) { + return false; + } + } + return true; + } + + public void verifyType(Onnx.TypeProto typeProto) { + Onnx.TensorShapeProto shape = typeProto.getTensorType().getShape(); + if (shape != null) { + if (shape.getDimCount() != type.rank()) { + throw new IllegalArgumentException("Onnx shape of does not match Vespa shape"); + } + for (int onnxIndex = 0; onnxIndex < dimensions.size(); ++onnxIndex) { + int vespaIndex = dimensionMap[onnxIndex]; + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex); + TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex); + if (onnxDimension.getDimValue() != vespaDimension.size().orElse(-1L)) { + throw new IllegalArgumentException("TensorFlow dimensions of does not match Vespa dimensions"); + } + } + } + } + public OrderedTensorType rename(DimensionRenamer renamer) { + List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); + for (TensorType.Dimension dimension : dimensions) { + String oldName = dimension.name(); + Optional<String> newName = renamer.dimensionNameOf(oldName); + if (!newName.isPresent()) + return this; // presumably, already renamed + TensorType.Dimension.Type dimensionType = dimension.type(); + if (dimensionType == TensorType.Dimension.Type.indexedBound) { + renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); + } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) { + renamedDimensions.add(TensorType.Dimension.indexed(newName.get())); + } else if (dimensionType == TensorType.Dimension.Type.mapped) { + renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); + } + } + return new OrderedTensorType(renamedDimensions); + } + + public static OrderedTensorType fromOnnxType(Onnx.TypeProto type) { + return fromOnnxType(type, "d"); // standard naming convention: d0, d1, ... 
+ } + + public static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { + Onnx.TensorShapeProto shape = type.getTensorType().getShape(); + Builder builder = new Builder(shape); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); + if (onnxDimension.getDimValue() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, onnxDimension.getDimValue())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + + public static OrderedTensorType fromOnnxType(List<Long> dims, String dimensionPrefix) { + Builder builder = new Builder(); + for (int i = 0; i < dims.size(); ++ i) { + String dimensionName = dimensionPrefix + i; + Long dimSize = dims.get(i); + if (dimSize >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + + public static OrderedTensorType standardType(OrderedTensorType type) { + Builder builder = new Builder(); + for (int i = 0; i < type.dimensions().size(); ++ i) { + TensorType.Dimension dim = type.dimensions().get(i); + String dimensionName = "d" + i; + if (dim.size().isPresent() && dim.size().get() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + + public static class Builder { + + private final Onnx.TensorShapeProto shape; + private final List<TensorType.Dimension> dimensions; + + public Builder(Onnx.TensorShapeProto shape) { + this.shape = shape; + this.dimensions = new ArrayList<>(shape.getDimCount()); + } + + public Builder() { + this.shape = null; + this.dimensions = new ArrayList<>(); + } + + public Builder add(TensorType.Dimension vespaDimension) { + if (shape != null) { + int index = dimensions.size(); + Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(index); + long size = onnxDimension.getDimValue(); + if (size >= 0) { + if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) { + throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + + "dimension types"); + } + if (!vespaDimension.size().isPresent()) { + throw new IllegalArgumentException("Tensor dimension is indexed bound but does " + + "not have a size"); + } + if (vespaDimension.size().get() != size) { + throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + + "dimension sizes. 
TensorFlow: " + size + " Vespa: " + + vespaDimension.size().get()); + } + } else { + if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) { + throw new IllegalArgumentException("Non-agreement between Onnx and Vespa " + + "dimension types"); + } + } + } + this.dimensions.add(vespaDimension); + return this; + } + + public OrderedTensorType build() { + return new OrderedTensorType(dimensions); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java new file mode 100644 index 00000000000..1c5fef456cb --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/TensorConverter.java @@ -0,0 +1,79 @@ +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer; + +import com.google.protobuf.ByteString; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import onnx.Onnx; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.FloatBuffer; +import java.util.List; + +/** + * Converts Onnx tensors into Vespa tensors. + * + * @author lesters + */ +public class TensorConverter { + + public static Tensor toVespaTensor(Onnx.TensorProto tensorProto, OrderedTensorType type) { + Values values = readValuesOf(tensorProto); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); + for (int i = 0; i < values.size(); i++) { + builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i)); + } + return builder.build(); + } + + /* todo: support more types */ + private static Values readValuesOf(Onnx.TensorProto tensorProto) { + if (tensorProto.hasRawData()) { + switch (tensorProto.getDataType()) { + case FLOAT: return new RawFloatValues(tensorProto); + } + } else { + switch (tensorProto.getDataType()) { + case FLOAT: return new FloatValues(tensorProto); + } + } + throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + + tensorProto.getDataType() + " to a Vespa tensor"); + } + + /** Allows reading values from buffers of various numeric types as bytes */ + private static abstract class Values { + abstract double get(int i); + abstract int size(); + } + + private static abstract class RawValues extends Values { + ByteBuffer bytes(Onnx.TensorProto tensorProto) { + ByteString byteString = tensorProto.getRawData(); + return byteString.asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN); + } + } + + private static class RawFloatValues extends RawValues { + private final FloatBuffer values; + private final int size; + RawFloatValues(Onnx.TensorProto tensorProto) { + values = bytes(tensorProto).asFloatBuffer(); + size = values.remaining(); + } + @Override double get(int i) { return values.get(i); } + @Override int size() { return size; } + } + + private static class FloatValues extends Values { + private final Onnx.TensorProto tensorProto; + FloatValues(Onnx.TensorProto tensorProto) { + this.tensorProto = tensorProto; + } + @Override double get(int i) { return tensorProto.getFloatData(i); } + @Override int size() { return tensorProto.getFloatDataCount(); } + } + + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java new 
file mode 100644 index 00000000000..a8d8d63daf4 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Argument.java @@ -0,0 +1,64 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; + +import java.util.Collections; +import java.util.List; + +public class Argument extends OnnxOperation { + + private Onnx.ValueInfoProto valueInfo; + private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... + + public Argument(Onnx.ValueInfoProto valueInfoProto) { + super(null, Collections.emptyList()); + valueInfo = valueInfoProto; + standardNamingType = OrderedTensorType.fromOnnxType(valueInfo.getType()); + } + + @Override + public String vespaName() { + return vespaName(valueInfo.getName()); + } + + @Override + protected OrderedTensorType lazyGetType() { + return OrderedTensorType.fromOnnxType(valueInfo.getType(), vespaName() + "_"); + } + + @Override + protected TensorFunction lazyGetFunction() { + TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); + if (!standardNamingType.equals(type)) { + List<String> renameFrom = standardNamingType.dimensionNames(); + List<String> renameTo = type.dimensionNames(); + output = new Rename(output, renameFrom, renameTo); + } + return output; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public boolean isInput() { + return true; + } + + @Override + public boolean isConstant() { + return false; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java new file mode 100644 index 00000000000..ab650bf8d77 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Constant.java @@ -0,0 +1,59 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.TensorConverter; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; + +import java.util.Collections; +import java.util.Optional; + +public class Constant extends OnnxOperation { + + final Onnx.TensorProto tensorProto; + + public Constant(Onnx.TensorProto tensorProto) { + super(null, Collections.emptyList()); + this.tensorProto = tensorProto; + } + + /** todo: Constant names are prefixed by "modelName_" to avoid name conflicts between models */ + @Override + public String vespaName() { +// return modelName() + "_" + super.vespaName(); + return vespaName(tensorProto.getName()); + } + + @Override + protected OrderedTensorType lazyGetType() { + return OrderedTensorType.fromOnnxType(tensorProto.getDimsList(), vespaName() + "_"); + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; // will be added by function() since this is constant. + } + + @Override + public Optional<Value> getConstantValue() { + return Optional.of(new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public boolean isConstant() { + return true; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java new file mode 100644 index 00000000000..fe2004a528d --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/Join.java @@ -0,0 +1,122 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.function.DoubleBinaryOperator; + +public class Join extends OnnxOperation { + + private final DoubleBinaryOperator operator; + + public Join(Onnx.NodeProto node, List<OnnxOperation> inputs, DoubleBinaryOperator operator) { + super(node, inputs); + this.operator = operator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType a = largestInput().type().get(); + OrderedTensorType b = smallestInput().type().get(); + + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + int sizeDifference = a.rank() - b.rank(); + for (int i = 0; i < a.rank(); ++i) { + TensorType.Dimension aDim = a.dimensions().get(i); + long size = aDim.size().orElse(-1L); + + if (i - sizeDifference >= 0) { + TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference); + size = Math.max(size, bDim.size().orElse(-1L)); + } + + if (aDim.type() == TensorType.Dimension.Type.indexedBound) { + builder.add(TensorType.Dimension.indexed(aDim.name(), size)); + } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) { + builder.add(TensorType.Dimension.indexed(aDim.name())); + } else if (aDim.type() == TensorType.Dimension.Type.mapped) { + builder.add(TensorType.Dimension.mapped(aDim.name())); + } + } + return builder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + if (!allInputFunctionsPresent(2)) { + return null; + } + + OnnxOperation a = largestInput(); + OnnxOperation b = smallestInput(); + + List<String> aDimensionsToReduce = new ArrayList<>(); + List<String> bDimensionsToReduce = new ArrayList<>(); + int sizeDifference = a.type().get().rank() - b.type().get().rank(); + for (int i = 0; i < b.type().get().rank(); ++i) { + TensorType.Dimension bDim = b.type().get().dimensions().get(i); + TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference); + long bSize = bDim.size().orElse(-1L); + long aSize = aDim.size().orElse(-1L); + if (bSize == 1L && aSize != 1L) { + bDimensionsToReduce.add(bDim.name()); + } + if (aSize == 1L && bSize != 1L) { + aDimensionsToReduce.add(bDim.name()); + } + } + + TensorFunction aReducedFunction = a.function().get(); + if (aDimensionsToReduce.size() > 0) { + aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce); + } + TensorFunction bReducedFunction = b.function().get(); + if (bDimensionsToReduce.size() > 0) { + bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); + } + + return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(2)) { + return; + } + OrderedTensorType a = largestInput().type().get(); + OrderedTensorType b = smallestInput().type().get(); + int sizeDifference = a.rank() - b.rank(); + for (int i = 0; i < b.rank(); ++i) { + String bDim = 
b.dimensions().get(i).name(); + String aDim = a.dimensions().get(i + sizeDifference).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + } + } + + private OnnxOperation largestInput() { + OrderedTensorType a = inputs.get(0).type().get(); + OrderedTensorType b = inputs.get(1).type().get(); + return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1); + } + + private OnnxOperation smallestInput() { + OrderedTensorType a = inputs.get(0).type().get(); + OrderedTensorType b = inputs.get(1).type().get(); + return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java new file mode 100644 index 00000000000..1b388e2ae89 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/MatMul.java @@ -0,0 +1,75 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; + +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.function.DoubleBinaryOperator; + +public class MatMul extends OnnxOperation { + + public MatMul(Onnx.NodeProto node, List<OnnxOperation> inputs) { + super(node, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); + typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType aType = inputs.get(0).type().get(); + OrderedTensorType bType = inputs.get(1).type().get(); + if (aType.type().rank() < 2 || bType.type().rank() < 2) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); + if (aType.type().rank() != bType.type().rank()) + throw new IllegalArgumentException("Tensors in matmul must have the same rank"); + + Optional<TensorFunction> aFunction = inputs.get(0).function(); + Optional<TensorFunction> bFunction = inputs.get(1).function(); + if (!aFunction.isPresent() || !bFunction.isPresent()) { + return null; + } + return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(2)) { + return; + } + List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); + List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + + String aDim0 = aDimensions.get(0).name(); + String aDim1 = aDimensions.get(1).name(); + String bDim0 = bDimensions.get(0).name(); + String bDim1 = bDimensions.get(1).name(); + + // The second dimension of a should have the same name as the first 
dimension of b + renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); + + // The first dimension of a should have a different name than the second dimension of b + renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); + + // For efficiency, the dimensions to join over should be innermost - soft constraint + renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); + renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java new file mode 100644 index 00000000000..b1136a0ce0a --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/NoOp.java @@ -0,0 +1,32 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; + +import java.util.Collections; +import java.util.List; + +public class NoOp extends OnnxOperation { + + public NoOp(Onnx.NodeProto node, List<OnnxOperation> inputs) { + super(node, Collections.emptyList()); // don't propagate inputs + } + + @Override + protected OrderedTensorType lazyGetType() { + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; + } + + @Override + public boolean isConstant() { + return true; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java new file mode 100644 index 00000000000..2c8003f5951 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/onnx/importer/operations/OnnxOperation.java @@ -0,0 +1,119 @@ +package com.yahoo.searchlib.rankingexpression.integration.onnx.importer.operations; + +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.onnx.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.TensorFunction; +import onnx.Onnx; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +/** + * Wraps an ONNX node and produces the respective Vespa tensor operation. + * During import, a graph of these operations are constructed. Then, the + * types are used to deduce sensible dimension names using the + * DimensionRenamer. After the types have been renamed, the proper + * Vespa expressions can be extracted. 
+ * + * @author lesters + */ +public abstract class OnnxOperation { + + protected final Onnx.NodeProto node; // can be null for onnx inputs and constants + protected final List<OnnxOperation> inputs; + protected final List<OnnxOperation> outputs = new ArrayList<>(); + protected final List<String> importWarnings = new ArrayList<>(); + + protected OrderedTensorType type; + protected TensorFunction function; + protected Value constantValue = null; + + OnnxOperation(Onnx.NodeProto node, List<OnnxOperation> inputs) { + this.node = node; + this.inputs = Collections.unmodifiableList(inputs); + this.inputs.forEach(i -> i.outputs.add(this)); + } + + protected abstract OrderedTensorType lazyGetType(); + protected abstract TensorFunction lazyGetFunction(); + + /** Returns the Vespa tensor type of this operation if it exists */ + public Optional<OrderedTensorType> type() { + if (type == null) { + type = lazyGetType(); + } + return Optional.ofNullable(type); + } + + /** Returns the Vespa tensor function implementing all operations from this node with inputs */ + public Optional<TensorFunction> function() { + if (function == null) { + if (isConstant()) { + ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); + function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); + } else { + function = lazyGetFunction(); + } + } + return Optional.ofNullable(function); + } + + /** Return Onnx node */ + public Onnx.NodeProto node() { return node; } + + /** Return unmodifiable list of inputs */ + public List<OnnxOperation> inputs() { return inputs; } + + /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */ + public List<OnnxOperation> outputs() { return Collections.unmodifiableList(outputs); } + + /** Add dimension name constraints for this operation */ + public void addDimensionNameConstraints(DimensionRenamer renamer) { } + + /** Performs dimension rename for this operation */ + public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } + + /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ + public boolean isInput() { return false; } + + /** Return true if this node is constant */ + public boolean isConstant() { return inputs.stream().allMatch(OnnxOperation::isConstant); } + + /** Gets the constant value if it exists */ + public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); } + + /** Retrieve the valid Vespa name of this node */ + public String vespaName() { return vespaName(node.getName()); } + public String vespaName(String name) { return name != null ? 
name.replace('/', '_').replace(':','_') : null; } + + /** Retrieve the list of warnings produced during its lifetime */ + public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } + + /** Set an input warning */ + public void warning(String warning) { importWarnings.add(warning); } + + boolean verifyInputs(int expected, Function<OnnxOperation, Optional<?>> func) { + if (inputs.size() != expected) { + throw new IllegalArgumentException("Expected " + expected + " inputs " + + "for '" + node.getName() + "', got " + inputs.size()); + } + return inputs.stream().map(func).allMatch(Optional::isPresent); + } + + boolean allInputTypesPresent(int expected) { + return verifyInputs(expected, OnnxOperation::type); + } + + boolean allInputFunctionsPresent(int expected) { + return verifyInputs(expected, OnnxOperation::function); + } + +} diff --git a/searchlib/src/main/protobuf/onnx.proto b/searchlib/src/main/protobuf/onnx.proto new file mode 100644 index 00000000000..dc6542867e0 --- /dev/null +++ b/searchlib/src/main/protobuf/onnx.proto @@ -0,0 +1,464 @@ +// +// WARNING: This file is automatically generated! Please edit onnx.in.proto. +// + + +// Copyright (c) Facebook Inc. and Microsoft Corporation. +// Licensed under the MIT license. + +syntax = "proto2"; + +package onnx; + +// Overview +// +// ONNX is an open specification that is comprised of the following components: +// +// 1) A definition of an extensible computation graph model. +// 2) Definitions of standard data types. +// 3) Definitions of built-in operators. +// +// This document describes the syntax of models and their computation graphs, +// as well as the standard data types. Together, they are referred to as the ONNX +// Intermediate Representation, or 'IR' for short. +// +// The normative semantic specification of the ONNX IR is found in docs/IR.md. +// Definitions of the built-in neural network operators may be found in docs/Operators.md. + +// Notes +// +// Release +// +// We are still in the very early stage of defining ONNX. The current +// version of ONNX is a starting point. While we are actively working +// towards a complete spec, we would like to get the community involved +// by sharing our working version of ONNX. +// +// Protobuf compatibility +// +// To simplify framework compatibility, ONNX is defined using the subset of protobuf +// that is compatible with both protobuf v2 and v3. This means that we do not use any +// protobuf features that are only available in one of the two versions. +// +// Here are the most notable contortions we have to carry out to work around +// these limitations: +// +// - No 'map' (added protobuf 3.0). We instead represent mappings as lists +// of key-value pairs, where order does not matter and duplicates +// are not allowed. + + +// Versioning +// +// ONNX versioning is specified in docs/IR.md and elaborated on in docs/Versioning.md +// +// To be compatible with both proto2 and proto3, we will use a version number +// that is not defined by the default value but an explicit enum number. +enum Version { + // proto3 requires the first enum value to be zero. + // We add this just to appease the compiler. + _START_VERSION = 0; + // The version field is always serialized and we will use it to store the + // version that the graph is generated from. This helps us set up version + // control. We should use version as + // xx(major) - xx(minor) - xxxx(bugfix) + // and we are starting with 0x00000001 (0.0.1), which was the + // version we published on Oct 10, 2017. 
+ IR_VERSION_2017_10_10 = 0x00000001; + + // IR_VERSION 0.0.2 published on Oct 30, 2017 + // - Added type discriminator to AttributeProto to support proto3 users + IR_VERSION_2017_10_30 = 0x00000002; + + // IR VERSION 0.0.3 published on Nov 3, 2017 + // - For operator versioning: + // - Added new message OperatorSetIdProto + // - Added opset_import in ModelProto + // - For vendor extensions, added domain in NodeProto + IR_VERSION = 0x00000003; +} + +// Attributes +// +// A named attribute containing either singular float, integer, string, graph, +// and tensor values, or repeated float, integer, string, graph, and tensor values. +// An AttributeProto MUST contain the name field, and *only one* of the +// following content fields, effectively enforcing a C/C++ union equivalent. +message AttributeProto { + + // Note: this enum is structurally identical to the OpSchema::AttrType + // enum defined in schema.h. If you rev one, you likely need to rev the other. + enum AttributeType { + UNDEFINED = 0; + FLOAT = 1; + INT = 2; + STRING = 3; + TENSOR = 4; + GRAPH = 5; + + FLOATS = 6; + INTS = 7; + STRINGS = 8; + TENSORS = 9; + GRAPHS = 10; + } + + // The name field MUST be present for this version of the IR. + optional string name = 1; // namespace Attribute + + // if ref_attr_name is not empty, ref_attr_name is the attribute name in parent function. + // In this case, this AttributeProto does not contain data, and it's a reference of attribute + // in parent scope. + // NOTE: This should ONLY be used in function (sub-graph). It's invalid to be used in main graph. + optional string ref_attr_name = 21; + + // A human-readable documentation for this attribute. Markdown is allowed. + optional string doc_string = 13; + + // The type field MUST be present for this version of the IR. + // For 0.0.1 versions of the IR, this field was not defined, and + // implementations needed to use has_field hueristics to determine + // which value field was in use. For IR_VERSION 0.0.2 or later, this + // field MUST be set and match the f|i|s|t|... field in use. This + // change was made to accomodate proto3 implementations. + optional AttributeType type = 20; // discriminator that indicates which field below is in use + + // Exactly ONE of the following fields must be present for this version of the IR + optional float f = 2; // float + optional int64 i = 3; // int + optional bytes s = 4; // UTF-8 string + optional TensorProto t = 5; // tensor value + optional GraphProto g = 6; // graph + // Do not use field below, it's deprecated. + // optional ValueProto v = 12; // value - subsumes everything but graph + + repeated float floats = 7; // list of floats + repeated int64 ints = 8; // list of ints + repeated bytes strings = 9; // list of UTF-8 strings + repeated TensorProto tensors = 10; // list of tensors + repeated GraphProto graphs = 11; // list of graph +} + +// Defines information on value, including the name, the type, and +// the shape of the value. +message ValueInfoProto { + // This field MUST be present in this version of the IR. + optional string name = 1; // namespace Value + // This field MUST be present in this version of the IR. + optional TypeProto type = 2; + // A human-readable documentation for this value. Markdown is allowed. + optional string doc_string = 3; +} + +// Nodes +// +// Computation graphs are made up of a DAG of nodes, which represent what is +// commonly called a "layer" or "pipeline stage" in machine learning frameworks. 
+// +// For example, it can be a node of type "Conv" that takes in an image, a filter +// tensor and a bias tensor, and produces the convolved output. +message NodeProto { + repeated string input = 1; // namespace Value + repeated string output = 2; // namespace Value + + // An optional identifier for this node in a graph. + // This field MAY be absent in ths version of the IR. + optional string name = 3; // namespace Node + + // The symbolic identifier of the Operator to execute. + optional string op_type = 4; // namespace Operator + // The domain of the OperatorSet that specifies the operator named by op_type. + optional string domain = 7; // namespace Domain + + // Additional named attributes. + repeated AttributeProto attribute = 5; + + // A human-readable documentation for this node. Markdown is allowed. + optional string doc_string = 6; +} + +// Models +// +// ModelProto is a top-level file/container format for bundling a ML model and +// associating its computation graph with metadata. +// +// The semantics of the model are described by the associated GraphProto. +message ModelProto { + // The version of the IR this model targets. See Version enum above. + // This field MUST be present. + optional int64 ir_version = 1; + + // The OperatorSets this model relies on. + // All ModelProtos MUST have at least one entry that + // specifies which version of the ONNX OperatorSet is + // being imported. + // + // All nodes in the ModelProto's graph will bind against the operator + // with the same-domain/same-op_type operator with the HIGHEST version + // in the referenced operator sets. + repeated OperatorSetIdProto opset_import = 8; + + // The name of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_name = 2; + + // The version of the framework or tool used to generate this model. + // This field SHOULD be present to indicate which implementation/tool/framework + // emitted the model. + optional string producer_version = 3; + + // Domain name of the model. + // We use reverse domain names as name space indicators. For example: + // `com.facebook.fair` or `com.microsoft.cognitiveservices` + // + // Together with `model_version` and GraphProto.name, this forms the unique identity of + // the graph. + optional string domain = 4; + + // The version of the graph encoded. See Version enum below. + optional int64 model_version = 5; + + // A human-readable documentation for this model. Markdown is allowed. + optional string doc_string = 6; + + // The parameterized graph that is evaluated to execute the model. + optional GraphProto graph = 7; + + // Named metadata values; keys should be distinct. + repeated StringStringEntryProto metadata_props = 14; +}; + +// StringStringEntryProto follows the pattern for cross-proto-version maps. +// See https://developers.google.com/protocol-buffers/docs/proto3#maps +message StringStringEntryProto { + optional string key = 1; + optional string value= 2; +}; + +// Graphs +// +// A graph defines the computational logic of a model and is comprised of a parameterized +// list of nodes that form a directed acyclic graph based on their inputs and outputs. +// This is the equivalent of the "network" or "graph" in many deep learning +// frameworks. +message GraphProto { + // The nodes in the graph, sorted topologically. + repeated NodeProto node = 1; + + // The name of the graph. 
+ optional string name = 2; // namespace Graph + + // A list of named tensor values, used to specify constant inputs of the graph. + // Each TensorProto entry must have a distinct name (within the list) that + // also appears in the input list. + repeated TensorProto initializer = 5; + + // A human-readable documentation for this graph. Markdown is allowed. + optional string doc_string = 10; + + // The inputs and outputs of the graph. + repeated ValueInfoProto input = 11; + repeated ValueInfoProto output = 12; + + // Information for the values in the graph. The ValueInfoProto.name's + // must be distinct. It is optional for a value to appear in value_info list. + repeated ValueInfoProto value_info = 13; + + // DO NOT USE the following fields, they were deprecated from earlier versions. + // repeated string input = 3; + // repeated string output = 4; + // optional int64 ir_version = 6; + // optional int64 producer_version = 7; + // optional string producer_tag = 8; + // optional string domain = 9; +} + +// Tensors +// +// A serialized tensor value. +message TensorProto { + enum DataType { + UNDEFINED = 0; + // Basic types. + FLOAT = 1; // float + UINT8 = 2; // uint8_t + INT8 = 3; // int8_t + UINT16 = 4; // uint16_t + INT16 = 5; // int16_t + INT32 = 6; // int32_t + INT64 = 7; // int64_t + STRING = 8; // string + BOOL = 9; // bool + + // Advanced types + FLOAT16 = 10; + DOUBLE = 11; + UINT32 = 12; + UINT64 = 13; + COMPLEX64 = 14; // complex with float32 real and imaginary components + COMPLEX128 = 15; // complex with float64 real and imaginary components + // Future extensions go here. + } + + // The shape of the tensor. + repeated int64 dims = 1; + + // The data type of the tensor. + optional DataType data_type = 2; + + // For very large tensors, we may want to store them in chunks, in which + // case the following fields will specify the segment that is stored in + // the current TensorProto. + message Segment { + optional int64 begin = 1; + optional int64 end = 2; + } + optional Segment segment = 3; + + // Tensor content must be organized in row-major order. + // + // Depending on the data_type field, exactly one of the fields below with + // name ending in _data is used to store the elements of the tensor. + + // For float and complex64 values + // Complex64 tensors are encoded as a single array of floats, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be FLOAT or COMPLEX64. + repeated float float_data = 4 [packed = true]; + + // For int32, uint8, int8, uint16, int16, bool, and float16 values + // float16 values must be bit-wise converted to an uint16_t prior + // to writing to the buffer. + // When this field is present, the data_type field MUST be + // INT32, INT16, INT8, UINT16, INT8, BOOL, or FLOAT32 + repeated int32 int32_data = 5 [packed = true]; + + // For strings. + // Each element of string_data is a UTF-8 encoded Unicode + // string. No trailing null, no leading BOM. The protobuf "string" + // scalar type is not used to match ML community conventions. + // When this field is present, the data_type field MUST be STRING + repeated bytes string_data = 6; + + // For int64. + // When this field is present, the data_type field MUST be INT64 + repeated int64 int64_data = 7 [packed = true]; + + // Optionally, a name for the tensor. 
+ optional string name = 8; // namespace Value + + // A human-readable documentation for this tensor. Markdown is allowed. + optional string doc_string = 12; + + // Serializations can either use one of the fields above, or use this + // raw bytes field. The only exception is the string case, where one is + // required to store the content in the repeated bytes string_data field. + // + // When this raw_data field is used to store tensor value, elements MUST + // be stored in as fixed-width, little-endian order. + // Floating-point data types MUST be stored in IEEE 754 format. + // Complex64 elements must be written as two consecutive FLOAT values, real component first. + // Complex128 elements must be written as two consecutive DOUBLE values, real component first. + // Boolean type MUST be written one byte per tensor element (00000001 for true, 00000000 for false). + // + // Note: the advantage of specific field rather than the raw_data field is + // that in some cases (e.g. int data), protobuf does a better packing via + // variable length storage, and may lead to smaller binary footprint. + // When this field is present, the data_type field MUST NOT be STRING or UNDEFINED + optional bytes raw_data = 9; + + // For double + // Complex64 tensors are encoded as a single array of doubles, + // with the real components appearing in odd numbered positions, + // and the corresponding imaginary component apparing in the + // subsequent even numbered position. (e.g., [1.0 + 2.0i, 3.0 + 4.0i] + // is encoded as [1.0, 2.0 ,3.0 ,4.0] + // When this field is present, the data_type field MUST be DOUBLE or COMPLEX128 + repeated double double_data = 10 [packed = true]; + + // For uint64 and uint32 values + // When this field is present, the data_type field MUST be + // UINT32 or UINT64 + repeated uint64 uint64_data = 11 [packed = true]; +} + +// Defines a tensor shape. A dimension can be either an integer value +// or a symbolic variable. A symbolic variable represents an unknown +// dimension. +message TensorShapeProto { + message Dimension { + oneof value { + int64 dim_value = 1; + string dim_param = 2; // namespace Shape + }; + // Standard denotation can optionally be used to denote tensor + // dimensions with standard semantic descriptions to ensure + // that operations are applied to the correct axis of a tensor. + optional string denotation = 3; + }; + repeated Dimension dim = 1; +} + +// A set of pre-defined constants to be used as values for +// the standard denotation field in TensorShapeProto.Dimension +// for semantic description of the tensor dimension. +message DenotationConstProto { + // Describe a batch number dimension. + optional string DATA_BATCH = 1 [default = "DATA_BATCH"]; + // Describe a channel dimension. + optional string DATA_CHANNEL = 2 [default = "DATA_CHANNEL"]; + // Describe a time dimension. + optional string DATA_TIME = 3 [default = "DATA_TIME"]; + // Describe a feature dimension. This is typically a feature + // dimension in RNN and/or spatial dimension in CNN. + optional string DATA_FEATURE = 4 [default = "DATA_FEATURE"]; + // Describe a filter in-channel dimension. This is the dimension + // that is identical (in size) to the channel dimension of the input + // image feature maps. + optional string FILTER_IN_CHANNEL = 5 [default = "FILTER_IN_CHANNEL"]; + // Describe a filter out channel dimension. This is the dimension + // that is identical (int size) to the channel dimension of the output + // image feature maps. 
+ optional string FILTER_OUT_CHANNEL = 6 [default = "FILTER_OUT_CHANNEL"]; + // Describe a filter spatial dimension. + optional string FILTER_SPATIAL = 7 [default = "FILTER_SPATIAL"]; +} + +// Types +// +// The standard ONNX data types. +message TypeProto { + + message Tensor { + // This field MUST NOT have the value of UNDEFINED + // This field MUST be present for this version of the IR. + optional TensorProto.DataType elem_type = 1; + optional TensorShapeProto shape = 2; + } + + + oneof value { + // The type of a tensor. + Tensor tensor_type = 1; + + } +} + +// Operator Sets +// +// OperatorSets are uniquely identified by a (domain, opset_version) pair. +message OperatorSetIdProto { + // The domain of the operator set being identified. + // The empty string ("") or absence of this field implies the operator + // set that is defined as part of the ONNX specification. + // This field MUST be present in this version of the IR when referring to any other operator set. + optional string domain = 1; + + // The version of the operator set being identified. + // This field MUST be present in this version of the IR. + optional int64 version = 2; +}
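A note on consuming TensorProto values: as the comments in that message spell out, a dense float tensor arrives either in the typed float_data field or as fixed-width, little-endian bytes in raw_data. The following is a minimal sketch (not part of this change) of how a consumer might normalise the two layouts. It assumes the generated Java classes for this file end up under the onnx.Onnx outer class (protobuf's default for onnx.proto with package onnx), and the helper class and method names are made up for illustration.

package com.yahoo.searchlib.rankingexpression.integration.onnx;

import onnx.Onnx;

import java.nio.ByteBuffer;
import java.nio.ByteOrder;

/** Hypothetical helper showing the two dense float layouts a TensorProto may use. */
public class TensorProtoValuesSketch {

    static float[] floatValues(Onnx.TensorProto tensor) {
        if (tensor.getDataType() != Onnx.TensorProto.DataType.FLOAT)
            throw new IllegalArgumentException("Expected a FLOAT tensor, got " + tensor.getDataType());

        if (tensor.getFloatDataCount() > 0) {              // typed storage: repeated float float_data
            float[] values = new float[tensor.getFloatDataCount()];
            for (int i = 0; i < values.length; i++)
                values[i] = tensor.getFloatData(i);
            return values;
        }

        // raw storage: fixed-width IEEE 754 floats, little-endian, row-major
        ByteBuffer buffer = tensor.getRawData().asReadOnlyByteBuffer().order(ByteOrder.LITTLE_ENDIAN);
        float[] values = new float[buffer.remaining() / Float.BYTES];
        for (int i = 0; i < values.length; i++)
            values[i] = buffer.getFloat();
        return values;
    }
}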
\ No newline at end of file diff --git a/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx b/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx Binary files differnew file mode 100644 index 00000000000..a86019bf53a --- /dev/null +++ b/searchlib/src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java new file mode 100644 index 00000000000..e118c2b885a --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -0,0 +1,109 @@ +package com.yahoo.searchlib.rankingexpression.integration.onnx; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; +import org.tensorflow.SavedModelBundle; + +import java.io.IOException; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * @author lesters + */ +public class OnnxMnistSoftmaxImportTestCase { + + @Test + public void testMnistSoftmaxImport() throws IOException { + OnnxModel model = new OnnxImporter().importModel("src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx", "add"); + + // Check constants + assertEquals(2, model.largeConstants().size()); + + Tensor constant0 = model.largeConstants().get("Variable_0"); + assertNotNull(constant0); + assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), + constant0.type()); + assertEquals(7840, constant0.size()); + + Tensor constant1 = model.largeConstants().get("Variable_1_0"); + assertNotNull(constant1); + assertEquals(new TensorType.Builder().indexed("d1", 10).build(), + constant1.type()); + assertEquals(10, constant1.size()); + + // Check required macros (inputs) + assertEquals(1, model.requiredMacros().size()); + assertTrue(model.requiredMacros().containsKey("Placeholder_0")); + assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), + model.requiredMacros().get("Placeholder_0")); + + // Check outputs + RankingExpression output = model.expressions().get("add"); + assertNotNull(output); + assertEquals("add", output.getName()); + assertEquals("join(reduce(join(rename(Placeholder_0, (d0, d1), (d0, d2)), constant(Variable_0), f(a,b)(a * b)), sum, d2), constant(Variable_1_0), f(a,b)(a + b))", + output.getRoot().toString()); + } + + @Test + public void testComparisonBetweenOnnxAndTensorflow() { + String tfModelPath = "src/test/files/integration/tensorflow/mnist_softmax/saved"; + String onnxModelPath = "src/test/files/integration/onnx/mnist_softmax/mnist_softmax.onnx"; + + Tensor argument = placeholderArgument(); + Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add"); + Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder_0", "add"); + + 
assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult); + } + + private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) { + SavedModelBundle tensorFlowModel = SavedModelBundle.load(path, "serve"); + TensorFlowModel model = new TensorFlowImporter().importModel("test", tensorFlowModel); + return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); + } + + private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) { + OnnxModel model = new OnnxImporter().importModel(path, output); + return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); + } + + private Tensor evaluateExpression(RankingExpression expression, Context context, Tensor argument, String input) { + context.put(input, new TensorValue(argument)); + return expression.evaluate(context).asTensor(); + } + + private Context contextFrom(TensorFlowModel result) { + MapContext context = new MapContext(); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + return context; + } + + private Context contextFrom(OnnxModel result) { + MapContext context = new MapContext(); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + return context; + } + + private Tensor placeholderArgument() { + Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", 784).build()); + for (int d0 = 0; d0 < 1; d0++) + for (int d1 = 0; d1 < 784; d1++) + b.cell(d1 * 1.0 / 784, d0, d1); + return b.build(); + } + + +} |