aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-19 23:02:04 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-19 23:02:04 +0100
commit35d59981840614bf4b877714ee88e273816c46d2 (patch)
treefba37b2e8bc9fcee46821821ab2886d371fcd696 /vespajlib/src/main/java/com/yahoo/tensor/functions
parent067eb48b7d2fc062a74392b1c16f5538b5031d5b (diff)
Use longs for dimensions lengths in all API's
This is to be able to support tensor dimensions with more than 2B elements in the future without API change.
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java36
6 files changed, 42 insertions, 42 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index faa0ca36cb6..d4affe0ef9b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -67,7 +67,7 @@ public class Concat extends PrimitiveTensorFunction {
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
- int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
+ long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
int[] aToIndexes = mapIndexes(a.type(), concatType);
int[] bToIndexes = mapIndexes(b.type(), concatType);
concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
@@ -75,7 +75,7 @@ public class Concat extends PrimitiveTensorFunction {
return builder.build();
}
- private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
+ private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
@@ -129,8 +129,8 @@ public class Concat extends PrimitiveTensorFunction {
DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
for (int i = 0; i < concatSizes.dimensions(); i++) {
String currentDimension = concatType.dimensions().get(i).name();
- int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0);
- int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0);
+ long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
+ long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
if (currentDimension.equals(concatDimension))
concatSizes.set(i, aSize + bSize);
else if (aSize != 0 && bSize != 0 && aSize!=bSize )
@@ -148,8 +148,8 @@ public class Concat extends PrimitiveTensorFunction {
* (in some other dimension than the concat dimension)
*/
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
- TensorType concatType, int concatOffset, String concatDimension) {
- int[] combinedLabels = new int[concatType.dimensions().size()];
+ TensorType concatType, long concatOffset, String concatDimension) {
+ long[] combinedLabels = new long[concatType.dimensions().size()];
Arrays.fill(combinedLabels, -1);
int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
@@ -179,15 +179,15 @@ public class Concat extends PrimitiveTensorFunction {
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private boolean mapContent(TensorAddress from, int[] to, int[] indexMap, int concatDimension, int concatOffset) {
+ private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
for (int i = 0; i < from.size(); i++) {
int toIndex = indexMap[i];
if (concatDimension == toIndex) {
- to[toIndex] = from.intLabel(i) + concatOffset;
+ to[toIndex] = from.numericLabel(i) + concatOffset;
}
else {
- if (to[toIndex] != -1 && to[toIndex] != from.intLabel(i)) return false;
- to[toIndex] = from.intLabel(i);
+ if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
+ to[toIndex] = from.numericLabel(i);
}
}
return true;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index c75d8ee4753..653be8dacf0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -17,7 +17,7 @@ import java.util.stream.Stream;
public class Diag extends CompositeTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> diagFunction;
+ private final Function<List<Long>, Double> diagFunction;
public Diag(TensorType type) {
this.type = type;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index e42d25197e2..ef2770c04f5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -22,17 +22,17 @@ import java.util.function.Function;
public class Generate extends PrimitiveTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> generator;
+ private final Function<List<Long>, Double> generator;
/**
* Creates a generated tensor
*
* @param type the type of the tensor
- * @param generator the function generating values from a list of ints specifying the indexes of the
+ * @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, Function<List<Integer>, Double> generator) {
+ public Generate(TensorType type, Function<List<Long>, Double> generator) {
Objects.requireNonNull(type, "The argument tensor type cannot be null");
Objects.requireNonNull(generator, "The argument function cannot be null");
validateType(type);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index ff887e3e9a6..174a8e4c435 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -56,8 +56,8 @@ public class Join extends PrimitiveTensorFunction {
if (aDim.name().equals(bDim.name())) { // include
if (aDim.isIndexed() && bDim.isIndexed()) {
if (aDim.size().isPresent() || bDim.size().isPresent())
- typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
- bDim.size().orElse(Integer.MAX_VALUE)));
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Long.MAX_VALUE),
+ bDim.size().orElse(Long.MAX_VALUE)));
else
typeBuilder.indexed(aDim.name());
}
@@ -118,11 +118,11 @@ public class Join extends PrimitiveTensorFunction {
}
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
- int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
+ long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
- IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build());
- for (int i = 0; i < joinedLength; i++)
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
+ for (int i = 0; i < joinedRank; i++)
builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
return builder.build();
}
@@ -169,10 +169,10 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
- private void joinSubspaces(Iterator<Double> subspace, int subspaceSize,
- Iterator<Tensor.Cell> superspace, int superspaceSize,
+ private void joinSubspaces(Iterator<Double> subspace, long subspaceSize,
+ Iterator<Tensor.Cell> superspace, long superspaceSize,
boolean reversedArgumentOrder, IndexedTensor.Builder builder) {
- int joinedLength = Math.min(subspaceSize, superspaceSize);
+ long joinedLength = Math.min(subspaceSize, superspaceSize);
if (reversedArgumentOrder) {
for (int i = 0; i < joinedLength; i++) {
Tensor.Cell supercell = superspace.next();
@@ -281,7 +281,7 @@ public class Join extends PrimitiveTensorFunction {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
if (retainDimensions.contains(addressType.dimensions().get(i).name()))
- builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
+ builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i));
return builder.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index a56f82b026a..8e7f4e4c773 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -18,7 +18,7 @@ import java.util.stream.Stream;
public class Range extends CompositeTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> rangeFunction;
+ private final Function<List<Long>, Double> rangeFunction;
public Range(TensorType type) {
this.type = type;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index fb5029fbfd6..f1dadba2a29 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -14,8 +14,8 @@ import java.util.stream.Collectors;
/**
* Factory of scalar Java functions.
* The purpose of this is to embellish anonymous functions with a runtime type
- * such that they can be inspected and will return a parseable toString.
- *
+ * such that they can be inspected and will return a parsable toString.
+ *
* @author bratseth
*/
@Beta
@@ -31,9 +31,9 @@ public class ScalarFunctions {
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
- public static Function<List<Integer>, Double> random() { return new Random(); }
- public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
- public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
+ public static Function<List<Long>, Double> random() { return new Random(); }
+ public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
+ public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
// Binary operators -----------------------------------------------------------------------------
@@ -60,7 +60,7 @@ public class ScalarFunctions {
public static class Multiply implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left * right; }
+ public double applyAsDouble(double left, double right) { return left * right; }
@Override
public String toString() { return "f(a,b)(a * b)"; }
}
@@ -100,26 +100,26 @@ public class ScalarFunctions {
// Variable-length operators -----------------------------------------------------------------------------
- public static class EqualElements implements Function<List<Integer>, Double> {
- private final ImmutableList<String> argumentNames;
+ public static class EqualElements implements Function<List<Long>, Double> {
+ private final ImmutableList<String> argumentNames;
private EqualElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@Override
- public Double apply(List<Integer> values) {
+ public Double apply(List<Long> values) {
if (values.isEmpty()) return 1.0;
- for (Integer value : values)
+ for (Long value : values)
if ( ! value.equals(values.get(0)))
return 0.0;
return 1.0;
}
@Override
- public String toString() {
+ public String toString() {
if (argumentNames.size() == 0) return "1";
if (argumentNames.size() == 1) return "1";
if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1);
-
+
StringBuilder b = new StringBuilder();
for (int i = 0; i < argumentNames.size() -1; i++) {
b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")");
@@ -130,25 +130,25 @@ public class ScalarFunctions {
}
}
- public static class Random implements Function<List<Integer>, Double> {
+ public static class Random implements Function<List<Long>, Double> {
@Override
- public Double apply(List<Integer> values) {
+ public Double apply(List<Long> values) {
return ThreadLocalRandom.current().nextDouble();
}
@Override
public String toString() { return "random"; }
}
- public static class SumElements implements Function<List<Integer>, Double> {
+ public static class SumElements implements Function<List<Long>, Double> {
private final ImmutableList<String> argumentNames;
private SumElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@Override
- public Double apply(List<Integer> values) {
- int sum = 0;
- for (Integer value : values)
+ public Double apply(List<Long> values) {
+ long sum = 0;
+ for (Long value : values)
sum += value;
return (double)sum;
}