summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java72
1 files changed, 60 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index 32ad6171e57..d5f77be0dd0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -2,8 +2,10 @@
package com.yahoo.tensor;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import java.util.Optional;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -11,26 +13,36 @@ import java.util.regex.Pattern;
* Class for parsing a tensor type spec.
*
* @author geirst
+ * @author bratseth
*/
public class TensorTypeParser {
- private final static String START_STRING = "tensor(";
+ private final static String START_STRING = "tensor";
private final static String END_STRING = ")";
private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]");
private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");
public static TensorType fromSpec(String specString) {
- return new TensorType.Builder(dimensionsFromSpec(specString)).build();
- }
+ if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING))
+ throw formatException(specString);
+ String specBody = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
- public static List<TensorType.Dimension> dimensionsFromSpec(String specString) {
- if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
- throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" +
- " and end with '" + END_STRING + "', but was '" + specString + "'");
+ String dimensionsSpec;
+ TensorType.Value valueType;
+ if (specBody.startsWith("(")) {
+ valueType = TensorType.Value.DOUBLE; // no value type spec: Use default
+ dimensionsSpec = specBody.substring(1);
+ }
+ else {
+ int parenthesisIndex = specBody.indexOf("(");
+ if (parenthesisIndex < 0)
+ throw formatException(specString);
+ valueType = parseValueTypeSpec(specBody.substring(0, parenthesisIndex), specString);
+ dimensionsSpec = specBody.substring(parenthesisIndex + 1);
}
- String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
- if (dimensionsSpec.isEmpty()) return Collections.emptyList();
+
+ if (dimensionsSpec.isEmpty()) return new TensorType.Builder(valueType, Collections.emptyList()).build();
List<TensorType.Dimension> dimensions = new ArrayList<>();
for (String element : dimensionsSpec.split(",")) {
@@ -38,10 +50,30 @@ public class TensorTypeParser {
boolean success = tryParseIndexedDimension(trimmedElement, dimensions) ||
tryParseMappedDimension(trimmedElement, dimensions);
if ( ! success)
- throw new IllegalArgumentException("Failed parsing element '" + element +
- "' in type spec '" + specString + "'");
+ throw formatException(specString, "Dimension '" + element + "' is on the wrong format");
+ }
+ return new TensorType.Builder(valueType, dimensions).build();
+ }
+
+ public static TensorType.Value toValueType(String valueTypeString) {
+ switch (valueTypeString) {
+ case "double" : return TensorType.Value.DOUBLE;
+ case "float" : return TensorType.Value.FLOAT;
+ default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
+ " but was '" + valueTypeString + "'");
+ }
+ }
+
+ private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
+ if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
+ throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
+
+ try {
+ return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
+ }
+ catch (IllegalArgumentException e) {
+ throw formatException(fullSpecString, e.getMessage());
}
- return dimensions;
}
private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) {
@@ -69,5 +101,21 @@ public class TensorTypeParser {
return false;
}
+
+ private static IllegalArgumentException formatException(String spec) {
+ return formatException(spec, Optional.empty());
+ }
+
+ private static IllegalArgumentException formatException(String spec, String errorDetail) {
+ return formatException(spec, Optional.of(errorDetail));
+ }
+
+ private static IllegalArgumentException formatException(String spec, Optional<String> errorDetail) {
+ throw new IllegalArgumentException("A tensor type spec must be on the form " +
+ "tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was '" + spec + "'. " +
+ errorDetail.map(s -> s + ". ").orElse("") +
+ "Examples: tensor(x[]), tensor<float>(name{}, x[10])");
+ }
+
}