aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java16
1 files changed, 8 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 9125b35ea5d..2635cbecb94 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -355,16 +355,16 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) {
- int[] labels = new int[plan.resultType.rank()];
+ long[] labels = new long[plan.resultType.rank()];
int out = 0;
int m = 0;
int a = 0;
int b = 0;
for (var how : plan.combineHow) {
switch (how) {
- case left -> labels[out++] = (int) leftOnly.numericLabel(a++);
- case right -> labels[out++] = (int) rightOnly.numericLabel(b++);
- case both -> labels[out++] = (int) match.numericLabel(m++);
+ case left -> labels[out++] = leftOnly.numericLabel(a++);
+ case right -> labels[out++] = rightOnly.numericLabel(b++);
+ case both -> labels[out++] = match.numericLabel(m++);
case concat -> labels[out++] = concatDimIdx;
default -> throw new IllegalArgumentException("cannot handle: " + how);
}
@@ -399,8 +399,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
CellVectorMapMap decompose(Tensor input, SplitHow how) {
var iter = input.cellIterator();
- int[] commonLabels = new int[(int)how.numCommon()];
- int[] separateLabels = new int[(int)how.numSeparate()];
+ long[] commonLabels = new long[(int)how.numCommon()];
+ long[] separateLabels = new long[(int)how.numSeparate()];
CellVectorMapMap result = new CellVectorMapMap();
while (iter.hasNext()) {
var cell = iter.next();
@@ -410,8 +410,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
int separateIdx = 0;
for (int i = 0; i < how.handleDims.size(); i++) {
switch (how.handleDims.get(i)) {
- case common -> commonLabels[commonIdx++] = (int) addr.numericLabel(i);
- case separate -> separateLabels[separateIdx++] = (int) addr.numericLabel(i);
+ case common -> commonLabels[commonIdx++] = addr.numericLabel(i);
+ case separate -> separateLabels[separateIdx++] = addr.numericLabel(i);
case concat -> ccDimIndex = addr.numericLabel(i);
default -> throw new IllegalArgumentException("cannot handle: " + how.handleDims.get(i));
}