From 351e0f62fe4f9b1b1015c0c1289a3f519fa9f868 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Sun, 21 Jan 2024 11:20:59 +0100 Subject: - Extract dimension names in a set to avoid recomputing it in dimensionNames() mathod. - Also extract some information about what kind of dimensions a tensor has construcion time. This avoids streamin through all diemsion later on, and keeps this local in the TensorType. --- .../main/java/com/yahoo/tensor/IndexedTensor.java | 15 +++++------ .../main/java/com/yahoo/tensor/MixedTensor.java | 2 +- .../src/main/java/com/yahoo/tensor/Tensor.java | 10 ++++---- .../src/main/java/com/yahoo/tensor/TensorType.java | 29 +++++++++++++++++++--- 4 files changed, 40 insertions(+), 16 deletions(-) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 93cdc3f630f..5d384e0329b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -164,9 +164,10 @@ public abstract class IndexedTensor implements Tensor { long valueIndex = 0; for (int i = 0; i < address.size(); i++) { - if (address.numericLabel(i) >= sizes.size(i)) + long label = address.numericLabel(i); + if (label >= sizes.size(i)) throw new IllegalArgumentException(address + " is not within the bounds of " + type); - valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i); + valueIndex += sizes.productOfDimensionsAfter(i) * label; } return valueIndex; } @@ -281,7 +282,7 @@ public abstract class IndexedTensor implements Tensor { } public static Builder of(TensorType type) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type)); else return new UnboundBuilder(type); @@ -295,7 +296,7 @@ public abstract class IndexedTensor implements Tensor { * must not be further mutated by the caller */ public static Builder of(TensorType type, float[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -309,7 +310,7 @@ public abstract class IndexedTensor implements Tensor { * must not be further mutated by the caller */ public static Builder of(TensorType type, double[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -615,11 +616,11 @@ public abstract class IndexedTensor implements Tensor { private final class ValueIterator implements Iterator { - private long count = 0; + private int count = 0; @Override public boolean hasNext() { - return count < size(); + return count < sizeAsInt(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 74b338fb503..d755d6d6337 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -266,7 +266,7 @@ public class MixedTensor implements Tensor { * a temporary structure while finding dimension bounds. */ public static Builder of(TensorType type) { - if (type.dimensions().stream().anyMatch(d -> d instanceof TensorType.IndexedUnboundDimension)) { + if (type.hasIndexedUnboundDimensions()) { return new UnboundBuilder(type); } else { return new BoundBuilder(type); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index cc8e1602adb..cff17fdfd7c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -113,7 +113,7 @@ public interface Tensor { * @throws IllegalStateException if this does not have zero dimensions and one value */ default double asDouble() { - if (type().dimensions().size() > 0) + if (!type().dimensions().isEmpty()) throw new IllegalStateException("Require a dimensionless tensor but has " + type()); if (size() == 0) return Double.NaN; return valueIterator().next(); @@ -553,8 +553,8 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) @@ -565,8 +565,8 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type, DimensionSizes dimensionSizes) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b30b664a5f7..b7346348672 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.google.common.collect.ImmutableSet; import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; @@ -86,16 +87,20 @@ public class TensorType { /** Sorted list of the dimensions of this */ private final List dimensions; + private final Set dimensionNames; private final TensorType mappedSubtype; private final TensorType indexedSubtype; + private final int indexedUnBoundCount; // only used to initialize the "empty" instance private TensorType() { this.valueType = Value.DOUBLE; this.dimensions = List.of(); + this.dimensionNames = Set.of(); this.mappedSubtype = this; this.indexedSubtype = this; + indexedUnBoundCount = 0; } public TensorType(Value valueType, Collection dimensions) { @@ -103,12 +108,25 @@ public class TensorType { List dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); this.dimensions = List.copyOf(dimensionList); + ImmutableSet.Builder namesbuilder = new ImmutableSet.Builder<>(); + int indexedBoundCount = 0, indexedUnBoundCount = 0, mappedCount = 0; + for (Dimension dimension : dimensionList) { + namesbuilder.add(dimension.name()); + Dimension.Type type = dimension.type(); + switch (type) { + case indexedUnbound -> indexedUnBoundCount++; + case indexedBound -> indexedBoundCount++; + case mapped -> mappedCount++; + } + } + this.indexedUnBoundCount = indexedUnBoundCount; + dimensionNames = namesbuilder.build(); - if (dimensionList.stream().allMatch(Dimension::isIndexed)) { + if (mappedCount == 0) { mappedSubtype = empty; indexedSubtype = this; } - else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) { + else if ((indexedBoundCount + indexedUnBoundCount) == 0) { mappedSubtype = this; indexedSubtype = empty; } @@ -118,6 +136,11 @@ public class TensorType { } } + boolean hasIndexedDimensions() { return indexedSubtype != empty; } + boolean hasMappedDimensions() { return mappedSubtype != empty; } + boolean hasOnlyIndexedBoundDimensions() { return !hasMappedDimensions() && ! hasIndexedUnboundDimensions(); } + boolean hasIndexedUnboundDimensions() { return indexedUnBoundCount > 0; } + static public Value combinedValueType(TensorType ... types) { List valueTypes = new ArrayList<>(); for (TensorType type : types) { @@ -161,7 +184,7 @@ public class TensorType { /** Returns an immutable set of the names of the dimensions of this */ public Set dimensionNames() { - return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); + return dimensionNames; } /** Returns the dimension with this name, or empty if not present */ -- cgit v1.2.3