diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-03-08 11:06:53 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-03-08 11:06:53 +0100 |
commit | 692d43c3c85352c8f8e40615ce37aa3d2f83b5d3 (patch) | |
tree | f661a10d076f6e8220e3ec46389b7adffefef678 /searchlib | |
parent | 2005e5dd57b2bafd1150917560da8066b39d64f7 (diff) |
OrderedTensorType.to/from spec
Diffstat (limited to 'searchlib')
2 files changed, 39 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java index acc7875623b..3a6e5e6ebe9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TensorTypeParser; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.TensorShapeProto; @@ -171,6 +172,23 @@ public class OrderedTensorType { return new OrderedTensorType(renamedDimensions); } + /** + * Returns a string representation of this: A standard tensor type string where dimensions + * are listed in the order of this rather than in the natural order of their names. + */ + @Override + public String toString() { + return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")"; + } + + /** + * Creates an instance from the string representation of this: A standard tensor type string + * where dimensions are listed in the order of this rather than the natural order of their names. + */ + public static OrderedTensorType fromSpec(String typeSpec) { + return new OrderedTensorType(TensorTypeParser.fromSpec(typeSpec)); + } + public static OrderedTensorType fromTensorFlowType(NodeDef node) { return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java new file mode 100644 index 00000000000..beec2ab1ead --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java @@ -0,0 +1,21 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class OrderedTensorTypeTestCase { + + @Test + public void testToFromSpec() { + String spec = "tensor(b[],c{},a[3])"; + OrderedTensorType type = OrderedTensorType.fromSpec(spec); + assertEquals(spec, type.toString()); + assertEquals("tensor(a[3],b[],c{})", type.type().toString()); + } + +} |