summaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2018-06-07 08:21:03 +0200
committerGitHub <noreply@github.com>2018-06-07 08:21:03 +0200
commit84bac783dd9a5ee5a76db09ab4e7070d90364656 (patch)
tree3f192fabc3eb2e19d27f17793fadcf22146fc685 /vespajlib/src
parent0ef3ee3405bcda518cd0155b8ab80e1f7a4b2407 (diff)
parent0bf235c481d24d627c82901a84bef585fe84bbb2 (diff)
Merge pull request #6111 from vespa-engine/lesters/revert-revert-onnx
Refactor ONNX and TF import to use same code base
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java123
1 files changed, 121 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 944755c9db2..3a66eef258d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -22,22 +22,37 @@ public class ScalarFunctions {
public static DoubleBinaryOperator add() { return new Add(); }
public static DoubleBinaryOperator divide() { return new Divide(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
+ public static DoubleBinaryOperator greater() { return new Greater(); }
+ public static DoubleBinaryOperator less() { return new Less(); }
public static DoubleBinaryOperator max() { return new Max(); }
public static DoubleBinaryOperator min() { return new Min(); }
+ public static DoubleBinaryOperator mean() { return new Mean(); }
public static DoubleBinaryOperator multiply() { return new Multiply(); }
+ public static DoubleBinaryOperator pow() { return new Pow(); }
public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); }
public static DoubleBinaryOperator subtract() { return new Subtract(); }
+ public static DoubleUnaryOperator abs() { return new Abs(); }
public static DoubleUnaryOperator acos() { return new Acos(); }
+ public static DoubleUnaryOperator asin() { return new Asin(); }
+ public static DoubleUnaryOperator atan() { return new Atan(); }
+ public static DoubleUnaryOperator ceil() { return new Ceil(); }
+ public static DoubleUnaryOperator cos() { return new Cos(); }
public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator floor() { return new Floor(); }
+ public static DoubleUnaryOperator log() { return new Log(); }
+ public static DoubleUnaryOperator neg() { return new Neg(); }
+ public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); }
public static DoubleUnaryOperator relu() { return new Relu(); }
public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
public static DoubleUnaryOperator selu() { return new Selu(); }
+ public static DoubleUnaryOperator sin() { return new Sin(); }
public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
+ public static DoubleUnaryOperator tan() { return new Tan(); }
+ public static DoubleUnaryOperator tanh() { return new Tanh(); }
public static Function<List<Long>, Double> random() { return new Random(); }
public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
@@ -59,6 +74,20 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a==b)"; }
}
+ public static class Greater implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; }
+ @Override
+ public String toString() { return "f(a,b)(a > b)"; }
+ }
+
+ public static class Less implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; }
+ @Override
+ public String toString() { return "f(a,b)(a < b)"; }
+ }
+
public static class Max implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return Math.max(left, right); }
@@ -73,6 +102,13 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(min(a, b))"; }
}
+ public static class Mean implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return (left + right) / 2; }
+ @Override
+ public String toString() { return "f(a,b)((a + b) / 2)"; }
+ }
+
public static class Multiply implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left * right; }
@@ -80,6 +116,13 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a * b)"; }
}
+ public static class Pow implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return Math.pow(left, right); }
+ @Override
+ public String toString() { return "f(a,b)(pow(a, b))"; }
+ }
+
public static class Divide implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left / right; }
@@ -104,6 +147,13 @@ public class ScalarFunctions {
// Unary operators ------------------------------------------------------------------------------
+ public static class Abs implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.abs(operand); }
+ @Override
+ public String toString() { return "f(a)(fabs(a))"; }
+ }
+
public static class Acos implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return Math.acos(operand); }
@@ -111,6 +161,34 @@ public class ScalarFunctions {
public String toString() { return "f(a)(acos(a))"; }
}
+ public static class Asin implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.asin(operand); }
+ @Override
+ public String toString() { return "f(a)(asin(a))"; }
+ }
+
+ public static class Atan implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.atan(operand); }
+ @Override
+ public String toString() { return "f(a)(atan(a))"; }
+ }
+
+ public static class Ceil implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.ceil(operand); }
+ @Override
+ public String toString() { return "f(a)(ceil(a))"; }
+ }
+
+ public static class Cos implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.cos(operand); }
+ @Override
+ public String toString() { return "f(a)(cos(a))"; }
+ }
+
public static class Elu implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; }
@@ -132,6 +210,26 @@ public class ScalarFunctions {
public String toString() { return "f(a)(floor(a))"; }
}
+ public static class Log implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.log(operand); }
+ @Override
+ public String toString() { return "f(a)(log(a))"; }
+ }
+
+ public static class Neg implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return -operand; }
+ @Override
+ public String toString() { return "f(a)(-a)"; }
+ }
+
+ public static class Reciprocal implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return 1.0 / operand; }
+ @Override
+ public String toString() { return "f(a)(1 / a)"; }
+ }
public static class Relu implements DoubleUnaryOperator {
@Override
@@ -150,6 +248,13 @@ public class ScalarFunctions {
public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1)))", scale, alpha); }
}
+ public static class Sin implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.sin(operand); }
+ @Override
+ public String toString() { return "f(a)(sin(a))"; }
+ }
+
public static class Rsqrt implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); }
@@ -172,15 +277,29 @@ public class ScalarFunctions {
}
public static class Square implements DoubleUnaryOperator {
-
@Override
public double applyAsDouble(double operand) { return operand * operand; }
-
@Override
public String toString() { return "f(a)(a * a)"; }
+ }
+ public static class Tan implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.tan(operand); }
+ @Override
+ public String toString() { return "f(a)(tan(a))"; }
}
+ public static class Tanh implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.tanh(operand); }
+ @Override
+ public String toString() { return "f(a)(tanh(a))"; }
+ }
+
+
+
+
// Variable-length operators -----------------------------------------------------------------------------
public static class EqualElements implements Function<List<Long>, Double> {