diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index faa0ca36cb6..d4affe0ef9b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -67,7 +67,7 @@ public class Concat extends PrimitiveTensorFunction { DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); - int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); + long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); int[] aToIndexes = mapIndexes(a.type(), concatType); int[] bToIndexes = mapIndexes(b.type(), concatType); concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); @@ -75,7 +75,7 @@ public class Concat extends PrimitiveTensorFunction { return builder.build(); } - private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType, + private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) { @@ -129,8 +129,8 @@ public class Concat extends PrimitiveTensorFunction { DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); for (int i = 0; i < concatSizes.dimensions(); i++) { String currentDimension = concatType.dimensions().get(i).name(); - int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0); - int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0); + long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L); + long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L); if (currentDimension.equals(concatDimension)) concatSizes.set(i, aSize + bSize); else if (aSize != 0 && bSize != 0 && aSize!=bSize ) @@ -148,8 +148,8 @@ public class Concat extends PrimitiveTensorFunction { * (in some other dimension than the concat dimension) */ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, - TensorType concatType, int concatOffset, String concatDimension) { - int[] combinedLabels = new int[concatType.dimensions().size()]; + TensorType concatType, long concatOffset, String concatDimension) { + long[] combinedLabels = new long[concatType.dimensions().size()]; Arrays.fill(combinedLabels, -1); int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension @@ -179,15 +179,15 @@ public class Concat extends PrimitiveTensorFunction { * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private boolean mapContent(TensorAddress from, int[] to, int[] indexMap, int concatDimension, int concatOffset) { + private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) { for (int i = 0; i < from.size(); i++) { int toIndex = indexMap[i]; if (concatDimension == toIndex) { - to[toIndex] = from.intLabel(i) + concatOffset; + to[toIndex] = from.numericLabel(i) + concatOffset; } else { - if (to[toIndex] != -1 && to[toIndex] != from.intLabel(i)) return false; - to[toIndex] = from.intLabel(i); + if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; + to[toIndex] = from.numericLabel(i); } } return true; |