diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-11-09 11:03:09 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-11-09 11:15:22 +0000 |
commit | 44ffed14e6ebbdf14fc457f57868597ec0664527 (patch) | |
tree | 0032161bbfca3dceb5001016246683f899847c9b /vespajlib/src/main | |
parent | 18237f7a83591a8f86ed45448f9114a54201311b (diff) |
expose dense subspace blocks for serializing
Diffstat (limited to 'vespajlib/src/main')
4 files changed, 57 insertions, 63 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 1d55fa78b83..8b8ad5c2dcf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -28,20 +28,22 @@ public class MixedTensor implements Tensor { /** The dimension specification for this tensor */ private final TensorType type; - private final int blockSize; // aka dense subspace size - - static final class DenseBlock { - final TensorAddress sparseAddr; - final double[] cells; - DenseBlock(TensorAddress sparseAddr, double[] cells) { + private final int denseSubspaceSize; + + // XXX consider using "record" instead + /** only exposed for internal use; subject to change without notice */ + public static final class DenseSubspace { + public final TensorAddress sparseAddr; + public final double[] cells; + DenseSubspace(TensorAddress sparseAddr, double[] cells) { this.sparseAddr = sparseAddr; this.cells = cells; } @Override public int hashCode() { - return Objects.hash(sparseAddr, cells); + return Objects.hash(sparseAddr, cells[0]); } @Override public boolean equals(Object other) { - if (other instanceof DenseBlock o) { + if (other instanceof DenseSubspace o) { return sparseAddr.equals(o.sparseAddr) && Arrays.equals(cells, o.cells); } return false; @@ -49,30 +51,33 @@ public class MixedTensor implements Tensor { } /** The cells in the tensor */ - private final List<DenseBlock> cellBlocks; + private final List<DenseSubspace> denseSubspaces; + + /** only exposed for internal use; subject to change without notice */ + public List<DenseSubspace> getInternalDenseSubspaces() { return denseSubspaces; } /** An index structure over the cell list */ private final Index index; - private MixedTensor(TensorType type, List<DenseBlock> cellBlocks, Index index) { + private MixedTensor(TensorType type, List<DenseSubspace> denseSubspaces, Index index) { this.type = type; - this.blockSize = index.denseSubspaceSize(); - this.cellBlocks = List.copyOf(cellBlocks); + this.denseSubspaceSize = index.denseSubspaceSize(); + this.denseSubspaces = List.copyOf(denseSubspaces); this.index = index; - if (this.blockSize < 1) { - throw new IllegalStateException("invalid dense subspace size: " + blockSize); + if (this.denseSubspaceSize < 1) { + throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); } long count = 0; - for (var block : this.cellBlocks) { + for (var block : this.denseSubspaces) { if (index.sparseMap.get(block.sparseAddr) != count) { throw new IllegalStateException("map vs list mismatch: block #" + count + " address maps to #" + index.sparseMap.get(block.sparseAddr)); } - if (block.cells.length != blockSize) { + if (block.cells.length != denseSubspaceSize) { throw new IllegalStateException("dense subspace size mismatch, expected " - + blockSize + + denseSubspaceSize + " cells, but got: " + block.cells.length); } @@ -92,17 +97,17 @@ public class MixedTensor implements Tensor { /** Returns the size of the tensor measured in number of cells */ @Override - public long size() { return cellBlocks.size() * blockSize; } + public long size() { return denseSubspaces.size() * denseSubspaceSize; } /** Returns the value at the given address */ @Override public double get(TensorAddress address) { int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > cellBlocks.size()) { + if (blockNum < 0 || blockNum > denseSubspaces.size()) { return 0.0; } int denseOffset = index.denseOffsetOf(address); - var block = cellBlocks.get(blockNum); + var block = denseSubspaces.get(blockNum); if (denseOffset < 0 || denseOffset >= block.cells.length) { return 0.0; } @@ -112,11 +117,11 @@ public class MixedTensor implements Tensor { @Override public boolean has(TensorAddress address) { int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > cellBlocks.size()) { + if (blockNum < 0 || blockNum > denseSubspaces.size()) { return false; } int denseOffset = index.denseOffsetOf(address); - var block = cellBlocks.get(blockNum); + var block = denseSubspaces.get(blockNum); return (denseOffset >= 0 && denseOffset < block.cells.length); } @@ -130,16 +135,16 @@ public class MixedTensor implements Tensor { @Override public Iterator<Cell> cellIterator() { return new Iterator<>() { - final Iterator<DenseBlock> blockIterator = cellBlocks.iterator(); - DenseBlock currBlock = null; - int currOffset = blockSize; + final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); + DenseSubspace currBlock = null; + int currOffset = denseSubspaceSize; @Override public boolean hasNext() { - return (currOffset < blockSize || blockIterator.hasNext()); + return (currOffset < denseSubspaceSize || blockIterator.hasNext()); } @Override public Cell next() { - if (currOffset == blockSize) { + if (currOffset == denseSubspaceSize) { currBlock = blockIterator.next(); currOffset = 0; } @@ -157,16 +162,16 @@ public class MixedTensor implements Tensor { @Override public Iterator<Double> valueIterator() { return new Iterator<>() { - final Iterator<DenseBlock> blockIterator = cellBlocks.iterator(); + final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); double[] currBlock = null; - int currOffset = blockSize; + int currOffset = denseSubspaceSize; @Override public boolean hasNext() { - return (currOffset < blockSize || blockIterator.hasNext()); + return (currOffset < denseSubspaceSize || blockIterator.hasNext()); } @Override public Double next() { - if (currOffset == blockSize) { + if (currOffset == denseSubspaceSize) { currBlock = blockIterator.next().cells; currOffset = 0; } @@ -192,14 +197,14 @@ public class MixedTensor implements Tensor { throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" + this.type + "', requested type: '" + type + "'"); } - return new MixedTensor(other, cellBlocks, index); + return new MixedTensor(other, denseSubspaces, index); } @Override public Tensor remove(Set<TensorAddress> addresses) { var indexBuilder = new Index.Builder(type); - List<DenseBlock> list = new ArrayList<>(); - for (var block : cellBlocks) { + List<DenseSubspace> list = new ArrayList<>(); + for (var block : denseSubspaces) { if ( ! addresses.contains(block.sparseAddr)) { // assumption: addresses only contain the sparse part indexBuilder.addBlock(block.sparseAddr, list.size()); list.add(block); @@ -209,7 +214,7 @@ public class MixedTensor implements Tensor { } @Override - public int hashCode() { return Objects.hash(type, cellBlocks); } + public int hashCode() { return Objects.hash(type, denseSubspaces); } @Override public String toString() { @@ -244,7 +249,7 @@ public class MixedTensor implements Tensor { /** Returns the size of dense subspaces */ public long denseSubspaceSize() { - return blockSize; + return denseSubspaceSize; } /** @@ -357,11 +362,11 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { - List<DenseBlock> list = new ArrayList<>(); + List<DenseSubspace> list = new ArrayList<>(); for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) { TensorAddress sparsePart = entry.getKey(); double[] denseSubspace = entry.getValue(); - var block = new DenseBlock(sparsePart, denseSubspace); + var block = new DenseSubspace(sparsePart, denseSubspace); indexBuilder.addBlock(sparsePart, list.size()); list.add(block); } @@ -621,7 +626,7 @@ public class MixedTensor implements Tensor { } private double getDouble(int subspaceIndex, int denseOffset, MixedTensor tensor) { - return tensor.cellBlocks.get(subspaceIndex).cells[denseOffset]; + return tensor.denseSubspaces.get(subspaceIndex).cells[denseOffset]; } static class Builder { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 7f890a9ec51..b30b664a5f7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -90,6 +90,7 @@ public class TensorType { private final TensorType mappedSubtype; private final TensorType indexedSubtype; + // only used to initialize the "empty" instance private TensorType() { this.valueType = Value.DOUBLE; this.dimensions = List.of(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 3dc563feb40..dd433b92493 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -151,22 +151,14 @@ public class JsonFormat { // Create tensor type for mapped dimensions subtype TensorType mappedSubType = new TensorType.Builder(mappedDimensions).build(); - - // Find all unique indices for the mapped dimensions - Set<TensorAddress> denseSubSpaceAddresses = new HashSet<>(); - tensor.cellIterator().forEachRemaining((cell) -> { - denseSubSpaceAddresses.add(subAddress(cell.getKey(), mappedSubType, tensor.type())); - }); - - // Slice out dense subspace of each and encode dense subspace as a list - for (TensorAddress denseSubSpaceAddress : denseSubSpaceAddresses) { - IndexedTensor denseSubspace = (IndexedTensor) sliceSubAddress(tensor, denseSubSpaceAddress, mappedSubType); - + TensorType denseSubType = tensor.type().indexedSubtype(); + for (var subspace : tensor.getInternalDenseSubspaces()) { + IndexedTensor denseSubspace = IndexedTensor.Builder.of(denseSubType, subspace.cells).build(); if (mappedDimensions.size() == 1) { - encodeValues(denseSubspace, cursor.setArray(denseSubSpaceAddress.label(0)), new long[denseSubspace.dimensionSizes().dimensions()], 0); + encodeValues(denseSubspace, cursor.setArray(subspace.sparseAddr.label(0)), new long[denseSubspace.dimensionSizes().dimensions()], 0); } else { Cursor block = cursor.addObject(); - encodeAddress(mappedSubType, denseSubSpaceAddress, block.setObject("address")); + encodeAddress(mappedSubType, subspace.sparseAddr, block.setObject("address")); encodeValues(denseSubspace, block.setArray("values"), new long[denseSubspace.dimensionSizes().dimensions()], 0); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index b184e6e0159..edebc9acdd6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -73,20 +73,16 @@ class MixedBinaryFormat implements BinaryFormat { private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor, Consumer<Double> consumer) { List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).toList(); long denseSubspaceSize = tensor.denseSubspaceSize(); + var denseSubspaces = tensor.getInternalDenseSubspaces(); if (sparseDimensions.size() > 0) { - buffer.putInt1_4Bytes((int)(tensor.size() / denseSubspaceSize)); // XXX: Size truncation + buffer.putInt1_4Bytes(denseSubspaces.size()); } - Iterator<Tensor.Cell> cellIterator = tensor.cellIterator(); - while (cellIterator.hasNext()) { - Tensor.Cell cell = cellIterator.next(); - for (TensorType.Dimension dimension : sparseDimensions) { - int index = tensor.type().indexOfDimension(dimension.name()).orElseThrow(() -> - new IllegalStateException("Dimension not found in address.")); - buffer.putUtf8String(cell.getKey().label(index)); + for (var subspace : denseSubspaces) { + for (int index = 0; index < subspace.sparseAddr.size(); index++) { + buffer.putUtf8String(subspace.sparseAddr.label(index)); } - consumer.accept(cell.getValue()); - for (int i = 1; i < denseSubspaceSize; ++i ) { - consumer.accept(cellIterator.next().getValue()); + for (double val : subspace.cells) { + consumer.accept(val); } } } |