diff options
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java | 21 |
1 files changed, 21 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java index 174b218a2d3..2416d8697c1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java @@ -6,11 +6,16 @@ import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Argument; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.ConcatV2; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Constant; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Identity; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.IntermediateOperation; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Join; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Map; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.MatMul; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.NoOp; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Reshape; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations.Shape; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; @@ -27,7 +32,23 @@ public class GraphImporter { switch (node.getOpType().toLowerCase()) { case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "concat": return new ConcatV2(modelName, nodeName, inputs); + case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "identity": return new Identity(modelName, nodeName, inputs); + case "reshape": return new Reshape(modelName, nodeName, inputs); + case "shape": return new Shape(modelName, nodeName, inputs); case "matmul": return new MatMul(modelName, nodeName, inputs); + case "max": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); + case "min": return new Join(modelName, nodeName, inputs, ScalarFunctions.min()); + case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); + case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); + case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); } IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); |