diff options
author | Lester Solbakken <lesters@oath.com> | 2018-06-01 15:02:35 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-06-01 15:02:35 +0200 |
commit | c5089d4c8f5c6259190ecbf80fbe0c96f391c218 (patch) | |
tree | 1d414cf348b7b347686f7df8bda1afe070441fa8 /searchlib | |
parent | d3bdbedb5aeba5c36e932b77bca57b582971ad21 (diff) |
Put model names back in generated macros
Diffstat (limited to 'searchlib')
22 files changed, 92 insertions, 97 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java index 47623daf022..b7620b6b5c7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java @@ -19,15 +19,18 @@ import java.util.stream.Collectors; public class GraphImporter { - public static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs) { + public static IntermediateOperation mapOperation(Onnx.NodeProto node, + List<IntermediateOperation> inputs, + IntermediateGraph graph) { String nodeName = node.getName(); + String modelName = graph.name(); switch (node.getOpType().toLowerCase()) { - case "add": return new Join(nodeName, inputs, ScalarFunctions.add()); - case "matmul": return new MatMul(nodeName, inputs); + case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "matmul": return new MatMul(modelName, nodeName, inputs); } - IntermediateOperation op = new NoOp(node.getName(), inputs); + IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); op.warning("Operation '" + node.getOpType() + "' is currently not implemented"); return op; } @@ -62,7 +65,7 @@ public class GraphImporter { if (valueInfoProto == null) throw new IllegalArgumentException("Could not find argument tensor: " + name); OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType()); - operation = new Argument(valueInfoProto.getName(), type); + operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type); intermediateGraph.inputs(intermediateGraph.defaultSignature()) .put(IntermediateOperation.namePartOf(name), operation.vespaName()); @@ -70,14 +73,13 @@ public class GraphImporter { } else if (isConstantTensor(name, onnxGraph)) { Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList()); - operation = new Constant(name, defaultType, intermediateGraph.name()); + operation = new Constant(intermediateGraph.name(), name, defaultType); operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); } else { - Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph); List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph); - operation = mapOperation(node, inputs); + operation = mapOperation(node, inputs, intermediateGraph); if (isOutputNode(name, onnxGraph)) { intermediateGraph.outputs(intermediateGraph.defaultSignature()) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java index 8783152841d..7fc2aae87d1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java @@ -15,8 +15,8 @@ public class Argument extends IntermediateOperation { private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... - public Argument(String name, OrderedTensorType type) { - super(name, Collections.emptyList()); + public Argument(String modelName, String nodeName, OrderedTensorType type) { + super(modelName, nodeName, Collections.emptyList()); this.type = type.rename(vespaName() + "_"); standardNamingType = OrderedTensorType.standardType(type); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java index 759ccd32657..1b8c62fe0e9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java @@ -14,8 +14,8 @@ public class ConcatV2 extends IntermediateOperation { private String concatDimensionName; - public ConcatV2(String name, List<IntermediateOperation> inputs) { - super(name, inputs); + public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java index 45103f35402..3c0f8569c47 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java @@ -17,16 +17,14 @@ import java.util.Optional; public class Const extends IntermediateOperation { - private final String modelName; private final AttributeMap attributeMap; - public Const(String name, + public Const(String modelName, + String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap, - OrderedTensorType type, - String modelName) { - super(name, inputs); - this.modelName = modelName; + OrderedTensorType type) { + super(modelName, nodeName, inputs); this.attributeMap = attributeMap; this.type = type.rename(vespaName() + "_"); setConstantValue(value()); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java index 323d1327d2e..5e4abeaa234 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java @@ -14,8 +14,8 @@ public class Constant extends IntermediateOperation { private final String modelName; - public Constant(String name, OrderedTensorType type, String modelName) { - super(name, Collections.emptyList()); + public Constant(String modelName, String nodeName, OrderedTensorType type) { + super(modelName, nodeName, Collections.emptyList()); this.modelName = modelName; this.type = type.rename(vespaName() + "_"); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java index ec290078f12..742ed8b89ab 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java @@ -21,8 +21,8 @@ public class ExpandDims extends IntermediateOperation { private List<String> expandDimensions; - public ExpandDims(String name, List<IntermediateOperation> inputs) { - super(name, inputs); + public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java index 9b7700aa01f..d29bd4b7a9e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java @@ -8,11 +8,8 @@ import java.util.List; public class Identity extends IntermediateOperation { - private final String modelName; - - public Identity(String name, List<IntermediateOperation> inputs, String modelName) { - super(name, inputs); - this.modelName = modelName; + public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java index 3115c79e81a..e24b2a828b5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java @@ -23,6 +23,7 @@ public abstract class IntermediateOperation { protected final static String MACRO_PREFIX = "imported_ml_macro_"; protected final String name; + protected final String modelName; protected final List<IntermediateOperation> inputs; protected final List<IntermediateOperation> outputs = new ArrayList<>(); protected final List<String> importWarnings = new ArrayList<>(); @@ -34,8 +35,9 @@ public abstract class IntermediateOperation { protected Function<OrderedTensorType, Value> constantValueFunction = null; protected List<IntermediateOperation> controlInputs = Collections.emptyList(); - IntermediateOperation(String name, List<IntermediateOperation> inputs) { + IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) { this.name = name; + this.modelName = modelName; this.inputs = Collections.unmodifiableList(inputs); this.inputs.forEach(i -> i.outputs.add(this)); } @@ -119,11 +121,7 @@ public abstract class IntermediateOperation { public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } /** Retrieve the valid Vespa name of this node if it is a macro */ - public String macroName() { - return vespaName() != null ? MACRO_PREFIX + "_" + vespaName() : null; -// return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; - // todo: add model name - } + public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; } /** Retrieve the list of warnings produced during its lifetime */ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java index e3e2c94d04e..8413ed74118 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java @@ -15,8 +15,8 @@ public class Join extends IntermediateOperation { private final DoubleBinaryOperator operator; - public Join(String name, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) { - super(name, inputs); + public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) { + super(modelName, nodeName, inputs); this.operator = operator; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java index 2d0086e7db3..f54ae83052f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java @@ -12,8 +12,8 @@ public class Map extends IntermediateOperation { private final DoubleUnaryOperator operator; - public Map(String name, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) { - super(name, inputs); + public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) { + super(modelName, nodeName, inputs); this.operator = operator; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java index 1873ec80375..52e223f9518 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java @@ -11,8 +11,8 @@ import java.util.Optional; public class MatMul extends IntermediateOperation { - public MatMul(String name, List<IntermediateOperation> inputs) { - super(name, inputs); + public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java index 9b566c8b8ad..822656916f8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java @@ -25,8 +25,8 @@ public class Mean extends IntermediateOperation { private final AttributeMap attributeMap; private List<String> reduceDimensions; - public Mean(String name, List<IntermediateOperation> inputs, AttributeMap attributeMap) { - super(name, inputs); + public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); this.attributeMap = attributeMap; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java index ea2b89b1fc6..9d9eca47b1c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java @@ -8,8 +8,8 @@ import java.util.List; public class Merge extends IntermediateOperation { - public Merge(String name, List<IntermediateOperation> inputs) { - super(name, inputs); + public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java index 12941568d8d..19ba146492c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java @@ -9,8 +9,8 @@ import java.util.List; public class NoOp extends IntermediateOperation { - public NoOp(String name, List<IntermediateOperation> inputs) { - super(name, Collections.emptyList()); // don't propagate inputs + public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java index 4ef1b8b0672..9299ae9be12 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java @@ -9,8 +9,8 @@ import java.util.Optional; public class PlaceholderWithDefault extends IntermediateOperation { - public PlaceholderWithDefault(String name, List<IntermediateOperation> inputs) { - super(name, inputs); + public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java index 8164632b35b..e91c2305f7d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java @@ -28,8 +28,8 @@ import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.Orde public class Reshape extends IntermediateOperation { - public Reshape(String name, List<IntermediateOperation> inputs) { - super(name, inputs); + public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java index cffb705c774..927a4a368f9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java @@ -16,8 +16,8 @@ import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.Orde public class Select extends IntermediateOperation { - public Select(String name, List<IntermediateOperation> inputs) { - super(name, inputs); + public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java index d0ab692f8bd..da566909adc 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java @@ -12,8 +12,8 @@ import java.util.List; public class Shape extends IntermediateOperation { - public Shape(String name, List<IntermediateOperation> inputs) { - super(name, inputs); + public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); createConstantValue(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java index 16a7f666fc9..c750c47e27e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java @@ -18,8 +18,8 @@ public class Squeeze extends IntermediateOperation { private final AttributeMap attributeMap; private List<String> squeezeDimensions; - public Squeeze(String name, List<IntermediateOperation> inputs, AttributeMap attributeMap) { - super(name, inputs); + public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); this.attributeMap = attributeMap; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java index 0e726b55711..0171d1ea171 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java @@ -11,8 +11,8 @@ public class Switch extends IntermediateOperation { private final int port; - public Switch(String name, List<IntermediateOperation> inputs, int port) { - super(name, inputs); + public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) { + super(modelName, nodeName, inputs); this.port = port; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java index 4249e2285b1..dcea8f1a230 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java @@ -50,58 +50,58 @@ public class GraphImporter { switch (node.getOp().toLowerCase()) { // array ops - case "concatv2": return new ConcatV2(nodeName, inputs); - case "const": return new Const(nodeName, inputs, attributes, nodeType, modelName); // todo: test this - case "expanddims": return new ExpandDims(nodeName, inputs); - case "identity": return new Identity(nodeName, inputs, modelName); - case "placeholder": return new Argument(nodeName, nodeType); - case "placeholderwithdefault": return new PlaceholderWithDefault(nodeName, inputs); - case "reshape": return new Reshape(nodeName, inputs); - case "shape": return new Shape(nodeName, inputs); - case "squeeze": return new Squeeze(nodeName, inputs, attributes); // todo: test this + case "concatv2": return new ConcatV2(modelName, nodeName, inputs); + case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType); + case "expanddims": return new ExpandDims(modelName, nodeName, inputs); + case "identity": return new Identity(modelName, nodeName, inputs); + case "placeholder": return new Argument(modelName, nodeName, nodeType); + case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs); + case "reshape": return new Reshape(modelName, nodeName, inputs); + case "shape": return new Shape(modelName, nodeName, inputs); + case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); // control flow - case "merge": return new Merge(nodeName, inputs); - case "switch": return new Switch(nodeName, inputs, nodePort); // todo: test this + case "merge": return new Merge(modelName, nodeName, inputs); + case "switch": return new Switch(modelName, nodeName, inputs, nodePort); // math ops - case "add": return new Join(nodeName, inputs, ScalarFunctions.add()); - case "add_n": return new Join(nodeName, inputs, ScalarFunctions.add()); - case "acos": return new Map(nodeName, inputs, ScalarFunctions.acos()); - case "div": return new Join(nodeName, inputs, ScalarFunctions.divide()); - case "realdiv": return new Join(nodeName, inputs, ScalarFunctions.divide()); - case "floor": return new Map(nodeName, inputs, ScalarFunctions.floor()); - case "matmul": return new MatMul(nodeName, inputs); - case "maximum": return new Join(nodeName, inputs, ScalarFunctions.max()); - case "mean": return new Mean(nodeName, inputs, attributes); // todo: test this - case "reducemean": return new Mean(nodeName, inputs, attributes); - case "mul": return new Join(nodeName, inputs, ScalarFunctions.multiply()); - case "multiply": return new Join(nodeName, inputs, ScalarFunctions.multiply()); - case "rsqrt": return new Map(nodeName, inputs, ScalarFunctions.rsqrt()); - case "select": return new Select(nodeName, inputs); - case "where3": return new Select(nodeName, inputs); - case "sigmoid": return new Map(nodeName, inputs, ScalarFunctions.sigmoid()); - case "squareddifference": return new Join(nodeName, inputs, ScalarFunctions.squareddifference()); - case "sub": return new Join(nodeName, inputs, ScalarFunctions.subtract()); - case "subtract": return new Join(nodeName, inputs, ScalarFunctions.subtract()); + case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); + case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); + case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); + case "matmul": return new MatMul(modelName, nodeName, inputs); + case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); + case "mean": return new Mean(modelName, nodeName, inputs, attributes); + case "reducemean": return new Mean(modelName, nodeName, inputs, attributes); + case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); + case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt()); + case "select": return new Select(modelName, nodeName, inputs); + case "where3": return new Select(modelName, nodeName, inputs); + case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); + case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference()); + case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); // nn ops - case "biasadd": return new Join(nodeName, inputs, ScalarFunctions.add()); - case "elu": return new Map(nodeName, inputs, ScalarFunctions.elu()); - case "relu": return new Map(nodeName, inputs, ScalarFunctions.relu()); - case "selu": return new Map(nodeName, inputs, ScalarFunctions.selu()); + case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); + case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); + case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); + case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); // state ops - case "variable": return new Constant(nodeName, nodeType, modelName); - case "variablev2": return new Constant(nodeName, nodeType, modelName); + case "variable": return new Constant(modelName, nodeName, nodeType); + case "variablev2": return new Constant(modelName, nodeName, nodeType); // evaluation no-ops - case "stopgradient":return new Identity(nodeName, inputs, modelName); - case "noop": return new NoOp(nodeName, inputs); + case "stopgradient":return new Identity(modelName, nodeName, inputs); + case "noop": return new NoOp(modelName, nodeName, inputs); } - IntermediateOperation op = new NoOp(node.getName(), inputs); + IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); op.warning("Operation '" + node.getOp() + "' is currently not implemented"); return op; } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index 00d3517f90d..a63c7346335 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/Maximum", output.getName()); - assertEquals("join(join(imported_ml_macro__outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro__outputs_BiasAdd, f(a,b)(max(a,b)))", + assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } |