summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-06-01 15:02:35 +0200
committerLester Solbakken <lesters@oath.com>2018-06-01 15:02:35 +0200
commitc5089d4c8f5c6259190ecbf80fbe0c96f391c218 (patch)
tree1d414cf348b7b347686f7df8bda1afe070441fa8 /searchlib
parentd3bdbedb5aeba5c36e932b77bca57b582971ad21 (diff)
Put model names back in generated macros
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java18
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java10
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java78
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java2
22 files changed, 92 insertions, 97 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
index 47623daf022..b7620b6b5c7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
@@ -19,15 +19,18 @@ import java.util.stream.Collectors;
public class GraphImporter {
- public static IntermediateOperation mapOperation(Onnx.NodeProto node, List<IntermediateOperation> inputs) {
+ public static IntermediateOperation mapOperation(Onnx.NodeProto node,
+ List<IntermediateOperation> inputs,
+ IntermediateGraph graph) {
String nodeName = node.getName();
+ String modelName = graph.name();
switch (node.getOpType().toLowerCase()) {
- case "add": return new Join(nodeName, inputs, ScalarFunctions.add());
- case "matmul": return new MatMul(nodeName, inputs);
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "matmul": return new MatMul(modelName, nodeName, inputs);
}
- IntermediateOperation op = new NoOp(node.getName(), inputs);
+ IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
op.warning("Operation '" + node.getOpType() + "' is currently not implemented");
return op;
}
@@ -62,7 +65,7 @@ public class GraphImporter {
if (valueInfoProto == null)
throw new IllegalArgumentException("Could not find argument tensor: " + name);
OrderedTensorType type = TypeConverter.fromOnnxType(valueInfoProto.getType());
- operation = new Argument(valueInfoProto.getName(), type);
+ operation = new Argument(intermediateGraph.name(), valueInfoProto.getName(), type);
intermediateGraph.inputs(intermediateGraph.defaultSignature())
.put(IntermediateOperation.namePartOf(name), operation.vespaName());
@@ -70,14 +73,13 @@ public class GraphImporter {
} else if (isConstantTensor(name, onnxGraph)) {
Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList());
- operation = new Constant(name, defaultType, intermediateGraph.name());
+ operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
} else {
-
Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
- operation = mapOperation(node, inputs);
+ operation = mapOperation(node, inputs, intermediateGraph);
if (isOutputNode(name, onnxGraph)) {
intermediateGraph.outputs(intermediateGraph.defaultSignature())
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
index 8783152841d..7fc2aae87d1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Argument.java
@@ -15,8 +15,8 @@ public class Argument extends IntermediateOperation {
private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
- public Argument(String name, OrderedTensorType type) {
- super(name, Collections.emptyList());
+ public Argument(String modelName, String nodeName, OrderedTensorType type) {
+ super(modelName, nodeName, Collections.emptyList());
this.type = type.rename(vespaName() + "_");
standardNamingType = OrderedTensorType.standardType(type);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
index 759ccd32657..1b8c62fe0e9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ConcatV2.java
@@ -14,8 +14,8 @@ public class ConcatV2 extends IntermediateOperation {
private String concatDimensionName;
- public ConcatV2(String name, List<IntermediateOperation> inputs) {
- super(name, inputs);
+ public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java
index 45103f35402..3c0f8569c47 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Const.java
@@ -17,16 +17,14 @@ import java.util.Optional;
public class Const extends IntermediateOperation {
- private final String modelName;
private final AttributeMap attributeMap;
- public Const(String name,
+ public Const(String modelName,
+ String nodeName,
List<IntermediateOperation> inputs,
AttributeMap attributeMap,
- OrderedTensorType type,
- String modelName) {
- super(name, inputs);
- this.modelName = modelName;
+ OrderedTensorType type) {
+ super(modelName, nodeName, inputs);
this.attributeMap = attributeMap;
this.type = type.rename(vespaName() + "_");
setConstantValue(value());
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java
index 323d1327d2e..5e4abeaa234 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Constant.java
@@ -14,8 +14,8 @@ public class Constant extends IntermediateOperation {
private final String modelName;
- public Constant(String name, OrderedTensorType type, String modelName) {
- super(name, Collections.emptyList());
+ public Constant(String modelName, String nodeName, OrderedTensorType type) {
+ super(modelName, nodeName, Collections.emptyList());
this.modelName = modelName;
this.type = type.rename(vespaName() + "_");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java
index ec290078f12..742ed8b89ab 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/ExpandDims.java
@@ -21,8 +21,8 @@ public class ExpandDims extends IntermediateOperation {
private List<String> expandDimensions;
- public ExpandDims(String name, List<IntermediateOperation> inputs) {
- super(name, inputs);
+ public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java
index 9b7700aa01f..d29bd4b7a9e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Identity.java
@@ -8,11 +8,8 @@ import java.util.List;
public class Identity extends IntermediateOperation {
- private final String modelName;
-
- public Identity(String name, List<IntermediateOperation> inputs, String modelName) {
- super(name, inputs);
- this.modelName = modelName;
+ public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index 3115c79e81a..e24b2a828b5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -23,6 +23,7 @@ public abstract class IntermediateOperation {
protected final static String MACRO_PREFIX = "imported_ml_macro_";
protected final String name;
+ protected final String modelName;
protected final List<IntermediateOperation> inputs;
protected final List<IntermediateOperation> outputs = new ArrayList<>();
protected final List<String> importWarnings = new ArrayList<>();
@@ -34,8 +35,9 @@ public abstract class IntermediateOperation {
protected Function<OrderedTensorType, Value> constantValueFunction = null;
protected List<IntermediateOperation> controlInputs = Collections.emptyList();
- IntermediateOperation(String name, List<IntermediateOperation> inputs) {
+ IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
this.name = name;
+ this.modelName = modelName;
this.inputs = Collections.unmodifiableList(inputs);
this.inputs.forEach(i -> i.outputs.add(this));
}
@@ -119,11 +121,7 @@ public abstract class IntermediateOperation {
public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
/** Retrieve the valid Vespa name of this node if it is a macro */
- public String macroName() {
- return vespaName() != null ? MACRO_PREFIX + "_" + vespaName() : null;
-// return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null;
- // todo: add model name
- }
+ public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; }
/** Retrieve the list of warnings produced during its lifetime */
public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
index e3e2c94d04e..8413ed74118 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Join.java
@@ -15,8 +15,8 @@ public class Join extends IntermediateOperation {
private final DoubleBinaryOperator operator;
- public Join(String name, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) {
- super(name, inputs);
+ public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) {
+ super(modelName, nodeName, inputs);
this.operator = operator;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java
index 2d0086e7db3..f54ae83052f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Map.java
@@ -12,8 +12,8 @@ public class Map extends IntermediateOperation {
private final DoubleUnaryOperator operator;
- public Map(String name, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) {
- super(name, inputs);
+ public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) {
+ super(modelName, nodeName, inputs);
this.operator = operator;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java
index 1873ec80375..52e223f9518 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/MatMul.java
@@ -11,8 +11,8 @@ import java.util.Optional;
public class MatMul extends IntermediateOperation {
- public MatMul(String name, List<IntermediateOperation> inputs) {
- super(name, inputs);
+ public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
index 9b566c8b8ad..822656916f8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Mean.java
@@ -25,8 +25,8 @@ public class Mean extends IntermediateOperation {
private final AttributeMap attributeMap;
private List<String> reduceDimensions;
- public Mean(String name, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
- super(name, inputs);
+ public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
this.attributeMap = attributeMap;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java
index ea2b89b1fc6..9d9eca47b1c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Merge.java
@@ -8,8 +8,8 @@ import java.util.List;
public class Merge extends IntermediateOperation {
- public Merge(String name, List<IntermediateOperation> inputs) {
- super(name, inputs);
+ public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
index 12941568d8d..19ba146492c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/NoOp.java
@@ -9,8 +9,8 @@ import java.util.List;
public class NoOp extends IntermediateOperation {
- public NoOp(String name, List<IntermediateOperation> inputs) {
- super(name, Collections.emptyList()); // don't propagate inputs
+ public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
index 4ef1b8b0672..9299ae9be12 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
@@ -9,8 +9,8 @@ import java.util.Optional;
public class PlaceholderWithDefault extends IntermediateOperation {
- public PlaceholderWithDefault(String name, List<IntermediateOperation> inputs) {
- super(name, inputs);
+ public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
index 8164632b35b..e91c2305f7d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
@@ -28,8 +28,8 @@ import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.Orde
public class Reshape extends IntermediateOperation {
- public Reshape(String name, List<IntermediateOperation> inputs) {
- super(name, inputs);
+ public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java
index cffb705c774..927a4a368f9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Select.java
@@ -16,8 +16,8 @@ import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.Orde
public class Select extends IntermediateOperation {
- public Select(String name, List<IntermediateOperation> inputs) {
- super(name, inputs);
+ public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
}
@Override
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java
index d0ab692f8bd..da566909adc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Shape.java
@@ -12,8 +12,8 @@ import java.util.List;
public class Shape extends IntermediateOperation {
- public Shape(String name, List<IntermediateOperation> inputs) {
- super(name, inputs);
+ public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
createConstantValue();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java
index 16a7f666fc9..c750c47e27e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Squeeze.java
@@ -18,8 +18,8 @@ public class Squeeze extends IntermediateOperation {
private final AttributeMap attributeMap;
private List<String> squeezeDimensions;
- public Squeeze(String name, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
- super(name, inputs);
+ public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
+ super(modelName, nodeName, inputs);
this.attributeMap = attributeMap;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java
index 0e726b55711..0171d1ea171 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Switch.java
@@ -11,8 +11,8 @@ public class Switch extends IntermediateOperation {
private final int port;
- public Switch(String name, List<IntermediateOperation> inputs, int port) {
- super(name, inputs);
+ public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) {
+ super(modelName, nodeName, inputs);
this.port = port;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
index 4249e2285b1..dcea8f1a230 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
@@ -50,58 +50,58 @@ public class GraphImporter {
switch (node.getOp().toLowerCase()) {
// array ops
- case "concatv2": return new ConcatV2(nodeName, inputs);
- case "const": return new Const(nodeName, inputs, attributes, nodeType, modelName); // todo: test this
- case "expanddims": return new ExpandDims(nodeName, inputs);
- case "identity": return new Identity(nodeName, inputs, modelName);
- case "placeholder": return new Argument(nodeName, nodeType);
- case "placeholderwithdefault": return new PlaceholderWithDefault(nodeName, inputs);
- case "reshape": return new Reshape(nodeName, inputs);
- case "shape": return new Shape(nodeName, inputs);
- case "squeeze": return new Squeeze(nodeName, inputs, attributes); // todo: test this
+ case "concatv2": return new ConcatV2(modelName, nodeName, inputs);
+ case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType);
+ case "expanddims": return new ExpandDims(modelName, nodeName, inputs);
+ case "identity": return new Identity(modelName, nodeName, inputs);
+ case "placeholder": return new Argument(modelName, nodeName, nodeType);
+ case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
+ case "reshape": return new Reshape(modelName, nodeName, inputs);
+ case "shape": return new Shape(modelName, nodeName, inputs);
+ case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
// control flow
- case "merge": return new Merge(nodeName, inputs);
- case "switch": return new Switch(nodeName, inputs, nodePort); // todo: test this
+ case "merge": return new Merge(modelName, nodeName, inputs);
+ case "switch": return new Switch(modelName, nodeName, inputs, nodePort);
// math ops
- case "add": return new Join(nodeName, inputs, ScalarFunctions.add());
- case "add_n": return new Join(nodeName, inputs, ScalarFunctions.add());
- case "acos": return new Map(nodeName, inputs, ScalarFunctions.acos());
- case "div": return new Join(nodeName, inputs, ScalarFunctions.divide());
- case "realdiv": return new Join(nodeName, inputs, ScalarFunctions.divide());
- case "floor": return new Map(nodeName, inputs, ScalarFunctions.floor());
- case "matmul": return new MatMul(nodeName, inputs);
- case "maximum": return new Join(nodeName, inputs, ScalarFunctions.max());
- case "mean": return new Mean(nodeName, inputs, attributes); // todo: test this
- case "reducemean": return new Mean(nodeName, inputs, attributes);
- case "mul": return new Join(nodeName, inputs, ScalarFunctions.multiply());
- case "multiply": return new Join(nodeName, inputs, ScalarFunctions.multiply());
- case "rsqrt": return new Map(nodeName, inputs, ScalarFunctions.rsqrt());
- case "select": return new Select(nodeName, inputs);
- case "where3": return new Select(nodeName, inputs);
- case "sigmoid": return new Map(nodeName, inputs, ScalarFunctions.sigmoid());
- case "squareddifference": return new Join(nodeName, inputs, ScalarFunctions.squareddifference());
- case "sub": return new Join(nodeName, inputs, ScalarFunctions.subtract());
- case "subtract": return new Join(nodeName, inputs, ScalarFunctions.subtract());
+ case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
+ case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
+ case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
+ case "matmul": return new MatMul(modelName, nodeName, inputs);
+ case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
+ case "mean": return new Mean(modelName, nodeName, inputs, attributes);
+ case "reducemean": return new Mean(modelName, nodeName, inputs, attributes);
+ case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
+ case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt());
+ case "select": return new Select(modelName, nodeName, inputs);
+ case "where3": return new Select(modelName, nodeName, inputs);
+ case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
+ case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference());
+ case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
// nn ops
- case "biasadd": return new Join(nodeName, inputs, ScalarFunctions.add());
- case "elu": return new Map(nodeName, inputs, ScalarFunctions.elu());
- case "relu": return new Map(nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(nodeName, inputs, ScalarFunctions.selu());
+ case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
+ case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
+ case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
+ case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
// state ops
- case "variable": return new Constant(nodeName, nodeType, modelName);
- case "variablev2": return new Constant(nodeName, nodeType, modelName);
+ case "variable": return new Constant(modelName, nodeName, nodeType);
+ case "variablev2": return new Constant(modelName, nodeName, nodeType);
// evaluation no-ops
- case "stopgradient":return new Identity(nodeName, inputs, modelName);
- case "noop": return new NoOp(nodeName, inputs);
+ case "stopgradient":return new Identity(modelName, nodeName, inputs);
+ case "noop": return new NoOp(modelName, nodeName, inputs);
}
- IntermediateOperation op = new NoOp(node.getName(), inputs);
+ IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
op.warning("Operation '" + node.getOp() + "' is currently not implemented");
return op;
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index 00d3517f90d..a63c7346335 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/Maximum", output.getName());
- assertEquals("join(join(imported_ml_macro__outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro__outputs_BiasAdd, f(a,b)(max(a,b)))",
+ assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}