diff options
author | Lester Solbakken <lesters@oath.com> | 2018-06-04 13:10:27 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-06-04 13:10:27 +0200 |
commit | e4626398c7e9c1b4b0fa5dbd974e1696c377dd77 (patch) | |
tree | 38edfd2f056e23759917fbb6319d12c779810879 /searchlib | |
parent | 30ac849f0893b4d98e9392648a2f59e014d6f617 (diff) |
Add more ONNX operations
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java | 27 |
1 files changed, 22 insertions, 5 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java index 470b04fb44a..3fe92440cae 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java @@ -37,24 +37,41 @@ public class GraphImporter { String modelName = graph.name(); switch (node.getOpType().toLowerCase()) { + case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); + case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); + case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); case "concat": return new ConcatV2(modelName, nodeName, inputs); + case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "equal": return new Join(modelName, nodeName, inputs, ScalarFunctions.equal()); + case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "greater": return new Join(modelName, nodeName, inputs, ScalarFunctions.greater()); case "identity": return new Identity(modelName, nodeName, inputs); - case "reshape": return new Reshape(modelName, nodeName, inputs); - case "shape": return new Shape(modelName, nodeName, inputs); + case "less": return new Join(modelName, nodeName, inputs, ScalarFunctions.less()); + case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log()); case "matmul": return new MatMul(modelName, nodeName, inputs); case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min()); + case "mean": return new Join(modelName, nodeName, inputs, ScalarFunctions.mean()); case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "neg": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); + case "pow": return new Join(modelName, nodeName, inputs, ScalarFunctions.pow()); + case "reshape": return new Reshape(modelName, nodeName, inputs); + case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); + case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + case "shape": return new Shape(modelName, nodeName, inputs); + case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); - case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); - case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); - case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); + case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); + case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); } IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); |