diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-28 21:35:59 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-28 21:35:59 +0100 |
commit | b5ffe229474223844c150e99d24ca618e5e9f8dd (patch) | |
tree | 7c9ac3da58ff567fae79019bc688ba2a4e4d904c | |
parent | 1d6791e6fa004ae80e85dbc6a6c7c2e4b8037a4f (diff) |
Complete prototype TensorFlow mapping
11 files changed, 275 insertions, 49 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 160af794faf..8dcd31b270e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -1,15 +1,24 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +import com.google.common.collect.ImmutableList; import com.google.protobuf.ProtocolStringList; import com.google.protobuf.TextFormat; import com.yahoo.io.IOUtils; +import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Matmul; +import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.yolean.Exceptions; +import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.MetaGraphDef; import org.tensorflow.framework.NodeDef; @@ -17,11 +26,14 @@ import org.tensorflow.framework.OpDef; import org.tensorflow.framework.SavedModel; import org.tensorflow.framework.SignatureDef; import org.tensorflow.framework.TensorInfo; +import org.tensorflow.framework.TensorShapeProto; import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.DoubleBinaryOperator; import java.util.stream.Collectors; /** @@ -31,17 +43,31 @@ import java.util.stream.Collectors; */ public class TensorFlowImporter { + /* + A note on conversion from implicitly numbered to explicitly named dimensions: + Vespa tensor dimensions are explicitly named and thus have an explicit notion of being + 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation + comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation + around dimension renaming operations which mirrors those built into the TF operation definitions. + + To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost' + dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation + and the result is then renamed again (if necessary) to recover this convention across a full nested + computation. + + This requires us to track tensor types throughout the conversion. + */ + /** * Imports a saved TensorFlow model from a directory. * The model should be saved as a pbtxt file. * The name of the model is taken at the pbtxt file name (not including the .pbtxt ending). */ - public void importModel(String modelDir) { + public List<RankingExpression> importModel(String modelDir) { try { SavedModel.Builder builder = SavedModel.newBuilder(); TextFormat.getParser().merge(IOUtils.createReader(modelDir + "/saved_model.pbtxt"), builder); - //System.out.println("Read " + builder); - importModel(builder.build()); + return importModel(builder.build()); // TODO: Support binary reading: //SavedModel.parseFrom(new FileInputStream(modelDir + "/saved_model.pbtxt")); @@ -52,53 +78,161 @@ public class TensorFlowImporter { } - private void importModel(SavedModel model) { - model.getMetaGraphsList().forEach(this::importGraph); + /** Import all declared inputs in all the graphs in the given model */ + private List<RankingExpression> importModel(SavedModel model) { + // TODO: Handle name conflicts between output keys in different graphs? + return model.getMetaGraphsList().stream() + .flatMap(graph -> importGraph(graph).stream()) + .collect(Collectors.toList()); } - - private void importGraph(MetaGraphDef graph) { + + private List<RankingExpression> importGraph(MetaGraphDef graph) { System.out.println("Importing graph"); + List<RankingExpression> expressions = new ArrayList<>(); for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { System.out.println(" Importing signature def " + signatureEntry.getKey() + " with method name " + signatureEntry.getValue().getMethodName()); - signatureEntry.getValue().getOutputsMap().values() - .forEach(output -> importOutput(output, signatureEntry.getValue().getMethodName(), graph.getGraphDef())); + Map<String, TensorType> inputs = importInputs(signatureEntry.getValue().getInputsMap()); + for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { + try { + ExpressionNode result = importOutput(output.getValue(), + inputs, + graph.getGraphDef()); + expressions.add(new RankingExpression(output.getKey(), result)); + } + catch (IllegalArgumentException e) { + System.err.println("Skipping output '" + output.getValue().getName() + "' of signature '" + // TODO: Log, or ... + signatureEntry.getValue().getMethodName() + + "': " + Exceptions.toMessageString(e)); + } + } } + return expressions; + } + + private Map<String, TensorType> importInputs(Map<String, TensorInfo> inputInfoMap) { + Map<String, TensorType> inputs = new HashMap<>(); + inputInfoMap.forEach((key, value) -> inputs.put(nameOf(value.getName()), + importTensorType(value.getTensorShape()))); + return inputs; } - private void importOutput(TensorInfo output, String signatureName, GraphDef graph) { - try { - System.out.println(" Importing output " + output.getName()); - NodeDef node = getNode(nameOf(output.getName()), graph); - // System.out.println("Ops:-------------"); - // graph.getStrippedOpList().getOpList().stream().forEach(s -> System.out.println(s.getName())); - // System.out.println("-----------------"); - importNode(node, graph, ""); - } - catch (IllegalArgumentException e) { - System.err.println("Skipping output '" + output.getName() + "' of signature '" + signatureName + "': " + Exceptions.toMessageString(e)); + private TensorType importTensorType(TensorShapeProto tensorShape) { + TensorType.Builder b = new TensorType.Builder(); + for (int i = 0; i < tensorShape.getDimCount(); i++) { + int dimensionSize = (int) tensorShape.getDim(i).getSize(); + if (dimensionSize >= 0) + b.indexed("d" + i, dimensionSize); + else + b.indexed("d" + i); // unbound size } + return b.build(); } - private ExpressionNode importNode(NodeDef tfNode, GraphDef graph, String indent) { - System.out.println(" " + indent + "Importing node " + tfNode.getName()); - List<ExpressionNode> arguments = new ArrayList<>(); - for (String input : tfNode.getInputList()) - arguments.add(importNode(getNode(nameOf(input), graph), graph, indent + " ")); - ExpressionNode node = expressionNodeOf(tfNode.getName(), arguments); + private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph) { + System.out.println(" Importing output " + output.getName()); + NodeDef node = getNode(nameOf(output.getName()), graph); + return new TensorFunctionNode(importNode(node, inputs, graph, "").function()); + } + + /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ + private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, String indent) { + System.out.println(" " + indent + "Importing node " + tfNode.getName() + " with operation " + tfNode.getOp()); + return tensorFunctionOf(tfNode, inputs, graph, indent); + } + + private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, + Map<String, TensorType> inputs, + GraphDef graph, + String indent) { + // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops + switch (tfNode.getOp()) { + case "Identity" : return identity(tfNode, inputs); + case "Add" : return join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add()); + case "MatMul" : return matmul(importArguments(tfNode, inputs, graph, indent)); + case "Softmax" : return softmax(importArguments(tfNode, inputs, graph, indent)); + default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); + } } - private ExpressionNode expressionNodeOf(String node, List<ExpressionNode> arguments) { - return new TensorFunctionNode(tensorFunctionOf(node, arguments.stream() - .map(TensorFunctionNode.TensorFunctionExpressionNode::new) - .collect(Collectors.toList()))); + private List<TypedTensorFunction> importArguments(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, String indent) { + return tfNode.getInputList().stream() + .map(argNode -> importNode(getNode(nameOf(argNode), graph), inputs, graph, indent + " ")) + .collect(Collectors.toList()); } - private TensorFunction tensorFunctionOf(String node, List<TensorFunction> arguments) { - switch (node) { - case "add" : return new Join(arguments.get(0), arguments.get(1), ScalarFunctions.add()); - case "MatMul" : return new Matmul(arguments.get(0), arguments.get(1), ScalarFunctions.add()); + private TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) { + ensureArguments(2, arguments, "join"); + TypedTensorFunction a = arguments.get(0); + TypedTensorFunction b = arguments.get(0); + // TODO: Verify with TF doc + TensorType resultType = Join.resultType(a.type(), b.type()); + Join function = new Join(a.function(), b.function(), doubleFunction); + return new TypedTensorFunction(resultType, function); + } + + private TypedTensorFunction matmul(List<TypedTensorFunction> arguments) { + ensureArguments(2, arguments, "matmul"); + TypedTensorFunction a = arguments.get(0); + TypedTensorFunction b = arguments.get(0); + if (a.type().rank() < 2 || b.type.rank() < 2) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); + if (a.type().rank() != b.type.rank()) + throw new IllegalArgumentException("Tensors in matmul must have the same rank"); + + // Let the second-to-last dimension of the second tensor be the same as the last dimension of the first + // and the last dimension of the second argument be not present in the first argument, while leaving the + // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication. + + // TODO: Check if transpose_a or transpose_b is set and rename differently accordingly + + String beforeLastDim = "d" + (a.type().rank() - 1); + String lastDim = "d" + a.type().rank(); + String afterLastDim = "d" + (a.type().rank() + 1); + + Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim), + ImmutableList.of(lastDim, afterLastDim)); + Matmul matmul = new Matmul(a.function(), renamedB, lastDim); + return new TypedTensorFunction(Matmul.resultType(a.type(), b.type(), lastDim), + new Rename(matmul, afterLastDim, lastDim)); + } + + private TypedTensorFunction softmax(List<TypedTensorFunction> arguments) { + ensureArguments(1, arguments, "softmax"); + TypedTensorFunction a = arguments.get(0); + String dimension = "d0"; // TODO: Verify with TF doc + Softmax softmax = new Softmax(a.function(), dimension); + return new TypedTensorFunction(Softmax.resultType(a.type(), dimension), softmax); + } + + private TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs) { + // TODO: Verify with TF documentation + String name; + TensorType inputType; + if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants + if (tfNode.getInputList().size() != 1) + throw new IllegalArgumentException("A Variable/read node must have one input but has " + + tfNode.getInputList().size()); + name = tfNode.getInput(0); + AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); + if (shapes == null) + throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape"); + inputType = importTensorType(shapes.getList().getShape(0)); } + else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name + name = tfNode.getName(); + inputType = inputs.get(name); + if (inputType == null) + throw new IllegalArgumentException("An identity operation node is referencing input '" + name + + "', but there is no such input"); + } + return new TypedTensorFunction(inputType, new VariableTensor(name)); + } + + private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) { + if ( arguments.size() != count) + throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName + + ", but got " + arguments.size()); } private NodeDef getNode(String name, GraphDef graph) { @@ -120,15 +254,31 @@ public class TensorFlowImporter { } /** - * An output has the form name:index. + * A method signature input and output has the form name:index. * This returns the name part without the index. */ - private String nameOf(String outputName) { - return outputName.split(":")[0]; + private String nameOf(String name) { + return name.split(":")[0]; } private boolean contains(String string, ProtocolStringList strings) { return strings.asByteStringList().stream().anyMatch(s -> s.toStringUtf8().equals(string)); } + + /** A tensor function returning a specific tensor type */ + private static final class TypedTensorFunction { + + private final TensorType type; + private final TensorFunction function; + + public TypedTensorFunction(TensorType type, TensorFunction function) { + this.type = type; + this.function = function; + } + + public TensorType type() { return type; } + public TensorFunction function() { return function; } + + } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index 1f8db6e036c..ba765d07094 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -17,7 +17,7 @@ import java.util.Map; * @author bratseth */ public class SerializationContext { - + /** Expression functions indexed by name */ private final ImmutableMap<String, ExpressionFunction> functions; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index ce21e132980..ab5f1e7191d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -30,6 +30,9 @@ public class TensorFunctionNode extends CompositeNode { this.function = function; } + /** Returns the tensor function wrapped by this */ + public TensorFunction function() { return function; } + @Override public List<ExpressionNode> children() { return function.functionArguments().stream() diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java index 4c511047118..30328c3d9fe 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java @@ -1,7 +1,13 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import org.junit.Test; +import java.util.List; + +import static org.junit.Assert.assertEquals; + /** * @author bratseth */ @@ -9,7 +15,21 @@ public class TensorFlowImporterTestCase { @Test public void testModel1() { - new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/"); + List<RankingExpression> expressions = + new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/"); + assertEquals(1, expressions.size()); + assertEquals("scores", expressions.get(0).getName()); + assertEquals("" + + "softmax(join(rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " + + "rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " + + "f(a,b)(a + b)), " + + "d0)", + toNonPrimitiveString(expressions.get(0))); + } + + private String toNonPrimitiveString(RankingExpression expression) { + // toString on the wrapping expression will map to primitives, which is harder to read + return ((TensorFunctionNode)expression.getRoot()).function().toString(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index c05c35d6df3..c27ac57415d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -52,6 +52,9 @@ public class TensorType { public static TensorType fromSpec(String specString) { return TensorTypeParser.fromSpec(specString); } + + /** Returns the number of dimensions of this: dimensions().size() */ + public int rank() { return dimensions.size(); } /** Returns an immutable list of the dimensions of this */ public List<Dimension> dimensions() { return dimensions; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 8c4dbfb0acb..c89f63c0395 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -46,6 +46,30 @@ public class Join extends PrimitiveTensorFunction { this.combinator = combinator; } + /** Returns the type resulting from applying Join to the two given types */ + public static TensorType resultType(TensorType a, TensorType b) { + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (int i = 0; i < a.dimensions().size(); ++i) { + TensorType.Dimension aDim = a.dimensions().get(i); + for (int j = 0; j < b.dimensions().size(); ++j) { + TensorType.Dimension bDim = b.dimensions().get(j); + if (aDim.name().equals(bDim.name())) { // include + if (aDim.isIndexed() && bDim.isIndexed()) { + if (aDim.size().isPresent() || bDim.size().isPresent()) + typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), + bDim.size().orElse(Integer.MAX_VALUE))); + else + typeBuilder.indexed(aDim.name()); + } + else { + typeBuilder.mapped(aDim.name()); + } + } + } + } + return typeBuilder.build(); + } + public TensorFunction argumentA() { return argumentA; } public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index bb27e937699..cbb3f159623 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.TensorType; import java.util.List; @@ -20,6 +21,10 @@ public class Matmul extends CompositeTensorFunction { this.argument2 = argument2; this.dimension = dimension; } + + public static TensorType resultType(TensorType a, TensorType b, String dimension) { + return Reduce.resultType(Join.resultType(a, b), ImmutableList.of(dimension)); + } @Override public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index cfc78be7e0c..aa28a26deb2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -61,6 +61,15 @@ public class Reduce extends PrimitiveTensorFunction { this.dimensions = ImmutableList.copyOf(dimensions); } + public static TensorType resultType(TensorType type, List<String> reduceDimensions) { + TensorType.Builder b = new TensorType.Builder(); + for (TensorType.Dimension dimension : type.dimensions()) { + if ( ! reduceDimensions.contains(dimension.name())) + b.dimension(dimension); + } + return b.build(); + } + public TensorFunction argument() { return argument; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 6b0daf1b49a..6e52760424e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -28,6 +28,10 @@ public class Rename extends PrimitiveTensorFunction { private final TensorFunction argument; private final List<String> fromDimensions; private final List<String> toDimensions; + + public Rename(TensorFunction argument, String fromDimension, String toDimension) { + this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); + } public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index bf279eb24d8..45f78389c16 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -2,6 +2,8 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.TensorType; import java.util.Collections; import java.util.List; @@ -19,6 +21,10 @@ public class Softmax extends CompositeTensorFunction { this.argument = argument; this.dimension = dimension; } + + public static TensorType resultType(TensorType type, String dimension) { + return Reduce.resultType(type, ImmutableList.of(dimension)); + } @Override public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java index 6606e278102..9643c0a56e7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java @@ -14,8 +14,8 @@ public class MatmulTestCase { @Test public void testMatmul2d() { - // Convention: a is the 'outermost' dimension, etc. - Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3])")); + // d0 is the 'outermost' dimension, etc. + Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])")); ab.cell( 1,0, 0); ab.cell( 2,0, 1); ab.cell( 3,0, 2); @@ -24,7 +24,7 @@ public class MatmulTestCase { ab.cell( 6,1, 2); Tensor a = ab.build(); - Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[3],b[2])")); + Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[3],d1[2])")); bb.cell( 7,0, 0); bb.cell( 8,0, 1); bb.cell( 9,1, 0); @@ -33,21 +33,22 @@ public class MatmulTestCase { bb.cell(12,2, 1); Tensor b = bb.build(); - Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],c[2])")); + Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2])")); rb.cell( 58,0, 0); rb.cell( 64,0, 1); rb.cell(139,1, 0); rb.cell(154,1, 1); Tensor r = rb.build(); - Tensor result = a.matmul(b.rename(ImmutableList.of("a","b"),ImmutableList.of("b","c")), "b"); + Tensor result = a.matmul(b.rename(ImmutableList.of("d0","d1"), ImmutableList.of("d1","d2")), "d1") + .rename("d2","d1"); assertEquals(r, result); } @Test public void testMatmul3d() { // Convention: a is the 'outermost' dimension, etc. - Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],c[3])")); + Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[3])")); ab.cell( 1,0, 0, 0); ab.cell( 2,0, 0, 1); ab.cell( 3,0, 0, 2); @@ -62,7 +63,7 @@ public class MatmulTestCase { ab.cell(12,1, 1, 2); Tensor a = ab.build(); - Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3],c[2])")); + Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3],d2[2])")); bb.cell(13,0, 0, 0); bb.cell(14,0, 0, 1); bb.cell(15,0, 1, 0); @@ -77,7 +78,7 @@ public class MatmulTestCase { bb.cell(24,1, 2, 1); Tensor b = bb.build(); - Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],d[2])")); + Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[2])")); rb.cell( 94,0, 0, 0); rb.cell(100,0, 0, 1); rb.cell(229,0, 1, 0); @@ -88,8 +89,9 @@ public class MatmulTestCase { rb.cell(730,1, 1, 1); Tensor r = rb.build(); - Tensor result = a.matmul(b.rename(ImmutableList.of("b","c"),ImmutableList.of("c","d")), "c"); - System.out.println(result); + Tensor result = a.matmul(b.rename(ImmutableList.of("d1","d2"), ImmutableList.of("d2","d3")), "d2") + .rename("d3","d2"); + assertEquals(r, result); } } |