diff options
author | Lester Solbakken <lesters@oath.com> | 2021-09-27 13:45:10 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-09-27 13:45:10 +0200 |
commit | e02be90cd8ea302cb23444a7dd321c9ef774913a (patch) | |
tree | 609821fc539839867fb652c55709b424fa1127a5 /vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | |
parent | 9377da84086392e118d69b467006e73fe9ae3f70 (diff) |
Stateless REST API: short forms for sparse and mixed tensors
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 61 |
1 files changed, 58 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index cb7539d8565..bebd706f815 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -11,12 +11,19 @@ import com.yahoo.slime.Slime; import com.yahoo.slime.Type; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.ConstantTensor; +import com.yahoo.tensor.functions.Slice; +import java.util.HashSet; import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; /** * Writes tensors on the JSON format used in Vespa tensor document fields: @@ -46,12 +53,33 @@ public class JsonFormat { } /** Serializes the given tensor type and value into a short-form JSON format */ - public static byte[] encodeShortForm(IndexedTensor tensor) { + public static byte[] encodeShortForm(Tensor tensor) { Slime slime = new Slime(); Cursor root = slime.setObject(); root.setString("type", tensor.type().toString()); - Cursor value = root.setArray("value"); - encodeList(tensor, value, new long[tensor.dimensionSizes().dimensions()], 0); + + // Encode as nested lists if indexed tensor + if (tensor instanceof IndexedTensor) { + IndexedTensor denseTensor = (IndexedTensor) tensor; + encodeList(denseTensor, root.setArray("value"), new long[denseTensor.dimensionSizes().dimensions()], 0); + } + + // Short form for a single mapped dimension + else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) { + encodeMap((MappedTensor) tensor, root.setObject("value")); + } + + // Short form for a mixed tensor with a single mapped dimension + else if (tensor instanceof MixedTensor && + tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1) { + encodeMapBlocks((MixedTensor) tensor, root.setObject("value")); + } + + // No other short forms exist: default to standard cell address output + else { + encodeCells(tensor, root.setObject("value")); + } + return com.yahoo.slime.JsonFormat.toJsonBytes(slime); } @@ -81,6 +109,33 @@ public class JsonFormat { } } + private static void encodeMap(MappedTensor tensor, Cursor cursor) { + if (tensor.type().dimensions().size() > 1) + throw new IllegalStateException("JSON encode of mapped tensor can only contain a single dimension"); + tensor.cells().forEach((k,v) -> cursor.setDouble(k.label(0), v)); + } + + private static void encodeMapBlocks(MixedTensor tensor, Cursor cursor) { + var mappedDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); + if (mappedDimensions.size() != 1) { + throw new IllegalArgumentException("Should be ensured by caller"); + } + String mappedDimensionName = mappedDimensions.get(0).name(); + int mappedDimensionIndex = tensor.type().indexOfDimension(mappedDimensionName). + orElseThrow(() -> new IllegalStateException("Could not find mapped dimension index")); + + // Find all unique indices for the mapped dimension + Set<String> mappedIndices = new HashSet<>(); + tensor.cellIterator().forEachRemaining((cell) -> mappedIndices.add(cell.getKey().label(mappedDimensionIndex))); + + // Slice out dense subspace of each and encode dense subspace as a list + for (String mappedIndex : mappedIndices) { + IndexedTensor denseSubspace = (IndexedTensor) new Slice<>(new ConstantTensor<>(tensor), + List.of(new Slice.DimensionValue<>(mappedDimensionName, mappedIndex))).evaluate(); + encodeList(denseSubspace, cursor.setArray(mappedIndex), new long[denseSubspace.dimensionSizes().dimensions()], 0); + } + } + /** Deserializes the given tensor from JSON format */ // NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module public static Tensor decode(TensorType type, byte[] jsonTensorValue) { |