diff options
author | Arne Juul <arnej@vespa.ai> | 2024-06-19 12:49:29 +0000 |
---|---|---|
committer | Arne Juul <arnej@vespa.ai> | 2024-06-19 12:49:29 +0000 |
commit | 2b4e4cffd59f2c06a9e6d402cd90c27d96917a97 (patch) | |
tree | 4c15ed853da4547eea27f2898ce45a8faac4098e | |
parent | 1f0e68e758e3779aab26b8389b142acf20239406 (diff) |
accept just a hex string for dense tensors
3 files changed, 43 insertions, 1 deletion
diff --git a/document/src/main/java/com/yahoo/document/json/readers/SingleValueReader.java b/document/src/main/java/com/yahoo/document/json/readers/SingleValueReader.java index c6eccdacf26..465d7da5e8e 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/SingleValueReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/SingleValueReader.java @@ -6,7 +6,9 @@ import com.yahoo.document.DataType; import com.yahoo.document.DocumentId; import com.yahoo.document.PositionDataType; import com.yahoo.document.ReferenceDataType; +import com.yahoo.document.TensorDataType; import com.yahoo.document.datatypes.FieldValue; +import com.yahoo.document.datatypes.TensorFieldValue; import com.yahoo.document.json.TokenBuffer; import com.yahoo.document.update.ValueUpdate; @@ -41,6 +43,11 @@ public class SingleValueReader { } public static FieldValue readSingleValue(TokenBuffer buffer, DataType expectedType, boolean ignoreUndefinedFields) { + if (expectedType instanceof TensorDataType) { + FieldValue fieldValue = expectedType.createFieldValue(); + TensorReader.fillTensor(buffer, (TensorFieldValue) fieldValue); + return fieldValue; + } if (buffer.current().isScalarValue()) { return readAtomic(buffer.currentText(), expectedType); } else { diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index 3aa6dc96e56..82a67c08935 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -37,6 +37,18 @@ public class TensorReader { // MUST be kept in sync with com.yahoo.tensor.serialization.JsonFormat.decode in vespajlib static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) { Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType()); + if (buffer.current() == JsonToken.VALUE_STRING + && builder 
instanceof IndexedTensor.BoundBuilder indexedBuilder) + { + double[] decoded = decodeHexString(buffer.currentText(), builder.type().valueType()); + if (decoded.length == 0) + throw new IllegalArgumentException("Bad string input for tensor"); + for (int i = 0; i < decoded.length; i++) { + indexedBuilder.cellByDirectIndex(i, decoded[i]); + } + tensorFieldValue.assign(builder.build()); + return; + } expectOneOf(buffer.current(), JsonToken.START_OBJECT, JsonToken.START_ARRAY); int initNesting = buffer.nesting(); while (true) { diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index e72d3720024..2ab7365ea20 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -175,6 +175,8 @@ public class JsonReaderTestCase { new TensorDataType(new TensorType.Builder().indexed("x", 2).indexed("y", 3).build()))); x.addField(new Field("dense_int8_tensor", new TensorDataType(TensorType.fromSpec("tensor<int8>(x[2],y[3])")))); + x.addField(new Field("dense_float_tensor", + new TensorDataType(TensorType.fromSpec("tensor<float>(y[3])")))); x.addField(new Field("dense_unbound_tensor", new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build()))); x.addField(new Field("mixed_tensor", @@ -1780,7 +1782,7 @@ public class JsonReaderTestCase { "remove": "id:unittest:smoke::whee", "what is love": "baby, do not hurt me... 
much } - ]"""; + ]"""; // " new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next(); } @@ -1996,6 +1998,20 @@ public class JsonReaderTestCase { "values": "020304050607" }""", "dense_int8_tensor"), "dense_int8_tensor"); assertTrue(tensor instanceof IndexedTensor); // this matters for performance + tensor = assertTensorField(expected, + createPutWithTensor(""" + "020304050607" + """, "dense_int8_tensor"), "dense_int8_tensor"); + assertTrue(tensor instanceof IndexedTensor); // this matters for performance + builder = Tensor.Builder.of(TensorType.fromSpec("tensor<float>(y[3])")); + builder.cell().label("y", 0).value(42.0); + builder.cell().label("y", 1).value(-0.125); + builder.cell().label("y", 2).value(Double.POSITIVE_INFINITY); + expected = builder.build(); + tensor = assertTensorField(expected, + createPutWithTensor(""" + "42280000be0000007f800000" + """, "dense_float_tensor"), "dense_float_tensor"); } @Test @@ -2018,6 +2034,13 @@ public class JsonReaderTestCase { """; var put = createPutWithTensor(inputJson(mixedJson), "mixed_bfloat16_tensor"); Tensor tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor"); + mixedJson = """ + { + "blocks":{"foo":"400040404080", "bar":"40A040C040E0"} + } + """; + put = createPutWithTensor(inputJson(mixedJson), "mixed_bfloat16_tensor"); + tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor"); } /** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */ |