aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java160
1 files changed, 125 insertions, 35 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 45a9992c9ad..4d9bb258423 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -8,44 +8,59 @@ import java.util.Optional;
*/
class TensorParser {
- static Tensor tensorFrom(String tensorString, Optional<TensorType> type) {
+ static Tensor tensorFrom(String tensorString, Optional<TensorType> explicitType) {
+ Optional<TensorType> type;
+ String valueString;
+
tensorString = tensorString.trim();
- try {
- if (tensorString.startsWith("tensor")) {
- int colonIndex = tensorString.indexOf(':');
- String typeString = tensorString.substring(0, colonIndex);
- String valueString = tensorString.substring(colonIndex + 1);
- TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
- if (type.isPresent() && ! type.get().equals(typeFromString))
- throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
- "passed type " + type.get());
- return tensorFromValueString(valueString, typeFromString);
- }
- else if (tensorString.startsWith("{")) {
- return tensorFromValueString(tensorString, type.orElse(typeFromValueString(tensorString)));
- }
- else {
- if (type.isPresent() && ! type.get().equals(TensorType.empty))
- throw new IllegalArgumentException("Got zero-dimensional tensor '" + tensorString +
- "' where type " + type.get() + " is required");
+ if (tensorString.startsWith("tensor")) {
+ int colonIndex = tensorString.indexOf(':');
+ String typeString = tensorString.substring(0, colonIndex);
+ TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
+ if (explicitType.isPresent() && ! explicitType.get().equals(typeFromString))
+ throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
+ "passed type " + explicitType.get());
+ type = Optional.of(typeFromString);
+ valueString = tensorString.substring(colonIndex + 1);
+ }
+ else {
+ type = explicitType;
+ valueString = tensorString;
+ }
+
+ valueString = valueString.trim();
+ if (valueString.startsWith("{")) {
+ return tensorFromSparseValueString(valueString, type);
+ }
+ else if (valueString.startsWith("[")) {
+ return tensorFromDenseValueString(valueString, type);
+ }
+ else {
+ if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty))
+ throw new IllegalArgumentException("Got a zero-dimensional tensor value ('" + tensorString +
+ "') where type " + explicitType.get() + " is required");
+ try {
return Tensor.Builder.of(TensorType.empty).cell(Double.parseDouble(tensorString)).build();
}
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" +
- tensorString + "'");
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" +
+ tensorString + "'");
+ }
}
}
- /** Derive the tensor type from the first address string in the given tensor string */
- private static TensorType typeFromValueString(String s) {
- s = s.substring(1).trim(); // remove tensor start
+ /** Derives the tensor type from the first address string in the given tensor string */
+ private static TensorType typeFromSparseValueString(String valueString) {
+ String s = valueString.substring(1).trim(); // remove tensor start
int firstKeyOrTensorEnd = s.indexOf('}');
+ if (firstKeyOrTensorEnd < 0)
+ throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" +
+ valueString + "'");
String addressBody = s.substring(0, firstKeyOrTensorEnd).trim();
if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor
if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor
- addressBody = addressBody.substring(1); // remove key start
+ addressBody = addressBody.substring(1, addressBody.length()); // remove key start
if (addressBody.isEmpty()) return TensorType.empty; // Empty key
TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE);
@@ -60,19 +75,94 @@ class TensorParser {
return builder.build();
}
- private static Tensor tensorFromValueString(String tensorValueString, TensorType type) {
- Tensor.Builder builder = Tensor.Builder.of(type);
- tensorValueString = tensorValueString.trim();
+ private static Tensor tensorFromSparseValueString(String valueString, Optional<TensorType> type) {
try {
- if (tensorValueString.startsWith("{"))
- return fromCellString(builder, tensorValueString);
- else
- return builder.cell(Double.parseDouble(tensorValueString)).build();
+ valueString = valueString.trim();
+ Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString)));
+ return fromCellString(builder, valueString);
}
catch (NumberFormatException e) {
throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" +
- tensorValueString + "'");
+ valueString + "'");
+ }
+ }
+
+ private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) {
+ if (type.isEmpty())
+ throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " +
+ "on the form 'tensor(dimensions):...");
+ if (type.get().dimensions().stream().anyMatch(d -> ( d.size().isEmpty())))
+ throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " +
+ "only dense dimensions with a given size");
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(type.get());
+
+ // Since we know the dimensions the brackets are just syntactic sugar
+ long[] indexes = new long[builder.type().rank()];
+ int currentChar;
+ int nextNumberEnd = 0;
+ while ((currentChar = nextStartCharIndex(nextNumberEnd + 1, valueString)) < valueString.length()) {
+ nextNumberEnd = nextStopCharIndex(currentChar, valueString);
+ if (currentChar == nextNumberEnd) return builder.build();
+
+ if (builder.type().valueType() == TensorType.Value.DOUBLE)
+ builder.cellByDirectIndex(nextCellIndex(indexes, builder), Double.parseDouble(valueString.substring(currentChar, nextNumberEnd)));
+ else if (builder.type().valueType() == TensorType.Value.FLOAT)
+ builder.cellByDirectIndex(nextCellIndex(indexes, builder), Float.parseFloat(valueString.substring(currentChar, nextNumberEnd)));
+ else
+ throw new IllegalArgumentException(builder.type().valueType() + " is not supported");
+ }
+ return builder.build();
+ }
+
+ // -----
+
+ /**
+ * Advance to the next cell in left-adjac ent order.
+ *
+ * On rightmost vs. leftmost adjacency:
+ * A dense tensor is laid out with the rightmost dimension as adjacent numbers,
+ * but when we parse a dense tensor we encounter numbers in the leftmost-adjacent order, since
+ * that is the most natural way to write it: tensor(x,y)[[1,2],[3,4]]
+ * should mean {{x:0, y:0}:1, {x:1, y:0}:2, {x:0, y:1}:3, {x:1, y:1}:4}.
+ * Therefore we need to convert the encounter order (numberIndex) from left-adjacent to right-adjacent.
+ */
+ private static long nextCellIndex(long[] indexes, IndexedTensor.BoundBuilder builder) {
+ long cellIndex = IndexedTensor.toValueIndex(indexes, builder.sizes());
+
+ // Find next dimension to advance
+ int nextInDimension = 0;
+ while (nextInDimension < indexes.length && indexes[nextInDimension] + 1 >= builder.sizes().size(nextInDimension)) {
+ indexes[nextInDimension] = 0;
+ nextInDimension++;
+ }
+ if (nextInDimension < indexes.length)
+ indexes[nextInDimension]++;
+ else // there is no next - become invalid
+ indexes[0]++;
+
+ return cellIndex;
+ }
+
+ /** Returns the position of the next character that should contain a number, or if none the string length */
+ private static int nextStartCharIndex(int charIndex, String valueString) {
+ for (; charIndex < valueString.length(); charIndex++) {
+ if (valueString.charAt(charIndex) == ']') continue;
+ if (valueString.charAt(charIndex) == '[') continue;
+ if (valueString.charAt(charIndex) == ',') continue;
+ if (valueString.charAt(charIndex) == ' ') continue;
+ return charIndex;
+ }
+ return valueString.length();
+ }
+
+ private static int nextStopCharIndex(int charIndex, String valueString) {
+ while (charIndex < valueString.length()) {
+ if (valueString.charAt(charIndex) == ',') return charIndex;
+ if (valueString.charAt(charIndex) == ']') return charIndex;
+ charIndex++;
}
+ throw new IllegalArgumentException("Malformed tensor value '" + valueString +
+ "': Expected a ',' or ']' after position " + charIndex);
}
private static Tensor fromCellString(Tensor.Builder builder, String s) {