diff options
Diffstat (limited to 'document/src/main/java/com/yahoo/document/json/readers/TensorReader.java')
-rw-r--r-- | document/src/main/java/com/yahoo/document/json/readers/TensorReader.java | 29 |
1 files changed, 21 insertions, 8 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index ad016a40fca..27426f584bd 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import static com.yahoo.document.json.readers.JsonParserHelpers.*; +import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString; /** * Reads the tensor format defined at @@ -41,7 +42,7 @@ public class TensorReader { else if (TENSOR_BLOCKS.equals(buffer.currentName())) readTensorBlocks(buffer, builder); else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks'"); + throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks', but got: "+buffer.currentName()); } expectObjectEnd(buffer.currentToken()); tensorFieldValue.assign(builder.build()); @@ -91,10 +92,18 @@ public class TensorReader { throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; + if (buffer.currentToken() == JsonToken.VALUE_STRING) { + double[] decoded = decodeHexString(buffer.currentText(), builder.type().valueType()); + for (int i = 0; i < decoded.length; i++) { + indexedBuilder.cellByDirectIndex(i, decoded[i]); + } + return; + } int index = 0; int initNesting = buffer.nesting(); - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { indexedBuilder.cellByDirectIndex(index++, readDouble(buffer)); + } expectCompositeEnd(buffer.currentToken()); } @@ -167,17 +176,21 @@ public class TensorReader { * @return the values read */ private static double[] readValues(TokenBuffer buffer, int size, TensorAddress address, TensorType type) { - expectArrayStart(buffer.currentToken()); - int index = 0; - int initNesting = buffer.nesting(); double[] values = new double[size]; - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) - values[index++] = readDouble(buffer); + if (buffer.currentToken() == JsonToken.VALUE_STRING) { + values = decodeHexString(buffer.currentText(), type.valueType()); + index = values.length; + } else { + expectArrayStart(buffer.currentToken()); + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + values[index++] = readDouble(buffer); + expectCompositeEnd(buffer.currentToken()); + } if (index != size) throw new IllegalArgumentException((address != null ? "At " + address.toString(type) + ": " : "") + "Expected " + size + " values, but got " + index); - expectCompositeEnd(buffer.currentToken()); return values; } |