diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-11 10:04:14 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-11 10:04:14 +0200 |
commit | a7e6b478c536dee7abc14b62fa2700df2b9df93f (patch) | |
tree | a7427132e28e73293eacbd14dce1b4d1627ecfc5 | |
parent | 2014e93de206861200950343c5330ed5997d8770 (diff) |
Parse dense tensors in the rightmost adjacent order
4 files changed, 61 insertions, 69 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 04e68e60178..b2b895040bc 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1352,7 +1352,8 @@ "public static com.yahoo.tensor.TensorType$Value[] values()", "public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)", "public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)", - "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)" + "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)", + "public java.lang.String toString()" ], "fields": [ "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 4d9bb258423..4d8b34b7dcf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -94,55 +94,35 @@ class TensorParser { if (type.get().dimensions().stream().anyMatch(d -> ( d.size().isEmpty()))) throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " + "only dense dimensions with a given size"); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(type.get()); - // Since we know the dimensions the brackets are just syntactic sugar - long[] indexes = new long[builder.type().rank()]; + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(type.get()); + long index = 0; int currentChar; int nextNumberEnd = 0; + // Since we know the dimensions the brackets are just syntactic sugar: while ((currentChar = nextStartCharIndex(nextNumberEnd + 1, valueString)) < valueString.length()) { nextNumberEnd = nextStopCharIndex(currentChar, valueString); if (currentChar == 
nextNumberEnd) return builder.build(); - if (builder.type().valueType() == TensorType.Value.DOUBLE) - builder.cellByDirectIndex(nextCellIndex(indexes, builder), Double.parseDouble(valueString.substring(currentChar, nextNumberEnd))); - else if (builder.type().valueType() == TensorType.Value.FLOAT) - builder.cellByDirectIndex(nextCellIndex(indexes, builder), Float.parseFloat(valueString.substring(currentChar, nextNumberEnd))); - else - throw new IllegalArgumentException(builder.type().valueType() + " is not supported"); + TensorType.Value cellValueType = builder.type().valueType(); + String cellValueString = valueString.substring(currentChar, nextNumberEnd); + try { + if (cellValueType == TensorType.Value.DOUBLE) + builder.cellByDirectIndex(index, Double.parseDouble(cellValueString)); + else if (cellValueType == TensorType.Value.FLOAT) + builder.cellByDirectIndex(index, Float.parseFloat(cellValueString)); + else + throw new IllegalArgumentException(cellValueType + " is not supported"); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("At index " + index + ": '" + + cellValueString + "' is not a valid " + cellValueType); + } + index++; } return builder.build(); } - // ----- - - /** - * Advance to the next cell in left-adjacent order. - * - * On rightmost vs. leftmost adjacency: - * A dense tensor is laid out with the rightmost dimension as adjacent numbers, - * but when we parse a dense tensor we encounter numbers in the leftmost-adjacent order, since - * that is the most natural way to write it: tensor(x,y)[[1,2],[3,4]] - * should mean {{x:0, y:0}:1, {x:1, y:0}:2, {x:0, y:1}:3, {x:1, y:1}:4}. - * Therefore we need to convert the encounter order (numberIndex) from left-adjacent to right-adjacent. 
- */ - private static long nextCellIndex(long[] indexes, IndexedTensor.BoundBuilder builder) { - long cellIndex = IndexedTensor.toValueIndex(indexes, builder.sizes()); - - // Find next dimension to advance - int nextInDimension = 0; - while (nextInDimension < indexes.length && indexes[nextInDimension] + 1 >= builder.sizes().size(nextInDimension)) { - indexes[nextInDimension] = 0; - nextInDimension++; - } - if (nextInDimension < indexes.length) - indexes[nextInDimension]++; - else // there is no next - become invalid - indexes[0]++; - - return cellIndex; - } - /** Returns the position of the next character that should contain a number, or if none the string length */ private static int nextStartCharIndex(int charIndex, String valueString) { for (; charIndex < valueString.length(); charIndex++) { @@ -187,8 +167,21 @@ class TensorParser { } TensorAddress address = addressBuilder.build(); - Double value = asDouble(address, s.substring(index, valueEnd).trim()); - builder.cell(address, value); + TensorType.Value cellValueType = builder.type().valueType(); + String cellValueString = s.substring(index, valueEnd).trim(); + try { + if (cellValueType == TensorType.Value.DOUBLE) + builder.cell(address, Double.parseDouble(cellValueString)); + else if (cellValueType == TensorType.Value.FLOAT) + builder.cell(address, Float.parseFloat(cellValueString)); + else + throw new IllegalArgumentException(cellValueType + " is not supported"); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("At " + address.toString(builder.type()) + ": '" + + cellValueString + "' is not a valid " + cellValueType); + } + index = valueEnd+1; index = skipSpace(index, s); } @@ -220,13 +213,4 @@ class TensorParser { } } - private static Double asDouble(TensorAddress address, String s) { - try { - return Double.valueOf(s); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("At " + address + ": Expected a floating point number, got '" + s + "'"); - } - } - } 
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b1c7a2341c0..8e566fac0b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -48,6 +48,9 @@ public class TensorType { return FLOAT; } + @Override + public String toString() { return name().toLowerCase(); } + }; /** The empty tensor type - which is the same as a double */ diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java index 313cca833f1..63fe40565bd 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -31,36 +31,36 @@ public class TensorParserTestCase { Tensor.from("tensor(x[2]):[1.0, 2.0]")); assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[2],y[3])")) .cell(1.0, 0, 0) - .cell(2.0, 1, 0) - .cell(3.0, 0, 1) - .cell(4.0, 1, 1) - .cell(5.0, 0, 2) + .cell(2.0, 0, 1) + .cell(3.0, 0, 2) + .cell(4.0, 1, 0) + .cell(5.0, 1, 1) .cell(6.0, 1, 2).build(), Tensor.from("tensor(x[2],y[3]):[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]")); assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1],y[2],z[3])")) .cell(1.0, 0, 0, 0) - .cell(2.0, 0, 1, 0) - .cell(3.0, 0, 0, 1) - .cell(4.0, 0, 1, 1) - .cell(5.0, 0, 0, 2) + .cell(2.0, 0, 0, 1) + .cell(3.0, 0, 0, 2) + .cell(4.0, 0, 1, 0) + .cell(5.0, 0, 1, 1) .cell(6.0, 0, 1, 2).build(), Tensor.from("tensor(x[1],y[2],z[3]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]]")); assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])")) .cell(1.0, 0, 0, 0) - .cell(2.0, 1, 0, 0) - .cell(3.0, 2, 0, 0) - .cell(4.0, 0, 1, 0) - .cell(5.0, 1, 1, 0) + .cell(2.0, 0, 1, 0) + .cell(3.0, 1, 0, 0) + .cell(4.0, 1, 1, 0) + .cell(5.0, 2, 0, 0) .cell(6.0, 2, 1, 0).build(), Tensor.from("tensor(x[3],y[2],z[1]):[[[1.0, 
2.0, 3.0], [4.0, 5.0, 6.0]]]")); assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])")) - .cell(1.0, 0, 0, 0) - .cell(2.0, 1, 0, 0) - .cell(3.0, 2, 0, 0) - .cell(4.0, 0, 1, 0) - .cell(5.0, 1, 1, 0) - .cell(6.0, 2, 1, 0).build(), - Tensor.from("tensor( x[3],y[2],z[1]) : [ [ [1.0, 2.0, 3.0] , [4.0, 5,6.0] ] ]")); + .cell( 1.0, 0, 0, 0) + .cell( 2.0, 0, 1, 0) + .cell( 3.0, 1, 0, 0) + .cell( 4.0, 1, 1, 0) + .cell( 5.0, 2, 0, 0) + .cell(-6.0, 2, 1, 0).build(), + Tensor.from("tensor( x[3],y[2],z[1]) : [ [ [1.0, 2.0, 3.0] , [4.0, 5,-6.0] ] ]")); } @Test @@ -71,6 +71,10 @@ public class TensorParserTestCase { "{{'x':\"l0\"}:1.0}"); assertIllegal("dimension must be an identifier or integer, not '\"x\"'", "{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}"); + assertIllegal("At {x:0}: '1-.0' is not a valid double", + "{{x:0}:1-.0}"); + assertIllegal("At index 0: '1-.0' is not a valid double", + "tensor(x[1]):[1-.0]"); } private void assertIllegal(String message, String tensor) { |