aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java265
1 files changed, 210 insertions, 55 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 4d8b34b7dcf..04d3295795f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import java.util.List;
import java.util.Optional;
/**
@@ -9,6 +10,16 @@ import java.util.Optional;
class TensorParser {
+ /**
+  * Parses a tensor from its string form, decorating any parse failure with the offending input.
+  *
+  * @param tensorString the string to parse
+  * @param explicitType the type to parse the string as, if given
+  * @return the tensor returned by {@code tensorFromBody} for the same arguments
+  * @throws IllegalArgumentException wrapping the original failure as its cause, with a message naming
+  *         the input string and, when present, the explicit type
+  */
static Tensor tensorFrom(String tensorString, Optional<TensorType> explicitType) {
+     try {
+         return tensorFromBody(tensorString, explicitType);
+     }
+     catch (IllegalArgumentException cause) {
+         // Only mention the type when the caller supplied one
+         String typeSuffix = explicitType.map(t -> " of type " + t).orElse("");
+         throw new IllegalArgumentException("Could not parse '" + tensorString + "' as a tensor" + typeSuffix,
+                                            cause);
+     }
+ }
+
+ static Tensor tensorFromBody(String tensorString, Optional<TensorType> explicitType) {
Optional<TensorType> type;
String valueString;
@@ -29,9 +40,13 @@ class TensorParser {
}
valueString = valueString.trim();
- if (valueString.startsWith("{")) {
+ if (valueString.startsWith("{") &&
+ (type.isEmpty() || type.get().rank() == 0 || valueString.substring(1).trim().startsWith("{") || valueString.substring(1).trim().equals("}"))) {
return tensorFromSparseValueString(valueString, type);
}
+ else if (valueString.startsWith("{")) {
+ return tensorFromMixedValueString(valueString, type);
+ }
else if (valueString.startsWith("[")) {
return tensorFromDenseValueString(valueString, type);
}
@@ -54,8 +69,7 @@ class TensorParser {
String s = valueString.substring(1).trim(); // remove tensor start
int firstKeyOrTensorEnd = s.indexOf('}');
if (firstKeyOrTensorEnd < 0)
- throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" +
- valueString + "'");
+ throw new IllegalArgumentException("Excepted a number or a string starting by '{', '[' or 'tensor(...):...'");
String addressBody = s.substring(0, firstKeyOrTensorEnd).trim();
if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor
if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor
@@ -79,73 +93,51 @@ class TensorParser {
try {
valueString = valueString.trim();
Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString)));
- return fromCellString(builder, valueString);
+ return tensorFromSparseCellString(builder, valueString);
}
catch (NumberFormatException e) {
- throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" +
- valueString + "'");
+ throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('");
}
}
- private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) {
+ /**
+  * Parses a tensor on the mixed form {sparse-label:[dense subspace], ...}.
+  *
+  * @param valueString the value part of the tensor string
+  * @param type the tensor type, which must be present and have exactly one non-indexed (mapped) dimension
+  * @return the tensor built by a MixedParser over the value string
+  * @throws IllegalArgumentException if the type is missing, does not have exactly one mapped dimension,
+  *         or the value string is not enclosed in {}
+  */
+ private static Tensor tensorFromMixedValueString(String valueString, Optional<TensorType> type) {
if (type.isEmpty())
- throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " +
+ throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " +
"on the form 'tensor(dimensions):...");
- if (type.get().dimensions().stream().anyMatch(d -> ( d.size().isEmpty())))
- throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " +
- "only dense dimensions with a given size");
+ if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() != 1)
+ throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " +
+ "but got " + type.get());
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(type.get());
- long index = 0;
- int currentChar;
- int nextNumberEnd = 0;
- // Since we know the dimensions the brackets are just syntactic sugar:
- while ((currentChar = nextStartCharIndex(nextNumberEnd + 1, valueString)) < valueString.length()) {
- nextNumberEnd = nextStopCharIndex(currentChar, valueString);
- if (currentChar == nextNumberEnd) return builder.build();
- TensorType.Value cellValueType = builder.type().valueType();
- String cellValueString = valueString.substring(currentChar, nextNumberEnd);
- try {
- if (cellValueType == TensorType.Value.DOUBLE)
- builder.cellByDirectIndex(index, Double.parseDouble(cellValueString));
- else if (cellValueType == TensorType.Value.FLOAT)
- builder.cellByDirectIndex(index, Float.parseFloat(cellValueString));
- else
- throw new IllegalArgumentException(cellValueType + " is not supported");
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("At index " + index + ": '" +
- cellValueString + "' is not a valid " + cellValueType);
- }
- index++;
+ try {
+ valueString = valueString.trim();
+ // The negation must cover both conditions: '!' binds tighter than '&&', so without the
+ // parentheses a string starting with '{' would never be rejected here
+ if ( ! (valueString.startsWith("{") && valueString.endsWith("}")))
+ throw new IllegalArgumentException("A mixed tensor must be enclosed in {}");
+ // TODO: Check if there is also at least one bound indexed dimension
+ MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder)Tensor.Builder.of(type.get());
+ MixedParser parser = new MixedParser(valueString, builder);
+ parser.parse();
+ return builder.build();
}
- return builder.build();
- }
-
- /** Returns the position of the next character that should contain a number, or if none the string length */
- private static int nextStartCharIndex(int charIndex, String valueString) {
- for (; charIndex < valueString.length(); charIndex++) {
- if (valueString.charAt(charIndex) == ']') continue;
- if (valueString.charAt(charIndex) == '[') continue;
- if (valueString.charAt(charIndex) == ',') continue;
- if (valueString.charAt(charIndex) == ' ') continue;
- return charIndex;
+ catch (NumberFormatException e) {
+ // Keep the original failure as the cause so the offending number is not lost
+ throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('", e);
}
- return valueString.length();
}
- private static int nextStopCharIndex(int charIndex, String valueString) {
- while (charIndex < valueString.length()) {
- if (valueString.charAt(charIndex) == ',') return charIndex;
- if (valueString.charAt(charIndex) == ']') return charIndex;
- charIndex++;
- }
- throw new IllegalArgumentException("Malformed tensor value '" + valueString +
- "': Expected a ',' or ']' after position " + charIndex);
+ /**
+  * Parses a tensor on the dense form, a bracketed list of cell values.
+  *
+  * @param valueString the value part of the tensor string
+  * @param type the tensor type, which must be present and have a size on every dimension
+  * @return the tensor built from the cells read by a DenseParser
+  * @throws IllegalArgumentException if the type is missing or has a dimension without a size
+  */
+ private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) {
+     TensorType denseType = type.orElseThrow(() ->
+             new IllegalArgumentException("The dense tensor form requires an explicit tensor type " +
+                                          "on the form 'tensor(dimensions):..."));
+     boolean hasUnsizedDimension = denseType.dimensions().stream().anyMatch(d -> d.size().isEmpty());
+     if (hasUnsizedDimension)
+         throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " +
+                                            "only dense dimensions with a given size");
+
+     IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(denseType);
+     DenseParser parser = new DenseParser(valueString, builder);
+     parser.parse();
+     return builder.build();
}
- private static Tensor fromCellString(Tensor.Builder builder, String s) {
+ private static Tensor tensorFromSparseCellString(Tensor.Builder builder, String s) {
int index = 1;
index = skipSpace(index, s);
while (index + 1 < s.length()) {
@@ -194,6 +186,16 @@ class TensorParser {
return index;
}
+ /**
+  * Returns the index of the first ',' or ']' found at or after the given index.
+  *
+  * @param charIndex the index to start searching from
+  * @param valueString the string to search
+  * @throws IllegalArgumentException if neither character occurs before the end of the string
+  */
+ private static int nextStopCharIndex(int charIndex, String valueString) {
+     int i = charIndex;
+     for (; i < valueString.length(); i++) {
+         char c = valueString.charAt(i);
+         if (c == ',' || c == ']') return i;
+     }
+     // i is now where the scan gave up, which is the position reported to the caller
+     throw new IllegalArgumentException("Malformed tensor value '" + valueString +
+                                        "': Expected a ',' or ']' after position " + i);
+ }
+
/** Creates a tensor address from a string on the form {dimension1:label1,dimension2:label2,...} */
private static void addLabels(String mapAddressString, TensorAddress.Builder builder) {
mapAddressString = mapAddressString.trim();
@@ -213,4 +215,157 @@ class TensorParser {
}
}
+ /** Shared base of the value parsers: holds the string being parsed and the current read position in it */
+ private static abstract class ValueParser {
+
+     /** The complete string this parser reads from */
+     protected final String string;
+
+     /** Index into the string of the next character to read */
+     protected int position = 0;
+
+     protected ValueParser(String string) {
+         this.string = string;
+     }
+
+     /** Advances the position past any run of ' ' characters; no other whitespace is skipped */
+     protected void skipSpace() {
+         for (; position < string.length(); position++) {
+             if (string.charAt(position) != ' ') return;
+         }
+     }
+
+     /**
+      * Skips spaces, then requires the next character to be the given one and steps past it.
+      *
+      * @throws IllegalArgumentException if the string has ended or the next character is a different one
+      */
+     protected void consume(char character) {
+         skipSpace();
+
+         boolean atEnd = position >= string.length();
+         if ( ! atEnd && string.charAt(position) == character) {
+             position++;
+             return;
+         }
+
+         String found = atEnd ? "the end of the string" : "'" + string.charAt(position) + "'";
+         throw new IllegalArgumentException("At position " + position + ": Expected a '" + character +
+                                            "' but got " + found);
+     }
+
+ }
+
+ /**
+  * A single-use dense tensor string parser: reads cell values in order from a bracketed list
+  * and writes them to consecutive direct indexes of the given builder.
+  */
+ private static class DenseParser extends ValueParser {
+
+     private final IndexedTensor.DirectIndexBuilder builder;
+     private final IndexedTensor.Indexes indexes;
+
+     /** Whether the value has nested brackets, as opposed to a single flat list of numbers */
+     private final boolean hasInnerStructure;
+
+     /** The direct index the next parsed cell value is written to */
+     private long tensorIndex = 0;
+
+     public DenseParser(String string, IndexedTensor.DirectIndexBuilder builder) {
+         super(string);
+         this.builder = builder;
+         indexes = IndexedTensor.Indexes.of(builder.type());
+         hasInnerStructure = hasInnerStructure(string);
+     }
+
+     /**
+      * Reads one number per index of the builder's type and adds each as a cell.
+      *
+      * @throws IllegalArgumentException if an expected bracket, comma or number is missing or malformed
+      */
+     public void parse() {
+         // A flat list has just one enclosing bracket pair, consumed here and at the end
+         if (!hasInnerStructure)
+             consume('[');
+
+         while (indexes.hasNext()) {
+             indexes.next();
+
+             // NOTE(review): presumably the number of dimensions that start at this index - confirm in Indexes
+             for (int i = 0; i < indexes.rightDimensionsAtStart() && hasInnerStructure; i++)
+                 consume('[');
+
+             consumeNumber();
+
+             // NOTE(review): presumably the number of dimensions that end at this index - confirm in Indexes
+             for (int i = 0; i < indexes.rightDimensionsAtEnd() && hasInnerStructure; i++)
+                 consume(']');
+
+             if (indexes.hasNext())
+                 consume(',');
+         }
+
+         if (!hasInnerStructure)
+             consume(']');
+     }
+
+     /** Returns the current read position in the string given to this parser */
+     public int position() { return position; }
+
+     /** Are there inner square brackets in this or is it just a flat list of numbers until ']'? */
+     private static boolean hasInnerStructure(String valueString) {
+         valueString = valueString.trim();
+         // Nothing to inspect: answer "flat" so that parse() reports the missing '[' as an
+         // IllegalArgumentException instead of substring(1) failing with an index error
+         if (valueString.isEmpty()) return false;
+         valueString = valueString.substring(1);
+         int firstLeftBracket = valueString.indexOf('[');
+         return firstLeftBracket >= 0 && firstLeftBracket < valueString.indexOf(']');
+     }
+
+     /** Reads the text up to the next ',' or ']' as a number of the builder's cell value type */
+     private void consumeNumber() {
+         skipSpace();
+
+         int nextNumberEnd = nextStopCharIndex(position, string);
+         TensorType.Value cellValueType = builder.type().valueType();
+         String cellValueString = string.substring(position, nextNumberEnd);
+         try {
+             if (cellValueType == TensorType.Value.DOUBLE)
+                 builder.cellByDirectIndex(tensorIndex++, Double.parseDouble(cellValueString));
+             else if (cellValueType == TensorType.Value.FLOAT)
+                 builder.cellByDirectIndex(tensorIndex++, Float.parseFloat(cellValueString));
+             else
+                 throw new IllegalArgumentException(cellValueType + " is not supported");
+         }
+         catch (NumberFormatException e) {
+             // Keep the number format failure as the cause
+             throw new IllegalArgumentException("At position " + position + ": '" +
+                                                cellValueString + "' is not a valid " + cellValueType, e);
+         }
+         position = nextNumberEnd;
+     }
+
+ }
+
+ /** A single-use parser of the mixed tensor form {sparse-label:[dense subspace], ...} */
+ private static class MixedParser extends ValueParser {
+
+     private final MixedTensor.BoundBuilder builder;
+
+     public MixedParser(String string, MixedTensor.BoundBuilder builder) {
+         super(string);
+         this.builder = builder;
+     }
+
+     /**
+      * Reads each "label:[dense subspace]" entry and adds it to the builder under that label.
+      *
+      * @throws IllegalArgumentException if an entry has no label before a ':' or a delimiter is missing
+      */
+     private void parse() {
+         TensorType.Dimension mappedDimension = builder.type().dimensions().stream()
+                                                       .filter(d -> ! d.isIndexed())
+                                                       .findAny()
+                                                       .get();
+         TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(),
+                                                                  List.of(mappedDimension));
+
+         skipSpace();
+         consume('{');
+         skipSpace();
+         while (position + 1 < string.length()) {
+             // The label is everything from the current position up to the next ':'
+             int colon = string.indexOf(':', position);
+             if (colon <= position)
+                 throw new IllegalArgumentException("A mixed tensor value must be on the form {sparse-label:[dense subspace], ...} ");
+             String label = string.substring(position, colon);
+             position = colon + 1;
+             skipSpace();
+
+             TensorAddress.Builder addressBuilder = new TensorAddress.Builder(mappedSubtype);
+             parseDenseSubspace(addressBuilder.add(mappedDimension.name(), label).build());
+
+             // An entry is followed either by ',' or by the closing brace
+             boolean entryFollowedByComma = consumeOptional(',');
+             if ( ! entryFollowedByComma)
+                 consume('}');
+             skipSpace();
+         }
+     }
+
+     /** Parses the dense subspace starting at the current position and moves past what was read */
+     private void parseDenseSubspace(TensorAddress sparseAddress) {
+         String remainder = string.substring(position);
+         DenseParser subspaceParser = new DenseParser(remainder, builder.denseSubspaceBuilder(sparseAddress));
+         subspaceParser.parse();
+         // The sub-parser's position is relative to the remainder it was given
+         position += subspaceParser.position();
+     }
+
+     /** Skips spaces, then steps past the given character if it is next; returns whether it was */
+     private boolean consumeOptional(char character) {
+         skipSpace();
+
+         boolean matches = position < string.length() && string.charAt(position) == character;
+         if (matches)
+             position++;
+         return matches;
+     }
+
+ }
+
}