blob: f7b71c2149a4393c790618e52672c4bc72b094f3 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
|
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
import com.google.common.annotations.Beta;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Class for parsing a tensor type spec, e.g. {@code tensor(x[10],y[],z{})}.
 * <p>
 * A spec must start with {@code tensor(} and end with {@code )}. Between them is a comma-separated,
 * possibly empty, list of dimensions. Each dimension is either indexed ({@code name[size]}, or
 * {@code name[]} when no size is given) or mapped ({@code name{}}). Whitespace around each dimension
 * is ignored.
 *
 * @author geirst
 */
@Beta
class TensorTypeParser {

    private final static String START_STRING = "tensor(";
    private final static String END_STRING = ")";

    /** Indexed dimension: group 1 is the name, group 2 the (possibly empty) decimal size. */
    private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]");
    /** Mapped dimension: group 1 is the name. */
    private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");

    /**
     * Parses a tensor type spec into a {@link TensorType}.
     *
     * @param specString the spec to parse, e.g. {@code tensor(x[10],y{})}
     * @return the type described by the spec; a type without dimensions for {@code tensor()}
     * @throws IllegalArgumentException if the spec does not start with {@code tensor(} and end with
     *         {@code )}, if any comma-separated element (including an empty one, such as one produced
     *         by a leading or trailing comma) is not a valid dimension, or if an indexed size does not
     *         fit in an {@code int}
     */
    static TensorType fromSpec(String specString) {
        if (!specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
            throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" +
                                               " and end with '" + END_STRING + "', but was '" + specString + "'");
        }
        TensorType.Builder builder = new TensorType.Builder();
        String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
        // Whitespace around elements is ignored, so a whitespace-only list means "no dimensions" too.
        if (dimensionsSpec.trim().isEmpty()) {
            return builder.build();
        }
        // The negative limit keeps trailing empty strings. Without it, String.split silently drops
        // them and "tensor(x{},)" would be accepted as "tensor(x{})".
        for (String element : dimensionsSpec.split(",", -1)) {
            String trimmedElement = element.trim();
            if (!tryParseIndexedDimension(trimmedElement, builder) &&
                !tryParseMappedDimension(trimmedElement, builder)) {
                throw new IllegalArgumentException("Failed parsing element '" + element +
                                                   "' in type spec '" + specString + "'");
            }
        }
        return builder.build();
    }

    /**
     * Adds an indexed dimension to the builder if the element matches {@code name[size]} or {@code name[]}.
     *
     * @return true if the element was an indexed dimension and was added, false if it did not match
     * @throws IllegalArgumentException if the size does not fit in an {@code int}
     */
    private static boolean tryParseIndexedDimension(String element, TensorType.Builder builder) {
        Matcher matcher = indexedPattern.matcher(element);
        if (!matcher.matches()) {
            return false;
        }
        String dimensionName = matcher.group(1);
        String dimensionSize = matcher.group(2);
        if (dimensionSize.isEmpty()) {
            builder.indexed(dimensionName);
        } else {
            builder.indexed(dimensionName, parseDimensionSize(dimensionSize, element));
        }
        return true;
    }

    /**
     * Parses a size string that the indexed pattern has already restricted to ASCII digits.
     * Overflow is therefore the only possible failure, and it is reported with the offending element.
     */
    private static int parseDimensionSize(String dimensionSize, String element) {
        try {
            return Integer.parseInt(dimensionSize);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Dimension size in element '" + element +
                                               "' is too large: " + dimensionSize, e);
        }
    }

    /**
     * Adds a mapped dimension to the builder if the element matches {@code name{}}.
     *
     * @return true if the element was a mapped dimension and was added, false if it did not match
     */
    private static boolean tryParseMappedDimension(String element, TensorType.Builder builder) {
        Matcher matcher = mappedPattern.matcher(element);
        if (!matcher.matches()) {
            return false;
        }
        builder.mapped(matcher.group(1));
        return true;
    }

}
|