summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java21
1 files changed, 11 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index f212e66fc86..05999ff1240 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -2,6 +2,7 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -62,10 +63,10 @@ public class Concat extends PrimitiveTensorFunction {
IndexedTensor bIndexed = (IndexedTensor) b;
TensorType concatType = concatType(a, b);
- int[] concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
+ DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
- int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(aIndexed::size).orElseThrow(RuntimeException::new);
+ int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
int[] aToIndexes = mapIndexes(a.type(), concatType);
int[] bToIndexes = mapIndexes(b.type(), concatType);
concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
@@ -123,22 +124,22 @@ public class Concat extends PrimitiveTensorFunction {
}
/** Returns the concrete (not type) dimension sizes resulting from combining a and b */
- private int[] concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
- int[] joinedSizes = new int[concatType.dimensions().size()];
- for (int i = 0; i < joinedSizes.length; i++) {
+ private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
+ DimensionSizes.Builder joinedSizes = new DimensionSizes.Builder(concatType.dimensions().size());
+ for (int i = 0; i < joinedSizes.dimensions(); i++) {
String currentDimension = concatType.dimensions().get(i).name();
- int aSize = a.type().indexOfDimension(currentDimension).map(a::size).orElse(0);
- int bSize = b.type().indexOfDimension(currentDimension).map(b::size).orElse(0);
+ int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0);
+ int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0);
if (currentDimension.equals(concatDimension))
- joinedSizes[i] = aSize + bSize;
+ joinedSizes.set(i, aSize + bSize);
else if (aSize != 0 && bSize != 0 && aSize!=bSize )
throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when " +
"concatenating " + a.type() + " and " + b.type() + " along dimension " +
concatDimension + ", but was " + aSize + " and " + bSize);
else
- joinedSizes[i] = Math.max(aSize, bSize);
+ joinedSizes.set(i, Math.max(aSize, bSize));
}
- return joinedSizes;
+ return joinedSizes.build();
}
/**