diff options
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java | 78 |
1 files changed, 39 insertions, 39 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java index 4249e2285b1..dcea8f1a230 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java @@ -50,58 +50,58 @@ public class GraphImporter { switch (node.getOp().toLowerCase()) { // array ops - case "concatv2": return new ConcatV2(nodeName, inputs); - case "const": return new Const(nodeName, inputs, attributes, nodeType, modelName); // todo: test this - case "expanddims": return new ExpandDims(nodeName, inputs); - case "identity": return new Identity(nodeName, inputs, modelName); - case "placeholder": return new Argument(nodeName, nodeType); - case "placeholderwithdefault": return new PlaceholderWithDefault(nodeName, inputs); - case "reshape": return new Reshape(nodeName, inputs); - case "shape": return new Shape(nodeName, inputs); - case "squeeze": return new Squeeze(nodeName, inputs, attributes); // todo: test this + case "concatv2": return new ConcatV2(modelName, nodeName, inputs); + case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType); + case "expanddims": return new ExpandDims(modelName, nodeName, inputs); + case "identity": return new Identity(modelName, nodeName, inputs); + case "placeholder": return new Argument(modelName, nodeName, nodeType); + case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs); + case "reshape": return new Reshape(modelName, nodeName, inputs); + case "shape": return new Shape(modelName, nodeName, inputs); + case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); // control flow - case "merge": return new Merge(nodeName, inputs); - case "switch": return new Switch(nodeName, inputs, nodePort); // todo: test this + case "merge": return new Merge(modelName, nodeName, inputs); + case "switch": return new Switch(modelName, nodeName, inputs, nodePort); // math ops - case "add": return new Join(nodeName, inputs, ScalarFunctions.add()); - case "add_n": return new Join(nodeName, inputs, ScalarFunctions.add()); - case "acos": return new Map(nodeName, inputs, ScalarFunctions.acos()); - case "div": return new Join(nodeName, inputs, ScalarFunctions.divide()); - case "realdiv": return new Join(nodeName, inputs, ScalarFunctions.divide()); - case "floor": return new Map(nodeName, inputs, ScalarFunctions.floor()); - case "matmul": return new MatMul(nodeName, inputs); - case "maximum": return new Join(nodeName, inputs, ScalarFunctions.max()); - case "mean": return new Mean(nodeName, inputs, attributes); // todo: test this - case "reducemean": return new Mean(nodeName, inputs, attributes); - case "mul": return new Join(nodeName, inputs, ScalarFunctions.multiply()); - case "multiply": return new Join(nodeName, inputs, ScalarFunctions.multiply()); - case "rsqrt": return new Map(nodeName, inputs, ScalarFunctions.rsqrt()); - case "select": return new Select(nodeName, inputs); - case "where3": return new Select(nodeName, inputs); - case "sigmoid": return new Map(nodeName, inputs, ScalarFunctions.sigmoid()); - case "squareddifference": return new Join(nodeName, inputs, ScalarFunctions.squareddifference()); - case "sub": return new Join(nodeName, inputs, ScalarFunctions.subtract()); - case "subtract": return new Join(nodeName, inputs, ScalarFunctions.subtract()); + case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "matmul": return new MatMul(modelName, nodeName, inputs); + case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); + case "mean": return new Mean(modelName, nodeName, inputs, attributes); + case "reducemean": return new Mean(modelName, nodeName, inputs, attributes); + case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt()); + case "select": return new Select(modelName, nodeName, inputs); + case "where3": return new Select(modelName, nodeName, inputs); + case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); + case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference()); + case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); // nn ops - case "biasadd": return new Join(nodeName, inputs, ScalarFunctions.add()); - case "elu": return new Map(nodeName, inputs, ScalarFunctions.elu()); - case "relu": return new Map(nodeName, inputs, ScalarFunctions.relu()); - case "selu": return new Map(nodeName, inputs, ScalarFunctions.selu()); + case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); // state ops - case "variable": return new Constant(nodeName, nodeType, modelName); - case "variablev2": return new Constant(nodeName, nodeType, modelName); + case "variable": return new Constant(modelName, nodeName, nodeType); + case "variablev2": return new Constant(modelName, nodeName, nodeType); // evaluation no-ops - case "stopgradient":return new Identity(nodeName, inputs, modelName); - case "noop": return new NoOp(nodeName, inputs); + case "stopgradient":return new Identity(modelName, nodeName, inputs); + case "noop": return new NoOp(modelName, nodeName, inputs); } - IntermediateOperation op = new NoOp(node.getName(), inputs); + IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); op.warning("Operation '" + node.getOp() + "' is currently not implemented"); return op; } |