summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-03-08 11:06:53 +0100
committerJon Bratseth <bratseth@oath.com>2018-03-08 11:06:53 +0100
commit692d43c3c85352c8f8e40615ce37aa3d2f83b5d3 (patch)
treef661a10d076f6e8220e3ec46389b7adffefef678 /vespajlib
parent2005e5dd57b2bafd1150917560da8066b39d64f7 (diff)
OrderedTensorType.to/from spec
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java39
3 files changed, 23 insertions, 22 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 9b3a9328f07..5590ccaad0a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -18,7 +18,7 @@ class TensorParser {
int colonIndex = tensorString.indexOf(':');
String typeString = tensorString.substring(0, colonIndex);
String valueString = tensorString.substring(colonIndex + 1);
- TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
+ TensorType typeFromString = new TensorType(TensorTypeParser.fromSpec(typeString));
if (type.isPresent() && ! type.get().equals(typeFromString))
throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
"passed type " + type);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 0176dac6821..2483280817c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -32,7 +32,7 @@ public class TensorType {
/** Sorted list of the dimensions of this */
private final ImmutableList<Dimension> dimensions;
- private TensorType(Collection<Dimension> dimensions) {
+ public TensorType(Collection<Dimension> dimensions) {
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
this.dimensions = ImmutableList.copyOf(dimensionList);
@@ -50,7 +50,7 @@ public class TensorType {
* Example: <code>tensor(x[10],y[20])</code> (a matrix)
*/
public static TensorType fromSpec(String specString) {
- return TensorTypeParser.fromSpec(specString);
+ return new TensorType(TensorTypeParser.fromSpec(specString));
}
/** Returns the number of dimensions of this: dimensions().size() */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index 703101bc45b..6ed0b8202f1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -3,6 +3,9 @@ package com.yahoo.tensor;
import com.google.common.annotations.Beta;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -12,7 +15,7 @@ import java.util.regex.Pattern;
* @author geirst
*/
@Beta
-class TensorTypeParser {
+public class TensorTypeParser {
private final static String START_STRING = "tensor(";
private final static String END_STRING = ")";
@@ -20,48 +23,46 @@ class TensorTypeParser {
private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]");
private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");
- static TensorType fromSpec(String specString) {
- if (!specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
+ public static List<TensorType.Dimension> fromSpec(String specString) {
+ if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" +
- " and end with '" + END_STRING + "', but was '" + specString + "'");
+ " and end with '" + END_STRING + "', but was '" + specString + "'");
}
- TensorType.Builder builder = new TensorType.Builder();
String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
- if (dimensionsSpec.isEmpty()) {
- return builder.build();
- }
+ if (dimensionsSpec.isEmpty()) return Collections.emptyList();
+
+ List<TensorType.Dimension> dimensions = new ArrayList<>();
for (String element : dimensionsSpec.split(",")) {
String trimmedElement = element.trim();
- if (tryParseIndexedDimension(trimmedElement, builder)) {
- } else if (tryParseMappedDimension(trimmedElement, builder)) {
- } else {
+ boolean success = tryParseIndexedDimension(trimmedElement, dimensions) ||
+ tryParseMappedDimension(trimmedElement, dimensions);
+ if ( ! success)
throw new IllegalArgumentException("Failed parsing element '" + element +
- "' in type spec '" + specString + "'");
- }
+ "' in type spec '" + specString + "'");
}
- return builder.build();
+ return dimensions;
}
- private static boolean tryParseIndexedDimension(String element, TensorType.Builder builder) {
+ private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) {
Matcher matcher = indexedPattern.matcher(element);
if (matcher.matches()) {
String dimensionName = matcher.group(1);
String dimensionSize = matcher.group(2);
if (dimensionSize.isEmpty()) {
- builder.indexed(dimensionName);
+ dimensions.add(TensorType.Dimension.indexed(dimensionName));
} else {
- builder.indexed(dimensionName, Integer.valueOf(dimensionSize));
+ dimensions.add(TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize)));
}
return true;
}
return false;
}
- private static boolean tryParseMappedDimension(String element, TensorType.Builder builder) {
+ private static boolean tryParseMappedDimension(String element, List<TensorType.Dimension> dimensions) {
Matcher matcher = mappedPattern.matcher(element);
if (matcher.matches()) {
String dimensionName = matcher.group(1);
- builder.mapped(dimensionName);
+ dimensions.add(TensorType.Dimension.mapped(dimensionName));
return true;
}
return false;