aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@vespa.ai>2024-06-19 12:49:29 +0000
committerArne Juul <arnej@vespa.ai>2024-06-19 12:49:29 +0000
commit2b4e4cffd59f2c06a9e6d402cd90c27d96917a97 (patch)
tree4c15ed853da4547eea27f2898ce45a8faac4098e
parent1f0e68e758e3779aab26b8389b142acf20239406 (diff)
accept just a hex string for dense tensors
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/SingleValueReader.java7
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorReader.java12
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java25
3 files changed, 43 insertions, 1 deletions
diff --git a/document/src/main/java/com/yahoo/document/json/readers/SingleValueReader.java b/document/src/main/java/com/yahoo/document/json/readers/SingleValueReader.java
index c6eccdacf26..465d7da5e8e 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/SingleValueReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/SingleValueReader.java
@@ -6,7 +6,9 @@ import com.yahoo.document.DataType;
import com.yahoo.document.DocumentId;
import com.yahoo.document.PositionDataType;
import com.yahoo.document.ReferenceDataType;
+import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.FieldValue;
+import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
import com.yahoo.document.update.ValueUpdate;
@@ -41,6 +43,11 @@ public class SingleValueReader {
}
public static FieldValue readSingleValue(TokenBuffer buffer, DataType expectedType, boolean ignoreUndefinedFields) {
+ if (expectedType instanceof TensorDataType) {
+ FieldValue fieldValue = expectedType.createFieldValue();
+ TensorReader.fillTensor(buffer, (TensorFieldValue) fieldValue);
+ return fieldValue;
+ }
if (buffer.current().isScalarValue()) {
return readAtomic(buffer.currentText(), expectedType);
} else {
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index 3aa6dc96e56..82a67c08935 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -37,6 +37,18 @@ public class TensorReader {
// MUST be kept in sync with com.yahoo.tensor.serialization.JsonFormat.decode in vespajlib
static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) {
Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType());
+ if (buffer.current() == JsonToken.VALUE_STRING
+ && builder instanceof IndexedTensor.BoundBuilder indexedBuilder)
+ {
+ double[] decoded = decodeHexString(buffer.currentText(), builder.type().valueType());
+ if (decoded.length == 0)
+ throw new IllegalArgumentException("Bad string input for tensor");
+ for (int i = 0; i < decoded.length; i++) {
+ indexedBuilder.cellByDirectIndex(i, decoded[i]);
+ }
+ tensorFieldValue.assign(builder.build());
+ return;
+ }
expectOneOf(buffer.current(), JsonToken.START_OBJECT, JsonToken.START_ARRAY);
int initNesting = buffer.nesting();
while (true) {
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index e72d3720024..2ab7365ea20 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -175,6 +175,8 @@ public class JsonReaderTestCase {
new TensorDataType(new TensorType.Builder().indexed("x", 2).indexed("y", 3).build())));
x.addField(new Field("dense_int8_tensor",
new TensorDataType(TensorType.fromSpec("tensor<int8>(x[2],y[3])"))));
+ x.addField(new Field("dense_float_tensor",
+ new TensorDataType(TensorType.fromSpec("tensor<float>(y[3])"))));
x.addField(new Field("dense_unbound_tensor",
new TensorDataType(new TensorType.Builder().indexed("x").indexed("y").build())));
x.addField(new Field("mixed_tensor",
@@ -1780,7 +1782,7 @@ public class JsonReaderTestCase {
"remove": "id:unittest:smoke::whee",
"what is love": "baby, do not hurt me... much
}
- ]""";
+ ]"""; // "
new JsonReader(types, jsonToInputStream(jsonData), parserFactory).next();
}
@@ -1996,6 +1998,20 @@ public class JsonReaderTestCase {
"values": "020304050607"
}""", "dense_int8_tensor"), "dense_int8_tensor");
assertTrue(tensor instanceof IndexedTensor); // this matters for performance
+ tensor = assertTensorField(expected,
+ createPutWithTensor("""
+ "020304050607"
+ """, "dense_int8_tensor"), "dense_int8_tensor");
+ assertTrue(tensor instanceof IndexedTensor); // this matters for performance
+ builder = Tensor.Builder.of(TensorType.fromSpec("tensor<float>(y[3])"));
+ builder.cell().label("y", 0).value(42.0);
+ builder.cell().label("y", 1).value(-0.125);
+ builder.cell().label("y", 2).value(Double.POSITIVE_INFINITY);
+ expected = builder.build();
+ tensor = assertTensorField(expected,
+ createPutWithTensor("""
+ "42280000be0000007f800000"
+ """, "dense_float_tensor"), "dense_float_tensor");
}
@Test
@@ -2018,6 +2034,13 @@ public class JsonReaderTestCase {
""";
var put = createPutWithTensor(inputJson(mixedJson), "mixed_bfloat16_tensor");
Tensor tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor");
+ mixedJson = """
+ {
+ "blocks":{"foo":"400040404080", "bar":"40A040C040E0"}
+ }
+ """;
+ put = createPutWithTensor(inputJson(mixedJson), "mixed_bfloat16_tensor");
+ tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor");
}
/** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */