summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-10 11:39:39 -0800
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-10 11:39:39 -0800
commit4c46e1816d2cdfacd8435ad4d55e831929fc99ba (patch)
treed55a90aeeddcf9265a74e7f16129517e36f45375 /vespajlib
parentb8d2859a9fece15dac2b9260d71dea39f8ce19b3 (diff)
Tensor parsing improvements
- Mixed tensor format parsing (outside expressions) - Validate structure of dense tensor strings
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json34
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java74
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java49
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java265
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java33
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java20
10 files changed, 387 insertions, 111 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index e991173805f..d91b38a8a96 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -776,15 +776,14 @@
},
"com.yahoo.tensor.IndexedTensor$BoundBuilder": {
"superClass": "com.yahoo.tensor.IndexedTensor$Builder",
- "interfaces": [],
+ "interfaces": [
+ "com.yahoo.tensor.IndexedTensor$DirectIndexBuilder"
+ ],
"attributes": [
"public",
"abstract"
],
- "methods": [
- "public abstract void cellByDirectIndex(long, double)",
- "public abstract void cellByDirectIndex(long, float)"
- ],
+ "methods": [],
"fields": []
},
"com.yahoo.tensor.IndexedTensor$Builder": {
@@ -813,6 +812,21 @@
],
"fields": []
},
+ "com.yahoo.tensor.IndexedTensor$DirectIndexBuilder": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods": [
+ "public abstract com.yahoo.tensor.TensorType type()",
+ "public abstract void cellByDirectIndex(long, double)",
+ "public abstract void cellByDirectIndex(long, float)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.IndexedTensor$Indexes": {
"superClass": "java.lang.Object",
"interfaces": [],
@@ -829,7 +843,8 @@
"public java.util.List toList()",
"public java.lang.String toString()",
"public abstract long size()",
- "public abstract void next()"
+ "public abstract void next()",
+ "public abstract boolean hasNext()"
],
"fields": [
"protected final long[] indexes"
@@ -943,6 +958,7 @@
],
"methods": [
"public long denseSubspaceSize()",
+ "public com.yahoo.tensor.IndexedTensor$DirectIndexBuilder denseSubspaceBuilder(com.yahoo.tensor.TensorAddress)",
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)",
"public com.yahoo.tensor.Tensor$Builder block(com.yahoo.tensor.TensorAddress, double[])",
@@ -1035,8 +1051,8 @@
],
"methods": [
"public void <init>(int)",
- "public void add(java.lang.String, long)",
- "public void add(java.lang.String, java.lang.String)",
+ "public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, long)",
+ "public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, java.lang.String)",
"public com.yahoo.tensor.PartialAddress build()"
],
"fields": []
@@ -1236,6 +1252,7 @@
"methods": [
"public void <init>()",
"public static com.yahoo.tensor.TensorAddress of(java.lang.String[])",
+ "public static varargs com.yahoo.tensor.TensorAddress ofLabels(java.lang.String[])",
"public static varargs com.yahoo.tensor.TensorAddress of(long[])",
"public abstract int size()",
"public abstract java.lang.String label(int)",
@@ -1395,6 +1412,7 @@
"public"
],
"methods": [
+ "public void <init>(com.yahoo.tensor.TensorType$Value, java.util.Collection)",
"public static varargs com.yahoo.tensor.TensorType$Value combinedValueType(com.yahoo.tensor.TensorType[])",
"public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
"public com.yahoo.tensor.TensorType$Value valueType()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 176ddfefc13..30923976fa5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor {
indexes.next();
// start brackets
- for (int i = 0; i < indexes.rightDimensionsWhichAreAtStart(); i++)
+ for (int i = 0; i < indexes.rightDimensionsAtStart(); i++)
b.append("[");
// value
@@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor {
throw new IllegalStateException("Unexpected value type " + type.valueType());
// end bracket and comma
- for (int i = 0; i < indexes.rightDimensionsWhichAreAtEnd(); i++)
+ for (int i = 0; i < indexes.rightDimensionsAtEnd(); i++)
b.append("]");
if (index < size() - 1)
b.append(", ");
@@ -375,8 +375,22 @@ public abstract class IndexedTensor implements Tensor {
}
+ public interface DirectIndexBuilder {
+
+ TensorType type();
+
+
+
+ /** Sets a value by its <i>standard value order</i> index */
+ void cellByDirectIndex(long index, double value);
+
+ /** Sets a value by its <i>standard value order</i> index */
+ void cellByDirectIndex(long index, float value);
+
+ }
+
/** A bound builder can create the double array directly */
- public static abstract class BoundBuilder extends Builder {
+ public static abstract class BoundBuilder extends Builder implements DirectIndexBuilder {
private DimensionSizes sizes;
@@ -393,14 +407,16 @@ public abstract class IndexedTensor implements Tensor {
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
}
- BoundBuilder fill(float [] values) {
+
+ BoundBuilder fill(float[] values) {
long index = 0;
for (float value : values) {
cellByDirectIndex(index++, value);
}
return this;
}
- BoundBuilder fill(double [] values) {
+
+ BoundBuilder fill(double[] values) {
long index = 0;
for (double value : values) {
cellByDirectIndex(index++, value);
@@ -410,12 +426,6 @@ public abstract class IndexedTensor implements Tensor {
DimensionSizes sizes() { return sizes; }
- /** Sets a value by its <i>standard value order</i> index */
- public abstract void cellByDirectIndex(long index, double value);
-
- /** Sets a value by its <i>standard value order</i> index */
- public abstract void cellByDirectIndex(long index, float value);
-
}
/**
@@ -869,8 +879,11 @@ public abstract class IndexedTensor implements Tensor {
public abstract void next();
+ /** Returns whether further values are available by calling next() */
+ public abstract boolean hasNext();
+
/** Returns the number of dimensions from the right which are currently at the start position (0) */
- int rightDimensionsWhichAreAtStart() {
+ int rightDimensionsAtStart() {
int dimension = indexes.length - 1;
int atStartCount = 0;
while (dimension >= 0 && indexes[dimension] == 0) {
@@ -881,7 +894,7 @@ public abstract class IndexedTensor implements Tensor {
}
/** Returns the number of dimensions from the right which are currently at the end position */
- int rightDimensionsWhichAreAtEnd() {
+ int rightDimensionsAtEnd() {
int dimension = indexes.length - 1;
int atEndCount = 0;
while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) {
@@ -904,10 +917,15 @@ public abstract class IndexedTensor implements Tensor {
@Override
public void next() {}
+ @Override
+ public boolean hasNext() { return false; }
+
}
private final static class SingleValueIndexes extends Indexes {
+ private boolean exhausted = false;
+
private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) {
super(sourceSizes, iterateSizes, indexes);
}
@@ -916,7 +934,10 @@ public abstract class IndexedTensor implements Tensor {
public long size() { return 1; }
@Override
- public void next() {}
+ public void next() { exhausted = true; }
+
+ @Override
+ public boolean hasNext() { return ! exhausted; }
}
@@ -945,7 +966,7 @@ public abstract class IndexedTensor implements Tensor {
* Advances this to the next cell in the standard indexed tensor cell order.
* The first call to this will put it at the first position.
*
- * @throws RuntimeException if this is called more times than its size
+ * @throws RuntimeException if this is called when hasNext returns false
*/
@Override
public void next() {
@@ -957,6 +978,15 @@ public abstract class IndexedTensor implements Tensor {
indexes[iterateDimensions.get(iterateDimensionsIndex)]++;
}
+ @Override
+ public boolean hasNext() {
+ for (int iterateDimension : iterateDimensions) {
+ if (indexes[iterateDimension] + 1 < dimensionSizes().size(iterateDimension))
+ return true; // some dimension is not at the end
+ }
+ return false;
+ }
+
}
/** In this case we can reuse the source index computation for the iteration index */
@@ -1016,7 +1046,7 @@ public abstract class IndexedTensor implements Tensor {
* Advances this to the next cell in the standard indexed tensor cell order.
* The first call to this will put it at the first position.
*
- * @throws RuntimeException if this is called more times than its size
+ * @throws RuntimeException if this is called when hasNext returns false
*/
@Override
public void next() {
@@ -1031,6 +1061,11 @@ public abstract class IndexedTensor implements Tensor {
@Override
long toIterationValueIndex() { return currentIterationValueIndex; }
+ @Override
+ public boolean hasNext() {
+ return indexes[iterateDimension] + 1 < size;
+ }
+
}
/** In this case we only need to keep track of one index */
@@ -1068,7 +1103,7 @@ public abstract class IndexedTensor implements Tensor {
* Advances this to the next cell in the standard indexed tensor cell order.
* The first call to this will put it at the first position.
*
- * @throws RuntimeException if this is called more times than its size
+ * @throws RuntimeException if this is called when hasNext returns false
*/
@Override
public void next() {
@@ -1077,6 +1112,11 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
+ public boolean hasNext() {
+ return indexes[iterateDimension] + 1 < size;
+ }
+
+ @Override
long toSourceValueIndex() { return currentValueIndex; }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 1cde1fcdbb7..0c4efe78113 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -217,25 +217,34 @@ public class MixedTensor implements Tensor {
public static class BoundBuilder extends Builder {
/** For each sparse partial address, hold a dense subspace */
- final private Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>();
- final private Index.Builder indexBuilder;
- final private Index index;
+ private final Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>();
+ private final Index.Builder indexBuilder;
+ private final Index index;
+ private final TensorType denseSubtype;
private BoundBuilder(TensorType type) {
super(type);
indexBuilder = new Index.Builder(type);
index = indexBuilder.index();
+ denseSubtype = new TensorType(type.valueType(),
+ type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList()));
}
public long denseSubspaceSize() {
return index.denseSubspaceSize();
}
- private double[] denseSubspace(TensorAddress sparsePartial) {
- if (!denseSubspaceMap.containsKey(sparsePartial)) {
- denseSubspaceMap.put(sparsePartial, new double[(int)denseSubspaceSize()]);
+ private double[] denseSubspace(TensorAddress sparseAddress) {
+ if (!denseSubspaceMap.containsKey(sparseAddress)) {
+ denseSubspaceMap.put(sparseAddress, new double[(int)denseSubspaceSize()]);
}
- return denseSubspaceMap.get(sparsePartial);
+ return denseSubspaceMap.get(sparseAddress);
+ }
+
+ public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress sparseAddress) {
+ double[] values = new double[(int)denseSubspaceSize()];
+ denseSubspaceMap.put(sparseAddress, values);
+ return new DenseSubspaceBuilder(denseSubtype, values);
}
@Override
@@ -280,7 +289,6 @@ public class MixedTensor implements Tensor {
}
-
/**
* Temporarily stores all cells to find bounds of indexed dimensions,
* then creates a tensor using BoundBuilder. This is due to the
@@ -491,6 +499,31 @@ public class MixedTensor implements Tensor {
}
+ private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder {
+
+ private final TensorType type;
+ private final double[] values;
+
+ public DenseSubspaceBuilder(TensorType type, double[] values) {
+ this.type = type;
+ this.values = values;
+ }
+
+ @Override
+ public TensorType type() { return type; }
+
+ @Override
+ public void cellByDirectIndex(long index, double value) {
+ values[(int)index] = value;
+ }
+
+ @Override
+ public void cellByDirectIndex(long index, float value) {
+ values[(int)index] = value;
+ }
+
+ }
+
public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) {
TensorType.Builder builder = new TensorType.Builder(valueType);
for (TensorType.Dimension dimension : dimensions) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
index 4eca9c47402..84f26d96725 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
@@ -122,16 +122,18 @@ public class PartialAddress {
labels = new Object[size];
}
- public void add(String dimensionName, long label) {
+ public Builder add(String dimensionName, long label) {
dimensionNames[index] = dimensionName;
labels[index] = label;
index++;
+ return this;
}
- public void add(String dimensionName, String label) {
+ public Builder add(String dimensionName, String label) {
dimensionNames[index] = dimensionName;
labels[index] = label;
index++;
+ return this;
}
public PartialAddress build() {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 52256293a5b..43d1bb0e468 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -18,6 +18,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return new StringTensorAddress(labels);
}
+ public static TensorAddress ofLabels(String ... labels) {
+ return new StringTensorAddress(labels);
+ }
+
public static TensorAddress of(long ... labels) {
return new NumericTensorAddress(labels);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 4d8b34b7dcf..04d3295795f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import java.util.List;
import java.util.Optional;
/**
@@ -9,6 +10,16 @@ import java.util.Optional;
class TensorParser {
static Tensor tensorFrom(String tensorString, Optional<TensorType> explicitType) {
+ try {
+ return tensorFromBody(tensorString, explicitType);
+ } catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not parse '" + tensorString + "' as a tensor" +
+ (explicitType.isPresent() ? " of type " + explicitType.get() : ""),
+ e);
+ }
+ }
+
+ static Tensor tensorFromBody(String tensorString, Optional<TensorType> explicitType) {
Optional<TensorType> type;
String valueString;
@@ -29,9 +40,13 @@ class TensorParser {
}
valueString = valueString.trim();
- if (valueString.startsWith("{")) {
+ if (valueString.startsWith("{") &&
+ (type.isEmpty() || type.get().rank() == 0 || valueString.substring(1).trim().startsWith("{") || valueString.substring(1).trim().equals("}"))) {
return tensorFromSparseValueString(valueString, type);
}
+ else if (valueString.startsWith("{")) {
+ return tensorFromMixedValueString(valueString, type);
+ }
else if (valueString.startsWith("[")) {
return tensorFromDenseValueString(valueString, type);
}
@@ -54,8 +69,7 @@ class TensorParser {
String s = valueString.substring(1).trim(); // remove tensor start
int firstKeyOrTensorEnd = s.indexOf('}');
if (firstKeyOrTensorEnd < 0)
- throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" +
- valueString + "'");
+ throw new IllegalArgumentException("Excepted a number or a string starting by '{', '[' or 'tensor(...):...'");
String addressBody = s.substring(0, firstKeyOrTensorEnd).trim();
if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor
if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor
@@ -79,73 +93,51 @@ class TensorParser {
try {
valueString = valueString.trim();
Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString)));
- return fromCellString(builder, valueString);
+ return tensorFromSparseCellString(builder, valueString);
}
catch (NumberFormatException e) {
- throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" +
- valueString + "'");
+ throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('");
}
}
- private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) {
+ private static Tensor tensorFromMixedValueString(String valueString, Optional<TensorType> type) {
if (type.isEmpty())
- throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " +
+ throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " +
"on the form 'tensor(dimensions):...");
- if (type.get().dimensions().stream().anyMatch(d -> ( d.size().isEmpty())))
- throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " +
- "only dense dimensions with a given size");
+ if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() != 1)
+ throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " +
+ "but got " + type.get());
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(type.get());
- long index = 0;
- int currentChar;
- int nextNumberEnd = 0;
- // Since we know the dimensions the brackets are just syntactic sugar:
- while ((currentChar = nextStartCharIndex(nextNumberEnd + 1, valueString)) < valueString.length()) {
- nextNumberEnd = nextStopCharIndex(currentChar, valueString);
- if (currentChar == nextNumberEnd) return builder.build();
- TensorType.Value cellValueType = builder.type().valueType();
- String cellValueString = valueString.substring(currentChar, nextNumberEnd);
- try {
- if (cellValueType == TensorType.Value.DOUBLE)
- builder.cellByDirectIndex(index, Double.parseDouble(cellValueString));
- else if (cellValueType == TensorType.Value.FLOAT)
- builder.cellByDirectIndex(index, Float.parseFloat(cellValueString));
- else
- throw new IllegalArgumentException(cellValueType + " is not supported");
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("At index " + index + ": '" +
- cellValueString + "' is not a valid " + cellValueType);
- }
- index++;
+ try {
+ valueString = valueString.trim();
+ if ( ! valueString.startsWith("{") && valueString.endsWith("}"))
+ throw new IllegalArgumentException("A mixed tensor must be enclosed in {}");
+ // TODO: Check if there is also at least one bound indexed dimension
+ MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder)Tensor.Builder.of(type.get());
+ MixedParser parser = new MixedParser(valueString, builder);
+ parser.parse();
+ return builder.build();
}
- return builder.build();
- }
-
- /** Returns the position of the next character that should contain a number, or if none the string length */
- private static int nextStartCharIndex(int charIndex, String valueString) {
- for (; charIndex < valueString.length(); charIndex++) {
- if (valueString.charAt(charIndex) == ']') continue;
- if (valueString.charAt(charIndex) == '[') continue;
- if (valueString.charAt(charIndex) == ',') continue;
- if (valueString.charAt(charIndex) == ' ') continue;
- return charIndex;
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('");
}
- return valueString.length();
}
- private static int nextStopCharIndex(int charIndex, String valueString) {
- while (charIndex < valueString.length()) {
- if (valueString.charAt(charIndex) == ',') return charIndex;
- if (valueString.charAt(charIndex) == ']') return charIndex;
- charIndex++;
- }
- throw new IllegalArgumentException("Malformed tensor value '" + valueString +
- "': Expected a ',' or ']' after position " + charIndex);
+ private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) {
+ if (type.isEmpty())
+ throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " +
+ "on the form 'tensor(dimensions):...");
+ if (type.get().dimensions().stream().anyMatch(d -> (d.size().isEmpty())))
+ throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " +
+ "only dense dimensions with a given size");
+
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get());
+ new DenseParser(valueString, builder).parse();
+ return builder.build();
}
- private static Tensor fromCellString(Tensor.Builder builder, String s) {
+ private static Tensor tensorFromSparseCellString(Tensor.Builder builder, String s) {
int index = 1;
index = skipSpace(index, s);
while (index + 1 < s.length()) {
@@ -194,6 +186,16 @@ class TensorParser {
return index;
}
+ private static int nextStopCharIndex(int charIndex, String valueString) {
+ while (charIndex < valueString.length()) {
+ if (valueString.charAt(charIndex) == ',') return charIndex;
+ if (valueString.charAt(charIndex) == ']') return charIndex;
+ charIndex++;
+ }
+ throw new IllegalArgumentException("Malformed tensor value '" + valueString +
+ "': Expected a ',' or ']' after position " + charIndex);
+ }
+
/** Creates a tenor address from a string on the form {dimension1:label1,dimension2:label2,...} */
private static void addLabels(String mapAddressString, TensorAddress.Builder builder) {
mapAddressString = mapAddressString.trim();
@@ -213,4 +215,157 @@ class TensorParser {
}
}
+ private static abstract class ValueParser {
+
+ protected final String string;
+ protected int position = 0;
+
+ protected ValueParser(String string) {
+ this.string = string;
+ }
+
+ protected void skipSpace() {
+ while (position < string.length() && string.charAt(position) == ' ')
+ position++;
+ }
+
+ protected void consume(char character) {
+ skipSpace();
+
+ if (position >= string.length())
+ throw new IllegalArgumentException("At position " + position + ": Expected a '" + character +
+ "' but got the end of the string");
+ if ( string.charAt(position) != character)
+ throw new IllegalArgumentException("At position " + position + ": Expected a '" + character +
+ "' but got '" + string.charAt(position) + "'");
+ position++;
+ }
+
+ }
+
+ /** A single-use dense tensor string parser */
+ private static class DenseParser extends ValueParser {
+
+ private final IndexedTensor.DirectIndexBuilder builder;
+ private final IndexedTensor.Indexes indexes;
+ private final boolean hasInnerStructure;
+
+ private long tensorIndex = 0;
+
+ public DenseParser(String string, IndexedTensor.DirectIndexBuilder builder) {
+ super(string);
+ this.builder = builder;
+ indexes = IndexedTensor.Indexes.of(builder.type());
+ hasInnerStructure = hasInnerStructure(string);
+ }
+
+ public void parse() {
+ if (!hasInnerStructure)
+ consume('[');
+
+ while (indexes.hasNext()) {
+ indexes.next();
+
+ for (int i = 0; i < indexes.rightDimensionsAtStart() && hasInnerStructure; i++)
+ consume('[');
+
+ consumeNumber();
+
+ for (int i = 0; i < indexes.rightDimensionsAtEnd() && hasInnerStructure; i++)
+ consume(']');
+
+ if (indexes.hasNext())
+ consume(',');
+ }
+
+ if (!hasInnerStructure)
+ consume(']');
+ }
+
+ public int position() { return position; }
+
+ /** Are there inner square brackets in this or is it just a flat list of numbers until ']'? */
+ private static boolean hasInnerStructure(String valueString) {
+ valueString = valueString.trim();
+ valueString = valueString.substring(1);
+ int firstLeftBracket = valueString.indexOf('[');
+ return firstLeftBracket >= 0 && firstLeftBracket < valueString.indexOf(']');
+ }
+
+ private void consumeNumber() {
+ skipSpace();
+
+ int nextNumberEnd = nextStopCharIndex(position, string);
+ TensorType.Value cellValueType = builder.type().valueType();
+ String cellValueString = string.substring(position, nextNumberEnd);
+ try {
+ if (cellValueType == TensorType.Value.DOUBLE)
+ builder.cellByDirectIndex(tensorIndex++, Double.parseDouble(cellValueString));
+ else if (cellValueType == TensorType.Value.FLOAT)
+ builder.cellByDirectIndex(tensorIndex++, Float.parseFloat(cellValueString));
+ else
+ throw new IllegalArgumentException(cellValueType + " is not supported");
+ }
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("At position " + position + ": '" +
+ cellValueString + "' is not a valid " + cellValueType);
+ }
+ position = nextNumberEnd;
+ }
+
+ }
+
+ private static class MixedParser extends ValueParser {
+
+ private final MixedTensor.BoundBuilder builder;
+
+ public MixedParser(String string, MixedTensor.BoundBuilder builder) {
+ super(string);
+ this.builder = builder;
+ }
+
+ private void parse() {
+ TensorType.Dimension sparseDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get();
+ TensorType sparseSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(sparseDimension));
+
+ skipSpace();
+ consume('{');
+ skipSpace();
+ while (position + 1 < string.length()) {
+ int labelEnd = string.indexOf(':', position);
+ if (labelEnd <= position)
+ throw new IllegalArgumentException("A mixed tensor value must be on the form {sparse-label:[dense subspace], ...} ");
+ String label = string.substring(position, labelEnd);
+ position = labelEnd + 1;
+ skipSpace();
+
+ TensorAddress sparseAddress = new TensorAddress.Builder(sparseSubtype).add(sparseDimension.name(), label).build();
+ parseDenseSubspace(sparseAddress);
+ if ( ! consumeOptional(','))
+ consume('}');
+ skipSpace();
+ }
+ }
+
+ private void parseDenseSubspace(TensorAddress sparseAddress) {
+ DenseParser denseParser = new DenseParser(string.substring(position), builder.denseSubspaceBuilder(sparseAddress));
+ denseParser.parse();
+ position+= denseParser.position();
+ }
+
+ private boolean consumeOptional(char character) {
+ skipSpace();
+
+ if (position >= string.length())
+ return false;
+ if ( string.charAt(position) != character)
+ return false;
+
+ position++;
+ return true;
+ }
+
+
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 95cc70804e2..ca3f8ff28a4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -82,7 +82,7 @@ public class TensorType {
private final TensorType mappedSubtype;
- private TensorType(Value valueType, Collection<Dimension> dimensions) {
+ public TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
index 1928971820c..b2aba5b02eb 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
@@ -22,6 +22,12 @@ public class TensorParserTestCase {
}
@Test
+ public void testSingle() {
+ assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(1.0, 0).build(),
+ "tensor(x[1]):[1.0]");
+ }
+
+ @Test
public void testDenseParsing() {
assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor()")).build(),
"tensor():{0.0}");
@@ -55,18 +61,9 @@ public class TensorParserTestCase {
.cell(3.0, 1, 0, 0)
.cell(4.0, 1, 1, 0)
.cell(5.0, 2, 0, 0)
- .cell(6.0, 2, 1, 0).build(),
- "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]]");
- assertEquals("Messy input",
- Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])"))
- .cell( 1.0, 0, 0, 0)
- .cell( 2.0, 0, 1, 0)
- .cell( 3.0, 1, 0, 0)
- .cell( 4.0, 1, 1, 0)
- .cell( 5.0, 2, 0, 0)
.cell(-6.0, 2, 1, 0).build(),
- Tensor.from("tensor( x[3],y[2],z[1]) : [ [ [1.0, 2.0, 3.0] , [4.0, 5,-6.0] ] ]"));
- assertEquals("Skipping syntactic sugar",
+ "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [-6.0]]]");
+ assertEquals("Skipping structure",
Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])"))
.cell( 1.0, 0, 0, 0)
.cell( 2.0, 0, 1, 0)
@@ -77,6 +74,16 @@ public class TensorParserTestCase {
Tensor.from("tensor( x[3],y[2],z[1]) : [1.0, 2.0, 3.0 , 4.0, 5, -6.0]"));
}
+ @Test
+ public void testMixedParsing() {
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])"))
+ .cell(TensorAddress.ofLabels("a", "0"), 1)
+ .cell(TensorAddress.ofLabels("a", "1"), 2)
+ .cell(TensorAddress.ofLabels("b", "0"), 3)
+ .cell(TensorAddress.ofLabels("b", "1"), 4).build(),
+ Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}"));
+ }
+
private void assertDense(Tensor expectedTensor, String denseFormat) {
assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat));
assertEquals(denseFormat, expectedTensor.toString());
@@ -92,7 +99,7 @@ public class TensorParserTestCase {
"{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}");
assertIllegal("At {x:0}: '1-.0' is not a valid double",
"{{x:0}:1-.0}");
- assertIllegal("At index 0: '1-.0' is not a valid double",
+ assertIllegal("At position 1: '1-.0' is not a valid double",
"tensor(x[1]):[1-.0]");
}
@@ -102,7 +109,7 @@ public class TensorParserTestCase {
fail("Expected an IllegalArgumentException when parsing " + tensor);
}
catch (IllegalArgumentException e) {
- assertEquals(message, e.getMessage());
+ assertEquals(message, e.getCause().getMessage());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 11365531019..9f077cb7b00 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -56,7 +56,8 @@ public class TensorTestCase {
fail("Expected parse error");
}
catch (IllegalArgumentException expected) {
- assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'", expected.getMessage());
+ assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'",
+ expected.getCause().getMessage());
}
}
@@ -259,9 +260,9 @@ public class TensorTestCase {
assertLargest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0",
"tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l3}:5.0,{d1:l1,d2:l2}:6.0}");
assertLargest("{x:1,y:1}:4.0",
- "tensor(x[2],y[2]):[[1,2],[3,4]");
+ "tensor(x[2],y[2]):[[1,2],[3,4]]");
assertLargest("{x:0,y:0}:4.0, {x:1,y:1}:4.0",
- "tensor(x[2],y[2]):[[4,2],[3,4]");
+ "tensor(x[2],y[2]):[[4,2],[3,4]]");
}
@Test
@@ -273,9 +274,9 @@ public class TensorTestCase {
assertSmallest("{d1:l1,d2:l1}:5.0, {d1:l1,d2:l2}:5.0",
"tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l3}:6.0,{d1:l1,d2:l2}:5.0}");
assertSmallest("{x:0,y:0}:1.0",
- "tensor(x[2],y[2]):[[1,2],[3,4]");
+ "tensor(x[2],y[2]):[[1,2],[3,4]]");
assertSmallest("{x:0,y:1}:2.0",
- "tensor(x[2],y[2]):[[4,2],[3,4]");
+ "tensor(x[2],y[2]):[[4,2],[3,4]]");
}
private void assertLargest(String expectedCells, String tensorString) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
index e16b7b90a1d..7cddeab1641 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.Name;
import org.junit.Test;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
import static org.junit.Assert.assertEquals;
@@ -19,21 +20,36 @@ import static org.junit.Assert.assertEquals;
public class DynamicTensorTestCase {
@Test
- public void testDynamicTensorFunction() {
+ public void testDynamicIndexedRank1TensorFunction() {
TensorType dense = TensorType.fromSpec("tensor(x[3])");
DynamicTensor<Name> t1 = DynamicTensor.from(dense,
List.of(new Constant(1), new Constant(2), new Constant(3)));
assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate());
assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString());
+ }
+ @Test
+ public void testDynamicMappedRank1TensorFunction() {
TensorType sparse = TensorType.fromSpec("tensor(x{})");
DynamicTensor<Name> t2 = DynamicTensor.from(sparse,
Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
- new Constant(5)));
+ new Constant(5)));
assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate());
assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString());
}
+ @Test
+ public void testDynamicMappedRank2TensorFunction() {
+ TensorType sparse = TensorType.fromSpec("tensor(x{},y{})");
+ HashMap<TensorAddress, ScalarFunction<Name>> values = new HashMap<>();
+ values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "b").build(),
+ new Constant(5));
+ values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "c").build(),
+ new Constant(7));
+ DynamicTensor<Name> t2 = DynamicTensor.from(sparse, values);
+ assertEquals(Tensor.from(sparse, "{{x:a,y:b}:5, {x:a,y:c}:7}"), t2.evaluate());
+ }
+
private static class Constant implements ScalarFunction<Name> {
private final double value;