aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-19 23:02:04 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-19 23:02:04 +0100
commit35d59981840614bf4b877714ee88e273816c46d2 (patch)
treefba37b2e8bc9fcee46821821ab2886d371fcd696 /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
parent067eb48b7d2fc062a74392b1c16f5538b5031d5b (diff)
Use longs for dimensions lengths in all API's
This is to be able to support tensor dimensions with more than 2B elements in the future without API change.
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java18
1 files changed, 9 insertions, 9 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index ff887e3e9a6..174a8e4c435 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -56,8 +56,8 @@ public class Join extends PrimitiveTensorFunction {
if (aDim.name().equals(bDim.name())) { // include
if (aDim.isIndexed() && bDim.isIndexed()) {
if (aDim.size().isPresent() || bDim.size().isPresent())
- typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
- bDim.size().orElse(Integer.MAX_VALUE)));
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Long.MAX_VALUE),
+ bDim.size().orElse(Long.MAX_VALUE)));
else
typeBuilder.indexed(aDim.name());
}
@@ -118,11 +118,11 @@ public class Join extends PrimitiveTensorFunction {
}
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
- int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
+ long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
- IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build());
- for (int i = 0; i < joinedLength; i++)
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
+ for (int i = 0; i < joinedRank; i++)
builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
return builder.build();
}
@@ -169,10 +169,10 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
- private void joinSubspaces(Iterator<Double> subspace, int subspaceSize,
- Iterator<Tensor.Cell> superspace, int superspaceSize,
+ private void joinSubspaces(Iterator<Double> subspace, long subspaceSize,
+ Iterator<Tensor.Cell> superspace, long superspaceSize,
boolean reversedArgumentOrder, IndexedTensor.Builder builder) {
- int joinedLength = Math.min(subspaceSize, superspaceSize);
+ long joinedLength = Math.min(subspaceSize, superspaceSize);
if (reversedArgumentOrder) {
for (int i = 0; i < joinedLength; i++) {
Tensor.Cell supercell = superspace.next();
@@ -281,7 +281,7 @@ public class Join extends PrimitiveTensorFunction {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
if (retainDimensions.contains(addressType.dimensions().get(i).name()))
- builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
+ builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i));
return builder.build();
}