diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 21 |
1 files changed, 11 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index f212e66fc86..05999ff1240 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -62,10 +63,10 @@ public class Concat extends PrimitiveTensorFunction { IndexedTensor bIndexed = (IndexedTensor) b; TensorType concatType = concatType(a, b); - int[] concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); + DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); - int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(aIndexed::size).orElseThrow(RuntimeException::new); + int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); int[] aToIndexes = mapIndexes(a.type(), concatType); int[] bToIndexes = mapIndexes(b.type(), concatType); concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); @@ -123,22 +124,22 @@ public class Concat extends PrimitiveTensorFunction { } /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ - private int[] concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { - int[] joinedSizes = new int[concatType.dimensions().size()]; - for (int i = 0; i < joinedSizes.length; i++) { + private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { + DimensionSizes.Builder joinedSizes = new DimensionSizes.Builder(concatType.dimensions().size()); + for (int i = 0; i < joinedSizes.dimensions(); i++) { String currentDimension = concatType.dimensions().get(i).name(); - int aSize = a.type().indexOfDimension(currentDimension).map(a::size).orElse(0); - int bSize = b.type().indexOfDimension(currentDimension).map(b::size).orElse(0); + int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0); + int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0); if (currentDimension.equals(concatDimension)) - joinedSizes[i] = aSize + bSize; + joinedSizes.set(i, aSize + bSize); else if (aSize != 0 && bSize != 0 && aSize!=bSize ) throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when " + "concatenating " + a.type() + " and " + b.type() + " along dimension " + concatDimension + ", but was " + aSize + " and " + bSize); else - joinedSizes[i] = Math.max(aSize, bSize); + joinedSizes.set(i, Math.max(aSize, bSize)); } - return joinedSizes; + return joinedSizes.build(); } /** |