summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-03-08 11:06:53 +0100
committerJon Bratseth <bratseth@oath.com>2018-03-08 11:06:53 +0100
commit692d43c3c85352c8f8e40615ce37aa3d2f83b5d3 (patch)
treef661a10d076f6e8220e3ec46389b7adffefef678
parent2005e5dd57b2bafd1150917560da8066b39d64f7 (diff)
OrderedTensorType.to/from spec
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java18
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java39
5 files changed, 62 insertions, 22 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
index acc7875623b..3a6e5e6ebe9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorTypeParser;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorShapeProto;
@@ -171,6 +172,23 @@ public class OrderedTensorType {
return new OrderedTensorType(renamedDimensions);
}
+ /**
+ * Returns a string representation of this: A standard tensor type string where dimensions
+ * are listed in the order of this rather than in the natural order of their names.
+ */
+ @Override
+ public String toString() {
+ return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
+ }
+
+ /**
+ * Creates an instance from the string representation of this: A standard tensor type string
+ * where dimensions are listed in the order of this rather than the natural order of their names.
+ */
+ public static OrderedTensorType fromSpec(String typeSpec) {
+ return new OrderedTensorType(TensorTypeParser.fromSpec(typeSpec));
+ }
+
public static OrderedTensorType fromTensorFlowType(NodeDef node) {
return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
new file mode 100644
index 00000000000..beec2ab1ead
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
@@ -0,0 +1,21 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class OrderedTensorTypeTestCase {
+
+ @Test
+ public void testToFromSpec() {
+ String spec = "tensor(b[],c{},a[3])";
+ OrderedTensorType type = OrderedTensorType.fromSpec(spec);
+ assertEquals(spec, type.toString());
+ assertEquals("tensor(a[3],b[],c{})", type.type().toString());
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 9b3a9328f07..5590ccaad0a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -18,7 +18,7 @@ class TensorParser {
int colonIndex = tensorString.indexOf(':');
String typeString = tensorString.substring(0, colonIndex);
String valueString = tensorString.substring(colonIndex + 1);
- TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
+ TensorType typeFromString = new TensorType(TensorTypeParser.fromSpec(typeString));
if (type.isPresent() && ! type.get().equals(typeFromString))
throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
"passed type " + type);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 0176dac6821..2483280817c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -32,7 +32,7 @@ public class TensorType {
/** Sorted list of the dimensions of this */
private final ImmutableList<Dimension> dimensions;
- private TensorType(Collection<Dimension> dimensions) {
+ public TensorType(Collection<Dimension> dimensions) {
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
this.dimensions = ImmutableList.copyOf(dimensionList);
@@ -50,7 +50,7 @@ public class TensorType {
* Example: <code>tensor(x[10],y[20])</code> (a matrix)
*/
public static TensorType fromSpec(String specString) {
- return TensorTypeParser.fromSpec(specString);
+ return new TensorType(TensorTypeParser.fromSpec(specString));
}
/** Returns the number of dimensions of this: dimensions().size() */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index 703101bc45b..6ed0b8202f1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -3,6 +3,9 @@ package com.yahoo.tensor;
import com.google.common.annotations.Beta;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -12,7 +15,7 @@ import java.util.regex.Pattern;
* @author geirst
*/
@Beta
-class TensorTypeParser {
+public class TensorTypeParser {
private final static String START_STRING = "tensor(";
private final static String END_STRING = ")";
@@ -20,48 +23,46 @@ class TensorTypeParser {
private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]");
private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");
- static TensorType fromSpec(String specString) {
- if (!specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
+ public static List<TensorType.Dimension> fromSpec(String specString) {
+ if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" +
- " and end with '" + END_STRING + "', but was '" + specString + "'");
+ " and end with '" + END_STRING + "', but was '" + specString + "'");
}
- TensorType.Builder builder = new TensorType.Builder();
String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
- if (dimensionsSpec.isEmpty()) {
- return builder.build();
- }
+ if (dimensionsSpec.isEmpty()) return Collections.emptyList();
+
+ List<TensorType.Dimension> dimensions = new ArrayList<>();
for (String element : dimensionsSpec.split(",")) {
String trimmedElement = element.trim();
- if (tryParseIndexedDimension(trimmedElement, builder)) {
- } else if (tryParseMappedDimension(trimmedElement, builder)) {
- } else {
+ boolean success = tryParseIndexedDimension(trimmedElement, dimensions) ||
+ tryParseMappedDimension(trimmedElement, dimensions);
+ if ( ! success)
throw new IllegalArgumentException("Failed parsing element '" + element +
- "' in type spec '" + specString + "'");
- }
+ "' in type spec '" + specString + "'");
}
- return builder.build();
+ return dimensions;
}
- private static boolean tryParseIndexedDimension(String element, TensorType.Builder builder) {
+ private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) {
Matcher matcher = indexedPattern.matcher(element);
if (matcher.matches()) {
String dimensionName = matcher.group(1);
String dimensionSize = matcher.group(2);
if (dimensionSize.isEmpty()) {
- builder.indexed(dimensionName);
+ dimensions.add(TensorType.Dimension.indexed(dimensionName));
} else {
- builder.indexed(dimensionName, Integer.valueOf(dimensionSize));
+ dimensions.add(TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize)));
}
return true;
}
return false;
}
- private static boolean tryParseMappedDimension(String element, TensorType.Builder builder) {
+ private static boolean tryParseMappedDimension(String element, List<TensorType.Dimension> dimensions) {
Matcher matcher = mappedPattern.matcher(element);
if (matcher.matches()) {
String dimensionName = matcher.group(1);
- builder.mapped(dimensionName);
+ dimensions.add(TensorType.Dimension.mapped(dimensionName));
return true;
}
return false;