diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2018-06-07 08:21:03 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-06-07 08:21:03 +0200 |
commit | 84bac783dd9a5ee5a76db09ab4e7070d90364656 (patch) | |
tree | 3f192fabc3eb2e19d27f17793fadcf22146fc685 /vespajlib/src | |
parent | 0ef3ee3405bcda518cd0155b8ab80e1f7a4b2407 (diff) | |
parent | 0bf235c481d24d627c82901a84bef585fe84bbb2 (diff) |
Merge pull request #6111 from vespa-engine/lesters/revert-revert-onnx
Refactor ONNX and TF import to use same code base
Diffstat (limited to 'vespajlib/src')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java | 123 |
1 files changed, 121 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 944755c9db2..3a66eef258d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -22,22 +22,37 @@ public class ScalarFunctions { public static DoubleBinaryOperator add() { return new Add(); } public static DoubleBinaryOperator divide() { return new Divide(); } public static DoubleBinaryOperator equal() { return new Equal(); } + public static DoubleBinaryOperator greater() { return new Greater(); } + public static DoubleBinaryOperator less() { return new Less(); } public static DoubleBinaryOperator max() { return new Max(); } public static DoubleBinaryOperator min() { return new Min(); } + public static DoubleBinaryOperator mean() { return new Mean(); } public static DoubleBinaryOperator multiply() { return new Multiply(); } + public static DoubleBinaryOperator pow() { return new Pow(); } public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); } public static DoubleBinaryOperator subtract() { return new Subtract(); } + public static DoubleUnaryOperator abs() { return new Abs(); } public static DoubleUnaryOperator acos() { return new Acos(); } + public static DoubleUnaryOperator asin() { return new Asin(); } + public static DoubleUnaryOperator atan() { return new Atan(); } + public static DoubleUnaryOperator ceil() { return new Ceil(); } + public static DoubleUnaryOperator cos() { return new Cos(); } public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator exp() { return new Exp(); } public static DoubleUnaryOperator floor() { return new Floor(); } + public static DoubleUnaryOperator log() { return new Log(); } + public static DoubleUnaryOperator neg() { return new Neg(); } + public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); } public static DoubleUnaryOperator relu() { return new Relu(); } public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } public static DoubleUnaryOperator selu() { return new Selu(); } + public static DoubleUnaryOperator sin() { return new Sin(); } public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator square() { return new Square(); } + public static DoubleUnaryOperator tan() { return new Tan(); } + public static DoubleUnaryOperator tanh() { return new Tanh(); } public static Function<List<Long>, Double> random() { return new Random(); } public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } @@ -59,6 +74,20 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(a==b)"; } } + public static class Greater implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; } + @Override + public String toString() { return "f(a,b)(a > b)"; } + } + + public static class Less implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; } + @Override + public String toString() { return "f(a,b)(a < b)"; } + } + public static class Max implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return Math.max(left, right); } @@ -73,6 +102,13 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(min(a, b))"; } } + public static class Mean implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return (left + right) / 2; } + @Override + public String toString() { return "f(a,b)((a + b) / 2)"; } + } + public static class Multiply implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left * right; } @@ -80,6 +116,13 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(a * b)"; } } + public static class Pow implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return Math.pow(left, right); } + @Override + public String toString() { return "f(a,b)(pow(a, b))"; } + } + public static class Divide implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left / right; } @@ -104,6 +147,13 @@ public class ScalarFunctions { // Unary operators ------------------------------------------------------------------------------ + public static class Abs implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.abs(operand); } + @Override + public String toString() { return "f(a)(fabs(a))"; } + } + public static class Acos implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return Math.acos(operand); } @@ -111,6 +161,34 @@ public class ScalarFunctions { public String toString() { return "f(a)(acos(a))"; } } + public static class Asin implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.asin(operand); } + @Override + public String toString() { return "f(a)(asin(a))"; } + } + + public static class Atan implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.atan(operand); } + @Override + public String toString() { return "f(a)(atan(a))"; } + } + + public static class Ceil implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.ceil(operand); } + @Override + public String toString() { return "f(a)(ceil(a))"; } + } + + public static class Cos implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.cos(operand); } + @Override + public String toString() { return "f(a)(cos(a))"; } + } + public static class Elu implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; } @@ -132,6 +210,26 @@ public class ScalarFunctions { public String toString() { return "f(a)(floor(a))"; } } + public static class Log implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.log(operand); } + @Override + public String toString() { return "f(a)(log(a))"; } + } + + public static class Neg implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return -operand; } + @Override + public String toString() { return "f(a)(-a)"; } + } + + public static class Reciprocal implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return 1.0 / operand; } + @Override + public String toString() { return "f(a)(1 / a)"; } + } public static class Relu implements DoubleUnaryOperator { @Override @@ -150,6 +248,13 @@ public class ScalarFunctions { public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1)))", scale, alpha); } } + public static class Sin implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.sin(operand); } + @Override + public String toString() { return "f(a)(sin(a))"; } + } + public static class Rsqrt implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); } @@ -172,15 +277,29 @@ public class ScalarFunctions { } public static class Square implements DoubleUnaryOperator { - @Override public double applyAsDouble(double operand) { return operand * operand; } - @Override public String toString() { return "f(a)(a * a)"; } + } + public static class Tan implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.tan(operand); } + @Override + public String toString() { return "f(a)(tan(a))"; } } + public static class Tanh implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.tanh(operand); } + @Override + public String toString() { return "f(a)(tanh(a))"; } + } + + + + // Variable-length operators ----------------------------------------------------------------------------- public static class EqualElements implements Function<List<Long>, Double> { |