diff options
2 files changed, 143 insertions, 7 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java index 470b04fb44a..3fe92440cae 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java @@ -37,24 +37,41 @@ public class GraphImporter { String modelName = graph.name(); switch (node.getOpType().toLowerCase()) { + case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); + case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); + case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); case "concat": return new ConcatV2(modelName, nodeName, inputs); + case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); + case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater()); case "identity": return new Identity(modelName, nodeName, inputs); - case "reshape": return new Reshape(modelName, nodeName, inputs); - case "shape": return new Shape(modelName, nodeName, inputs); + case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less()); + case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log()); case "matmul": return new MatMul(modelName, nodeName, inputs); case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min()); + case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean()); case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); + case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); + case "reshape": return new Reshape(modelName, nodeName, inputs); + case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); + case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + case "shape": return new Shape(modelName, nodeName, inputs); + case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); - case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); - case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); - case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); + case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); } IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 944755c9db2..3a66eef258d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -22,22 +22,37 @@ public class ScalarFunctions { public static DoubleBinaryOperator add() { return new Add(); } public static DoubleBinaryOperator divide() { return new Divide(); } public static DoubleBinaryOperator equal() { return new Equal(); } + public static DoubleBinaryOperator greater() { return new Greater(); } + public static DoubleBinaryOperator less() { return new Less(); } public static DoubleBinaryOperator max() { return new Max(); } public static DoubleBinaryOperator min() { return new Min(); } + public static DoubleBinaryOperator mean() { return new Mean(); } public static DoubleBinaryOperator multiply() { return new Multiply(); } + public static DoubleBinaryOperator pow() { return new Pow(); } public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); } public static DoubleBinaryOperator subtract() { return new Subtract(); } + public static DoubleUnaryOperator abs() { return new Abs(); } public static DoubleUnaryOperator acos() { return new Acos(); } + public static DoubleUnaryOperator asin() { return new Asin(); } + public static DoubleUnaryOperator atan() { return new Atan(); } + public static DoubleUnaryOperator ceil() { return new Ceil(); } + public static DoubleUnaryOperator cos() { return new Cos(); } public static DoubleUnaryOperator elu() { return new Elu(); } public static DoubleUnaryOperator exp() { return new Exp(); } public static DoubleUnaryOperator floor() { return new Floor(); } + public static DoubleUnaryOperator log() { return new Log(); } + public static DoubleUnaryOperator neg() { return new Neg(); } + public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); } public static DoubleUnaryOperator relu() { return new Relu(); } public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); } public static DoubleUnaryOperator selu() { return new Selu(); } + public static DoubleUnaryOperator sin() { return new Sin(); } public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator square() { return new Square(); } + public static DoubleUnaryOperator tan() { return new Tan(); } + public static DoubleUnaryOperator tanh() { return new Tanh(); } public static Function<List<Long>, Double> random() { return new Random(); } public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } @@ -59,6 +74,20 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(a==b)"; } } + public static class Greater implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; } + @Override + public String toString() { return "f(a,b)(a > b)"; } + } + + public static class Less implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; } + @Override + public String toString() { return "f(a,b)(a < b)"; } + } + public static class Max implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return Math.max(left, right); } @@ -73,6 +102,13 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(min(a, b))"; } } + public static class Mean implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return (left + right) / 2; } + @Override + public String toString() { return "f(a,b)((a + b) / 2)"; } + } + public static class Multiply implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left * right; } @@ -80,6 +116,13 @@ public class ScalarFunctions { public String toString() { return "f(a,b)(a * b)"; } } + public static class Pow implements DoubleBinaryOperator { + @Override + public double applyAsDouble(double left, double right) { return Math.pow(left, right); } + @Override + public String toString() { return "f(a,b)(pow(a, b))"; } + } + public static class Divide implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left / right; } @@ -104,6 +147,13 @@ public class ScalarFunctions { // Unary operators ------------------------------------------------------------------------------ + public static class Abs implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.abs(operand); } + @Override + public String toString() { return "f(a)(fabs(a))"; } + } + public static class Acos implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return Math.acos(operand); } @@ -111,6 +161,34 @@ public class ScalarFunctions { public String toString() { return "f(a)(acos(a))"; } } + public static class Asin implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.asin(operand); } + @Override + public String toString() { return "f(a)(asin(a))"; } + } + + public static class Atan implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.atan(operand); } + @Override + public String toString() { return "f(a)(atan(a))"; } + } + + public static class Ceil implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.ceil(operand); } + @Override + public String toString() { return "f(a)(ceil(a))"; } + } + + public static class Cos implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.cos(operand); } + @Override + public String toString() { return "f(a)(cos(a))"; } + } + public static class Elu implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; } @@ -132,6 +210,26 @@ public class ScalarFunctions { public String toString() { return "f(a)(floor(a))"; } } + public static class Log implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.log(operand); } + @Override + public String toString() { return "f(a)(log(a))"; } + } + + public static class Neg implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return -operand; } + @Override + public String toString() { return "f(a)(-a)"; } + } + + public static class Reciprocal implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return 1.0 / operand; } + @Override + public String toString() { return "f(a)(1 / a)"; } + } public static class Relu implements DoubleUnaryOperator { @Override @@ -150,6 +248,13 @@ public class ScalarFunctions { public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1)))", scale, alpha); } } + public static class Sin implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.sin(operand); } + @Override + public String toString() { return "f(a)(sin(a))"; } + } + public static class Rsqrt implements DoubleUnaryOperator { @Override public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); } @@ -172,15 +277,29 @@ public class ScalarFunctions { } public static class Square implements DoubleUnaryOperator { - @Override public double applyAsDouble(double operand) { return operand * operand; } - @Override public String toString() { return "f(a)(a * a)"; } + } + public static class Tan implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.tan(operand); } + @Override + public String toString() { return "f(a)(tan(a))"; } } + public static class Tanh implements DoubleUnaryOperator { + @Override + public double applyAsDouble(double operand) { return Math.tanh(operand); } + @Override + public String toString() { return "f(a)(tanh(a))"; } + } + + + + // Variable-length operators ----------------------------------------------------------------------------- public static class EqualElements implements Function<List<Long>, Double> { |