summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-11-09 11:03:09 +0000
committerArne Juul <arnej@yahooinc.com>2023-11-09 11:15:22 +0000
commit44ffed14e6ebbdf14fc457f57868597ec0664527 (patch)
tree0032161bbfca3dceb5001016246683f899847c9b /vespajlib/src/main/java/com/yahoo/tensor
parent18237f7a83591a8f86ed45448f9114a54201311b (diff)
expose dense subspace blocks for serializing
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java83
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java18
4 files changed, 57 insertions, 63 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 1d55fa78b83..8b8ad5c2dcf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -28,20 +28,22 @@ public class MixedTensor implements Tensor {
/** The dimension specification for this tensor */
private final TensorType type;
- private final int blockSize; // aka dense subspace size
-
- static final class DenseBlock {
- final TensorAddress sparseAddr;
- final double[] cells;
- DenseBlock(TensorAddress sparseAddr, double[] cells) {
+ private final int denseSubspaceSize;
+
+ // XXX consider using "record" instead
+ /** only exposed for internal use; subject to change without notice */
+ public static final class DenseSubspace {
+ public final TensorAddress sparseAddr;
+ public final double[] cells;
+ DenseSubspace(TensorAddress sparseAddr, double[] cells) {
this.sparseAddr = sparseAddr;
this.cells = cells;
}
@Override public int hashCode() {
- return Objects.hash(sparseAddr, cells);
+ return Objects.hash(sparseAddr, cells[0]);
}
@Override public boolean equals(Object other) {
- if (other instanceof DenseBlock o) {
+ if (other instanceof DenseSubspace o) {
return sparseAddr.equals(o.sparseAddr) && Arrays.equals(cells, o.cells);
}
return false;
@@ -49,30 +51,33 @@ public class MixedTensor implements Tensor {
}
/** The cells in the tensor */
- private final List<DenseBlock> cellBlocks;
+ private final List<DenseSubspace> denseSubspaces;
+
+ /** only exposed for internal use; subject to change without notice */
+ public List<DenseSubspace> getInternalDenseSubspaces() { return denseSubspaces; }
/** An index structure over the cell list */
private final Index index;
- private MixedTensor(TensorType type, List<DenseBlock> cellBlocks, Index index) {
+ private MixedTensor(TensorType type, List<DenseSubspace> denseSubspaces, Index index) {
this.type = type;
- this.blockSize = index.denseSubspaceSize();
- this.cellBlocks = List.copyOf(cellBlocks);
+ this.denseSubspaceSize = index.denseSubspaceSize();
+ this.denseSubspaces = List.copyOf(denseSubspaces);
this.index = index;
- if (this.blockSize < 1) {
- throw new IllegalStateException("invalid dense subspace size: " + blockSize);
+ if (this.denseSubspaceSize < 1) {
+ throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize);
}
long count = 0;
- for (var block : this.cellBlocks) {
+ for (var block : this.denseSubspaces) {
if (index.sparseMap.get(block.sparseAddr) != count) {
throw new IllegalStateException("map vs list mismatch: block #"
+ count
+ " address maps to #"
+ index.sparseMap.get(block.sparseAddr));
}
- if (block.cells.length != blockSize) {
+ if (block.cells.length != denseSubspaceSize) {
throw new IllegalStateException("dense subspace size mismatch, expected "
- + blockSize
+ + denseSubspaceSize
+ " cells, but got: "
+ block.cells.length);
}
@@ -92,17 +97,17 @@ public class MixedTensor implements Tensor {
/** Returns the size of the tensor measured in number of cells */
@Override
- public long size() { return cellBlocks.size() * blockSize; }
+ public long size() { return denseSubspaces.size() * denseSubspaceSize; }
/** Returns the value at the given address */
@Override
public double get(TensorAddress address) {
int blockNum = index.blockIndexOf(address);
- if (blockNum < 0 || blockNum > cellBlocks.size()) {
+ if (blockNum < 0 || blockNum > denseSubspaces.size()) {
return 0.0;
}
int denseOffset = index.denseOffsetOf(address);
- var block = cellBlocks.get(blockNum);
+ var block = denseSubspaces.get(blockNum);
if (denseOffset < 0 || denseOffset >= block.cells.length) {
return 0.0;
}
@@ -112,11 +117,11 @@ public class MixedTensor implements Tensor {
@Override
public boolean has(TensorAddress address) {
int blockNum = index.blockIndexOf(address);
- if (blockNum < 0 || blockNum > cellBlocks.size()) {
+ if (blockNum < 0 || blockNum > denseSubspaces.size()) {
return false;
}
int denseOffset = index.denseOffsetOf(address);
- var block = cellBlocks.get(blockNum);
+ var block = denseSubspaces.get(blockNum);
return (denseOffset >= 0 && denseOffset < block.cells.length);
}
@@ -130,16 +135,16 @@ public class MixedTensor implements Tensor {
@Override
public Iterator<Cell> cellIterator() {
return new Iterator<>() {
- final Iterator<DenseBlock> blockIterator = cellBlocks.iterator();
- DenseBlock currBlock = null;
- int currOffset = blockSize;
+ final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator();
+ DenseSubspace currBlock = null;
+ int currOffset = denseSubspaceSize;
@Override
public boolean hasNext() {
- return (currOffset < blockSize || blockIterator.hasNext());
+ return (currOffset < denseSubspaceSize || blockIterator.hasNext());
}
@Override
public Cell next() {
- if (currOffset == blockSize) {
+ if (currOffset == denseSubspaceSize) {
currBlock = blockIterator.next();
currOffset = 0;
}
@@ -157,16 +162,16 @@ public class MixedTensor implements Tensor {
@Override
public Iterator<Double> valueIterator() {
return new Iterator<>() {
- final Iterator<DenseBlock> blockIterator = cellBlocks.iterator();
+ final Iterator<DenseSubspace> blockIterator = denseSubspaces.iterator();
double[] currBlock = null;
- int currOffset = blockSize;
+ int currOffset = denseSubspaceSize;
@Override
public boolean hasNext() {
- return (currOffset < blockSize || blockIterator.hasNext());
+ return (currOffset < denseSubspaceSize || blockIterator.hasNext());
}
@Override
public Double next() {
- if (currOffset == blockSize) {
+ if (currOffset == denseSubspaceSize) {
currBlock = blockIterator.next().cells;
currOffset = 0;
}
@@ -192,14 +197,14 @@ public class MixedTensor implements Tensor {
throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" +
this.type + "', requested type: '" + type + "'");
}
- return new MixedTensor(other, cellBlocks, index);
+ return new MixedTensor(other, denseSubspaces, index);
}
@Override
public Tensor remove(Set<TensorAddress> addresses) {
var indexBuilder = new Index.Builder(type);
- List<DenseBlock> list = new ArrayList<>();
- for (var block : cellBlocks) {
+ List<DenseSubspace> list = new ArrayList<>();
+ for (var block : denseSubspaces) {
if ( ! addresses.contains(block.sparseAddr)) { // assumption: addresses only contain the sparse part
indexBuilder.addBlock(block.sparseAddr, list.size());
list.add(block);
@@ -209,7 +214,7 @@ public class MixedTensor implements Tensor {
}
@Override
- public int hashCode() { return Objects.hash(type, cellBlocks); }
+ public int hashCode() { return Objects.hash(type, denseSubspaces); }
@Override
public String toString() {
@@ -244,7 +249,7 @@ public class MixedTensor implements Tensor {
/** Returns the size of dense subspaces */
public long denseSubspaceSize() {
- return blockSize;
+ return denseSubspaceSize;
}
/**
@@ -357,11 +362,11 @@ public class MixedTensor implements Tensor {
@Override
public MixedTensor build() {
- List<DenseBlock> list = new ArrayList<>();
+ List<DenseSubspace> list = new ArrayList<>();
for (Map.Entry<TensorAddress, double[]> entry : denseSubspaceMap.entrySet()) {
TensorAddress sparsePart = entry.getKey();
double[] denseSubspace = entry.getValue();
- var block = new DenseBlock(sparsePart, denseSubspace);
+ var block = new DenseSubspace(sparsePart, denseSubspace);
indexBuilder.addBlock(sparsePart, list.size());
list.add(block);
}
@@ -621,7 +626,7 @@ public class MixedTensor implements Tensor {
}
private double getDouble(int subspaceIndex, int denseOffset, MixedTensor tensor) {
- return tensor.cellBlocks.get(subspaceIndex).cells[denseOffset];
+ return tensor.denseSubspaces.get(subspaceIndex).cells[denseOffset];
}
static class Builder {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 7f890a9ec51..b30b664a5f7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -90,6 +90,7 @@ public class TensorType {
private final TensorType mappedSubtype;
private final TensorType indexedSubtype;
+ // only used to initialize the "empty" instance
private TensorType() {
this.valueType = Value.DOUBLE;
this.dimensions = List.of();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 3dc563feb40..dd433b92493 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -151,22 +151,14 @@ public class JsonFormat {
// Create tensor type for mapped dimensions subtype
TensorType mappedSubType = new TensorType.Builder(mappedDimensions).build();
-
- // Find all unique indices for the mapped dimensions
- Set<TensorAddress> denseSubSpaceAddresses = new HashSet<>();
- tensor.cellIterator().forEachRemaining((cell) -> {
- denseSubSpaceAddresses.add(subAddress(cell.getKey(), mappedSubType, tensor.type()));
- });
-
- // Slice out dense subspace of each and encode dense subspace as a list
- for (TensorAddress denseSubSpaceAddress : denseSubSpaceAddresses) {
- IndexedTensor denseSubspace = (IndexedTensor) sliceSubAddress(tensor, denseSubSpaceAddress, mappedSubType);
-
+ TensorType denseSubType = tensor.type().indexedSubtype();
+ for (var subspace : tensor.getInternalDenseSubspaces()) {
+ IndexedTensor denseSubspace = IndexedTensor.Builder.of(denseSubType, subspace.cells).build();
if (mappedDimensions.size() == 1) {
- encodeValues(denseSubspace, cursor.setArray(denseSubSpaceAddress.label(0)), new long[denseSubspace.dimensionSizes().dimensions()], 0);
+ encodeValues(denseSubspace, cursor.setArray(subspace.sparseAddr.label(0)), new long[denseSubspace.dimensionSizes().dimensions()], 0);
} else {
Cursor block = cursor.addObject();
- encodeAddress(mappedSubType, denseSubSpaceAddress, block.setObject("address"));
+ encodeAddress(mappedSubType, subspace.sparseAddr, block.setObject("address"));
encodeValues(denseSubspace, block.setArray("values"), new long[denseSubspace.dimensionSizes().dimensions()], 0);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index b184e6e0159..edebc9acdd6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -73,20 +73,16 @@ class MixedBinaryFormat implements BinaryFormat {
private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor, Consumer<Double> consumer) {
List<TensorType.Dimension> sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).toList();
long denseSubspaceSize = tensor.denseSubspaceSize();
+ var denseSubspaces = tensor.getInternalDenseSubspaces();
if (sparseDimensions.size() > 0) {
- buffer.putInt1_4Bytes((int)(tensor.size() / denseSubspaceSize)); // XXX: Size truncation
+ buffer.putInt1_4Bytes(denseSubspaces.size());
}
- Iterator<Tensor.Cell> cellIterator = tensor.cellIterator();
- while (cellIterator.hasNext()) {
- Tensor.Cell cell = cellIterator.next();
- for (TensorType.Dimension dimension : sparseDimensions) {
- int index = tensor.type().indexOfDimension(dimension.name()).orElseThrow(() ->
- new IllegalStateException("Dimension not found in address."));
- buffer.putUtf8String(cell.getKey().label(index));
+ for (var subspace : denseSubspaces) {
+ for (int index = 0; index < subspace.sparseAddr.size(); index++) {
+ buffer.putUtf8String(subspace.sparseAddr.label(index));
}
- consumer.accept(cell.getValue());
- for (int i = 1; i < denseSubspaceSize; ++i ) {
- consumer.accept(cellIterator.next().getValue());
+ for (double val : subspace.cells) {
+ consumer.accept(val);
}
}
}