summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-12-19 23:02:04 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-12-19 23:02:04 +0100
commit35d59981840614bf4b877714ee88e273816c46d2 (patch)
treefba37b2e8bc9fcee46821821ab2886d371fcd696 /vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
parent067eb48b7d2fc062a74392b1c16f5538b5031d5b (diff)
Use longs for dimensions lengths in all API's
This is to be able to support tensor dimensions with more than 2B elements in the future without API change.
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java36
1 files changed, 18 insertions, 18 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index fb5029fbfd6..f1dadba2a29 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -14,8 +14,8 @@ import java.util.stream.Collectors;
/**
* Factory of scalar Java functions.
* The purpose of this is to embellish anonymous functions with a runtime type
- * such that they can be inspected and will return a parseable toString.
- *
+ * such that they can be inspected and will return a parsable toString.
+ *
* @author bratseth
*/
@Beta
@@ -31,9 +31,9 @@ public class ScalarFunctions {
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
- public static Function<List<Integer>, Double> random() { return new Random(); }
- public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
- public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
+ public static Function<List<Long>, Double> random() { return new Random(); }
+ public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
+ public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
// Binary operators -----------------------------------------------------------------------------
@@ -60,7 +60,7 @@ public class ScalarFunctions {
public static class Multiply implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left * right; }
+ public double applyAsDouble(double left, double right) { return left * right; }
@Override
public String toString() { return "f(a,b)(a * b)"; }
}
@@ -100,26 +100,26 @@ public class ScalarFunctions {
// Variable-length operators -----------------------------------------------------------------------------
- public static class EqualElements implements Function<List<Integer>, Double> {
- private final ImmutableList<String> argumentNames;
+ public static class EqualElements implements Function<List<Long>, Double> {
+ private final ImmutableList<String> argumentNames;
private EqualElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@Override
- public Double apply(List<Integer> values) {
+ public Double apply(List<Long> values) {
if (values.isEmpty()) return 1.0;
- for (Integer value : values)
+ for (Long value : values)
if ( ! value.equals(values.get(0)))
return 0.0;
return 1.0;
}
@Override
- public String toString() {
+ public String toString() {
if (argumentNames.size() == 0) return "1";
if (argumentNames.size() == 1) return "1";
if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1);
-
+
StringBuilder b = new StringBuilder();
for (int i = 0; i < argumentNames.size() -1; i++) {
b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")");
@@ -130,25 +130,25 @@ public class ScalarFunctions {
}
}
- public static class Random implements Function<List<Integer>, Double> {
+ public static class Random implements Function<List<Long>, Double> {
@Override
- public Double apply(List<Integer> values) {
+ public Double apply(List<Long> values) {
return ThreadLocalRandom.current().nextDouble();
}
@Override
public String toString() { return "random"; }
}
- public static class SumElements implements Function<List<Integer>, Double> {
+ public static class SumElements implements Function<List<Long>, Double> {
private final ImmutableList<String> argumentNames;
private SumElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@Override
- public Double apply(List<Integer> values) {
- int sum = 0;
- for (Integer value : values)
+ public Double apply(List<Long> values) {
+ long sum = 0;
+ for (Long value : values)
sum += value;
return (double)sum;
}