diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-03-08 11:06:53 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-03-08 11:06:53 +0100 |
commit | 692d43c3c85352c8f8e40615ce37aa3d2f83b5d3 (patch) | |
tree | f661a10d076f6e8220e3ec46389b7adffefef678 /vespajlib/src | |
parent | 2005e5dd57b2bafd1150917560da8066b39d64f7 (diff) |
OrderedTensorType.to/from spec
Diffstat (limited to 'vespajlib/src')
3 files changed, 23 insertions, 22 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 9b3a9328f07..5590ccaad0a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -18,7 +18,7 @@ class TensorParser { int colonIndex = tensorString.indexOf(':'); String typeString = tensorString.substring(0, colonIndex); String valueString = tensorString.substring(colonIndex + 1); - TensorType typeFromString = TensorTypeParser.fromSpec(typeString); + TensorType typeFromString = new TensorType(TensorTypeParser.fromSpec(typeString)); if (type.isPresent() && ! type.get().equals(typeFromString)) throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + "passed type " + type); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 0176dac6821..2483280817c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -32,7 +32,7 @@ public class TensorType { /** Sorted list of the dimensions of this */ private final ImmutableList<Dimension> dimensions; - private TensorType(Collection<Dimension> dimensions) { + public TensorType(Collection<Dimension> dimensions) { List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); this.dimensions = ImmutableList.copyOf(dimensionList); @@ -50,7 +50,7 @@ public class TensorType { * Example: <code>tensor(x[10],y[20])</code> (a matrix) */ public static TensorType fromSpec(String specString) { - return TensorTypeParser.fromSpec(specString); + return new TensorType(TensorTypeParser.fromSpec(specString)); } /** Returns the number of dimensions of this: dimensions().size() */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index 703101bc45b..6ed0b8202f1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -3,6 +3,9 @@ package com.yahoo.tensor; import com.google.common.annotations.Beta; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -12,7 +15,7 @@ import java.util.regex.Pattern; * @author geirst */ @Beta -class TensorTypeParser { +public class TensorTypeParser { private final static String START_STRING = "tensor("; private final static String END_STRING = ")"; @@ -20,48 +23,46 @@ class TensorTypeParser { private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]"); private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}"); - static TensorType fromSpec(String specString) { - if (!specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) { + public static List<TensorType.Dimension> fromSpec(String specString) { + if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) { throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" + - " and end with '" + END_STRING + "', but was '" + specString + "'"); + " and end with '" + END_STRING + "', but was '" + specString + "'"); } - TensorType.Builder builder = new TensorType.Builder(); String dimensionsSpec = specString.substring(START_STRING.length(), specString.length() - END_STRING.length()); - if (dimensionsSpec.isEmpty()) { - return builder.build(); - } + if (dimensionsSpec.isEmpty()) return Collections.emptyList(); + + List<TensorType.Dimension> dimensions = new ArrayList<>(); for (String element : dimensionsSpec.split(",")) { String trimmedElement = element.trim(); - if (tryParseIndexedDimension(trimmedElement, builder)) { - } else if (tryParseMappedDimension(trimmedElement, builder)) { - } else { + boolean success = tryParseIndexedDimension(trimmedElement, dimensions) || + tryParseMappedDimension(trimmedElement, dimensions); + if ( ! success) throw new IllegalArgumentException("Failed parsing element '" + element + - "' in type spec '" + specString + "'"); - } + "' in type spec '" + specString + "'"); } - return builder.build(); + return dimensions; } - private static boolean tryParseIndexedDimension(String element, TensorType.Builder builder) { + private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) { Matcher matcher = indexedPattern.matcher(element); if (matcher.matches()) { String dimensionName = matcher.group(1); String dimensionSize = matcher.group(2); if (dimensionSize.isEmpty()) { - builder.indexed(dimensionName); + dimensions.add(TensorType.Dimension.indexed(dimensionName)); } else { - builder.indexed(dimensionName, Integer.valueOf(dimensionSize)); + dimensions.add(TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize))); } return true; } return false; } - private static boolean tryParseMappedDimension(String element, TensorType.Builder builder) { + private static boolean tryParseMappedDimension(String element, List<TensorType.Dimension> dimensions) { Matcher matcher = mappedPattern.matcher(element); if (matcher.matches()) { String dimensionName = matcher.group(1); - builder.mapped(dimensionName); + dimensions.add(TensorType.Dimension.mapped(dimensionName)); return true; } return false; |