From a8e4dec815f4df43e8f473906f88fadfaedc622c Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Sun, 21 Jan 2024 16:41:36 +0100 Subject: - Avoid expensive iteration with hash lookups just for sanity checking in private MixedTensor constructor. - Refactor so there is less need for sanity checks, and do them in the buiulder instead. --- .../main/java/com/yahoo/tensor/MixedTensor.java | 126 ++++++++------------- .../java/com/yahoo/tensor/functions/Reduce.java | 8 +- .../main/java/com/yahoo/tensor/impl/Convert.java | 16 +++ 3 files changed, 71 insertions(+), 79 deletions(-) create mode 100644 vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 55fb4af551e..09ed0f11b07 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -2,6 +2,7 @@ package com.yahoo.tensor; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.tensor.impl.NumericTensorAddress; import com.yahoo.tensor.impl.StringTensorAddress; @@ -31,7 +32,6 @@ public class MixedTensor implements Tensor { /** The dimension specification for this tensor */ private final TensorType type; - private final int denseSubspaceSize; // XXX consider using "record" instead /** only exposed for internal use; subject to change without notice */ @@ -53,45 +53,15 @@ public class MixedTensor implements Tensor { } } - /** The cells in the tensor */ - private final List denseSubspaces; - /** only exposed for internal use; subject to change without notice */ - public List getInternalDenseSubspaces() { return denseSubspaces; } + public List getInternalDenseSubspaces() { return index.denseSubspaces; } /** An index structure over the cell list */ private final Index index; - private MixedTensor(TensorType type, List denseSubspaces, Index index) { + private MixedTensor(TensorType type, Index index) { this.type = type; - this.denseSubspaceSize = index.denseSubspaceSize(); - this.denseSubspaces = List.copyOf(denseSubspaces); this.index = index; - if (this.denseSubspaceSize < 1) { - throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); - } - long count = 0; - for (var block : this.denseSubspaces) { - if (index.sparseMap.get(block.sparseAddress) != count) { - throw new IllegalStateException("map vs list mismatch: block #" - + count - + " address maps to #" - + index.sparseMap.get(block.sparseAddress)); - } - if (block.cells.length != denseSubspaceSize) { - throw new IllegalStateException("dense subspace size mismatch, expected " - + denseSubspaceSize - + " cells, but got: " - + block.cells.length); - } - ++count; - } - if (count != index.sparseMap.size()) { - throw new IllegalStateException("mismatch: list size is " - + count - + " but map size is " - + index.sparseMap.size()); - } } /** Returns the tensor type */ @@ -100,18 +70,14 @@ public class MixedTensor implements Tensor { /** Returns the size of the tensor measured in number of cells */ @Override - public long size() { return denseSubspaces.size() * denseSubspaceSize; } + public long size() { return index.denseSubspaces.size() * index.denseSubspaceSize; } /** Returns the value at the given address */ @Override public double get(TensorAddress address) { - int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > denseSubspaces.size()) { - return 0.0; - } + var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); - var block = denseSubspaces.get(blockNum); - if (denseOffset < 0 || denseOffset >= block.cells.length) { + if (block == null || denseOffset < 0 || denseOffset >= block.cells.length) { return 0.0; } return block.cells[denseOffset]; @@ -119,13 +85,9 @@ public class MixedTensor implements Tensor { @Override public boolean has(TensorAddress address) { - int blockNum = index.blockIndexOf(address); - if (blockNum < 0 || blockNum > denseSubspaces.size()) { - return false; - } + var block = index.blockOf(address); int denseOffset = index.denseOffsetOf(address); - var block = denseSubspaces.get(blockNum); - return (denseOffset >= 0 && denseOffset < block.cells.length); + return (block != null && denseOffset >= 0 && denseOffset < block.cells.length); } /** @@ -138,16 +100,16 @@ public class MixedTensor implements Tensor { @Override public Iterator cellIterator() { return new Iterator<>() { - final Iterator blockIterator = denseSubspaces.iterator(); + final Iterator blockIterator = index.denseSubspaces.iterator(); DenseSubspace currBlock = null; - int currOffset = denseSubspaceSize; + int currOffset = index.denseSubspaceSize; @Override public boolean hasNext() { - return (currOffset < denseSubspaceSize || blockIterator.hasNext()); + return (currOffset < index.denseSubspaceSize || blockIterator.hasNext()); } @Override public Cell next() { - if (currOffset == denseSubspaceSize) { + if (currOffset == index.denseSubspaceSize) { currBlock = blockIterator.next(); currOffset = 0; } @@ -165,16 +127,16 @@ public class MixedTensor implements Tensor { @Override public Iterator valueIterator() { return new Iterator<>() { - final Iterator blockIterator = denseSubspaces.iterator(); + final Iterator blockIterator = index.denseSubspaces.iterator(); double[] currBlock = null; - int currOffset = denseSubspaceSize; + int currOffset = index.denseSubspaceSize; @Override public boolean hasNext() { - return (currOffset < denseSubspaceSize || blockIterator.hasNext()); + return (currOffset < index.denseSubspaceSize || blockIterator.hasNext()); } @Override public Double next() { - if (currOffset == denseSubspaceSize) { + if (currOffset == index.denseSubspaceSize) { currBlock = blockIterator.next().cells; currOffset = 0; } @@ -200,24 +162,22 @@ public class MixedTensor implements Tensor { throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" + this.type + "', requested type: '" + type + "'"); } - return new MixedTensor(other, denseSubspaces, index); + return new MixedTensor(other, index); } @Override public Tensor remove(Set addresses) { var indexBuilder = new Index.Builder(type); - List list = new ArrayList<>(); - for (var block : denseSubspaces) { + for (var block : index.denseSubspaces) { if ( ! addresses.contains(block.sparseAddress)) { // assumption: addresses only contain the sparse part - indexBuilder.addBlock(block.sparseAddress, list.size()); - list.add(block); + indexBuilder.addBlock(block); } } - return new MixedTensor(type, list, indexBuilder.build()); + return new MixedTensor(type, indexBuilder.build()); } @Override - public int hashCode() { return Objects.hash(type, denseSubspaces); } + public int hashCode() { return Objects.hash(type, index.denseSubspaces); } @Override public String toString() { @@ -252,7 +212,7 @@ public class MixedTensor implements Tensor { /** Returns the size of dense subspaces */ public long denseSubspaceSize() { - return denseSubspaceSize; + return index.denseSubspaceSize; } /** @@ -367,15 +327,14 @@ public class MixedTensor implements Tensor { @Override public MixedTensor build() { - List list = new ArrayList<>(); - for (Map.Entry entry : denseSubspaceMap.entrySet()) { + Set> entrySet = denseSubspaceMap.entrySet(); + for (Map.Entry entry : entrySet) { TensorAddress sparsePart = entry.getKey(); double[] denseSubspace = entry.getValue(); var block = new DenseSubspace(sparsePart, denseSubspace); - indexBuilder.addBlock(sparsePart, list.size()); - list.add(block); + indexBuilder.addBlock(block); } - return new MixedTensor(type, list, indexBuilder.build()); + return new MixedTensor(type, indexBuilder.build()); } public static BoundBuilder of(TensorType type) { @@ -467,6 +426,7 @@ public class MixedTensor implements Tensor { private final List indexedDimensions; private ImmutableMap sparseMap; + private List denseSubspaces; private final int denseSubspaceSize; static private int computeDSS(List dimensions) { @@ -485,14 +445,21 @@ public class MixedTensor implements Tensor { this.sparseType = createPartialType(type.valueType(), mappedDimensions); this.denseType = createPartialType(type.valueType(), indexedDimensions); this.denseSubspaceSize = computeDSS(this.indexedDimensions); + if (this.denseSubspaceSize < 1) { + throw new IllegalStateException("invalid dense subspace size: " + denseSubspaceSize); + } } - int blockIndexOf(TensorAddress address) { + private DenseSubspace blockOf(TensorAddress address) { TensorAddress sparsePart = sparsePartialAddress(address); - return sparseMap.getOrDefault(sparsePart, -1); + Integer blockNum = sparseMap.get(sparsePart); + if (blockNum == null || blockNum >= denseSubspaces.size()) { + return null; + } + return denseSubspaces.get(blockNum); } - int denseOffsetOf(TensorAddress address) { + private int denseOffsetOf(TensorAddress address) { long innerSize = 1; long offset = 0; for (int i = type.dimensions().size(); --i >= 0; ) { @@ -629,25 +596,32 @@ public class MixedTensor implements Tensor { } private double getDouble(int subspaceIndex, int denseOffset, MixedTensor tensor) { - return tensor.denseSubspaces.get(subspaceIndex).cells[denseOffset]; + return tensor.index.denseSubspaces.get(subspaceIndex).cells[denseOffset]; } - static class Builder { + private static class Builder { private final Index index; - private final ImmutableMap.Builder builder; + private final ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); + private final ImmutableList.Builder listBuilder = new ImmutableList.Builder<>(); + private int count = 0; Builder(TensorType type) { index = new Index(type); - builder = new ImmutableMap.Builder<>(); } - void addBlock(TensorAddress address, int sz) { - builder.put(address, sz); + void addBlock(DenseSubspace block) { + if (block.cells.length != index.denseSubspaceSize) { + throw new IllegalStateException("dense subspace size mismatch, expected " + index.denseSubspaceSize + + " cells, but got: " + block.cells.length); + } + builder.put(block.sparseAddress, count++); + listBuilder.add(block); } Index build() { index.sparseMap = builder.build(); + index.denseSubspaces = listBuilder.build(); return index; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index fe20c41174a..3ba57b29ebc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.Convert; import com.yahoo.tensor.impl.StringTensorAddress; import java.util.ArrayList; @@ -172,14 +173,15 @@ public class Reduce extends PrimitiveTensorFunction i = argument.valueIterator(); i.hasNext(); ) valueAggregator.aggregate(i.next()); - return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); + return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build(); } private static Tensor reduceIndexedVector(IndexedTensor argument, Aggregator aggregator) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); - for (int i = 0; i < argument.dimensionSizes().size(0); i++) + int dimensionSize = Convert.safe2Int(argument.dimensionSizes().size(0)); + for (int i = 0; i < dimensionSize ; i++) valueAggregator.aggregate(argument.get(i)); - return Tensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); + return Tensor.Builder.of(TensorType.empty).cell(valueAggregator.aggregatedValue()).build(); } static abstract class ValueAggregator { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java new file mode 100644 index 00000000000..e2cb64fdd1f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/Convert.java @@ -0,0 +1,16 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.impl; + +/** + * Utility to make common conversions safe + * + * @author baldersheim + */ +public class Convert { + public static int safe2Int(long value) { + if (value > Integer.MAX_VALUE || value < Integer.MIN_VALUE) { + throw new IndexOutOfBoundsException("value = " + value + ", which is too large to fit in an int"); + } + return (int) value; + } +} -- cgit v1.2.3