diff options
author | Lester Solbakken <lesters@oath.com> | 2020-04-21 15:26:58 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-04-21 15:26:58 +0200 |
commit | aad5c7184f37e1441c928efa77b434620742ff88 (patch) | |
tree | 34a92e7f954aa92e21d48816335771ff607fe404 | |
parent | 6f5ca49e45cdc8262fcf360b1c731a393385ffa8 (diff) |
Update model-integration for supporting BERT-type models
27 files changed, 752 insertions, 159 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index c7f320ed3b4..11abd3d24d8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -66,7 +66,7 @@ public class DimensionRenamer { void solve() { log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints)); - renames = solve(100000); + renames = solve(1000000); log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames)); } @@ -86,7 +86,7 @@ public class DimensionRenamer { private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) { Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations); - if ( solution == null) { + if (solution == null) { ListMap<Arc, Constraint> hardConstraints = new ListMap<>(); boolean anyRemoved = copyHard(constraints, hardConstraints); if (anyRemoved) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index 14aa3ebf84e..ea981603481 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -7,6 +7,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -81,28 +82,36 @@ public class IntermediateGraph { DimensionRenamer renamer = new DimensionRenamer(this); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - addDimensionNameConstraints(operations.get(output), renamer); + addDimensionNameConstraints(operations.get(output), renamer, new HashSet<>()); } } renamer.solve(); for (String signature : signatures()) { for (String output : outputs(signature).values()) { - renameDimensions(operations.get(output), renamer); + renameDimensions(operations.get(output), renamer, new HashSet<>()); } } } - private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) { + private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer, Set<String> processed) { + if (processed.contains(operation.name())) { + return; + } if (operation.type().isPresent()) { - operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer, processed)); operation.addDimensionNameConstraints(renamer); + processed.add(operation.name()); } } - private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) { + private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer, Set<String> processed) { + if (processed.contains(operation.name())) { + return; + } if (operation.type().isPresent()) { - operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.inputs().forEach(input -> renameDimensions(input, renamer, processed)); operation.renameDimensions(renamer); + processed.add(operation.name()); } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 3774e64c886..a9d71b7d9d5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -15,9 +15,11 @@ import com.yahoo.text.ExpressionFormatter; import com.yahoo.yolean.Exceptions; import java.io.File; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; @@ -108,6 +110,9 @@ public abstract class ModelImporter implements MlModelImporter { } private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) { + if (model.expressions().containsKey(operation.name())) { + return operation.function(); + } if (operation.type().isEmpty()) { return Optional.empty(); } @@ -206,18 +211,22 @@ public abstract class ModelImporter implements MlModelImporter { private static void reportWarnings(IntermediateGraph graph, ImportedModel model) { for (ImportedModel.Signature signature : model.signatures().values()) { for (String outputName : signature.outputs().values()) { - reportWarnings(graph.get(outputName), model); + reportWarnings(graph.get(outputName), model, new HashSet<>()); } } } - private static void reportWarnings(IntermediateOperation operation, ImportedModel model) { + private static void reportWarnings(IntermediateOperation operation, ImportedModel model, Set<String> processed) { + if (processed.contains(operation.name())) { + return; + } for (String warning : operation.warnings()) { // If we want to report warnings, that code goes here } for (IntermediateOperation input : operation.inputs()) { - reportWarnings(input, model); + reportWarnings(input, model, processed); } + processed.add(operation.name()); } /** diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java index 21cc6b27dad..9a7fcc85ee1 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java @@ -37,7 +37,8 @@ class NamingConstraintSolver { private static ListMap<String, Integer> allPossibilities(Set<String> dimensions) { ListMap<String, Integer> all = new ListMap<>(); for (String dimension : dimensions) { - for (int i = 0; i < dimensions.size(); ++i) + // 20 (different dimension names) should be enough for most problems. + for (int i = 0; i < Math.min(dimensions.size(), 20); ++i) all.put(dimension, i); } return all; @@ -89,6 +90,7 @@ class NamingConstraintSolver { workList.add(constraint); } } + if (iterations > maxIterations) return false; } return true; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index ffc64c38f16..d14ad033a69 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -2,7 +2,6 @@ package ai.vespa.rankingexpression.importer.onnx; -import ai.vespa.rankingexpression.importer.operations.ExpandDims; import ai.vespa.rankingexpression.importer.operations.Gather; import ai.vespa.rankingexpression.importer.operations.OnnxCast; import ai.vespa.rankingexpression.importer.operations.Gemm; @@ -12,7 +11,10 @@ import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; import ai.vespa.rankingexpression.importer.operations.Slice; import ai.vespa.rankingexpression.importer.operations.Softmax; +import ai.vespa.rankingexpression.importer.operations.Split; import ai.vespa.rankingexpression.importer.operations.Squeeze; +import ai.vespa.rankingexpression.importer.operations.Tile; +import ai.vespa.rankingexpression.importer.operations.Transpose; import ai.vespa.rankingexpression.importer.operations.Unsqueeze; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -32,6 +34,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; +import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -53,19 +57,21 @@ class GraphImporter { private static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs, - IntermediateGraph graph) { + IntermediateGraph graph, + int outputIndex) { String type = node.getOpType(); String modelName = graph.name(); String nodeName = getNodeName(node); AttributeConverter attributes = AttributeConverter.convert(node); - return mapOperation(type, inputs, modelName, nodeName, attributes); + return mapOperation(type, inputs, modelName, nodeName, attributes, outputIndex); } static IntermediateOperation mapOperation(String opType, List<IntermediateOperation> inputs, String modelName, String nodeName, - AttributeConverter attributes) { + AttributeConverter attributes, + int outputIndex) { switch (opType.toLowerCase()) { case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); @@ -115,12 +121,15 @@ class GraphImporter { case "slice": return new Slice(modelName, nodeName, inputs, attributes); case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "split": return new Split(modelName, nodeName, inputs, attributes, outputIndex); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square()); case "where": return new Select(modelName, nodeName, inputs); case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); + case "tile": return new Tile(modelName, nodeName, inputs); + case "transpose": return new Transpose(modelName, nodeName, inputs, attributes); case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes); } @@ -168,11 +177,11 @@ class GraphImporter { OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto); operation = new Constant(intermediateGraph.name(), name, defaultType); operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); - } else { Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph); + int outputIndex = getOutputIndex(node, name); List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph); - operation = mapOperation(node, inputs, intermediateGraph); + operation = mapOperation(node, inputs, intermediateGraph, outputIndex); // propagate constant values if all inputs are constant if (operation.isConstant()) { @@ -185,7 +194,7 @@ class GraphImporter { } } intermediateGraph.put(operation.name(), operation); - + intermediateGraph.put(name, operation); return operation; } @@ -296,6 +305,10 @@ class GraphImporter { return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst(); } + private static int getOutputIndex(Onnx.NodeProto node, String outputName) { + return node.getOutputCount() == 0 ? 0 : Math.max(node.getOutputList().indexOf(outputName), 0); + } + private static String getNodeName(Onnx.NodeProto node) { String nodeName = node.getName(); if (nodeName.length() > 0) @@ -307,11 +320,14 @@ class GraphImporter { } private static Set<String> getWarnings(IntermediateOperation op) { - Set<String> warnings = new HashSet<>(op.warnings()); - for (IntermediateOperation input : op.inputs()) { - warnings.addAll(getWarnings(input)); - } - return warnings; + java.util.Map<String, Set<String>> warnings = new HashMap<>(); + getWarnings(op, warnings); + return warnings.values().stream().flatMap(Collection::stream).collect(Collectors.toSet()); } + private static void getWarnings(IntermediateOperation op, java.util.Map<String, Set<String>> warnings) { + if (warnings.containsKey(op.name())) return; + op.inputs().forEach(input -> getWarnings(input, warnings)); + warnings.put(op.name(), new HashSet<>(op.warnings())); + } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 7c8038cea66..9354a346aaf 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -22,8 +22,11 @@ class TypeConverter { for (int onnxIndex = 0; onnxIndex < type.dimensions().size(); ++onnxIndex) { int vespaIndex = type.dimensionMap(onnxIndex); Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(onnxIndex); - TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); long onnxDimensionSize = onnxDimension.getDimValue() == 0 ? 1 : onnxDimension.getDimValue(); + if (onnxDimensionSize == -1) { + continue; // disregard batch dimensions + } + TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); if (onnxDimensionSize != vespaDimension.size().orElse(-1L)) { throw new IllegalArgumentException("Onnx dimensions of does not match Vespa dimensions"); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 01fd7ee55bd..bc883076b33 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -53,12 +53,6 @@ public class Const extends IntermediateOperation { return new TensorFunctionNode.ExpressionTensorFunction(expressionNode); } - /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + super.vespaName(); - } - @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { addConstraintsFrom(type, renamer); @@ -77,7 +71,7 @@ public class Const extends IntermediateOperation { private Value value() { Optional<Value> value = attributeMap.get("value", type); - if ( ! value.isPresent()) { + if (value.isEmpty()) { throw new IllegalArgumentException("Node '" + name + "' of type " + "const has missing or non-recognized 'value' attribute"); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java index ad56eefe5f2..8f6e2335b10 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -4,7 +4,6 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; import java.util.Collections; @@ -13,20 +12,11 @@ import java.util.Optional; public class Constant extends IntermediateOperation { - private final String modelName; - public Constant(String modelName, String nodeName, OrderedTensorType type) { super(modelName, nodeName, Collections.emptyList()); - this.modelName = modelName; this.type = type.rename(vespaName() + "_"); } - /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + vespaName(name); - } - @Override protected OrderedTensorType lazyGetType() { return type; @@ -61,7 +51,9 @@ public class Constant extends IntermediateOperation { public Constant withInputs(List<IntermediateOperation> inputs) { if ( ! inputs.isEmpty()) throw new IllegalArgumentException("Constant cannot take inputs"); - return new Constant(modelName(), name(), type); + Constant constant = new Constant(modelName(), name(), type); + constant.setConstantValueFunction(constantValueFunction); + return constant; } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java index 5463f645355..af192fcec38 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java @@ -12,12 +12,6 @@ public class Identity extends IntermediateOperation { super(modelName, nodeName, inputs); } - /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + super.vespaName(); - } - @Override protected OrderedTensorType lazyGetType() { if (!allInputTypesPresent(1)) diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 2aa8b2a0d48..af134fac6cf 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -13,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; @@ -47,6 +49,8 @@ public abstract class IntermediateOperation { protected TensorFunction rankingExpressionFunction = null; protected boolean exportAsRankingFunction = false; + private boolean hasRenamedDimensions = false; + private final List<String> importWarnings = new ArrayList<>(); private Value constantValue = null; private List<IntermediateOperation> controlInputs = Collections.emptyList(); @@ -121,7 +125,10 @@ public abstract class IntermediateOperation { } /** Performs dimension rename for this operation */ - public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } + public void renameDimensions(DimensionRenamer renamer) { + type = type.rename(renamer); + hasRenamedDimensions = true; + } /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ public boolean isInput() { return false; } @@ -153,12 +160,23 @@ public abstract class IntermediateOperation { public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } /** Retrieve the valid Vespa name of this node */ - public String vespaName() { return vespaName(name); } - public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; } + public String vespaName() { + if (isConstant()) + return modelName + "_" + vespaName(name); + return vespaName(name); + } + + public String vespaName(String name) { + return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; + } /** Retrieve the valid Vespa name of this node if it is a ranking expression function */ public String rankingExpressionFunctionName() { - return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null; + String vespaName = vespaName(); + if (vespaName == null) { + return null; + } + return isConstant() ? "constant(" + vespaName + ")" : FUNCTION_PREFIX + modelName + "_" + vespaName; } /** Retrieve the list of warnings produced during its lifetime */ @@ -188,12 +206,51 @@ public abstract class IntermediateOperation { if ( ! isConstant() ) { throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant."); } - Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN)); - if (type != null && ! val.asTensor().type().equals(type.type()) ) { + Value val = evaluableCopy().evaluateAsConstant(new MapContext(DoubleValue.NaN)); + if (type == null) { + return val; + } + Tensor tensor = val.asTensor(); + checkIfRenameableTo(tensor, type); + setConstantValueFunction(t -> new TensorValue(tensor.withType(t.type()))); // so we don't have to re-evaluate + return new TensorValue(tensor.withType(type.type())); + } + + private void checkIfRenameableTo(Tensor tensor, OrderedTensorType type) { + if ( ! tensor.type().isRenamableTo(type.type()) ) { throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " + - "Expected: " + type.type() + " Got: " + val.asTensor().type()); + "Expected: " + type.type() + " Got: " + tensor.type()); + } + } + + private IntermediateOperation evaluableCopy() { + if (hasRenamedDimensions) { + return this; + } + IntermediateOperation copy = copyTree(); + + // Must have performed dimension renaming to properly evaluate as constant + IntermediateGraph graph = new IntermediateGraph(modelName); + graph.put(name, copy); + graph.outputs(graph.defaultSignature()).put(name, name); + graph.optimize(); + + return copy; + } + + private IntermediateOperation copyTree() { + List<IntermediateOperation> in = new ArrayList<>(); + if (constantValue != null) { + IntermediateOperation constant = new Constant(modelName, name, type); + constant.setConstantValueFunction(t -> new TensorValue(constantValue.asTensor().withType(t.type()))); + return constant; + } + inputs.forEach(i -> in.add(i.copyTree())); + IntermediateOperation copy = withInputs(in); + if (constantValueFunction != null) { + copy.constantValueFunction = constantValueFunction; } - return val; + return copy; } private Value evaluateAsConstant(Context context) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index adb54474812..3211a44fa68 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -82,6 +82,13 @@ public class Join extends IntermediateOperation { bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); } + // retain order of inputs + if (a == inputs.get(1)) { + TensorFunction temp = bReducedFunction; + bReducedFunction = aReducedFunction; + aReducedFunction = temp; + } + return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 6849e64641e..2b0af93fd8e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -3,13 +3,23 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.Slice; import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.text.ExpressionFormatter; +import java.util.ArrayList; import java.util.List; import java.util.Optional; +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + public class MatMul extends IntermediateOperation { public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) { @@ -20,62 +30,139 @@ public class MatMul extends IntermediateOperation { protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType typeA = inputs.get(0).type().get(); + OrderedTensorType typeB = inputs.get(1).type().get(); + + if (typeA.type().rank() < 1 || typeB.type().rank() < 1) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1"); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); - typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); - typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); + OrderedTensorType largestRankType = typeA.rank() >= typeB.rank() ? typeA : typeB; + OrderedTensorType smallestRankType = typeA.rank() >= typeB.rank() ? typeB : typeA; + for (int i = 0; i < largestRankType.rank() - 2; ++i) { + TensorType.Dimension dim = largestRankType.dimensions().get(i); + // broadcasting + int j = smallestRankType.rank() - largestRankType.rank() + i; + if (j >= 0 && smallestRankType.dimensions().get(j).size().get() > dim.size().get()) { + dim = smallestRankType.dimensions().get(j); + } + typeBuilder.add(dim); + } + if (typeA.rank() >= 2) { + typeBuilder.add(typeA.dimensions().get(typeA.rank() - 2)); + } + if (typeB.rank() >= 2) { + typeBuilder.add(typeB.dimensions().get(typeB.rank() - 1)); + } return typeBuilder.build(); } @Override protected TensorFunction lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; + + OrderedTensorType typeA = inputs.get(0).type().get(); + OrderedTensorType typeB = inputs.get(1).type().get(); + + TensorFunction functionA = handleBroadcasting(inputs.get(0).function().get(), typeA, typeB); + TensorFunction functionB = handleBroadcasting(inputs.get(1).function().get(), typeB, typeA); + + return new com.yahoo.tensor.functions.Reduce( + new Join(functionA, functionB, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + typeA.dimensions().get(typeA.rank() - 1).name()); + } - OrderedTensorType aType = inputs.get(0).type().get(); - OrderedTensorType bType = inputs.get(1).type().get(); - if (aType.type().rank() < 2 || bType.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (aType.type().rank() != bType.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - - Optional<TensorFunction> aFunction = inputs.get(0).function(); - Optional<TensorFunction> bFunction = inputs.get(1).function(); - if (!aFunction.isPresent() || !bFunction.isPresent()) { - return null; + private TensorFunction handleBroadcasting(TensorFunction tensorFunction, OrderedTensorType typeA, OrderedTensorType typeB) { + List<Slice.DimensionValue> slices = new ArrayList<>(); + for (int i = 0; i < typeA.rank() - 2; ++i) { + long dimSizeA = typeA.dimensions().get(i).size().get(); + String dimNameA = typeA.dimensionNames().get(i); + int j = typeB.rank() - typeA.rank() + i; + if (j >= 0) { + long dimSizeB = typeB.dimensions().get(j).size().get(); + if (dimSizeB > dimSizeA && dimSizeA == 1) { + ExpressionNode dimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); + slices.add(new Slice.DimensionValue(Optional.of(dimNameA), wrapScalar(dimensionExpression))); + } + } } - return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); + return slices.size() == 0 ? tensorFunction : new Slice(tensorFunction, slices); } @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { if ( ! allInputTypesPresent(2)) return; - List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); - List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); - - assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); - assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + OrderedTensorType typeA = inputs.get(0).type().get(); + OrderedTensorType typeB = inputs.get(1).type().get(); - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); + String lastDimA = typeA.dimensions().get(typeA.rank()-1).name(); + String lastDimB = typeB.dimensions().get(typeB.rank()-1).name(); + String secondLastDimA = typeA.dimensions().get(Math.max(0,typeA.rank()-2)).name(); + String secondLastDimB = typeB.dimensions().get(Math.max(0,typeB.rank()-2)).name(); - // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this); + // The last dimension of A should have the same name as the second-to-last dimension of B + renamer.addConstraint(lastDimA, secondLastDimB, DimensionRenamer.Constraint.equal(false), this); - // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this); + // The second-to-last dimension of a should have a different name than the last dimension of b + if (typeA.rank() >= 2 && typeB.rank() >= 2) { + renamer.addConstraint(secondLastDimA, lastDimB, DimensionRenamer.Constraint.lessThan(false), this); + } // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this); - } + if (typeA.rank() >= 2) { + renamer.addConstraint(secondLastDimA, lastDimA, DimensionRenamer.Constraint.lessThan(true), this); + } + if (typeB.rank() >= 2) { + renamer.addConstraint(secondLastDimB, lastDimB, DimensionRenamer.Constraint.greaterThan(true), this); + } + + // Handle different cases when at least one of the tensors have rank > 2 + for (int i = 0; i < typeA.rank() - 2; ++i) { + String iDim = typeA.dimensionNames().get(i); + + // a1 < a2 < a3 < a4 + for (int j = i+1; j < typeA.rank(); ++j) { + String jDim = typeA.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this); + } + // not equal to last 2 dimensions in B + for (int j = typeB.rank()-2; j < typeB.rank(); ++j) { + if (j < 0) continue; + String jDim = typeB.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.notEqual(false), this); + } + // equal to matching dimension in tensor B + int j = typeB.rank() - typeA.rank() + i; + if (j >= 0) { + String jDim = typeB.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.equal(false), this); + } + } - private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { - if (dimensions.size() >= 2) return; - throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + - " but got just " + dimensions + " from\n" + - ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); + for (int i = 0; i < typeB.rank() - 2; ++i) { + String iDim = typeB.dimensionNames().get(i); + + // b1 < b2 < b3 < b4 + for (int j = i+1; j < typeB.rank(); ++j) { + String jDim = typeB.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this); + } + // not equal to last 2 dimensions in A + for (int j = typeA.rank()-2; j < typeA.rank(); ++j) { + if (j < 0) continue; + String jDim = typeA.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.notEqual(false), this); + } + // equal to matching dimension in tensor A + int j = typeA.rank() - typeB.rank() + i; + if (j >= 0) { + String jDim = typeA.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.equal(false), this); + } + } } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java index e040ae62149..07ac457cca8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java @@ -54,7 +54,7 @@ public class Rename extends IntermediateOperation { } public void renameDimensions(DimensionRenamer renamer) { - type = type.rename(renamer); + super.renameDimensions(renamer); from = renamer.dimensionNameOf(from).orElse(from); to = renamer.dimensionNameOf(to).orElse(to); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index c88fc18e6c6..d8e806ae7e4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -2,30 +2,29 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; -import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.Function; +import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.stream.Collectors; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; public class Reshape extends IntermediateOperation { @@ -38,6 +37,10 @@ public class Reshape extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + if (inputs.size() == 2) { return typeWithShapeAsInput(); } else if (inputs.size() == 1) { @@ -126,52 +129,53 @@ public class Reshape extends IntermediateOperation { return new Reshape(modelName(), name(), inputs, attributeMap); } - public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { + public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type()))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); + IntermediateOperation input = inputs.get(0); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, // then use the dimension order of the new shape to roll back into a tensor. - // Here we create a transformation tensor that is multiplied with the from tensor to map into - // the new shape. We have to introduce temporary dimension names and rename back if dimension names - // in the new and old tensor type overlap. - - // Todo: change this to use tensor generate when available - - List<String> from = new ArrayList<>(); - List<String> to = new ArrayList<>(); - boolean dimensionNamesOverlap = dimensionNamesOverlap(inputType, outputType); - if (dimensionNamesOverlap) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(outputType.type().valueType()); - for (int i = 0; i < outputType.rank(); ++i) { - TensorType.Dimension dim = outputType.dimensions().get(i); - from.add(dim.name()); - to.add("temp_" + dim.name()); - builder.add(dim.withName("temp_" + dim.name())); - } - outputType = builder.build(); - } - - ExpressionNode unrollFrom = unrollTensorExpression(inputType); - ExpressionNode unrollTo = unrollTensorExpression(outputType); - ExpressionNode transformExpression = new ComparisonNode(new EmbracedNode(unrollFrom), TruthOperator.EQUAL, new EmbracedNode(unrollTo)); - TensorType transformationType = new TensorType.Builder(inputType.type(), outputType.type()).build(); - Generate transformTensor = new Generate(transformationType, - new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); + ExpressionNode unrolled = new EmbracedNode(unrollTensorExpression(outputType)); - TensorFunction result = new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); + long innerSize = 1; + for (int dim = 0; dim < inputType.rank(); ++dim) { + innerSize *= inputType.dimensions().get(dim).size().get(); + } - if (dimensionNamesOverlap) { - result = new Rename(result, to, from); + for (int dim = 0; dim < inputType.rank(); ++dim) { + String inputDimensionName = inputType.dimensions().get(dim).name(); + long inputDimensionSize = inputType.dimensions().get(dim).size().get(); + long previousInnerSize = innerSize; + innerSize /= inputDimensionSize; + + ExpressionNode inputDimensionExpression; + if (inputDimensionSize == 1) { + inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); + } else if (dim == (inputType.rank() - 1)) { + ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); + ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size); + inputDimensionExpression = new EmbracedNode(div); + } else { + ExpressionNode size = new ConstantNode(new DoubleValue(innerSize)); + ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize)); + ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize); + ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size); + inputDimensionExpression = new EmbracedNode(new FunctionNode(Function.floor, div)); + } + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression))); } - return result; - } - private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) { - return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent()); + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + return Generate.bound(outputType.type(), wrapScalar(sliceExpression)); } private static ExpressionNode unrollTensorExpression(OrderedTensorType type) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java index e5463291ef8..372d70a9fda 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -176,13 +176,11 @@ public class Slice extends IntermediateOperation { com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); - TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); - return generate; + return Generate.bound(type.type(), wrapScalar(sliceExpression)); } @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - // Todo: what to do? for (int i = 0; i < type.dimensions().size(); i++) { renamer.addDimension(type.dimensions().get(i).name()); for (int j = i + 1; j < type.dimensions().size(); j++) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index 83086926316..d03827f4c72 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -28,6 +28,10 @@ public class Softmax extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(1)) return null; + + // input is referenced twice due to overflow avoidance, so make this it's own function. + inputs.get(0).exportAsRankingFunction = true; + return inputs.get(0).type().get(); } @@ -50,7 +54,9 @@ public class Softmax extends IntermediateOperation { } TensorFunction input = inputs.get(0).function().get(); - TensorFunction exp = new Map(input, ScalarFunctions.exp()); + TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions); + TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow + TensorFunction exp = new Map(cap, ScalarFunctions.exp()); TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions); TensorFunction div = new Join(exp, sum, ScalarFunctions.divide()); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java new file mode 100644 index 00000000000..3241732671d --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java @@ -0,0 +1,119 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +public class Split extends IntermediateOperation { + + private final AttributeMap attributes; + private final int output; + + private final int axis; + private int start; + private int end; + + public Split(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes, int output) { + super(modelName, nodeName, inputs); + this.attributes = attributes; + this.output = output; + axis = (int) attributes.get("axis").orElse(DoubleValue.zero).asDouble(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) + return null; + OrderedTensorType inputType = inputs.get(0).type().get(); + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + + int axisSize = inputType.dimensions().get(axis).size().get().intValue(); + start = 0; + end = axisSize; + + if (attributes.getList("split").isPresent()) { + List<Value> splitList = attributes.getList("split").get(); + if (output > splitList.size()) { + throw new IllegalArgumentException("Split in " + name + ": output out of range of split list"); + } + for (int i = 0; i < output; ++i) { + start += (int) splitList.get(i).asDouble(); + } + if (output < splitList.size()) { + end = start + (int) splitList.get(output).asDouble(); + } + } else { + start = axisSize / 2 * output; + end = start + axisSize / 2; + } + + if (start >= axisSize || start < 0) { + throw new IllegalArgumentException("Split in " + name + ": split start index out of range (" + start + ")"); + } + if (end > axisSize || end < 0) { + throw new IllegalArgumentException("Split in " + name + ": split end index out of range (" + end + ")"); + } + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < inputType.rank(); ++i) { + TensorType.Dimension inputDimension = inputType.dimensions().get(i); + long dimSize = i == axis ? end - start : inputDimension.size().get(); + typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), dimSize)); + } + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) return null; + + IntermediateOperation input = inputs.get(0); + OrderedTensorType inputType = input.type().get(); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + + for (int i = 0; i < inputType.rank(); ++i) { + String inputDimensionName = inputType.dimensions().get(i).name(); + ExpressionNode reference = new ReferenceNode(inputDimensionName); + ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0))); + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset)))); + } + + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + return generate; + } + + @Override + public Split withInputs(List<IntermediateOperation> inputs) { + return new Split(modelName(), name(), inputs, attributes, output); + } + + @Override + public String operationName() { return "Split"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index a9e3fc6a43a..7ab01bf65c1 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -80,7 +80,7 @@ public class Squeeze extends IntermediateOperation { private OrderedTensorType reducedType(OrderedTensorType inputType) { OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); - for (TensorType.Dimension dimension: inputType.type().dimensions()) { + for (TensorType.Dimension dimension: inputType.dimensions()) { if ( ! squeezeDimensions.contains(dimension.name())) { builder.add(dimension); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java new file mode 100644 index 00000000000..a3c755bf1c0 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -0,0 +1,100 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +/** + * Onnx tile operation. + */ +public class Tile extends IntermediateOperation { + + public Tile(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) return null; + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + + IntermediateOperation repeats = inputs.get(1); + if (repeats.getConstantValue().isEmpty()) + throw new IllegalArgumentException("Tile " + name + ": repeats input must be a constant."); + + Tensor shape = repeats.getConstantValue().get().asTensor(); + if (shape.type().rank() != 1) + throw new IllegalArgumentException("Tile " + name + ": repeats must be a 1-d tensor."); + + OrderedTensorType inputType = inputs.get(0).type().get(); + if (shape.type().dimensions().get(0).size().get() != inputType.rank()) + throw new IllegalArgumentException("Tile " + name + ": repeats must be the same size as input rank."); + + List<Integer> dimSizes = new ArrayList<>(inputType.rank()); + shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue())); + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < dimSizes.size(); ++i) { + TensorType.Dimension inputDimension = inputType.dimensions().get(i); + typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get() * dimSizes.get(i))); + } + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(2)) return null; + + IntermediateOperation input = inputs.get(0); + OrderedTensorType inputType = input.type().get(); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + + for (int axis = 0; axis < inputType.rank(); ++axis) { + String inputDimensionName = inputType.dimensions().get(axis).name(); + long inputDimensionSize = inputType.dimensions().get(axis).size().get(); + + ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); + ExpressionNode reference = new ReferenceNode(inputDimensionName); + ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size); + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod)))); + } + + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + return generate; + } + + @Override + public Tile withInputs(List<IntermediateOperation> inputs) { + return new Tile(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Tile"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java new file mode 100644 index 00000000000..5e7bc1a1f36 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java @@ -0,0 +1,53 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +public class Transpose extends IntermediateOperation { + + private final AttributeMap attributes; + + public Transpose(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes) { + super(modelName, nodeName, inputs); + this.attributes = attributes; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().get(); + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < inputType.rank(); ++i) { + int inputIndex = inputType.rank() - 1 - i; + if (attributes.getList("perm").isPresent()) { + inputIndex = (int) attributes.getList("perm").get().get(i).asDouble(); + } + TensorType.Dimension inputDimension = inputType.dimensions().get(inputIndex); + typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get())); + } + OrderedTensorType result = typeBuilder.build(); + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) + return null; + return inputs.get(0).function().orElse(null); + } + + @Override + public Transpose withInputs(List<IntermediateOperation> inputs) { + return new Transpose(modelName(), name(), inputs, attributes); + } + + @Override + public String operationName() { return "Transpose"; } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index 94c5577357b..d5dff7fb1b7 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -107,6 +107,18 @@ public class OnnxOperationsTestCase { assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y)); assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y)); assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y)); + + // broadcasting - opposite order + x = evaluate("random(d0[4]) + 1"); + y = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + assertEval("add", x, y, evaluate("rename(x, d0, d2) + y", x, y)); + assertEval("sub", x, y, evaluate("rename(x, d0, d2) - y", x, y)); + assertEval("mul", x, y, evaluate("rename(x, d0, d2) * y", x, y)); + assertEval("div", x, y, evaluate("rename(x, d0, d2) / y", x, y)); + assertEval("greater", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a > b))", x, y)); + assertEval("less", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a < b))", x, y)); + assertEval("equal", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a == b))", x, y)); + assertEval("pow", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(pow(a,b)))", x, y)); } @Test @@ -185,9 +197,49 @@ public class OnnxOperationsTestCase { @Test public void testMatMul1() throws ParseException { - Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]"); - Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]"); - assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]")); + Tensor a = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + Tensor b = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("91")); + + a = evaluate("tensor(d0[3]):[1,2,3]"); + b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2]):[22, 28]")); + + a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[3]):[1,2,3]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2]):[14, 32]")); + + a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[2],d1[3],d2[2]):[1,2,3,4,5,6,7,8,9,10,11,12]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2]):[22,28,49,64,58,64,139,154]")); + + a = evaluate("tensor(d0[2],d1[2],d2[3]):[1,2,3,4,5,6,7,8,9,10,11,12]"); + b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2]):[22,28,49,64,76,100,103,136]")); + + a = evaluate("tensor(d0[1],d1[4],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + b = evaluate("tensor(d0[1],d1[4],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,220,244,301,334,634,676,769,820,1264,1324,1453,1522]")); + + a = evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[2],d1[2],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2],d2[2],d3[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]")); } @Test @@ -217,6 +269,10 @@ public class OnnxOperationsTestCase { y = evaluate("tensor(d0[4]):[3,2,-1,1]"); assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]")); + + x = evaluate("tensor(d0[1],d1[2],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12]"); + y = evaluate("tensor(d0[2]):[2,6]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[6]):[1,2,3,4,5,6,7,8,9,10,11,12]")); } @Test @@ -435,6 +491,48 @@ public class OnnxOperationsTestCase { } + @Test + public void testTranspose1() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[3]):[[1,2,3],[4,5,6]]"); + assertEval("transpose", x, evaluate("tensor(d0[3],d1[2]):[[1,4],[2,5],[3,6]]")); + } + + @Test + public void testTile6() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"); + Tensor y = evaluate("tensor(d0[2]):[1,2]"); + assertEval("tile", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,1,2,3,4,3,4]")); + + x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"); + y = evaluate("tensor(d0[2]):[3,1]"); + assertEval("tile", x, y, evaluate("tensor(d0[6],d1[2]):[1,2,3,4,1,2,3,4,1,2,3,4]")); + + x = evaluate("tensor(d0[1],d1[1],d2[1]):[1]"); + y = evaluate("tensor(d0[3]):[1,6,1]"); + assertEval("tile", x, y, evaluate("tensor(d0[1],d1[6],d2[1]):[1,1,1,1,1,1]")); + } + + @Test + public void testSplit2() throws ParseException { + Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + assertEval("split", x, evaluate("tensor(d0[3]):[1,2,3]"), 0); + assertEval("split", x, evaluate("tensor(d0[3]):[4,5,6]"), 1); + assertEval("split", x, evaluate("tensor(d0[2]):[1,2]"), createAttribute("split", new int[] {2}), 0); + assertEval("split", x, evaluate("tensor(d0[4]):[3,4,5,6]"), createAttribute("split", new int[] {2}), 1); + assertEval("split", x, evaluate("tensor(d0[3]):[3,4,5]"), createAttribute("split", new int[] {2,3}), 1); + assertEval("split", x, evaluate("tensor(d0[1]):[6]"), createAttribute("split", new int[] {2,3}), 2); + + x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]")); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), 0); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), 1); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), createAttribute("split", new int[] {1}), 0); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), createAttribute("split", new int[] {1}), 1); + assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[1,4]"), createAttribute("axis", 1), 0); + assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[2,5]"), createAttribute("axis", 1), 1); + assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[3,6]"), createAttribute("axis", 1), 2); + } + private Tensor evaluate(String expr) throws ParseException { return evaluate(expr, null, null, null); } @@ -461,41 +559,49 @@ public class OnnxOperationsTestCase { } private void assertEval(String opName, Tensor x, Tensor expected) { - assertEval(opName, x, null, null, null, null, expected, null); + assertEval(opName, x, null, null, null, null, expected, null, 0); + } + + private void assertEval(String opName, Tensor x, Tensor expected, int output) { + assertEval(opName, x, null, null, null, null, expected, null, output); } private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, null, null, null, null, expected, attr); + assertEval(opName, x, null, null, null, null, expected, attr, 0); + } + + private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr, int output) { + assertEval(opName, x, null, null, null, null, expected, attr, output); } private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, y, null, null, null, expected, attr); + assertEval(opName, x, y, null, null, null, expected, attr, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) { - assertEval(opName, x, y, null, null, null, expected, null); + assertEval(opName, x, y, null, null, null, expected, null, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) { - assertEval(opName, x, y, z, null, null, expected, null); + assertEval(opName, x, y, z, null, null, expected, null, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, y, z, null, null, expected, attr); + assertEval(opName, x, y, z, null, null, expected, attr, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor expected) { - assertEval(opName, x, y, z, q, null, expected, null); + assertEval(opName, x, y, z, q, null, expected, null, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected) { - assertEval(opName, x, y, z, q, r, expected, null); + assertEval(opName, x, y, z, q, r, expected, null, 0); } - private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr) { + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr, int output) { Context context = new MapContext(DoubleValue.NaN); List<IntermediateOperation> inputs = createInputs(context, x, y, z, q, r); - IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build()); + IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build(), output); optimizeAndRename(opName, op); Tensor result = evaluate(op); assertEquals(expected, result); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java index 9631bddd93d..04db902073b 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java @@ -11,7 +11,6 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -48,6 +47,19 @@ public class SimpleImportTestCase { assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]")); } + @Test + public void testConcat() { + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx"); + + MapContext context = new MapContext(); + context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]"))); + context.put("j", new TensorValue(Tensor.from("tensor(d0[1]):[2]"))); + context.put("k", new TensorValue(Tensor.from("tensor(d0[1]):[3]"))); + + Tensor result = model.expressions().get("y").evaluate(context).asTensor(); + assertEquals(result, Tensor.from("tensor(d0[3]):[1, 2, 3]")); + } + private void evaluateFunction(Context context, ImportedModel model, String functionName) { if (!context.names().contains(functionName)) { RankingExpression e = RankingExpression.from(model.functions().get(functionName)); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java index b9d767774be..25f8acf1f6d 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java @@ -34,7 +34,7 @@ public class DropoutImportTestCase { ImportedMlFunction function = signature.outputFunction("y", "y"); assertNotNull(function); - assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", + assertEquals("join(join(reduce(constant(test_outputs_Const), sum, d1), imported_ml_function_test_outputs_BiasAdd, f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", function.expression()); model.assertEqualResult("X", "outputs/Maximum"); assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); diff --git a/model-integration/src/test/models/onnx/simple/concat.onnx b/model-integration/src/test/models/onnx/simple/concat.onnx Binary files differnew file mode 100644 index 00000000000..945bc3c9445 --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/concat.onnx diff --git a/model-integration/src/test/models/onnx/simple/concat.py b/model-integration/src/test/models/onnx/simple/concat.py new file mode 100755 index 00000000000..186002c2abb --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/concat.py @@ -0,0 +1,25 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import onnx +from onnx import helper, TensorProto + +i_type = helper.make_tensor_value_info('i', TensorProto.FLOAT, [1]) +j_type = helper.make_tensor_value_info('j', TensorProto.FLOAT, [1]) +k_type = helper.make_tensor_value_info('k', TensorProto.FLOAT, [1]) + +output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + +node = onnx.helper.make_node( + 'Concat', + inputs=['i', 'j', 'k'], + outputs=['y'], + axis=0, +) +graph_def = onnx.helper.make_graph( + nodes = [node], + name = 'concat_test', + inputs = [i_type, j_type, k_type], + outputs = [output_type] +) +model_def = helper.make_model(graph_def, producer_name='concat.py') +onnx.save(model_def, 'concat.onnx') diff --git a/model-integration/src/test/models/onnx/simple/gather.onnx b/model-integration/src/test/models/onnx/simple/gather.onnx Binary files differindex 62451ad953d..0647d86ed0f 100644 --- a/model-integration/src/test/models/onnx/simple/gather.onnx +++ b/model-integration/src/test/models/onnx/simple/gather.onnx diff --git a/model-integration/src/test/models/onnx/simple/simple.onnx b/model-integration/src/test/models/onnx/simple/simple.onnx index 1c746c90efa..41b458451d0 100644 --- a/model-integration/src/test/models/onnx/simple/simple.onnx +++ b/model-integration/src/test/models/onnx/simple/simple.onnx @@ -1,4 +1,4 @@ - simple.py:ã + simple.py:ã 0 query_tensor attribute_tensormatmul"MatMul @@ -20,4 +20,4 @@ output -B +B
\ No newline at end of file |