summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/searcher/FieldCollapsingSearcher.java2
-rw-r--r--document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java9
-rw-r--r--document/src/main/java/com/yahoo/document/json/TokenBuffer.java20
-rw-r--r--document/src/main/java/com/yahoo/document/json/document/DocumentParser.java36
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/JsonParserHelpers.java7
-rw-r--r--document/src/main/java/com/yahoo/document/json/readers/TensorReader.java58
-rw-r--r--document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java67
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java50
9 files changed, 198 insertions, 86 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/FieldCollapsingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/FieldCollapsingSearcher.java
index 54c46b34163..ead6ad53715 100644
--- a/container-search/src/main/java/com/yahoo/prelude/searcher/FieldCollapsingSearcher.java
+++ b/container-search/src/main/java/com/yahoo/prelude/searcher/FieldCollapsingSearcher.java
@@ -89,7 +89,7 @@ public class FieldCollapsingSearcher extends Searcher {
if (collapseField == null) return execution.search(query);
- int collapseSize = query.properties().getInteger(collapsesize,defaultCollapseSize);
+ int collapseSize = query.properties().getInteger(collapsesize, defaultCollapseSize);
query.properties().set(collapse, "0");
int hitsToRequest = query.getHits() != 0 ? (int) Math.ceil((query.getOffset() + query.getHits() + 1) * extraFactor) : 0;
diff --git a/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java b/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java
index d33cc8078dd..7f6ead528fe 100644
--- a/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java
+++ b/document/src/main/java/com/yahoo/document/json/JsonSerializationHelper.java
@@ -48,6 +48,7 @@ import java.util.Set;
* @author Vegard Sjonfjell
*/
public class JsonSerializationHelper {
+
private final static Base64.Encoder base64Encoder = Base64.getEncoder(); // Important: _basic_ format
static class JsonSerializationException extends RuntimeException {
@@ -99,14 +100,6 @@ public class JsonSerializationHelper {
});
}
- private static void serializeTensorDimensions(JsonGenerator generator, Set<String> dimensions) throws IOException {
- generator.writeArrayFieldStart(TensorReader.TENSOR_DIMENSIONS);
- for (String dimension : dimensions) {
- generator.writeString(dimension);
- }
- generator.writeEndArray();
- }
-
static void serializeTensorCells(JsonGenerator generator, Tensor tensor) throws IOException {
generator.writeArrayFieldStart(TensorReader.TENSOR_CELLS);
for (Map.Entry<TensorAddress, Double> cell : tensor.cells().entrySet()) {
diff --git a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java
index 6b2bdbd53d8..e1214920296 100644
--- a/document/src/main/java/com/yahoo/document/json/TokenBuffer.java
+++ b/document/src/main/java/com/yahoo/document/json/TokenBuffer.java
@@ -18,13 +18,15 @@ import com.google.common.base.Preconditions;
public class TokenBuffer {
private final Deque<Token> buffer;
+
private int nesting = 0;
+ private Token previousToken;
public TokenBuffer() {
this(new ArrayDeque<>());
}
- private TokenBuffer(Deque<Token> buffer) {
+ public TokenBuffer(Deque<Token> buffer) {
this.buffer = buffer;
if (buffer.size() > 0) {
updateNesting(buffer.peekFirst().token);
@@ -35,7 +37,7 @@ public class TokenBuffer {
public boolean isEmpty() { return size() == 0; }
public JsonToken next() {
- buffer.removeFirst();
+ previousToken = buffer.removeFirst();
Token t = buffer.peekFirst();
if (t == null) {
return null;
@@ -44,6 +46,16 @@ public class TokenBuffer {
return t.token;
}
+ /** Goes one token back. Repeated calls to this method will *not* go back further. */
+ public JsonToken previous() {
+ if (previousToken == null) return null;
+ updateNestingGoingBackwards(currentToken());
+ Token newCurrent = previousToken;
+ previousToken = null;
+ buffer.push(newCurrent);
+ return newCurrent.token;
+ }
+
/** Returns the current token without changing position, or null if none */
public JsonToken currentToken() {
Token token = buffer.peekFirst();
@@ -130,6 +142,10 @@ public class TokenBuffer {
nesting += nestingOffset(t);
}
+ private void updateNestingGoingBackwards(JsonToken t) {
+ nesting -= nestingOffset(t);
+ }
+
public int nesting() {
return nesting;
}
diff --git a/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java b/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java
index 5e1c1eb6ac4..b63a39f51c5 100644
--- a/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java
+++ b/document/src/main/java/com/yahoo/document/json/document/DocumentParser.java
@@ -76,18 +76,10 @@ public class DocumentParser {
throw new IllegalArgumentException("Could not read document, no document?");
}
switch (currentToken) {
- case START_OBJECT:
- indentLevel++;
- break;
- case END_OBJECT:
- indentLevel--;
- break;
- case START_ARRAY:
- indentLevel += 10000L;
- break;
- case END_ARRAY:
- indentLevel -= 10000L;
- break;
+ case START_OBJECT -> indentLevel++;
+ case END_OBJECT -> indentLevel--;
+ case START_ARRAY -> indentLevel += 10000L;
+ case END_ARRAY -> indentLevel -= 10000L;
}
}
@@ -133,18 +125,12 @@ public class DocumentParser {
}
private static DocumentOperationType operationNameToOperationType(String operationName) {
- switch (operationName) {
- case PUT:
- case ID:
- return DocumentOperationType.PUT;
- case REMOVE:
- return DocumentOperationType.REMOVE;
- case UPDATE:
- return DocumentOperationType.UPDATE;
- default:
- throw new IllegalArgumentException(
- "Got " + operationName + " as document operation, only \"put\", " +
- "\"remove\" and \"update\" are supported.");
- }
+ return switch (operationName) {
+ case PUT, ID -> DocumentOperationType.PUT;
+ case REMOVE -> DocumentOperationType.REMOVE;
+ case UPDATE -> DocumentOperationType.UPDATE;
+ default -> throw new IllegalArgumentException("Got " + operationName + " as document operation, only \"put\", " +
+ "\"remove\" and \"update\" are supported.");
+ };
}
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/JsonParserHelpers.java b/document/src/main/java/com/yahoo/document/json/readers/JsonParserHelpers.java
index 1723df2bd54..594dfc5ab06 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/JsonParserHelpers.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/JsonParserHelpers.java
@@ -5,6 +5,8 @@ package com.yahoo.document.json.readers;
import com.fasterxml.jackson.core.JsonToken;
import com.google.common.base.Preconditions;
+import java.util.Arrays;
+
public class JsonParserHelpers {
public static void expectArrayStart(JsonToken token) {
@@ -61,4 +63,9 @@ public class JsonParserHelpers {
}
}
+ public static void expectOneOf(JsonToken token, JsonToken ... tokens) {
+ if (Arrays.stream(tokens).noneMatch(t -> t == token))
+ throw new IllegalArgumentException("Expected one of " + tokens + " but got " + token);
+ }
+
}
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index 193c9491e86..4b2da912c37 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -4,12 +4,18 @@ package com.yahoo.document.json.readers;
import com.fasterxml.jackson.core.JsonToken;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
+import com.yahoo.document.select.parser.Token;
+import com.yahoo.slime.Inspector;
+import com.yahoo.slime.Type;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import java.util.ArrayDeque;
+import java.util.Deque;
+
import static com.yahoo.document.json.readers.JsonParserHelpers.*;
import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString;
@@ -23,7 +29,6 @@ import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString;
public class TensorReader {
public static final String TENSOR_ADDRESS = "address";
- public static final String TENSOR_DIMENSIONS = "dimensions";
public static final String TENSOR_CELLS = "cells";
public static final String TENSOR_VALUES = "values";
public static final String TENSOR_BLOCKS = "blocks";
@@ -32,22 +37,39 @@ public class TensorReader {
// MUST be kept in sync with com.yahoo.tensor.serialization.JsonFormat.decode in vespajlib
static void fillTensor(TokenBuffer buffer, TensorFieldValue tensorFieldValue) {
Tensor.Builder builder = Tensor.Builder.of(tensorFieldValue.getDataType().getTensorType());
- expectObjectStart(buffer.currentToken());
+ expectOneOf(buffer.currentToken(), JsonToken.START_OBJECT, JsonToken.START_ARRAY);
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next()) {
- if (TENSOR_CELLS.equals(buffer.currentName()))
+ if (TENSOR_CELLS.equals(buffer.currentName()) && ! primitiveContent(new TokenBuffer(new ArrayDeque<>(buffer.rest())))) {
readTensorCells(buffer, builder);
- else if (TENSOR_VALUES.equals(buffer.currentName()))
+ }
+ else if (TENSOR_VALUES.equals(buffer.currentName()) && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) {
readTensorValues(buffer, builder);
- else if (TENSOR_BLOCKS.equals(buffer.currentName()))
+ }
+ else if (TENSOR_BLOCKS.equals(buffer.currentName())) {
readTensorBlocks(buffer, builder);
- else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty
- throw new IllegalArgumentException("Expected a tensor value to contain either 'cells', 'values' or 'blocks', but got: "+buffer.currentName());
+ }
+ else {
+ buffer.previous(); // Back up to the start of the enclosing block
+ readDirectTensorValue(buffer, builder);
+ buffer.previous(); // ... and back up to the end of the enclosing block
+ }
}
- expectObjectEnd(buffer.currentToken());
+ expectOneOf(buffer.currentToken(), JsonToken.END_OBJECT, JsonToken.END_ARRAY);
tensorFieldValue.assign(builder.build());
}
+ static boolean primitiveContent(TokenBuffer buffer) {
+ JsonToken cellsValue = buffer.currentToken();
+ if (cellsValue.isScalarValue()) return true;
+ if (cellsValue == JsonToken.START_ARRAY) {
+ JsonToken firstArrayValue = buffer.next();
+ if (firstArrayValue == JsonToken.END_ARRAY) return false;
+ if (firstArrayValue.isScalarValue()) return true;
+ }
+ return false;
+ }
+
static void readTensorCells(TokenBuffer buffer, Tensor.Builder builder) {
if (buffer.currentToken() == JsonToken.START_ARRAY) {
int initNesting = buffer.nesting();
@@ -88,10 +110,9 @@ public class TensorReader {
}
private static void readTensorValues(TokenBuffer buffer, Tensor.Builder builder) {
- if ( ! (builder instanceof IndexedTensor.BoundBuilder))
+ if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder))
throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
"Use 'cells' or 'blocks' instead");
- IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
if (buffer.currentToken() == JsonToken.VALUE_STRING) {
double[] decoded = decodeHexString(buffer.currentText(), builder.type().valueType());
if (decoded.length == 0)
@@ -112,11 +133,9 @@ public class TensorReader {
}
static void readTensorBlocks(TokenBuffer buffer, Tensor.Builder builder) {
- if ( ! (builder instanceof MixedTensor.BoundBuilder))
+ if ( ! (builder instanceof MixedTensor.BoundBuilder mixedBuilder))
throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " +
"Use 'cells' or 'values' instead");
-
- MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder;
if (buffer.currentToken() == JsonToken.START_ARRAY) {
int initNesting = buffer.nesting();
for (buffer.next(); buffer.nesting() >= initNesting; buffer.next())
@@ -160,6 +179,19 @@ public class TensorReader {
mixedBuilder.block(address, values);
}
+ /** Reads a tensor value directly at the root, where the format is decided by the tensor type. */
+ private static void readDirectTensorValue(TokenBuffer buffer, Tensor.Builder builder) {
+ boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
+ boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
+
+ if ( ! hasMapped)
+ readTensorValues(buffer, builder);
+ else if (hasMapped && hasIndexed)
+ readTensorBlocks(buffer, builder);
+ else
+ readTensorCells(buffer, builder);
+ }
+
private static TensorAddress readAddress(TokenBuffer buffer, TensorType type) {
expectObjectStart(buffer.currentToken());
TensorAddress.Builder builder = new TensorAddress.Builder(type);
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 41d607b0d8e..c19094ff231 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1403,12 +1403,6 @@ public class JsonReaderTestCase {
}
@Test
- public void testParsingOfTensorWithEmptyDimensions() {
- assertSparseTensorField("tensor(x{},y{}):{}",
- createPutWithSparseTensor(inputJson("{ 'dimensions': [] }")));
- }
-
- @Test
public void testParsingOfTensorWithEmptyCells() {
assertSparseTensorField("tensor(x{},y{}):{}",
createPutWithSparseTensor(inputJson("{ 'cells': [] }")));
@@ -1511,6 +1505,32 @@ public class JsonReaderTestCase {
Tensor tensor = assertTensorField(expected, put, "mixed_bfloat16_tensor");
}
+ /** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */
+ @Test
+ public void testDirectValue() {
+ assertTensorField("tensor(x{}):{a:2, b:3}", "sparse_single_dimension_tensor", "{'a':2.0, 'b':3.0}");
+ assertTensorField("tensor(x[2],y[3]):[2, 3, 4, 5, 6, 7]]", "dense_tensor", "[2, 3, 4, 5, 6, 7]");
+ assertTensorField("tensor(x{},y[3]):{a:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor", "{'a':[2, 3, 4], 'b':[4, 5, 6]}");
+ assertTensorField("tensor(x{},y{}):{{x:a,y:0}:2, {x:b,y:1}:3}", "sparse_tensor",
+ "[{'address':{'x':'a','y':'0'},'value':2}, {'address':{'x':'b','y':'1'},'value':3}]");
+ }
+
+ @Test
+ public void testDirectValueReservedNameKeys() {
+ // Single-valued
+ assertTensorField("tensor(x{}):{cells:2, b:3}", "sparse_single_dimension_tensor", "{'cells':2.0, 'b':3.0}");
+ assertTensorField("tensor(x{}):{values:2, b:3}", "sparse_single_dimension_tensor", "{'values':2.0, 'b':3.0}");
+ assertTensorField("tensor(x{}):{block:2, b:3}", "sparse_single_dimension_tensor", "{'block':2.0, 'b':3.0}");
+
+ // Multi-valued
+ assertTensorField("tensor(x{},y[3]):{cells:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor",
+ "{'cells':[2, 3, 4], 'b':[4, 5, 6]}");
+ assertTensorField("tensor(x{},y[3]):{values:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor",
+ "{'values':[2, 3, 4], 'b':[4, 5, 6]}");
+ assertTensorField("tensor(x{},y[3]):{block:[2, 3, 4], b:[4, 5, 6]}", "mixed_tensor",
+ "{'block':[2, 3, 4], 'b':[4, 5, 6]}");
+ }
+
@Test
public void testParsingOfMixedTensorOnMixedForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y[3])"));
@@ -2070,6 +2090,9 @@ public class JsonReaderTestCase {
private static Tensor assertSparseTensorField(String expectedTensor, DocumentPut put) {
return assertTensorField(expectedTensor, put, "sparse_tensor");
}
+ private Tensor assertTensorField(String expectedTensor, String fieldName, String inputJson) {
+ return assertTensorField(expectedTensor, createPutWithTensor(inputJson, fieldName), fieldName);
+ }
private static Tensor assertTensorField(String expectedTensor, DocumentPut put, String tensorFieldName) {
return assertTensorField(Tensor.from(expectedTensor), put, tensorFieldName);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index a7afc1efc6d..0e8fbc30bb6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -60,25 +60,21 @@ public class JsonFormat {
Cursor root = slime.setObject();
root.setString("type", tensor.type().toString());
- // Encode as nested lists if indexed tensor
- if (tensor instanceof IndexedTensor) {
- IndexedTensor denseTensor = (IndexedTensor) tensor;
+ if (tensor instanceof IndexedTensor denseTensor) {
+ // Encode as nested lists if indexed tensor
encodeValues(denseTensor, root.setArray("values"), new long[denseTensor.dimensionSizes().dimensions()], 0);
}
-
- // Short form for a single mapped dimension
else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) {
+ // Short form for a single mapped dimension
encodeSingleDimensionCells((MappedTensor) tensor, root);
}
-
- // Short form for a mixed tensor
else if (tensor instanceof MixedTensor &&
tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() >= 1) {
+ // Short form for a mixed tensor
encodeBlocks((MixedTensor) tensor, root);
}
-
- // No other short forms exist: default to standard cell address output
else {
+ // No other short forms exist: default to standard cell address output
encodeCells(tensor, root);
}
@@ -177,17 +173,25 @@ public class JsonFormat {
Tensor.Builder builder = Tensor.Builder.of(type);
Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get();
- if (root.field("cells").valid())
+ if (root.field("cells").valid() && ! primitiveContent(root.field("cells")))
decodeCells(root.field("cells"), builder);
- else if (root.field("values").valid())
+ else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed()))
decodeValues(root.field("values"), builder);
else if (root.field("blocks").valid())
decodeBlocks(root.field("blocks"), builder);
- else if (builder.type().dimensions().stream().anyMatch(d -> d.isIndexed())) // sparse can be empty
- throw new IllegalArgumentException("Expected a tensor value to contain either 'cells' or 'values' or 'blocks'");
+ else
+ decodeDirectValue(root, builder);
return builder.build();
}
+ private static boolean primitiveContent(Inspector cellsValue) {
+ if (cellsValue.type() == Type.DOUBLE) return true;
+ if (cellsValue.type() == Type.LONG) return true;
+ if (cellsValue.type() == Type.ARRAY && cellsValue.entries() > 0 &&
+ ( cellsValue.entry(0).type() == Type.DOUBLE || cellsValue.entry(0).type() == Type.LONG)) return true;
+ return false;
+ }
+
private static void decodeCells(Inspector cells, Tensor.Builder builder) {
if (cells.type() == Type.ARRAY)
cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, builder));
@@ -212,10 +216,9 @@ public class JsonFormat {
}
private static void decodeValues(Inspector values, Tensor.Builder builder) {
- if ( ! (builder instanceof IndexedTensor.BoundBuilder))
+ if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder))
throw new IllegalArgumentException("The 'values' field can only be used with dense tensors. " +
"Use 'cells' or 'blocks' instead");
- IndexedTensor.BoundBuilder indexedBuilder = (IndexedTensor.BoundBuilder)builder;
if (values.type() == Type.STRING) {
double[] decoded = decodeHexString(values.asString(), builder.type().valueType());
if (decoded.length == 0)
@@ -240,10 +243,9 @@ public class JsonFormat {
}
private static void decodeBlocks(Inspector values, Tensor.Builder builder) {
- if ( ! (builder instanceof MixedTensor.BoundBuilder))
+ if ( ! (builder instanceof MixedTensor.BoundBuilder mixedBuilder))
throw new IllegalArgumentException("The 'blocks' field can only be used with mixed tensors with bound dimensions. " +
"Use 'cells' or 'values' instead");
- MixedTensor.BoundBuilder mixedBuilder = (MixedTensor.BoundBuilder) builder;
if (values.type() == Type.ARRAY)
values.traverse((ArrayTraverser) (__, value) -> decodeBlock(value, mixedBuilder));
@@ -260,6 +262,19 @@ public class JsonFormat {
decodeValues(block.field("values"), mixedBuilder));
}
+ /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */
+ private static void decodeDirectValue(Inspector root, Tensor.Builder builder) {
+ boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
+ boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped);
+
+ if ( ! hasMapped)
+ decodeValues(root, builder);
+ else if (hasMapped && hasIndexed)
+ decodeBlocks(root, builder);
+ else
+ decodeCells(root, builder);
+ }
+
private static void decodeSingleDimensionBlock(String key, Inspector value, MixedTensor.BoundBuilder mixedBuilder) {
if (value.type() != Type.ARRAY)
throw new IllegalArgumentException("Expected an item in a 'blocks' array to be an array, not " + value.type());
@@ -334,18 +349,12 @@ public class JsonFormat {
}
public static double[] decodeHexString(String input, TensorType.Value valueType) {
- switch(valueType) {
- case INT8:
- return decodeHexStringAsBytes(input);
- case BFLOAT16:
- return decodeHexStringAsBFloat16s(input);
- case FLOAT:
- return decodeHexStringAsFloats(input);
- case DOUBLE:
- return decodeHexStringAsDoubles(input);
- default:
- throw new IllegalArgumentException("Cannot handle value type: "+valueType);
- }
+ return switch (valueType) {
+ case INT8 -> decodeHexStringAsBytes(input);
+ case BFLOAT16 -> decodeHexStringAsBFloat16s(input);
+ case FLOAT -> decodeHexStringAsFloats(input);
+ case DOUBLE -> decodeHexStringAsDoubles(input);
+ };
}
private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 1c884186879..f71a68ec5ed 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -1,7 +1,6 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
-import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -16,6 +15,30 @@ import static org.junit.Assert.fail;
*/
public class JsonFormatTestCase {
+ /** Tests parsing of various tensor values set at the root, i.e. no 'cells', 'blocks' or 'values' */
+ @Test
+ public void testDirectValue() {
+ assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}");
+ assertDecoded("tensor(x{}):{a:2, b:3}", "{'a':2.0, 'b':3.0}");
+ assertDecoded("tensor(x[2]):[2, 3]]", "[2.0, 3.0]");
+ assertDecoded("tensor(x{},y[2]):{a:[2, 3], b:[4, 5]}", "{'a':[2, 3], 'b':[4, 5]}");
+ assertDecoded("tensor(x{},y{}):{{x:a,y:0}:2, {x:b,y:1}:3}",
+ "[{'address':{'x':'a','y':'0'},'value':2}, {'address':{'x':'b','y':'1'},'value':3}]");
+ }
+
+ @Test
+ public void testDirectValueReservedNameKeys() {
+ // Single-valued
+ assertDecoded("tensor(x{}):{cells:2, b:3}", "{'cells':2.0, 'b':3.0}");
+ assertDecoded("tensor(x{}):{values:2, b:3}", "{'values':2.0, 'b':3.0}");
+ assertDecoded("tensor(x{}):{block:2, b:3}", "{'block':2.0, 'b':3.0}");
+
+ // Multi-valued
+ assertDecoded("tensor(x{},y[2]):{cells:[2, 3], b:[4, 5]}", "{'cells':[2, 3], 'b':[4, 5]}");
+ assertDecoded("tensor(x{},y[2]):{values:[2, 3], b:[4, 5]}", "{'values':[2, 3], 'b':[4, 5]}");
+ assertDecoded("tensor(x{},y[2]):{block:[2, 3], b:[4, 5]}", "{'block':[2, 3], 'b':[4, 5]}");
+ }
+
@Test
public void testSparseTensor() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
@@ -33,6 +56,21 @@ public class JsonFormatTestCase {
}
@Test
+ public void testEmptySparseTensor() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
+ Tensor tensor = builder.build();
+ byte[] json = JsonFormat.encode(tensor);
+ assertEquals("{\"cells\":[]}",
+ new String(json, StandardCharsets.UTF_8));
+ Tensor decoded = JsonFormat.decode(tensor.type(), json);
+ assertEquals(tensor, decoded);
+
+ json = "{}".getBytes(); // short form variant of the above
+ decoded = JsonFormat.decode(tensor.type(), json);
+ assertEquals(tensor, decoded);
+ }
+
+ @Test
public void testSingleSparseDimensionShortForm() {
Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{})"));
builder.cell().label("x", "a").value(2.0);
@@ -327,7 +365,7 @@ public class JsonFormatTestCase {
"]}";
try {
JsonFormat.decode(x2, json.getBytes(StandardCharsets.UTF_8));
- fail("Excpected exception");
+ fail("Expected exception");
}
catch (IllegalArgumentException e) {
assertEquals("cell address (2) is not within the bounds of tensor(x[2])", e.getMessage());
@@ -354,6 +392,14 @@ public class JsonFormatTestCase {
assertEquals(expected, new String(json, StandardCharsets.UTF_8));
}
+ private void assertDecoded(String expected, String jsonToDecode) {
+ assertDecoded(Tensor.from(expected), jsonToDecode);
+ }
+
+ private void assertDecoded(Tensor expected, String jsonToDecode) {
+ assertEquals(expected, JsonFormat.decode(expected.type(), jsonToDecode.getBytes(StandardCharsets.UTF_8)));
+ }
+
private void assertDecodeFails(TensorType type, String format, String msg) {
try {
Tensor decoded = JsonFormat.decode(type, format.getBytes(StandardCharsets.UTF_8));