summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-03-08 11:06:53 +0100
committerJon Bratseth <bratseth@oath.com>2018-03-08 11:06:53 +0100
commit692d43c3c85352c8f8e40615ce37aa3d2f83b5d3 (patch)
treef661a10d076f6e8220e3ec46389b7adffefef678 /searchlib
parent2005e5dd57b2bafd1150917560da8066b39d64f7 (diff)
OrderedTensorType.to/from spec
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java18
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java21
2 files changed, 39 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
index acc7875623b..3a6e5e6ebe9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorTypeParser;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorShapeProto;
@@ -171,6 +172,23 @@ public class OrderedTensorType {
return new OrderedTensorType(renamedDimensions);
}
+ /**
+ * Returns a string representation of this: A standard tensor type string where dimensions
+ * are listed in the order of this rather than in the natural order of their names.
+ */
+ @Override
+ public String toString() {
+ return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
+ }
+
+ /**
+ * Creates an instance from the string representation of this: A standard tensor type string
+ * where dimensions are listed in the order of this rather than the natural order of their names.
+ */
+ public static OrderedTensorType fromSpec(String typeSpec) {
+ return new OrderedTensorType(TensorTypeParser.fromSpec(typeSpec));
+ }
+
public static OrderedTensorType fromTensorFlowType(NodeDef node) {
return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
new file mode 100644
index 00000000000..beec2ab1ead
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
@@ -0,0 +1,21 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class OrderedTensorTypeTestCase {
+
+ @Test
+ public void testToFromSpec() {
+ String spec = "tensor(b[],c{},a[3])";
+ OrderedTensorType type = OrderedTensorType.fromSpec(spec);
+ assertEquals(spec, type.toString());
+ assertEquals("tensor(a[3],b[],c{})", type.type().toString());
+ }
+
+}