summary | refs | log | tree | commit | diff | stats
path: root/document/src
diff options
context:
space:
mode:
Diffstat (limited to 'document/src')
-rw-r--r-- document/src/main/java/com/yahoo/document/DataType.java | 2
-rw-r--r-- document/src/main/java/com/yahoo/document/json/TokenBuffer.java | 2
-rw-r--r-- document/src/main/java/com/yahoo/document/json/readers/TensorReader.java | 29
-rw-r--r-- document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java | 39
4 files changed, 62 insertions, 10 deletions
diff --git a/document/src/main/java/com/yahoo/document/DataType.java b/document/src/main/java/com/yahoo/document/DataType.java
index 104d63cae96..fd7ccfc5e96 100644
--- a/document/src/main/java/com/yahoo/document/DataType.java
+++ b/document/src/main/java/com/yahoo/document/DataType.java
@@ -54,7 +54,7 @@ public abstract class DataType extends Identifiable implements Serializable, Com
public final static NumericDataType BYTE = new NumericDataType("byte", 16, ByteFieldValue.class, ByteFieldValue.getFactory());
public final static PrimitiveDataType PREDICATE = new PrimitiveDataType("predicate", 20, PredicateFieldValue.class, PredicateFieldValue.getFactory());
public final static int tensorDataTypeCode = 21; // All TensorDataType instances have id=21 but carries additional type information serialized separately
- // ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference, tensor
+ // ADDITIONAL parametrized types added at runtime: map, struct, array, weighted set, annotation reference, tensor
// Tags are converted to weightedset<string> when reading the search definition TODO: Remove it
public final static WeightedSetDataType TAG = new WeightedSetDataType(DataType.STRING, true, true);
diff --git a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java
index 88353139b0f..9db80f3972b 100644
--- a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java
+++ b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java
@@ -29,7 +29,7 @@ public class TokenBuffer {
}
}
- private Deque<Token> buffer;
+ private final Deque<Token> buffer;
private int nesting = 0;
public TokenBuffer() {
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index ad016a40fca..27426f584bd 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -11,6 +11,7 @@ import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import static com.yahoo.document.json.readers.JsonParserHelpers.*;
+import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString;
/**
* Reads the tensor format defined at
@@ -41,7 +42,7 @@ public class TensorReader {
else if (TENSOR_BLOCKS.equals(buffer.currentName()))
readTensorBlocks(buffer, builder);
else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty
- throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks'");
+ throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks', but got: "+buffer.currentName());
}
expectObjectEnd(buffer.currentToken());
tensorFieldValue.assign(builder.build());
@@ -91,10 +92,18 @@ public class TensorReader {
throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
"Use 'cells' or 'blocks' instead");
IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
+ if (buffer.currentToken() == JsonToken.VALUE_STRING) {
+ double[] decoded = decodeHexString(buffer.currentText(), builder.type().valueType());
+ for (int i = 0; i < decoded.length; i++) {
+ indexedBuilder.cellByDirectIndex(i, decoded[i]);
+ }
+ return;
+ }
int index = 0;
int initNesting = buffer.nesting();
- for (buffer.next(); buffer.nesting() >= initNesting; buffer.next())
+ for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
indexedBuilder.cellByDirectIndex(index++, readDouble(buffer));
+ }
expectCompositeEnd(buffer.currentToken());
}
@@ -167,17 +176,21 @@ public class TensorReader {
* @return the values read
*/
private static double[] readValues(TokenBuffer buffer, int size, TensorAddress address, TensorType type) {
- expectArrayStart(buffer.currentToken());
-
int index = 0;
- int initNesting = buffer.nesting();
double[] values = new double[size];
- for (buffer.next(); buffer.nesting() >= initNesting; buffer.next())
- values[index++] = readDouble(buffer);
+ if (buffer.currentToken() == JsonToken.VALUE_STRING) {
+ values = decodeHexString(buffer.currentText(), type.valueType());
+ index = values.length;
+ } else {
+ expectArrayStart(buffer.currentToken());
+ int initNesting = buffer.nesting();
+ for (buffer.next(); buffer.nesting() >= initNesting; buffer.next())
+ values[index++] = readDouble(buffer);
+ expectCompositeEnd(buffer.currentToken());
+ }
if (index != size)
throw new IllegalArgumentException((address != null ? "At " + address.toString(type) + ": " : "") +
"Expected " + size + " values, but got " + index);
- expectCompositeEnd(buffer.currentToken());
return values;
}
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index da9ab4ea7bf..e50fd9734f7 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -164,10 +164,14 @@ public class JsonReaderTestCase {
new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").build())));
x.addField(new Field("dense_tensor",
new TensorDataType(new TensorType.Builder().indexed("x", 2).indexed("y", 3).build())));
+ x.addField(new Field("dense_int8_tensor",
+ new TensorDataType(TensorType.fromSpec("tensor<int8>(x[2],y[3])"))));
x.addField(new Field("dense_unbound_tensor",
new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build())));
x.addField(new Field("mixed_tensor",
new TensorDataType(new TensorType.Builder().mapped("x").indexed("y", 3).build())));
+ x.addField(new Field("mixed_bfloat16_tensor",
+ new TensorDataType(TensorType.fromSpec("tensor<bfloat16>(x{},y[3])"))));
x.addField(new Field("mixed_tensor_adv",
new TensorDataType(new TensorType.Builder().mapped("x").mapped("y").mapped("z").indexed("a", 3).build())));
types.registerDocumentType(x);
@@ -1324,6 +1328,41 @@ public class JsonReaderTestCase {
}
@Test
+ public void testParsingOfDenseTensorHexFormat() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[2],y[3])"));
+ builder.cell().label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("x", 0).label("y", 1).value(3.0);
+ builder.cell().label("x", 0).label("y", 2).value(4.0);
+ builder.cell().label("x", 1).label("y", 0).value(5.0);
+ builder.cell().label("x", 1).label("y", 1).value(6.0);
+ builder.cell().label("x", 1).label("y", 2).value(7.0);
+ Tensor expected = builder.build();
+ Tensor tensor = assertTensorField(expected,
+ createPutWithTensor(inputJson("{",
+ " 'values': \"020304050607\"",
+ "}"), "dense_int8_tensor"), "dense_int8_tensor");
+ assertTrue(tensor instanceof IndexedTensor); // this matters for performance
+ }
+
+ @Test
+ public void testParsingOfMixedTensorHexFormat() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x{},y[3])"));
+ builder.cell().label("x", "foo").label("y", 0).value(2.0);
+ builder.cell().label("x", "foo").label("y", 1).value(3.0);
+ builder.cell().label("x", "foo").label("y", 2).value(4.0);
+ builder.cell().label("x", "bar").label("y", 0).value(5.0);
+ builder.cell().label("x", "bar").label("y", 1).value(6.0);
+ builder.cell().label("x", "bar").label("y", 2).value(7.0);
+ Tensor expected = builder.build();
+ String mixedJson = "{\"blocks\":[" +
+ "{\"address\":{\"x\":\"foo\"},\"values\":\"400040404080\"}," +
+ "{\"address\":{\"x\":\"bar\"},\"values\":\"40A040C040E0\"}" +
+ "]}";
+ var put = createPutWithTensor(inputJson(mixedJson), "mixed_bfloat16_tensor");
+ Tensor tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor");
+ }
+
+ @Test
public void testParsingOfMixedTensorOnMixedForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])"));
builder.cell().label("x", 0).label("y", 0).value(2.0);