diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java | 72 |
1 files changed, 60 insertions, 12 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index 32ad6171e57..d5f77be0dd0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -2,8 +2,10 @@ package com.yahoo.tensor; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Optional; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -11,26 +13,36 @@ import java.util.regex.Pattern; * Class for parsing a tensor type spec. * * @author geirst + * @author bratseth */ public class TensorTypeParser { - private final static String START_STRING = "tensor("; + private final static String START_STRING = "tensor"; private final static String END_STRING = ")"; private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]"); private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}"); public static TensorType fromSpec(String specString) { - return new TensorType.Builder(dimensionsFromSpec(specString)).build(); - } + if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING)) + throw formatException(specString); + String specBody = specString.substring(START_STRING.length(), specString.length() - END_STRING.length()); - public static List<TensorType.Dimension> dimensionsFromSpec(String specString) { - if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) { - throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" + - " and end with '" + END_STRING + "', but was '" + specString + "'"); + String dimensionsSpec; + TensorType.Value valueType; + if (specBody.startsWith("(")) { + valueType = TensorType.Value.DOUBLE; // no value type spec: Use default + dimensionsSpec = specBody.substring(1); + } + else { + int parenthesisIndex = specBody.indexOf("("); + if (parenthesisIndex < 0) + throw formatException(specString); + valueType = parseValueTypeSpec(specBody.substring(0, parenthesisIndex), specString); + dimensionsSpec = specBody.substring(parenthesisIndex + 1); } - String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length()); - if (dimensionsSpec.isEmpty()) return Collections.emptyList(); + + if (dimensionsSpec.isEmpty()) return new TensorType.Builder(valueType, Collections.emptyList()).build(); List<TensorType.Dimension> dimensions = new ArrayList<>(); for (String element : dimensionsSpec.split(",")) { @@ -38,10 +50,30 @@ public class TensorTypeParser { boolean success = tryParseIndexedDimension(trimmedElement, dimensions) || tryParseMappedDimension(trimmedElement, dimensions); if ( ! success) - throw new IllegalArgumentException("Failed parsing element '" + element + - "' in type spec '" + specString + "'"); + throw formatException(specString, "Dimension '" + element + "' is on the wrong format"); + } + return new TensorType.Builder(valueType, dimensions).build(); + } + + public static TensorType.Value toValueType(String valueTypeString) { + switch (valueTypeString) { + case "double" : return TensorType.Value.DOUBLE; + case "float" : return TensorType.Value.FLOAT; + default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + + " but was '" + valueTypeString + "'"); + } + } + + private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) { + if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">")) + throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>")); + + try { + return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1)); + } + catch (IllegalArgumentException e) { + throw formatException(fullSpecString, e.getMessage()); } - return dimensions; } private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) { @@ -69,5 +101,21 @@ public class TensorTypeParser { return false; } + + private static IllegalArgumentException formatException(String spec) { + return formatException(spec, Optional.empty()); + } + + private static IllegalArgumentException formatException(String spec, String errorDetail) { + return formatException(spec, Optional.of(errorDetail)); + } + + private static IllegalArgumentException formatException(String spec, Optional<String> errorDetail) { + throw new IllegalArgumentException("A tensor type spec must be on the form " + + "tensor[<valuetype>]?(dimensionidentifier[{}|[length?]*), but was '" + spec + "'. " + + errorDetail.map(s -> s + ". ").orElse("") + + "Examples: tensor(x[]), tensor<float>(name{}, x[10])"); + } + } |