diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 16 |
1 files changed, 8 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 9125b35ea5d..2635cbecb94 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -355,16 +355,16 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) { - int[] labels = new int[plan.resultType.rank()]; + long[] labels = new long[plan.resultType.rank()]; int out = 0; int m = 0; int a = 0; int b = 0; for (var how : plan.combineHow) { switch (how) { - case left -> labels[out++] = (int) leftOnly.numericLabel(a++); - case right -> labels[out++] = (int) rightOnly.numericLabel(b++); - case both -> labels[out++] = (int) match.numericLabel(m++); + case left -> labels[out++] = leftOnly.numericLabel(a++); + case right -> labels[out++] = rightOnly.numericLabel(b++); + case both -> labels[out++] = match.numericLabel(m++); case concat -> labels[out++] = concatDimIdx; default -> throw new IllegalArgumentException("cannot handle: " + how); } @@ -399,8 +399,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET CellVectorMapMap decompose(Tensor input, SplitHow how) { var iter = input.cellIterator(); - int[] commonLabels = new int[(int)how.numCommon()]; - int[] separateLabels = new int[(int)how.numSeparate()]; + long[] commonLabels = new long[(int)how.numCommon()]; + long[] separateLabels = new long[(int)how.numSeparate()]; CellVectorMapMap result = new CellVectorMapMap(); while (iter.hasNext()) { var cell = iter.next(); @@ -410,8 +410,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET int separateIdx = 0; for (int i = 0; i < how.handleDims.size(); i++) { switch (how.handleDims.get(i)) { - case common -> commonLabels[commonIdx++] = (int) addr.numericLabel(i); - case separate -> separateLabels[separateIdx++] = (int) addr.numericLabel(i); + case common -> commonLabels[commonIdx++] = addr.numericLabel(i); + case separate -> separateLabels[separateIdx++] = addr.numericLabel(i); case concat -> ccDimIndex = addr.numericLabel(i); default -> throw new IllegalArgumentException("cannot handle: " + how.handleDims.get(i)); } |