diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-10 11:39:39 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-12-10 11:39:39 -0800 |
commit | 4c46e1816d2cdfacd8435ad4d55e831929fc99ba (patch) | |
tree | d55a90aeeddcf9265a74e7f16129517e36f45375 /vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java | |
parent | b8d2859a9fece15dac2b9260d71dea39f8ce19b3 (diff) |
Tensor parsing improvements
- Mixed tensor format parsing (outside expressions)
- Validate structure of dense tensor strings
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java | 49 |
1 files changed, 41 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 1cde1fcdbb7..0c4efe78113 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -217,25 +217,34 @@ public class MixedTensor implements Tensor { public static class BoundBuilder extends Builder { /** For each sparse partial address, hold a dense subspace */ - final private Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>(); - final private Index.Builder indexBuilder; - final private Index index; + private final Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>(); + private final Index.Builder indexBuilder; + private final Index index; + private final TensorType denseSubtype; private BoundBuilder(TensorType type) { super(type); indexBuilder = new Index.Builder(type); index = indexBuilder.index(); + denseSubtype = new TensorType(type.valueType(), + type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList())); } public long denseSubspaceSize() { return index.denseSubspaceSize(); } - private double[] denseSubspace(TensorAddress sparsePartial) { - if (!denseSubspaceMap.containsKey(sparsePartial)) { - denseSubspaceMap.put(sparsePartial, new double[(int)denseSubspaceSize()]); + private double[] denseSubspace(TensorAddress sparseAddress) { + if (!denseSubspaceMap.containsKey(sparseAddress)) { + denseSubspaceMap.put(sparseAddress, new double[(int)denseSubspaceSize()]); } - return denseSubspaceMap.get(sparsePartial); + return denseSubspaceMap.get(sparseAddress); + } + + public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress sparseAddress) { + double[] values = new double[(int)denseSubspaceSize()]; + denseSubspaceMap.put(sparseAddress, values); + return new DenseSubspaceBuilder(denseSubtype, values); } @Override @@ -280,7 +289,6 @@ public class MixedTensor implements Tensor { } - /** * Temporarily stores all cells to find bounds of indexed dimensions, * then creates a tensor using BoundBuilder. This is due to the @@ -491,6 +499,31 @@ public class MixedTensor implements Tensor { } + private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder { + + private final TensorType type; + private final double[] values; + + public DenseSubspaceBuilder(TensorType type, double[] values) { + this.type = type; + this.values = values; + } + + @Override + public TensorType type() { return type; } + + @Override + public void cellByDirectIndex(long index, double value) { + values[(int)index] = value; + } + + @Override + public void cellByDirectIndex(long index, float value) { + values[(int)index] = value; + } + + } + public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { TensorType.Builder builder = new TensorType.Builder(valueType); for (TensorType.Dimension dimension : dimensions) { |