diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-11-09 11:03:09 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-11-09 11:15:22 +0000 |
commit | 44ffed14e6ebbdf14fc457f57868597ec0664527 (patch) | |
tree | 0032161bbfca3dceb5001016246683f899847c9b /vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java | |
parent | 18237f7a83591a8f86ed45448f9114a54201311b (diff) |
expose dense subspace blocks for serializing
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java | 83 |
1 files changed, 44 insertions, 39 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 1d55fa78b83..8b8ad5c2dcf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -28,20 +28,22 @@ public class MixedTensor implements Tensor { /** The dimension specification for this tensor */ private final TensorType type; - private final int blockSize; // aka dense subspace size - - static final class DenseBlock { - final TensorAddress sparseAddr; - final double[] cells; - DenseBlock(TensorAddress sparseAddr, double[] cells) { + private final int denseSubspaceSize; + + // XXX consider using "record" instead + /** only exposed for internal use; subject to change without notice */ + public static final class DenseSubspace { + public final TensorAddress sparseAddr; + public final double[] cells; + DenseSubspace(TensorAddress sparseAddr, double[] cells) { this.sparseAddr = sparseAddr; this.cells = cells; } @Override public int hashCode() { - return Objects.hash(sparseAddr, cells); + return Objects.hash(sparseAddr, cells[0]); } @Override public boolean equals(Object other) { - if (other instanceof DenseBlock o) { + if (other instanceof DenseSubspace o) { return sparseAddr.equals(o.sparseAddr) && Arrays.equals(cells, o.cells); } return false; @@ -49,30 +51,33 @@ public class MixedTensor implements Tensor { } /** The cells in the tensor */ - private final List<DenseBlock> cellBlocks; + private final List<DenseSubspace> denseSubspaces; + + /** only exposed for internal use; subject to change without notice */ + public List<DenseSubspace> getInternalDenseSubspaces() { return denseSubspaces; } /** An index structure over the cell list */ private final Index index; - private MixedTensor(TensorType type, List<DenseBlock> cellBlocks, Index index) { + private MixedTensor(TensorType type, List<DenseSubspace> denseSubspaces, Index index) { this.type = type; - this.blockSize = index.denseSubspaceSize(); - this.cellBlocks = List.copyOf(cellBlocks); + this.denseSubspaceSize = index.denseSubspaceSize(); + this.denseSubspaces = List.copyOf(denseSubspaces); this.index = index; - if (this.blockSize < 1) { - throw new IllegalStateException("invalid dense subspace size: " + blockSize); + if (this.denseSubspaceSize < 1) { + throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); } long count = 0; - for (var block : this.cellBlocks) { + for (var block : this.denseSubspaces) { if (index.sparseMap.get(block.sparseAddr) != count) { throw new IllegalStateException("map vs list mismatch: block #" + count + " address maps to #" + index.sparseMap.get(block.sparseAddr)); } - if (block.cells.length != blockSize) { + if (block.cells.length != denseSubspaceSize) { throw new IllegalStateException("dense subspace size mismatch, expected " - + blockSize + + denseSubspaceSize + " cells, but got: " + block.cells.length); } @@ -92,17 +97,17 @@ public class MixedTensor implements Tensor { /** Returns the size of the tensor measured in number of cells */ @Override - public long size() { return cellBlocks.size() * blockSize; } + public long size() { return denseSubspaces.size() * denseSubspaceSize; } /** Returns the value at the given address */ @Override public double get(TensorAddress address) { int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > cellBlocks.size()) { + if (blockNum < 0 || blockNum > denseSubspaces.size()) { return 0.0; } int denseOffset = index.denseOffsetOf(address); - var block = cellBlocks.get(blockNum); + var block = denseSubspaces.get(blockNum); if (denseOffset < 0 || denseOffset >= block.cells.length) { return 0.0; } @@ -112,11 +117,11 @@ public class MixedTensor implements Tensor { @Override public boolean has(TensorAddress address) { int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > cellBlocks.size()) { + if (blockNum < 0 || blockNum > denseSubspaces.size()) { return false; } int denseOffset = index.denseOffsetOf(address); - var block = cellBlocks.get(blockNum); + var block = denseSubspaces.get(blockNum); return (denseOffset >= 0 && denseOffset < block.cells.length); } @@ -130,16 +135,16 @@ public class MixedTensor implements Tensor { @Override public Iterator<Cell> cellIterator() { return new Iterator<>() { - final Iterator<DenseBlock> blockIterator = cellBlocks.iterator(); - DenseBlock currBlock = null; - int currOffset = blockSize; + final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); + DenseSubspace currBlock = null; + int currOffset = denseSubspaceSize; @Override public boolean hasNext() { - return (currOffset < blockSize || blockIterator.hasNext()); + return (currOffset < denseSubspaceSize || blockIterator.hasNext()); } @Override public Cell next() { - if (currOffset == blockSize) { + if (currOffset == denseSubspaceSize) { currBlock = blockIterator.next(); currOffset = 0; } @@ -157,16 +162,16 @@ public class MixedTensor implements Tensor { @Override public Iterator<Double> valueIterator() { return new Iterator<>() { - final Iterator<DenseBlock> blockIterator = cellBlocks.iterator(); + final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator(); double[] currBlock = null; - int currOffset = blockSize; + int currOffset = denseSubspaceSize; @Override public boolean hasNext() { - return (currOffset < blockSize || blockIterator.hasNext()); + return (currOffset < denseSubspaceSize || blockIterator.hasNext()); } @Override public Double next() { - if (currOffset == blockSize) { + if (currOffset == denseSubspaceSize) { currBlock = blockIterator.next().cells; currOffset = 0; } @@ -192,14 +197,14 @@ public class MixedTensor implements Tensor { throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" + this.type + "', requested type: '" + type + "'"); } - return new MixedTensor(other, cellBlocks, index); + return new MixedTensor(other, denseSubspaces, index); } @Override public Tensor remove(Set<TensorAddress> addresses) { var indexBuilder = new Index.Builder(type); - List<DenseBlock> list = new ArrayList<>(); - for (var block : cellBlocks) { + List<DenseSubspace> list = new ArrayList<>(); + for (var block : denseSubspaces) { if ( ! addresses.contains(block.sparseAddr)) { // assumption: addresses only contain the sparse part indexBuilder.addBlock(block.sparseAddr, list.size()); list.add(block); @@ -209,7 +214,7 @@ public class MixedTensor implements Tensor { } @Override - public int hashCode() { return Objects.hash(type, cellBlocks); } + public int hashCode() { return Objects.hash(type, denseSubspaces); } @Override public String toString() { @@ -244,7 +249,7 @@ public class MixedTensor implements Tensor { /** Returns the size of dense subspaces */ public long denseSubspaceSize() { - return blockSize; + return denseSubspaceSize; } /** @@ -357,11 +362,11 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { - List<DenseBlock> list = new ArrayList<>(); + List<DenseSubspace> list = new ArrayList<>(); for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) { TensorAddress sparsePart = entry.getKey(); double[] denseSubspace = entry.getValue(); - var block = new DenseBlock(sparsePart, denseSubspace); + var block = new DenseSubspace(sparsePart, denseSubspace); indexBuilder.addBlock(sparsePart, list.size()); list.add(block); } @@ -621,7 +626,7 @@ public class MixedTensor implements Tensor { } private double getDouble(int subspaceIndex, int denseOffset, MixedTensor tensor) { - return tensor.cellBlocks.get(subspaceIndex).cells[denseOffset]; + return tensor.denseSubspaces.get(subspaceIndex).cells[denseOffset]; } static class Builder { |