diff options
10 files changed, 57 insertions, 33 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 174ce6332db..4a65b00a6a4 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1470,6 +1470,9 @@ ], "methods" : [ "public void <init>(com.yahoo.tensor.TensorType$Value, java.util.Collection)", + "public boolean hasIndexedDimensions()", + "public boolean hasMappedDimensions()", + "public boolean hasOnlyIndexedBoundDimensions()", "public static varargs com.yahoo.tensor.TensorType$Value combinedValueType(com.yahoo.tensor.TensorType[])", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", "public com.yahoo.tensor.TensorType$Value valueType()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 93cdc3f630f..5d384e0329b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -164,9 +164,10 @@ public abstract class IndexedTensor implements Tensor { long valueIndex = 0; for (int i = 0; i < address.size(); i++) { - if (address.numericLabel(i) >= sizes.size(i)) + long label = address.numericLabel(i); + if (label >= sizes.size(i)) throw new IllegalArgumentException(address + " is not within the bounds of " + type); - valueIndex += sizes.productOfDimensionsAfter(i) * address.numericLabel(i); + valueIndex += sizes.productOfDimensionsAfter(i) * label; } return valueIndex; } @@ -281,7 +282,7 @@ public abstract class IndexedTensor implements Tensor { } public static Builder of(TensorType type) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type)); else return new UnboundBuilder(type); @@ -295,7 +296,7 @@ public abstract class IndexedTensor implements Tensor { * must not be further mutated by the caller */ public static Builder of(TensorType type, float[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -309,7 +310,7 @@ public abstract class IndexedTensor implements Tensor { * must not be further mutated by the caller */ public static Builder of(TensorType type, double[] values) { - if (type.dimensions().stream().allMatch(d -> d instanceof TensorType.IndexedBoundDimension)) + if (type.hasOnlyIndexedBoundDimensions()) return of(type, BoundBuilder.dimensionSizesOf(type), values); else return new UnboundBuilder(type); @@ -615,11 +616,11 @@ public abstract class IndexedTensor implements Tensor { private final class ValueIterator implements Iterator<Double> { - private long count = 0; + private int count = 0; @Override public boolean hasNext() { - return count < size(); + return count < sizeAsInt(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index e6315dbef80..30dd1d6dc29 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -267,7 +267,7 @@ public class MixedTensor implements Tensor { * a temporary structure while finding dimension bounds. */ public static Builder of(TensorType type) { - if (type.dimensions().stream().anyMatch(d -> d instanceof TensorType.IndexedUnboundDimension)) { + if (type.hasIndexedUnboundDimensions()) { return new UnboundBuilder(type); } else { return new BoundBuilder(type); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index cc8e1602adb..cff17fdfd7c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -113,7 +113,7 @@ public interface Tensor { * @throws IllegalStateException if this does not have zero dimensions and one value */ default double asDouble() { - if (type().dimensions().size() > 0) + if (!type().dimensions().isEmpty()) throw new IllegalStateException("Require a dimensionless tensor but has " + type()); if (size() == 0) return Double.NaN; return valueIterator().next(); @@ -553,8 +553,8 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) @@ -565,8 +565,8 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type, DimensionSizes dimensionSizes) { - boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); + boolean containsIndexed = type.hasIndexedDimensions(); + boolean containsMapped = type.hasMappedDimensions(); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); if (containsMapped) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b30b664a5f7..82968476296 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.google.common.collect.ImmutableSet; import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; @@ -86,16 +87,20 @@ public class TensorType { /** Sorted list of the dimensions of this */ private final List<Dimension> dimensions; + private final Set<String> dimensionNames; private final TensorType mappedSubtype; private final TensorType indexedSubtype; + private final int indexedUnBoundCount; // only used to initialize the "empty" instance private TensorType() { this.valueType = Value.DOUBLE; this.dimensions = List.of(); + this.dimensionNames = Set.of(); this.mappedSubtype = this; this.indexedSubtype = this; + indexedUnBoundCount = 0; } public TensorType(Value valueType, Collection<Dimension> dimensions) { @@ -103,12 +108,25 @@ public class TensorType { List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); this.dimensions = List.copyOf(dimensionList); + ImmutableSet.Builder<String> namesbuilder = new ImmutableSet.Builder<>(); + int indexedBoundCount = 0, indexedUnBoundCount = 0, mappedCount = 0; + for (Dimension dimension : dimensionList) { + namesbuilder.add(dimension.name()); + Dimension.Type type = dimension.type(); + switch (type) { + case indexedUnbound -> indexedUnBoundCount++; + case indexedBound -> indexedBoundCount++; + case mapped -> mappedCount++; + } + } + this.indexedUnBoundCount = indexedUnBoundCount; + dimensionNames = namesbuilder.build(); - if (dimensionList.stream().allMatch(Dimension::isIndexed)) { + if (mappedCount == 0) { mappedSubtype = empty; indexedSubtype = this; } - else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) { + else if ((indexedBoundCount + indexedUnBoundCount) == 0) { mappedSubtype = this; indexedSubtype = empty; } @@ -118,6 +136,11 @@ public class TensorType { } } + public boolean hasIndexedDimensions() { return indexedSubtype != empty; } + public boolean hasMappedDimensions() { return mappedSubtype != empty; } + public boolean hasOnlyIndexedBoundDimensions() { return !hasMappedDimensions() && ! hasIndexedUnboundDimensions(); } + boolean hasIndexedUnboundDimensions() { return indexedUnBoundCount > 0; } + static public Value combinedValueType(TensorType ... types) { List<Value> valueTypes = new ArrayList<>(); for (TensorType type : types) { @@ -161,7 +184,7 @@ public class TensorType { /** Returns an immutable set of the names of the dimensions of this */ public Set<String> dimensionNames() { - return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); + return dimensionNames; } /** Returns the dimension with this name, or empty if not present */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 8d8fe2b356f..866b710b72e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -134,7 +134,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return tensor; } else { // extend tensor with this dimension - if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) + if (tensor.type().hasMappedDimensions()) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 3b6e03186a3..b595b1a40cd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -40,7 +40,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if (arguments.size() != 0) + if (!arguments.isEmpty()) throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size()); return this; } @@ -79,7 +79,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { var result = new ArrayList<TensorFunction<NAMETYPE>>(); for (var fun : cells.values()) { - fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + fun.asTensorFunction().ifPresent(result::add); } return result; } @@ -133,7 +133,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { super(type); - if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) + if ( ! type.hasOnlyIndexedBoundDimensions()) throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + "only indexed, bound dimensions, but this has " + type); this.cells = List.copyOf(cells); @@ -142,7 +142,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens public List<TensorFunction<NAMETYPE>> cellGeneratorFunctions() { var result = new ArrayList<TensorFunction<NAMETYPE>>(); for (var fun : cells) { - fun.asTensorFunction().ifPresent(tf -> result.add(tf)); + fun.asTensorFunction().ifPresent(result::add); } return result; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index aece782d296..2d5a0518747 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -92,11 +92,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N return false; if ( ! (a instanceof IndexedTensor)) return false; - if ( ! (a.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + if ( ! (a.type().hasOnlyIndexedBoundDimensions())) return false; if ( ! (b instanceof IndexedTensor)) return false; - if ( ! (b.type().dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))) + if ( ! (b.type().hasOnlyIndexedBoundDimensions())) return false; TensorType commonDimensions = dimensionsInCommon((IndexedTensor)a, (IndexedTensor)b); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 444ce02b14a..771b74633d9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -21,10 +21,8 @@ import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Slice; import java.util.ArrayList; -import java.util.HashSet; import java.util.Iterator; import java.util.List; -import java.util.Set; /** * Writes tensors on the JSON format used in Vespa tensor document fields: @@ -60,8 +58,7 @@ public class JsonFormat { // Short form for a single mapped dimension Cursor parent = root == null ? slime.setObject() : root.setObject("cells"); encodeSingleDimensionCells((MappedTensor) tensor, parent); - } else if (tensor instanceof MixedTensor && - tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped)) { + } else if (tensor instanceof MixedTensor && tensor.type().hasMappedDimensions()) { // Short form for a mixed tensor boolean singleMapped = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() == 1; Cursor parent = root == null ? ( singleMapped ? slime.setObject() : slime.setArray() ) @@ -204,7 +201,7 @@ public class JsonFormat { if (root.field("cells").valid() && ! primitiveContent(root.field("cells"))) decodeCells(root.field("cells"), builder); - else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) + else if (root.field("values").valid() && ! builder.type().hasMappedDimensions()) decodeValuesAtTop(root.field("values"), builder); else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); @@ -298,14 +295,14 @@ public class JsonFormat { /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */ private static void decodeDirectValue(Inspector root, Tensor.Builder builder) { - boolean hasIndexed = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); - boolean hasMapped = builder.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); + boolean hasIndexed = builder.type().hasIndexedDimensions(); + boolean hasMapped = builder.type().hasMappedDimensions(); if (isArrayOfObjects(root)) decodeCells(root, builder); else if ( ! hasMapped) decodeValuesAtTop(root, builder); - else if (hasMapped && hasIndexed) + else if (hasIndexed) decodeBlocks(root, builder); else decodeCells(root, builder); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index d4b18c73f11..0a5c713f3e2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -55,8 +55,8 @@ public class TypedBinaryFormat { } private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) { - boolean hasMappedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isMapped); - boolean hasIndexedDimensions = tensor.type().dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); + boolean hasMappedDimensions = tensor.type().hasMappedDimensions(); + boolean hasIndexedDimensions = tensor.type().hasIndexedDimensions(); boolean isMixed = hasMappedDimensions && hasIndexedDimensions; // TODO: Encoding as indexed if the implementation is mixed is not yet supported so use mixed format instead |