summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-06-01 15:27:43 +0200
committerLester Solbakken <lesters@oath.com>2018-06-01 15:27:43 +0200
commit6d5e6caa958f3f5913922530f8656e4126d26817 (patch)
treeffcd717086b4787f808acf16c54e4a0c7515ade1 /searchlib
parenta9c6b2b50de98990b879663c0e7176e6f07c8fb7 (diff)
Add more ONNX operations
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java21
1 files changed, 21 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
index 174b218a2d3..2416d8697c1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
@@ -6,11 +6,16 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
@@ -27,7 +32,23 @@ public class GraphImporter {
switch (node.getOpType().toLowerCase()) {
case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "concat": return new ConcatV2(modelName, nodeName, inputs);
+ case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "identity": return new Identity(modelName, nodeName, inputs);
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "shape": return new Shape(modelName, nodeName, inputs);
case "matmul": return new MatMul(modelName, nodeName, inputs);
+ case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
+ case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min());
+ case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
+ case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
}
IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);