diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java | 265 |
1 files changed, 210 insertions, 55 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 4d8b34b7dcf..04d3295795f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import java.util.List; import java.util.Optional; /** @@ -9,6 +10,16 @@ import java.util.Optional; class TensorParser { static Tensor tensorFrom(String tensorString, Optional<TensorType> explicitType) { + try { + return tensorFromBody(tensorString, explicitType); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not parse '" + tensorString + "' as a tensor" + + (explicitType.isPresent() ? " of type " + explicitType.get() : ""), + e); + } + } + + static Tensor tensorFromBody(String tensorString, Optional<TensorType> explicitType) { Optional<TensorType> type; String valueString; @@ -29,9 +40,13 @@ class TensorParser { } valueString = valueString.trim(); - if (valueString.startsWith("{")) { + if (valueString.startsWith("{") && + (type.isEmpty() || type.get().rank() == 0 || valueString.substring(1).trim().startsWith("{") || valueString.substring(1).trim().equals("}"))) { return tensorFromSparseValueString(valueString, type); } + else if (valueString.startsWith("{")) { + return tensorFromMixedValueString(valueString, type); + } else if (valueString.startsWith("[")) { return tensorFromDenseValueString(valueString, type); } @@ -54,8 +69,7 @@ class TensorParser { String s = valueString.substring(1).trim(); // remove tensor start int firstKeyOrTensorEnd = s.indexOf('}'); if (firstKeyOrTensorEnd < 0) - throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" + - valueString + "'"); + throw new IllegalArgumentException("Excepted a number or a string starting by '{', '[' or 'tensor(...):...'"); String addressBody = s.substring(0, firstKeyOrTensorEnd).trim(); if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor @@ -79,73 +93,51 @@ class TensorParser { try { valueString = valueString.trim(); Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString))); - return fromCellString(builder, valueString); + return tensorFromSparseCellString(builder, valueString); } catch (NumberFormatException e) { - throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" + - valueString + "'"); + throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('"); } } - private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) { + private static Tensor tensorFromMixedValueString(String valueString, Optional<TensorType> type) { if (type.isEmpty()) - throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " + + throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); - if (type.get().dimensions().stream().anyMatch(d -> ( d.size().isEmpty()))) - throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " + - "only dense dimensions with a given size"); + if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() != 1) + throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " + + "but got " + type.get()); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(type.get()); - long index = 0; - int currentChar; - int nextNumberEnd = 0; - // Since we know the dimensions the brackets are just syntactic sugar: - while ((currentChar = nextStartCharIndex(nextNumberEnd + 1, valueString)) < valueString.length()) { - nextNumberEnd = nextStopCharIndex(currentChar, valueString); - if (currentChar == nextNumberEnd) return builder.build(); - TensorType.Value cellValueType = builder.type().valueType(); - String cellValueString = valueString.substring(currentChar, nextNumberEnd); - try { - if (cellValueType == TensorType.Value.DOUBLE) - builder.cellByDirectIndex(index, Double.parseDouble(cellValueString)); - else if (cellValueType == TensorType.Value.FLOAT) - builder.cellByDirectIndex(index, Float.parseFloat(cellValueString)); - else - throw new IllegalArgumentException(cellValueType + " is not supported"); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("At index " + index + ": '" + - cellValueString + "' is not a valid " + cellValueType); - } - index++; + try { + valueString = valueString.trim(); + if ( ! valueString.startsWith("{") && valueString.endsWith("}")) + throw new IllegalArgumentException("A mixed tensor must be enclosed in {}"); + // TODO: Check if there is also at least one bound indexed dimension + MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder)Tensor.Builder.of(type.get()); + MixedParser parser = new MixedParser(valueString, builder); + parser.parse(); + return builder.build(); } - return builder.build(); - } - - /** Returns the position of the next character that should contain a number, or if none the string length */ - private static int nextStartCharIndex(int charIndex, String valueString) { - for (; charIndex < valueString.length(); charIndex++) { - if (valueString.charAt(charIndex) == ']') continue; - if (valueString.charAt(charIndex) == '[') continue; - if (valueString.charAt(charIndex) == ',') continue; - if (valueString.charAt(charIndex) == ' ') continue; - return charIndex; + catch (NumberFormatException e) { + throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('"); } - return valueString.length(); } - private static int nextStopCharIndex(int charIndex, String valueString) { - while (charIndex < valueString.length()) { - if (valueString.charAt(charIndex) == ',') return charIndex; - if (valueString.charAt(charIndex) == ']') return charIndex; - charIndex++; - } - throw new IllegalArgumentException("Malformed tensor value '" + valueString + - "': Expected a ',' or ']' after position " + charIndex); + private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) { + if (type.isEmpty()) + throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " + + "on the form 'tensor(dimensions):..."); + if (type.get().dimensions().stream().anyMatch(d -> (d.size().isEmpty()))) + throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " + + "only dense dimensions with a given size"); + + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get()); + new DenseParser(valueString, builder).parse(); + return builder.build(); } - private static Tensor fromCellString(Tensor.Builder builder, String s) { + private static Tensor tensorFromSparseCellString(Tensor.Builder builder, String s) { int index = 1; index = skipSpace(index, s); while (index + 1 < s.length()) { @@ -194,6 +186,16 @@ class TensorParser { return index; } + private static int nextStopCharIndex(int charIndex, String valueString) { + while (charIndex < valueString.length()) { + if (valueString.charAt(charIndex) == ',') return charIndex; + if (valueString.charAt(charIndex) == ']') return charIndex; + charIndex++; + } + throw new IllegalArgumentException("Malformed tensor value '" + valueString + + "': Expected a ',' or ']' after position " + charIndex); + } + /** Creates a tenor address from a string on the form {dimension1:label1,dimension2:label2,...} */ private static void addLabels(String mapAddressString, TensorAddress.Builder builder) { mapAddressString = mapAddressString.trim(); @@ -213,4 +215,157 @@ class TensorParser { } } + private static abstract class ValueParser { + + protected final String string; + protected int position = 0; + + protected ValueParser(String string) { + this.string = string; + } + + protected void skipSpace() { + while (position < string.length() && string.charAt(position) == ' ') + position++; + } + + protected void consume(char character) { + skipSpace(); + + if (position >= string.length()) + throw new IllegalArgumentException("At position " + position + ": Expected a '" + character + + "' but got the end of the string"); + if ( string.charAt(position) != character) + throw new IllegalArgumentException("At position " + position + ": Expected a '" + character + + "' but got '" + string.charAt(position) + "'"); + position++; + } + + } + + /** A single-use dense tensor string parser */ + private static class DenseParser extends ValueParser { + + private final IndexedTensor.DirectIndexBuilder builder; + private final IndexedTensor.Indexes indexes; + private final boolean hasInnerStructure; + + private long tensorIndex = 0; + + public DenseParser(String string, IndexedTensor.DirectIndexBuilder builder) { + super(string); + this.builder = builder; + indexes = IndexedTensor.Indexes.of(builder.type()); + hasInnerStructure = hasInnerStructure(string); + } + + public void parse() { + if (!hasInnerStructure) + consume('['); + + while (indexes.hasNext()) { + indexes.next(); + + for (int i = 0; i < indexes.rightDimensionsAtStart() && hasInnerStructure; i++) + consume('['); + + consumeNumber(); + + for (int i = 0; i < indexes.rightDimensionsAtEnd() && hasInnerStructure; i++) + consume(']'); + + if (indexes.hasNext()) + consume(','); + } + + if (!hasInnerStructure) + consume(']'); + } + + public int position() { return position; } + + /** Are there inner square brackets in this or is it just a flat list of numbers until ']'? */ + private static boolean hasInnerStructure(String valueString) { + valueString = valueString.trim(); + valueString = valueString.substring(1); + int firstLeftBracket = valueString.indexOf('['); + return firstLeftBracket >= 0 && firstLeftBracket < valueString.indexOf(']'); + } + + private void consumeNumber() { + skipSpace(); + + int nextNumberEnd = nextStopCharIndex(position, string); + TensorType.Value cellValueType = builder.type().valueType(); + String cellValueString = string.substring(position, nextNumberEnd); + try { + if (cellValueType == TensorType.Value.DOUBLE) + builder.cellByDirectIndex(tensorIndex++, Double.parseDouble(cellValueString)); + else if (cellValueType == TensorType.Value.FLOAT) + builder.cellByDirectIndex(tensorIndex++, Float.parseFloat(cellValueString)); + else + throw new IllegalArgumentException(cellValueType + " is not supported"); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("At position " + position + ": '" + + cellValueString + "' is not a valid " + cellValueType); + } + position = nextNumberEnd; + } + + } + + private static class MixedParser extends ValueParser { + + private final MixedTensor.BoundBuilder builder; + + public MixedParser(String string, MixedTensor.BoundBuilder builder) { + super(string); + this.builder = builder; + } + + private void parse() { + TensorType.Dimension sparseDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); + TensorType sparseSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(sparseDimension)); + + skipSpace(); + consume('{'); + skipSpace(); + while (position + 1 < string.length()) { + int labelEnd = string.indexOf(':', position); + if (labelEnd <= position) + throw new IllegalArgumentException("A mixed tensor value must be on the form {sparse-label:[dense subspace], ...} "); + String label = string.substring(position, labelEnd); + position = labelEnd + 1; + skipSpace(); + + TensorAddress sparseAddress = new TensorAddress.Builder(sparseSubtype).add(sparseDimension.name(), label).build(); + parseDenseSubspace(sparseAddress); + if ( ! consumeOptional(',')) + consume('}'); + skipSpace(); + } + } + + private void parseDenseSubspace(TensorAddress sparseAddress) { + DenseParser denseParser = new DenseParser(string.substring(position), builder.denseSubspaceBuilder(sparseAddress)); + denseParser.parse(); + position+= denseParser.position(); + } + + private boolean consumeOptional(char character) { + skipSpace(); + + if (position >= string.length()) + return false; + if ( string.charAt(position) != character) + return false; + + position++; + return true; + } + + + } + } |