diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java | 160 |
1 files changed, 125 insertions, 35 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 45a9992c9ad..4d9bb258423 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -8,44 +8,59 @@ import java.util.Optional; */ class TensorParser { - static Tensor tensorFrom(String tensorString, Optional<TensorType> type) { + static Tensor tensorFrom(String tensorString, Optional<TensorType> explicitType) { + Optional<TensorType> type; + String valueString; + tensorString = tensorString.trim(); - try { - if (tensorString.startsWith("tensor")) { - int colonIndex = tensorString.indexOf(':'); - String typeString = tensorString.substring(0, colonIndex); - String valueString = tensorString.substring(colonIndex + 1); - TensorType typeFromString = TensorTypeParser.fromSpec(typeString); - if (type.isPresent() && ! type.get().equals(typeFromString)) - throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + - "passed type " + type.get()); - return tensorFromValueString(valueString, typeFromString); - } - else if (tensorString.startsWith("{")) { - return tensorFromValueString(tensorString, type.orElse(typeFromValueString(tensorString))); - } - else { - if (type.isPresent() && ! type.get().equals(TensorType.empty)) - throw new IllegalArgumentException("Got zero-dimensional tensor '" + tensorString + - "' where type " + type.get() + " is required"); + if (tensorString.startsWith("tensor")) { + int colonIndex = tensorString.indexOf(':'); + String typeString = tensorString.substring(0, colonIndex); + TensorType typeFromString = TensorTypeParser.fromSpec(typeString); + if (explicitType.isPresent() && ! explicitType.get().equals(typeFromString)) + throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + + "passed type " + explicitType.get()); + type = Optional.of(typeFromString); + valueString = tensorString.substring(colonIndex + 1); + } + else { + type = explicitType; + valueString = tensorString; + } + + valueString = valueString.trim(); + if (valueString.startsWith("{")) { + return tensorFromSparseValueString(valueString, type); + } + else if (valueString.startsWith("[")) { + return tensorFromDenseValueString(valueString, type); + } + else { + if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty)) + throw new IllegalArgumentException("Got a zero-dimensional tensor value ('" + tensorString + + "') where type " + explicitType.get() + " is required"); + try { return Tensor.Builder.of(TensorType.empty).cell(Double.parseDouble(tensorString)).build(); } - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" + - tensorString + "'"); + catch (NumberFormatException e) { + throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" + + tensorString + "'"); + } } } - /** Derive the tensor type from the first address string in the given tensor string */ - private static TensorType typeFromValueString(String s) { - s = s.substring(1).trim(); // remove tensor start + /** Derives the tensor type from the first address string in the given tensor string */ + private static TensorType typeFromSparseValueString(String valueString) { + String s = valueString.substring(1).trim(); // remove tensor start int firstKeyOrTensorEnd = s.indexOf('}'); + if (firstKeyOrTensorEnd < 0) + throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" + + valueString + "'"); String addressBody = s.substring(0, firstKeyOrTensorEnd).trim(); if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor - addressBody = addressBody.substring(1); // remove key start + addressBody = addressBody.substring(1, addressBody.length()); // remove key start if (addressBody.isEmpty()) return TensorType.empty; // Empty key TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE); @@ -60,19 +75,94 @@ class TensorParser { return builder.build(); } - private static Tensor tensorFromValueString(String tensorValueString, TensorType type) { - Tensor.Builder builder = Tensor.Builder.of(type); - tensorValueString = tensorValueString.trim(); + private static Tensor tensorFromSparseValueString(String valueString, Optional<TensorType> type) { try { - if (tensorValueString.startsWith("{")) - return fromCellString(builder, tensorValueString); - else - return builder.cell(Double.parseDouble(tensorValueString)).build(); + valueString = valueString.trim(); + Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString))); + return fromCellString(builder, valueString); } catch (NumberFormatException e) { throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" + - tensorValueString + "'"); + valueString + "'"); + } + } + + private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) { + if (type.isEmpty()) + throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " + + "on the form 'tensor(dimensions):..."); + if (type.get().dimensions().stream().anyMatch(d -> ( d.size().isEmpty()))) + throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " + + "only dense dimensions with a given size"); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(type.get()); + + // Since we know the dimensions the brackets are just syntactic sugar + long[] indexes = new long[builder.type().rank()]; + int currentChar; + int nextNumberEnd = 0; + while ((currentChar = nextStartCharIndex(nextNumberEnd + 1, valueString)) < valueString.length()) { + nextNumberEnd = nextStopCharIndex(currentChar, valueString); + if (currentChar == nextNumberEnd) return builder.build(); + + if (builder.type().valueType() == TensorType.Value.DOUBLE) + builder.cellByDirectIndex(nextCellIndex(indexes, builder), Double.parseDouble(valueString.substring(currentChar, nextNumberEnd))); + else if (builder.type().valueType() == TensorType.Value.FLOAT) + builder.cellByDirectIndex(nextCellIndex(indexes, builder), Float.parseFloat(valueString.substring(currentChar, nextNumberEnd))); + else + throw new IllegalArgumentException(builder.type().valueType() + " is not supported"); + } + return builder.build(); + } + + // ----- + + /** + * Advance to the next cell in left-adjac ent order. + * + * On rightmost vs. leftmost adjacency: + * A dense tensor is laid out with the rightmost dimension as adjacent numbers, + * but when we parse a dense tensor we encounter numbers in the leftmost-adjacent order, since + * that is the most natural way to write it: tensor(x,y)[[1,2],[3,4]] + * should mean {{x:0, y:0}:1, {x:1, y:0}:2, {x:0, y:1}:3, {x:1, y:1}:4}. + * Therefore we need to convert the encounter order (numberIndex) from left-adjacent to right-adjacent. + */ + private static long nextCellIndex(long[] indexes, IndexedTensor.BoundBuilder builder) { + long cellIndex = IndexedTensor.toValueIndex(indexes, builder.sizes()); + + // Find next dimension to advance + int nextInDimension = 0; + while (nextInDimension < indexes.length && indexes[nextInDimension] + 1 >= builder.sizes().size(nextInDimension)) { + indexes[nextInDimension] = 0; + nextInDimension++; + } + if (nextInDimension < indexes.length) + indexes[nextInDimension]++; + else // there is no next - become invalid + indexes[0]++; + + return cellIndex; + } + + /** Returns the position of the next character that should contain a number, or if none the string length */ + private static int nextStartCharIndex(int charIndex, String valueString) { + for (; charIndex < valueString.length(); charIndex++) { + if (valueString.charAt(charIndex) == ']') continue; + if (valueString.charAt(charIndex) == '[') continue; + if (valueString.charAt(charIndex) == ',') continue; + if (valueString.charAt(charIndex) == ' ') continue; + return charIndex; + } + return valueString.length(); + } + + private static int nextStopCharIndex(int charIndex, String valueString) { + while (charIndex < valueString.length()) { + if (valueString.charAt(charIndex) == ',') return charIndex; + if (valueString.charAt(charIndex) == ']') return charIndex; + charIndex++; } + throw new IllegalArgumentException("Malformed tensor value '" + valueString + + "': Expected a ',' or ']' after position " + charIndex); } private static Tensor fromCellString(Tensor.Builder builder, String s) { |