diff options
3 files changed, 33 insertions, 3 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index ad016a40fca..c6722322982 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import static com.yahoo.document.json.readers.JsonParserHelpers.*; +import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString; /** * Reads the tensor format defined at @@ -41,7 +42,7 @@ public class TensorReader { else if (TENSOR_BLOCKS.equals(buffer.currentName())) readTensorBlocks(buffer, builder); else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks'"); + throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks', but got: "+buffer.currentName()); } expectObjectEnd(buffer.currentToken()); tensorFieldValue.assign(builder.build()); @@ -91,10 +92,18 @@ public class TensorReader { throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; + if (buffer.currentToken() == JsonToken.VALUE_STRING) { + double[] decoded = decodeHexString(buffer.currentText(), builder.type().valueType()); + for (int i = 0; i < decoded.length; i++) { + indexedBuilder.cellByDirectIndex(i, decoded[i]); + } + return; + } int index = 0; int initNesting = buffer.nesting(); - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { indexedBuilder.cellByDirectIndex(index++, readDouble(buffer)); + } expectCompositeEnd(buffer.currentToken()); } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index da9ab4ea7bf..4fea220b2e8 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -164,6 +164,8 @@ public class JsonReaderTestCase { new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").build()))); x.addField(new Field("dense_tensor", new TensorDataType(new TensorType.Builder().indexed("x", 2).indexed("y", 3).build()))); + x.addField(new Field("dense_int8_tensor", + new TensorDataType(TensorType.fromSpec("tensor<int8>(x[2],y[3])")))); x.addField(new Field("dense_unbound_tensor", new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build()))); x.addField(new Field("mixed_tensor", @@ -1324,6 +1326,25 @@ public class JsonReaderTestCase { } @Test + public void testParsingOfDenseTensorHexFormat() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[3])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 0).label("y", 2).value(4.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(6.0); + builder.cell().label("x", 1).label("y", 2).value(7.0); + Tensor expected = builder.build(); + + Tensor tensor = assertTensorField(expected, + createPutWithTensor(inputJson("{", + " 'values': \"020304050607\"", + "}"), "dense_int8_tensor"), "dense_int8_tensor"); + assertTrue(tensor instanceof IndexedTensor); // this matters for performance + } + + + @Test public void testParsingOfMixedTensorOnMixedForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])")); builder.cell().label("x", 0).label("y", 0).value(2.0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 9eb9cb06666..461e73e3611 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -216,7 +216,7 @@ public class JsonFormat { return result; } - private static double[] decodeHexString(String input, TensorType.Value valueType) { + public static double[] decodeHexString(String input, TensorType.Value valueType) { switch(valueType) { case INT8: return decodeHexStringAsBytes(input); |