summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
blob: 6ed0b8202f10bacbcdfbb6e862cd8895402f0323 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;

import com.google.common.annotations.Beta;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Class for parsing a tensor type spec, such as {@code tensor(x[10], y[], z{})}.
 * <p>
 * A spec is the string {@code tensor(} followed by a comma-separated list of dimensions and a
 * closing {@code )}. Each dimension is either indexed ({@code name[size]}, or {@code name[]} when
 * no size is given) or mapped ({@code name{}}). Whitespace around each dimension is ignored.
 *
 * @author geirst
 */
@Beta
public class TensorTypeParser {

    private final static String START_STRING = "tensor(";
    private final static String END_STRING = ")";

    /** An indexed dimension: group 1 is the name, group 2 the (possibly empty) decimal size. */
    private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]");
    /** A mapped dimension: group 1 is the name. */
    private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");

    /**
     * Parses a tensor type spec into its dimensions, in the order they appear in the spec.
     *
     * @param specString the spec to parse, e.g. {@code tensor(x[10],y{})}
     * @return the parsed dimensions; an empty list if the spec declares no dimensions
     *         (e.g. {@code tensor()} or {@code tensor( )})
     * @throws IllegalArgumentException if the spec is not enclosed in {@code tensor(...)}, if any
     *         comma-separated element is not a valid dimension (this includes the empty element
     *         produced by a leading, trailing or doubled comma), or if an indexed dimension's
     *         size does not fit in an {@code int}
     */
    public static List<TensorType.Dimension> fromSpec(String specString) {
        if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
            throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" +
                                               " and end with '" + END_STRING + "', but was '" + specString + "'");
        }
        String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length());
        // Whitespace around elements is ignored below, so a whitespace-only body means "no dimensions"
        // just like an empty one.
        if (dimensionsSpec.trim().isEmpty()) return Collections.emptyList();

        List<TensorType.Dimension> dimensions = new ArrayList<>();
        // A negative limit keeps trailing empty strings, so "tensor(x[],)" is rejected just like
        // "tensor(,x[])" instead of the trailing comma being silently dropped by split.
        for (String element : dimensionsSpec.split(",", -1)) {
            String trimmedElement = element.trim();
            boolean success = tryParseIndexedDimension(trimmedElement, dimensions) ||
                              tryParseMappedDimension(trimmedElement, dimensions);
            if ( ! success)
                throw new IllegalArgumentException("Failed parsing element '" + element +
                                                   "' in type spec '" + specString + "'");
        }
        return dimensions;
    }

    /**
     * Appends an indexed dimension to {@code dimensions} if {@code element} is of the form
     * {@code name[size]} or {@code name[]}.
     *
     * @return true if the element was an indexed dimension and was added, false otherwise
     * @throws IllegalArgumentException if the size is present but does not fit in an {@code int}
     */
    private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) {
        Matcher matcher = indexedPattern.matcher(element);
        if ( ! matcher.matches()) return false;

        String dimensionName = matcher.group(1);
        String dimensionSize = matcher.group(2);
        if (dimensionSize.isEmpty()) {
            dimensions.add(TensorType.Dimension.indexed(dimensionName));
        } else {
            dimensions.add(TensorType.Dimension.indexed(dimensionName, parseSize(dimensionName, dimensionSize)));
        }
        return true;
    }

    /**
     * Parses the decimal size of an indexed dimension.
     * The pattern only admits ASCII digits, so the only possible parse failure is int overflow.
     */
    private static int parseSize(String dimensionName, String size) {
        try {
            return Integer.parseInt(size);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException("Size of dimension '" + dimensionName + "' is too large: " + size, e);
        }
    }

    /**
     * Appends a mapped dimension to {@code dimensions} if {@code element} is of the form {@code name{}}.
     *
     * @return true if the element was a mapped dimension and was added, false otherwise
     */
    private static boolean tryParseMappedDimension(String element, List<TensorType.Dimension> dimensions) {
        Matcher matcher = mappedPattern.matcher(element);
        if ( ! matcher.matches()) return false;

        dimensions.add(TensorType.Dimension.mapped(matcher.group(1)));
        return true;
    }

}