diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2021-09-29 09:58:13 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-09-29 09:58:13 +0200 |
commit | 8923accf7e72d147d6d57185eecc4faf2b4adeb7 (patch) | |
tree | 0f856be32d11455e89547c98507a2a2d315e3225 /vespajlib/src/main | |
parent | a50c3b478de99e23ee5dd1af12efd3ace03d5b28 (diff) | |
parent | ac28a2c925e90d0b1c651d8019e113ae4aa5cad9 (diff) |
Merge pull request #19304 from vespa-engine/lesters/additional-short-forms-stateless-rest-api
Stateless REST API: short forms for sparse and mixed tensors
Diffstat (limited to 'vespajlib/src/main')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java | 2 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 96 |
2 files changed, 92 insertions, 6 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 71ed347219e..33dcd458980 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -91,7 +91,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return b.toString(); } - /** Returns a label as a string with approriate quoting/escaping when necessary */ + /** Returns a label as a string with appropriate quoting/escaping when necessary */ public static String labelToString(String label) { if (TensorType.labelMatcher.matches(label)) return label; // no quoting if (label.contains("'")) return "\"" + label + "\""; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index cb7539d8565..87157495485 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -11,12 +11,21 @@ import com.yahoo.slime.Slime; import com.yahoo.slime.Type; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.functions.ConstantTensor; +import com.yahoo.tensor.functions.Slice; +import java.util.ArrayList; +import java.util.HashSet; import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; /** * Writes tensors on the JSON format used in Vespa tensor document fields: @@ -46,12 +55,33 @@ public class JsonFormat { } /** Serializes the given tensor type and value into a short-form JSON format */ - public static byte[] encodeShortForm(IndexedTensor tensor) { + public static byte[] encodeShortForm(Tensor tensor) { Slime slime = new Slime(); Cursor root = slime.setObject(); root.setString("type", tensor.type().toString()); - Cursor value = root.setArray("value"); - encodeList(tensor, value, new long[tensor.dimensionSizes().dimensions()], 0); + + // Encode as nested lists if indexed tensor + if (tensor instanceof IndexedTensor) { + IndexedTensor denseTensor = (IndexedTensor) tensor; + encodeValues(denseTensor, root.setArray("values"), new long[denseTensor.dimensionSizes().dimensions()], 0); + } + + // Short form for a single mapped dimension + else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) { + encodeSingleDimensionCells((MappedTensor) tensor, root); + } + + // Short form for a mixed tensor + else if (tensor instanceof MixedTensor && + tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() >= 1) { + encodeBlocks((MixedTensor) tensor, root); + } + + // No other short forms exist: default to standard cell address output + else { + encodeCells(tensor, root); + } + return com.yahoo.slime.JsonFormat.toJsonBytes(slime); } @@ -65,22 +95,78 @@ public class JsonFormat { } } + private static void encodeSingleDimensionCells(MappedTensor tensor, Cursor cursor) { + Cursor cells = cursor.setObject("cells"); + if (tensor.type().dimensions().size() > 1) + throw new IllegalStateException("JSON encode of mapped tensor can only contain a single dimension"); + tensor.cells().forEach((k,v) -> cells.setDouble(k.label(0), v)); + } + private static void encodeAddress(TensorType type, TensorAddress address, Cursor addressObject) { for (int i = 0; i < address.size(); i++) addressObject.setString(type.dimensions().get(i).name(), address.label(i)); } - private static void encodeList(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) { + private static void encodeValues(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) { DimensionSizes sizes = tensor.dimensionSizes(); for (indexes[dimension] = 0; indexes[dimension] < sizes.size(dimension); ++indexes[dimension]) { if (dimension < (sizes.dimensions() - 1)) { - encodeList(tensor, cursor.addArray(), indexes, dimension + 1); + encodeValues(tensor, cursor.addArray(), indexes, dimension + 1); } else { cursor.addDouble(tensor.get(indexes)); } } } + private static void encodeBlocks(MixedTensor tensor, Cursor cursor) { + var mappedDimensions = tensor.type().dimensions().stream().filter(d -> d.isMapped()) + .map(d -> TensorType.Dimension.mapped(d.name())).collect(Collectors.toList()); + if (mappedDimensions.size() < 1) { + throw new IllegalArgumentException("Should be ensured by caller"); + } + cursor = (mappedDimensions.size() == 1) ? cursor.setObject("blocks") : cursor.setArray("blocks"); + + // Create tensor type for mapped dimensions subtype + TensorType mappedSubType = new TensorType.Builder(mappedDimensions).build(); + + // Find all unique indices for the mapped dimensions + Set<TensorAddress> denseSubSpaceAddresses = new HashSet<>(); + tensor.cellIterator().forEachRemaining((cell) -> { + denseSubSpaceAddresses.add(subAddress(cell.getKey(), mappedSubType, tensor.type())); + }); + + // Slice out dense subspace of each and encode dense subspace as a list + for (TensorAddress denseSubSpaceAddress : denseSubSpaceAddresses) { + IndexedTensor denseSubspace = (IndexedTensor) sliceSubAddress(tensor, denseSubSpaceAddress, mappedSubType); + + if (mappedDimensions.size() == 1) { + encodeValues(denseSubspace, cursor.setArray(denseSubSpaceAddress.label(0)), new long[denseSubspace.dimensionSizes().dimensions()], 0); + } else { + Cursor block = cursor.addObject(); + encodeAddress(mappedSubType, denseSubSpaceAddress, block.setObject("address")); + encodeValues(denseSubspace, block.setArray("values"), new long[denseSubspace.dimensionSizes().dimensions()], 0); + } + + } + } + + private static TensorAddress subAddress(TensorAddress address, TensorType subType, TensorType origType) { + TensorAddress.Builder builder = new TensorAddress.Builder(subType); + for (TensorType.Dimension dim : subType.dimensions()) { + builder.add(dim.name(), address.label(origType.indexOfDimension(dim.name()). + orElseThrow(() -> new IllegalStateException("Could not find mapped dimension index")))); + } + return builder.build(); + } + + private static Tensor sliceSubAddress(Tensor tensor, TensorAddress subAddress, TensorType subType) { + List<Slice.DimensionValue<Name>> sliceDims = new ArrayList<>(subAddress.size()); + for (int i = 0; i < subAddress.size(); ++i) { + sliceDims.add(new Slice.DimensionValue<>(subType.dimensions().get(i).name(), subAddress.label(i))); + } + return new Slice<>(new ConstantTensor<>(tensor), sliceDims).evaluate(); + } + /** Deserializes the given tensor from JSON format */ // NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module public static Tensor decode(TensorType type, byte[] jsonTensorValue) { |