summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java96
1 files changed, 91 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index cb7539d8565..87157495485 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -11,12 +11,21 @@ import com.yahoo.slime.Slime;
import com.yahoo.slime.Type;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.functions.ConstantTensor;
+import com.yahoo.tensor.functions.Slice;
+import java.util.ArrayList;
+import java.util.HashSet;
import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
/**
* Writes tensors on the JSON format used in Vespa tensor document fields:
@@ -46,12 +55,33 @@ public class JsonFormat {
}
/** Serializes the given tensor type and value into a short-form JSON format */
- public static byte[] encodeShortForm(IndexedTensor tensor) {
+ public static byte[] encodeShortForm(Tensor tensor) {
Slime slime = new Slime();
Cursor root = slime.setObject();
root.setString("type", tensor.type().toString());
- Cursor value = root.setArray("value");
- encodeList(tensor, value, new long[tensor.dimensionSizes().dimensions()], 0);
+
+ // Encode as nested lists if indexed tensor
+ if (tensor instanceof IndexedTensor) {
+ IndexedTensor denseTensor = (IndexedTensor) tensor;
+ encodeValues(denseTensor, root.setArray("values"), new long[denseTensor.dimensionSizes().dimensions()], 0);
+ }
+
+ // Short form for a single mapped dimension
+ else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) {
+ encodeSingleDimensionCells((MappedTensor) tensor, root);
+ }
+
+ // Short form for a mixed tensor
+ else if (tensor instanceof MixedTensor &&
+ tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() >= 1) {
+ encodeBlocks((MixedTensor) tensor, root);
+ }
+
+ // No other short forms exist: default to standard cell address output
+ else {
+ encodeCells(tensor, root);
+ }
+
return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
}
@@ -65,22 +95,78 @@ public class JsonFormat {
}
}
+ private static void encodeSingleDimensionCells(MappedTensor tensor, Cursor cursor) {
+ Cursor cells = cursor.setObject("cells");
+ if (tensor.type().dimensions().size() > 1)
+ throw new IllegalStateException("JSON encode of mapped tensor can only contain a single dimension");
+ tensor.cells().forEach((k,v) -> cells.setDouble(k.label(0), v));
+ }
+
private static void encodeAddress(TensorType type, TensorAddress address, Cursor addressObject) {
for (int i = 0; i < address.size(); i++)
addressObject.setString(type.dimensions().get(i).name(), address.label(i));
}
- private static void encodeList(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) {
+ private static void encodeValues(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) {
DimensionSizes sizes = tensor.dimensionSizes();
for (indexes[dimension] = 0; indexes[dimension] < sizes.size(dimension); ++indexes[dimension]) {
if (dimension < (sizes.dimensions() - 1)) {
- encodeList(tensor, cursor.addArray(), indexes, dimension + 1);
+ encodeValues(tensor, cursor.addArray(), indexes, dimension + 1);
} else {
cursor.addDouble(tensor.get(indexes));
}
}
}
+ private static void encodeBlocks(MixedTensor tensor, Cursor cursor) {
+ var mappedDimensions = tensor.type().dimensions().stream().filter(d -> d.isMapped())
+ .map(d -> TensorType.Dimension.mapped(d.name())).collect(Collectors.toList());
+ if (mappedDimensions.size() < 1) {
+ throw new IllegalArgumentException("Should be ensured by caller");
+ }
+ cursor = (mappedDimensions.size() == 1) ? cursor.setObject("blocks") : cursor.setArray("blocks");
+
+ // Create tensor type for mapped dimensions subtype
+ TensorType mappedSubType = new TensorType.Builder(mappedDimensions).build();
+
+ // Find all unique indices for the mapped dimensions
+ Set<TensorAddress> denseSubSpaceAddresses = new HashSet<>();
+ tensor.cellIterator().forEachRemaining((cell) -> {
+ denseSubSpaceAddresses.add(subAddress(cell.getKey(), mappedSubType, tensor.type()));
+ });
+
+ // Slice out dense subspace of each and encode dense subspace as a list
+ for (TensorAddress denseSubSpaceAddress : denseSubSpaceAddresses) {
+ IndexedTensor denseSubspace = (IndexedTensor) sliceSubAddress(tensor, denseSubSpaceAddress, mappedSubType);
+
+ if (mappedDimensions.size() == 1) {
+ encodeValues(denseSubspace, cursor.setArray(denseSubSpaceAddress.label(0)), new long[denseSubspace.dimensionSizes().dimensions()], 0);
+ } else {
+ Cursor block = cursor.addObject();
+ encodeAddress(mappedSubType, denseSubSpaceAddress, block.setObject("address"));
+ encodeValues(denseSubspace, block.setArray("values"), new long[denseSubspace.dimensionSizes().dimensions()], 0);
+ }
+
+ }
+ }
+
+ private static TensorAddress subAddress(TensorAddress address, TensorType subType, TensorType origType) {
+ TensorAddress.Builder builder = new TensorAddress.Builder(subType);
+ for (TensorType.Dimension dim : subType.dimensions()) {
+ builder.add(dim.name(), address.label(origType.indexOfDimension(dim.name()).
+ orElseThrow(() -> new IllegalStateException("Could not find mapped dimension index"))));
+ }
+ return builder.build();
+ }
+
+ private static Tensor sliceSubAddress(Tensor tensor, TensorAddress subAddress, TensorType subType) {
+ List<Slice.DimensionValue<Name>> sliceDims = new ArrayList<>(subAddress.size());
+ for (int i = 0; i < subAddress.size(); ++i) {
+ sliceDims.add(new Slice.DimensionValue<>(subType.dimensions().get(i).name(), subAddress.label(i)));
+ }
+ return new Slice<>(new ConstantTensor<>(tensor), sliceDims).evaluate();
+ }
+
/** Deserializes the given tensor from JSON format */
// NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module
public static Tensor decode(TensorType type, byte[] jsonTensorValue) {