summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-06-04 13:10:27 +0200
committerLester Solbakken <lesters@oath.com>2018-06-04 13:10:27 +0200
commite4626398c7e9c1b4b0fa5dbd974e1696c377dd77 (patch)
tree38edfd2f056e23759917fbb6319d12c779810879 /searchlib
parent30ac849f0893b4d98e9392648a2f59e014d6f617 (diff)
Add more ONNX operations
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java27
1 files changed, 22 insertions, 5 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
index 470b04fb44a..3fe92440cae 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
@@ -37,24 +37,41 @@ public class GraphImporter {
String modelName = graph.name();
switch (node.getOpType().toLowerCase()) {
+ case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
+ case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
+ case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
case "concat": return new ConcatV2(modelName, nodeName, inputs);
+ case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal());
+ case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater());
case "identity": return new Identity(modelName, nodeName, inputs);
- case "reshape": return new Reshape(modelName, nodeName, inputs);
- case "shape": return new Shape(modelName, nodeName, inputs);
+ case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less());
+ case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
case "matmul": return new MatMul(modelName, nodeName, inputs);
case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
+ case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean());
case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
+ case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow());
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+ case "shape": return new Shape(modelName, nodeName, inputs);
+ case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
- case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
- case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
+ case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
+ case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
}
IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);