summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-11-09 11:03:09 +0000
committerArne Juul <arnej@yahooinc.com>2023-11-09 11:15:22 +0000
commit44ffed14e6ebbdf14fc457f57868597ec0664527 (patch)
tree0032161bbfca3dceb5001016246683f899847c9b /vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
parent18237f7a83591a8f86ed45448f9114a54201311b (diff)
expose dense subspace blocks for serializing
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java83
1 files changed, 44 insertions, 39 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 1d55fa78b83..8b8ad5c2dcf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -28,20 +28,22 @@ public class MixedTensor implements Tensor {
/** The dimension specification for this tensor */
private final TensorType type;
- private final int blockSize; // aka dense subspace size
-
- static final class DenseBlock {
- final TensorAddress sparseAddr;
- final double[] cells;
- DenseBlock(TensorAddress sparseAddr, double[] cells) {
+ private final int denseSubspaceSize;
+
+ // XXX consider using "record" instead
+ /** only exposed for internal use; subject to change without notice */
+ public static final class DenseSubspace {
+ public final TensorAddress sparseAddr;
+ public final double[] cells;
+ DenseSubspace(TensorAddress sparseAddr, double[] cells) {
this.sparseAddr = sparseAddr;
this.cells = cells;
}
@Override public int hashCode() {
- return Objects.hash(sparseAddr, cells);
+ return Objects.hash(sparseAddr, cells[0]);
}
@Override public boolean equals(Object other) {
- if (other instanceof DenseBlock o) {
+ if (other instanceof DenseSubspace o) {
return sparseAddr.equals(o.sparseAddr) && Arrays.equals(cells, o.cells);
}
return false;
@@ -49,30 +51,33 @@ public class MixedTensor implements Tensor {
}
/** The cells in the tensor */
- private final List<DenseBlock> cellBlocks;
+ private final List<DenseSubspace> denseSubspaces;
+
+ /** only exposed for internal use; subject to change without notice */
+ public List<DenseSubspace> getInternalDenseSubspaces() { return denseSubspaces; }
/** An index structure over the cell list */
private final Index index;
- private MixedTensor(TensorType type, List<DenseBlock> cellBlocks, Index index) {
+ private MixedTensor(TensorType type, List<DenseSubspace> denseSubspaces, Index index) {
this.type = type;
- this.blockSize = index.denseSubspaceSize();
- this.cellBlocks = List.copyOf(cellBlocks);
+ this.denseSubspaceSize = index.denseSubspaceSize();
+ this.denseSubspaces = List.copyOf(denseSubspaces);
this.index = index;
- if (this.blockSize < 1) {
- throw new IllegalStateException("invalid dense subspace size: " + blockSize);
+ if (this.denseSubspaceSize < 1) {
+ throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize);
}
long count = 0;
- for (var block : this.cellBlocks) {
+ for (var block : this.denseSubspaces) {
if (index.sparseMap.get(block.sparseAddr) != count) {
throw new IllegalStateException("map vs list mismatch: block #"
+ count
+ " address maps to #"
+ index.sparseMap.get(block.sparseAddr));
}
- if (block.cells.length != blockSize) {
+ if (block.cells.length != denseSubspaceSize) {
throw new IllegalStateException("dense subspace size mismatch, expected "
- + blockSize
+ + denseSubspaceSize
+ " cells, but got: "
+ block.cells.length);
}
@@ -92,17 +97,17 @@ public class MixedTensor implements Tensor {
/** Returns the size of the tensor measured in number of cells */
@Override
- public long size() { return cellBlocks.size() * blockSize; }
+ public long size() { return denseSubspaces.size() * denseSubspaceSize; }
/** Returns the value at the given address */
@Override
public double get(TensorAddress address) {
int blockNum = index.blockIndexOf(address);
- if (blockNum < 0 || blockNum > cellBlocks.size()) {
+ if (blockNum < 0 || blockNum > denseSubspaces.size()) {
return 0.0;
}
int denseOffset = index.denseOffsetOf(address);
- var block = cellBlocks.get(blockNum);
+ var block = denseSubspaces.get(blockNum);
if (denseOffset < 0 || denseOffset >= block.cells.length) {
return 0.0;
}
@@ -112,11 +117,11 @@ public class MixedTensor implements Tensor {
@Override
public boolean has(TensorAddress address) {
int blockNum = index.blockIndexOf(address);
- if (blockNum < 0 || blockNum > cellBlocks.size()) {
+ if (blockNum < 0 || blockNum > denseSubspaces.size()) {
return false;
}
int denseOffset = index.denseOffsetOf(address);
- var block = cellBlocks.get(blockNum);
+ var block = denseSubspaces.get(blockNum);
return (denseOffset >= 0 && denseOffset < block.cells.length);
}
@@ -130,16 +135,16 @@ public class MixedTensor implements Tensor {
@Override
public Iterator<Cell> cellIterator() {
return new Iterator<>() {
- final Iterator<DenseBlock> blockIterator = cellBlocks.iterator();
- DenseBlock currBlock = null;
- int currOffset = blockSize;
+ final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator();
+ DenseSubspace currBlock = null;
+ int currOffset = denseSubspaceSize;
@Override
public boolean hasNext() {
- return (currOffset < blockSize || blockIterator.hasNext());
+ return (currOffset < denseSubspaceSize || blockIterator.hasNext());
}
@Override
public Cell next() {
- if (currOffset == blockSize) {
+ if (currOffset == denseSubspaceSize) {
currBlock = blockIterator.next();
currOffset = 0;
}
@@ -157,16 +162,16 @@ public class MixedTensor implements Tensor {
@Override
public Iterator<Double> valueIterator() {
return new Iterator<>() {
- final Iterator<DenseBlock> blockIterator = cellBlocks.iterator();
+ final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator();
double[] currBlock = null;
- int currOffset = blockSize;
+ int currOffset = denseSubspaceSize;
@Override
public boolean hasNext() {
- return (currOffset < blockSize || blockIterator.hasNext());
+ return (currOffset < denseSubspaceSize || blockIterator.hasNext());
}
@Override
public Double next() {
- if (currOffset == blockSize) {
+ if (currOffset == denseSubspaceSize) {
currBlock = blockIterator.next().cells;
currOffset = 0;
}
@@ -192,14 +197,14 @@ public class MixedTensor implements Tensor {
throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" +
this.type + "', requested type: '" + type + "'");
}
- return new MixedTensor(other, cellBlocks, index);
+ return new MixedTensor(other, denseSubspaces, index);
}
@Override
public Tensor remove(Set<TensorAddress> addresses) {
var indexBuilder = new Index.Builder(type);
- List<DenseBlock> list = new ArrayList<>();
- for (var block : cellBlocks) {
+ List<DenseSubspace> list = new ArrayList<>();
+ for (var block : denseSubspaces) {
if ( ! addresses.contains(block.sparseAddr)) { // assumption: addresses only contain the sparse part
indexBuilder.addBlock(block.sparseAddr, list.size());
list.add(block);
@@ -209,7 +214,7 @@ public class MixedTensor implements Tensor {
}
@Override
- public int hashCode() { return Objects.hash(type, cellBlocks); }
+ public int hashCode() { return Objects.hash(type, denseSubspaces); }
@Override
public String toString() {
@@ -244,7 +249,7 @@ public class MixedTensor implements Tensor {
/** Returns the size of dense subspaces */
public long denseSubspaceSize() {
- return blockSize;
+ return denseSubspaceSize;
}
/**
@@ -357,11 +362,11 @@ public class MixedTensor implements Tensor {
@Override
public MixedTensor build() {
- List<DenseBlock> list = new ArrayList<>();
+ List<DenseSubspace> list = new ArrayList<>();
for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) {
TensorAddress sparsePart = entry.getKey();
double[] denseSubspace = entry.getValue();
- var block = new DenseBlock(sparsePart, denseSubspace);
+ var block = new DenseSubspace(sparsePart, denseSubspace);
indexBuilder.addBlock(sparsePart, list.size());
list.add(block);
}
@@ -621,7 +626,7 @@ public class MixedTensor implements Tensor {
}
private double getDouble(int subspaceIndex, int denseOffset, MixedTensor tensor) {
- return tensor.cellBlocks.get(subspaceIndex).cells[denseOffset];
+ return tensor.denseSubspaces.get(subspaceIndex).cells[denseOffset];
}
static class Builder {