summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java36
6 files changed, 42 insertions, 42 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index faa0ca36cb6..d4affe0ef9b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -67,7 +67,7 @@ public class Concat extends PrimitiveTensorFunction {
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
- int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
+ long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new);
int[] aToIndexes = mapIndexes(a.type(), concatType);
int[] bToIndexes = mapIndexes(b.type(), concatType);
concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder);
@@ -75,7 +75,7 @@ public class Concat extends PrimitiveTensorFunction {
return builder.build();
}
- private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
+ private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) {
@@ -129,8 +129,8 @@ public class Concat extends PrimitiveTensorFunction {
DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
for (int i = 0; i < concatSizes.dimensions(); i++) {
String currentDimension = concatType.dimensions().get(i).name();
- int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0);
- int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0);
+ long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L);
+ long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L);
if (currentDimension.equals(concatDimension))
concatSizes.set(i, aSize + bSize);
else if (aSize != 0 && bSize != 0 && aSize!=bSize )
@@ -148,8 +148,8 @@ public class Concat extends PrimitiveTensorFunction {
* (in some other dimension than the concat dimension)
*/
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
- TensorType concatType, int concatOffset, String concatDimension) {
- int[] combinedLabels = new int[concatType.dimensions().size()];
+ TensorType concatType, long concatOffset, String concatDimension) {
+ long[] combinedLabels = new long[concatType.dimensions().size()];
Arrays.fill(combinedLabels, -1);
int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
@@ -179,15 +179,15 @@ public class Concat extends PrimitiveTensorFunction {
* @return true if the mapping was successful, false if one of the destination positions was
* occupied by a different value
*/
- private boolean mapContent(TensorAddress from, int[] to, int[] indexMap, int concatDimension, int concatOffset) {
+ private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) {
for (int i = 0; i < from.size(); i++) {
int toIndex = indexMap[i];
if (concatDimension == toIndex) {
- to[toIndex] = from.intLabel(i) + concatOffset;
+ to[toIndex] = from.numericLabel(i) + concatOffset;
}
else {
- if (to[toIndex] != -1 && to[toIndex] != from.intLabel(i)) return false;
- to[toIndex] = from.intLabel(i);
+ if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false;
+ to[toIndex] = from.numericLabel(i);
}
}
return true;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index c75d8ee4753..653be8dacf0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -17,7 +17,7 @@ import java.util.stream.Stream;
public class Diag extends CompositeTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> diagFunction;
+ private final Function<List<Long>, Double> diagFunction;
public Diag(TensorType type) {
this.type = type;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index e42d25197e2..ef2770c04f5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -22,17 +22,17 @@ import java.util.function.Function;
public class Generate extends PrimitiveTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> generator;
+ private final Function<List<Long>, Double> generator;
/**
* Creates a generated tensor
*
* @param type the type of the tensor
- * @param generator the function generating values from a list of ints specifying the indexes of the
+ * @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, Function<List<Integer>, Double> generator) {
+ public Generate(TensorType type, Function<List<Long>, Double> generator) {
Objects.requireNonNull(type, "The argument tensor type cannot be null");
Objects.requireNonNull(generator, "The argument function cannot be null");
validateType(type);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index ff887e3e9a6..174a8e4c435 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -56,8 +56,8 @@ public class Join extends PrimitiveTensorFunction {
if (aDim.name().equals(bDim.name())) { // include
if (aDim.isIndexed() && bDim.isIndexed()) {
if (aDim.size().isPresent() || bDim.size().isPresent())
- typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
- bDim.size().orElse(Integer.MAX_VALUE)));
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Long.MAX_VALUE),
+ bDim.size().orElse(Long.MAX_VALUE)));
else
typeBuilder.indexed(aDim.name());
}
@@ -118,11 +118,11 @@ public class Join extends PrimitiveTensorFunction {
}
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
- int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
+ long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
Iterator<Double> bIterator = b.valueIterator();
- IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build());
- for (int i = 0; i < joinedLength; i++)
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build());
+ for (int i = 0; i < joinedRank; i++)
builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i);
return builder.build();
}
@@ -169,10 +169,10 @@ public class Join extends PrimitiveTensorFunction {
return builder.build();
}
- private void joinSubspaces(Iterator<Double> subspace, int subspaceSize,
- Iterator<Tensor.Cell> superspace, int superspaceSize,
+ private void joinSubspaces(Iterator<Double> subspace, long subspaceSize,
+ Iterator<Tensor.Cell> superspace, long superspaceSize,
boolean reversedArgumentOrder, IndexedTensor.Builder builder) {
- int joinedLength = Math.min(subspaceSize, superspaceSize);
+ long joinedLength = Math.min(subspaceSize, superspaceSize);
if (reversedArgumentOrder) {
for (int i = 0; i < joinedLength; i++) {
Tensor.Cell supercell = superspace.next();
@@ -281,7 +281,7 @@ public class Join extends PrimitiveTensorFunction {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
if (retainDimensions.contains(addressType.dimensions().get(i).name()))
- builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
+ builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i));
return builder.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index a56f82b026a..8e7f4e4c773 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -18,7 +18,7 @@ import java.util.stream.Stream;
public class Range extends CompositeTensorFunction {
private final TensorType type;
- private final Function<List<Integer>, Double> rangeFunction;
+ private final Function<List<Long>, Double> rangeFunction;
public Range(TensorType type) {
this.type = type;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index fb5029fbfd6..f1dadba2a29 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -14,8 +14,8 @@ import java.util.stream.Collectors;
/**
* Factory of scalar Java functions.
* The purpose of this is to embellish anonymous functions with a runtime type
- * such that they can be inspected and will return a parseable toString.
- *
+ * such that they can be inspected and will return a parsable toString.
+ *
* @author bratseth
*/
@Beta
@@ -31,9 +31,9 @@ public class ScalarFunctions {
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
- public static Function<List<Integer>, Double> random() { return new Random(); }
- public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
- public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
+ public static Function<List<Long>, Double> random() { return new Random(); }
+ public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
+ public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
// Binary operators -----------------------------------------------------------------------------
@@ -60,7 +60,7 @@ public class ScalarFunctions {
public static class Multiply implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left * right; }
+ public double applyAsDouble(double left, double right) { return left * right; }
@Override
public String toString() { return "f(a,b)(a * b)"; }
}
@@ -100,26 +100,26 @@ public class ScalarFunctions {
// Variable-length operators -----------------------------------------------------------------------------
- public static class EqualElements implements Function<List<Integer>, Double> {
- private final ImmutableList<String> argumentNames;
+ public static class EqualElements implements Function<List<Long>, Double> {
+ private final ImmutableList<String> argumentNames;
private EqualElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@Override
- public Double apply(List<Integer> values) {
+ public Double apply(List<Long> values) {
if (values.isEmpty()) return 1.0;
- for (Integer value : values)
+ for (Long value : values)
if ( ! value.equals(values.get(0)))
return 0.0;
return 1.0;
}
@Override
- public String toString() {
+ public String toString() {
if (argumentNames.size() == 0) return "1";
if (argumentNames.size() == 1) return "1";
if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1);
-
+
StringBuilder b = new StringBuilder();
for (int i = 0; i < argumentNames.size() -1; i++) {
b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")");
@@ -130,25 +130,25 @@ public class ScalarFunctions {
}
}
- public static class Random implements Function<List<Integer>, Double> {
+ public static class Random implements Function<List<Long>, Double> {
@Override
- public Double apply(List<Integer> values) {
+ public Double apply(List<Long> values) {
return ThreadLocalRandom.current().nextDouble();
}
@Override
public String toString() { return "random"; }
}
- public static class SumElements implements Function<List<Integer>, Double> {
+ public static class SumElements implements Function<List<Long>, Double> {
private final ImmutableList<String> argumentNames;
private SumElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@Override
- public Double apply(List<Integer> values) {
- int sum = 0;
- for (Integer value : values)
+ public Double apply(List<Long> values) {
+ long sum = 0;
+ for (Long value : values)
sum += value;
return (double)sum;
}