summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java123
2 files changed, 143 insertions, 7 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
index 470b04fb44a..3fe92440cae 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
@@ -37,24 +37,41 @@ public class GraphImporter {
String modelName = graph.name();
switch (node.getOpType().toLowerCase()) {
+ case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
+ case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
+ case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
case "concat": return new ConcatV2(modelName, nodeName, inputs);
+ case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
+ case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
case "identity": return new Identity(modelName, nodeName, inputs);
- case "reshape": return new Reshape(modelName, nodeName, inputs);
- case "shape": return new Shape(modelName, nodeName, inputs);
+ case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
+ case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
case "matmul": return new MatMul(modelName, nodeName, inputs);
case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
+ case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean());
case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
+ case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+ case "shape": return new Shape(modelName, nodeName, inputs);
+ case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
- case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
- case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+ case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
+ case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
}
IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 944755c9db2..3a66eef258d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -22,22 +22,37 @@ public class ScalarFunctions {
public static DoubleBinaryOperator add() { return new Add(); }
public static DoubleBinaryOperator divide() { return new Divide(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
+ public static DoubleBinaryOperator greater() { return new Greater(); }
+ public static DoubleBinaryOperator less() { return new Less(); }
public static DoubleBinaryOperator max() { return new Max(); }
public static DoubleBinaryOperator min() { return new Min(); }
+ public static DoubleBinaryOperator mean() { return new Mean(); }
public static DoubleBinaryOperator multiply() { return new Multiply(); }
+ public static DoubleBinaryOperator pow() { return new Pow(); }
public static DoubleBinaryOperator squareddifference() { return new SquaredDifference(); }
public static DoubleBinaryOperator subtract() { return new Subtract(); }
+ public static DoubleUnaryOperator abs() { return new Abs(); }
public static DoubleUnaryOperator acos() { return new Acos(); }
+ public static DoubleUnaryOperator asin() { return new Asin(); }
+ public static DoubleUnaryOperator atan() { return new Atan(); }
+ public static DoubleUnaryOperator ceil() { return new Ceil(); }
+ public static DoubleUnaryOperator cos() { return new Cos(); }
public static DoubleUnaryOperator elu() { return new Elu(); }
public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator floor() { return new Floor(); }
+ public static DoubleUnaryOperator log() { return new Log(); }
+ public static DoubleUnaryOperator neg() { return new Neg(); }
+ public static DoubleUnaryOperator reciprocal() { return new Reciprocal(); }
public static DoubleUnaryOperator relu() { return new Relu(); }
public static DoubleUnaryOperator rsqrt() { return new Rsqrt(); }
public static DoubleUnaryOperator selu() { return new Selu(); }
+ public static DoubleUnaryOperator sin() { return new Sin(); }
public static DoubleUnaryOperator sigmoid() { return new Sigmoid(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
+ public static DoubleUnaryOperator tan() { return new Tan(); }
+ public static DoubleUnaryOperator tanh() { return new Tanh(); }
public static Function<List<Long>, Double> random() { return new Random(); }
public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
@@ -59,6 +74,20 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a==b)"; }
}
+ public static class Greater implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return left > right ? 1 : 0; }
+ @Override
+ public String toString() { return "f(a,b)(a > b)"; }
+ }
+
+ public static class Less implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return left < right ? 1 : 0; }
+ @Override
+ public String toString() { return "f(a,b)(a < b)"; }
+ }
+
public static class Max implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return Math.max(left, right); }
@@ -73,6 +102,13 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(min(a, b))"; }
}
+ public static class Mean implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return (left + right) / 2; }
+ @Override
+ public String toString() { return "f(a,b)((a + b) / 2)"; }
+ }
+
public static class Multiply implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left * right; }
@@ -80,6 +116,13 @@ public class ScalarFunctions {
public String toString() { return "f(a,b)(a * b)"; }
}
+ public static class Pow implements DoubleBinaryOperator {
+ @Override
+ public double applyAsDouble(double left, double right) { return Math.pow(left, right); }
+ @Override
+ public String toString() { return "f(a,b)(pow(a, b))"; }
+ }
+
public static class Divide implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left / right; }
@@ -104,6 +147,13 @@ public class ScalarFunctions {
// Unary operators ------------------------------------------------------------------------------
+ public static class Abs implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.abs(operand); }
+ @Override
+ public String toString() { return "f(a)(fabs(a))"; }
+ }
+
public static class Acos implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return Math.acos(operand); }
@@ -111,6 +161,34 @@ public class ScalarFunctions {
public String toString() { return "f(a)(acos(a))"; }
}
+ public static class Asin implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.asin(operand); }
+ @Override
+ public String toString() { return "f(a)(asin(a))"; }
+ }
+
+ public static class Atan implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.atan(operand); }
+ @Override
+ public String toString() { return "f(a)(atan(a))"; }
+ }
+
+ public static class Ceil implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.ceil(operand); }
+ @Override
+ public String toString() { return "f(a)(ceil(a))"; }
+ }
+
+ public static class Cos implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.cos(operand); }
+ @Override
+ public String toString() { return "f(a)(cos(a))"; }
+ }
+
public static class Elu implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return operand < 0 ? Math.exp(operand) -1 : operand; }
@@ -132,6 +210,26 @@ public class ScalarFunctions {
public String toString() { return "f(a)(floor(a))"; }
}
+ public static class Log implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.log(operand); }
+ @Override
+ public String toString() { return "f(a)(log(a))"; }
+ }
+
+ public static class Neg implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return -operand; }
+ @Override
+ public String toString() { return "f(a)(-a)"; }
+ }
+
+ public static class Reciprocal implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return 1.0 / operand; }
+ @Override
+ public String toString() { return "f(a)(1 / a)"; }
+ }
public static class Relu implements DoubleUnaryOperator {
@Override
@@ -150,6 +248,13 @@ public class ScalarFunctions {
public String toString() { return String.format("f(a)(%f * if(a >= 0, a, %f*(exp(a)-1)))", scale, alpha); }
}
+ public static class Sin implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.sin(operand); }
+ @Override
+ public String toString() { return "f(a)(sin(a))"; }
+ }
+
public static class Rsqrt implements DoubleUnaryOperator {
@Override
public double applyAsDouble(double operand) { return 1.0 / Math.sqrt(operand); }
@@ -172,15 +277,29 @@ public class ScalarFunctions {
}
public static class Square implements DoubleUnaryOperator {
-
@Override
public double applyAsDouble(double operand) { return operand * operand; }
-
@Override
public String toString() { return "f(a)(a * a)"; }
+ }
+ public static class Tan implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.tan(operand); }
+ @Override
+ public String toString() { return "f(a)(tan(a))"; }
}
+ public static class Tanh implements DoubleUnaryOperator {
+ @Override
+ public double applyAsDouble(double operand) { return Math.tanh(operand); }
+ @Override
+ public String toString() { return "f(a)(tanh(a))"; }
+ }
+
+
+
+
// Variable-length operators -----------------------------------------------------------------------------
public static class EqualElements implements Function<List<Long>, Double> {