diff options
Diffstat (limited to 'document/src')
4 files changed, 62 insertions, 10 deletions
diff --git a/document/src/main/java/com/yahoo/document/DataType.java b/document/src/main/java/com/yahoo/document/DataType.java index 104d63cae96..fd7ccfc5e96 100644 --- a/document/src/main/java/com/yahoo/document/DataType.java +++ b/document/src/main/java/com/yahoo/document/DataType.java @@ -54,7 +54,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com public final static NumericDataType BYTE = new NumericDataType("byte", 16, ByteFieldValue.class, ByteFieldValue.getFactory()); public final static PrimitiveDataType PREDICATE = new PrimitiveDataType("predicate", 20, PredicateFieldValue.class, PredicateFieldValue.getFactory()); public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately - // ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference, tensor + // ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference, tensor // Tags are converted to weightedset<string> when reading the search definition TODO: Remove it public final static WeightedSetDataType TAG = new WeightedSetDataType(DataType.STRING, true, true); diff --git a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java index 88353139b0f..9db80f3972b 100644 --- a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java +++ b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java @@ -29,7 +29,7 @@ public class TokenBuffer { } } - private Deque<Token> buffer; + private final Deque<Token> buffer; private int nesting = 0; public TokenBuffer() { diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java index ad016a40fca..27426f584bd 100644 --- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java +++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java @@ -11,6 +11,7 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import static com.yahoo.document.json.readers.JsonParserHelpers.*; +import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString; /** * Reads the tensor format defined at @@ -41,7 +42,7 @@ public class TensorReader { else if (TENSOR_BLOCKS.equals(buffer.currentName())) readTensorBlocks(buffer, builder); else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty - throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks'"); + throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks', but got: "+buffer.currentName()); } expectObjectEnd(buffer.currentToken()); tensorFieldValue.assign(builder.build()); @@ -91,10 +92,18 @@ public class TensorReader { throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " + "Use 'cells' or 'blocks' instead"); IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder; + if (buffer.currentToken() == JsonToken.VALUE_STRING) { + double[] decoded = decodeHexString(buffer.currentText(), builder.type().valueType()); + for (int i = 0; i < decoded.length; i++) { + indexedBuilder.cellByDirectIndex(i, decoded[i]); + } + return; + } int index = 0; int initNesting = buffer.nesting(); - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) { indexedBuilder.cellByDirectIndex(index++, readDouble(buffer)); + } expectCompositeEnd(buffer.currentToken()); } @@ -167,17 +176,21 @@ public class TensorReader { * @return the values read */ private static double[] readValues(TokenBuffer buffer, int size, TensorAddress address, TensorType type) { - expectArrayStart(buffer.currentToken()); - int index = 0; - int initNesting = buffer.nesting(); double[] values = new double[size]; - for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) - values[index++] = readDouble(buffer); + if (buffer.currentToken() == JsonToken.VALUE_STRING) { + values = decodeHexString(buffer.currentText(), type.valueType()); + index = values.length; + } else { + expectArrayStart(buffer.currentToken()); + int initNesting = buffer.nesting(); + for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) + values[index++] = readDouble(buffer); + expectCompositeEnd(buffer.currentToken()); + } if (index != size) throw new IllegalArgumentException((address != null ? "At " + address.toString(type) + ": " : "") + "Expected " + size + " values, but got " + index); - expectCompositeEnd(buffer.currentToken()); return values; } diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java index da9ab4ea7bf..e50fd9734f7 100644 --- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java +++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java @@ -164,10 +164,14 @@ public class JsonReaderTestCase { new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").build()))); x.addField(new Field("dense_tensor", new TensorDataType(new TensorType.Builder().indexed("x", 2).indexed("y", 3).build()))); + x.addField(new Field("dense_int8_tensor", + new TensorDataType(TensorType.fromSpec("tensor<int8>(x[2],y[3])")))); x.addField(new Field("dense_unbound_tensor", new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build()))); x.addField(new Field("mixed_tensor", new TensorDataType(new TensorType.Builder().mapped("x").indexed("y", 3).build()))); + x.addField(new Field("mixed_bfloat16_tensor", + new TensorDataType(TensorType.fromSpec("tensor<bfloat16>(x{},y[3])")))); x.addField(new Field("mixed_tensor_adv", new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").mapped("z").indexed("a", 3).build()))); types.registerDocumentType(x); @@ -1324,6 +1328,41 @@ public class JsonReaderTestCase { } @Test + public void testParsingOfDenseTensorHexFormat() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[3])")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 0).label("y", 2).value(4.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(6.0); + builder.cell().label("x", 1).label("y", 2).value(7.0); + Tensor expected = builder.build(); + Tensor tensor = assertTensorField(expected, + createPutWithTensor(inputJson("{", + " 'values': \"020304050607\"", + "}"), "dense_int8_tensor"), "dense_int8_tensor"); + assertTrue(tensor instanceof IndexedTensor); // this matters for performance + } + + @Test + public void testParsingOfMixedTensorHexFormat() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x{},y[3])")); + builder.cell().label("x", "foo").label("y", 0).value(2.0); + builder.cell().label("x", "foo").label("y", 1).value(3.0); + builder.cell().label("x", "foo").label("y", 2).value(4.0); + builder.cell().label("x", "bar").label("y", 0).value(5.0); + builder.cell().label("x", "bar").label("y", 1).value(6.0); + builder.cell().label("x", "bar").label("y", 2).value(7.0); + Tensor expected = builder.build(); + String mixedJson = "{\"blocks\":[" + + "{\"address\":{\"x\":\"foo\"},\"values\":\"400040404080\"}," + + "{\"address\":{\"x\":\"bar\"},\"values\":\"40A040C040E0\"}" + + "]}"; + var put = createPutWithTensor(inputJson(mixedJson), "mixed_bfloat16_tensor"); + Tensor tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor"); + } + + @Test public void testParsingOfMixedTensorOnMixedForm() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])")); builder.cell().label("x", 0).label("y", 0).value(2.0); |