summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java78
1 files changed, 39 insertions, 39 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
index 4249e2285b1..dcea8f1a230 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
@@ -50,58 +50,58 @@ public class GraphImporter {
switch (node.getOp().toLowerCase()) {
// array ops
- case "concatv2": return new ConcatV2(nodeName, inputs);
- case "const": return new Const(nodeName, inputs, attributes, nodeType, modelName); // todo: test this
- case "expanddims": return new ExpandDims(nodeName, inputs);
- case "identity": return new Identity(nodeName, inputs, modelName);
- case "placeholder": return new Argument(nodeName, nodeType);
- case "placeholderwithdefault": return new PlaceholderWithDefault(nodeName, inputs);
- case "reshape": return new Reshape(nodeName, inputs);
- case "shape": return new Shape(nodeName, inputs);
- case "squeeze": return new Squeeze(nodeName, inputs, attributes); // todo: test this
+ case "concatv2": return new ConcatV2(modelName, nodeName, inputs);
+ case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType);
+ case "expanddims": return new ExpandDims(modelName, nodeName, inputs);
+ case "identity": return new Identity(modelName, nodeName, inputs);
+ case "placeholder": return new Argument(modelName, nodeName, nodeType);
+ case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "shape": return new Shape(modelName, nodeName, inputs);
+ case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
// control flow
- case "merge": return new Merge(nodeName, inputs);
- case "switch": return new Switch(nodeName, inputs, nodePort); // todo: test this
+ case "merge": return new Merge(modelName, nodeName, inputs);
+ case "switch": return new Switch(modelName, nodeName, inputs, nodePort);
// math ops
- case "add": return new Join(nodeName, inputs, ScalarFunctions.add());
- case "add_n": return new Join(nodeName, inputs, ScalarFunctions.add());
- case "acos": return new Map(nodeName, inputs, ScalarFunctions.acos());
- case "div": return new Join(nodeName, inputs, ScalarFunctions.divide());
- case "realdiv": return new Join(nodeName, inputs, ScalarFunctions.divide());
- case "floor": return new Map(nodeName, inputs, ScalarFunctions.floor());
- case "matmul": return new MatMul(nodeName, inputs);
- case "maximum": return new Join(nodeName, inputs, ScalarFunctions.max());
- case "mean": return new Mean(nodeName, inputs, attributes); // todo: test this
- case "reducemean": return new Mean(nodeName, inputs, attributes);
- case "mul": return new Join(nodeName, inputs, ScalarFunctions.multiply());
- case "multiply": return new Join(nodeName, inputs, ScalarFunctions.multiply());
- case "rsqrt": return new Map(nodeName, inputs, ScalarFunctions.rsqrt());
- case "select": return new Select(nodeName, inputs);
- case "where3": return new Select(nodeName, inputs);
- case "sigmoid": return new Map(nodeName, inputs, ScalarFunctions.sigmoid());
- case "squareddifference": return new Join(nodeName, inputs, ScalarFunctions.squareddifference());
- case "sub": return new Join(nodeName, inputs, ScalarFunctions.subtract());
- case "subtract": return new Join(nodeName, inputs, ScalarFunctions.subtract());
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "matmul": return new MatMul(modelName, nodeName, inputs);
+ case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
+ case "mean": return new Mean(modelName, nodeName, inputs, attributes);
+ case "reducemean": return new Mean(modelName, nodeName, inputs, attributes);
+ case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt());
+ case "select": return new Select(modelName, nodeName, inputs);
+ case "where3": return new Select(modelName, nodeName, inputs);
+ case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference());
+ case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
// nn ops
- case "biasadd": return new Join(nodeName, inputs, ScalarFunctions.add());
- case "elu": return new Map(nodeName, inputs, ScalarFunctions.elu());
- case "relu": return new Map(nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(nodeName, inputs, ScalarFunctions.selu());
+ case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
// state ops
- case "variable": return new Constant(nodeName, nodeType, modelName);
- case "variablev2": return new Constant(nodeName, nodeType, modelName);
+ case "variable": return new Constant(modelName, nodeName, nodeType);
+ case "variablev2": return new Constant(modelName, nodeName, nodeType);
// evaluation no-ops
- case "stopgradient":return new Identity(nodeName, inputs, modelName);
- case "noop": return new NoOp(nodeName, inputs);
+ case "stopgradient":return new Identity(modelName, nodeName, inputs);
+ case "noop": return new NoOp(modelName, nodeName, inputs);
}
- IntermediateOperation op = new NoOp(node.getName(), inputs);
+ IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
op.warning("Operation '" + node.getOp() + "' is currently not implemented");
return op;
}