summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-11-30 15:43:44 -0800
committerJon Bratseth <bratseth@yahoo-inc.com>2017-11-30 15:43:44 -0800
commitfc3a9e518eb0e5904609028e9e388d35ddc61db0 (patch)
tree7de72421e07b88399a257ffa538ee2991030bab3
parent146aff973397215f1f5ab4a9d0e6e1c32a2a2c61 (diff)
Don't write to System.out/err
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java31
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java27
3 files changed, 44 insertions, 19 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java
index cee501841b4..61cab2f6ce7 100644
--- a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java
+++ b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java
@@ -4,10 +4,9 @@ package com.yahoo.config.application.api;
import java.util.logging.Level;
/**
- * Used during application deployment to persist and propagate messages to end user
+ * Used during application deployment to propagate messages to the end user
*
- * @author lulf
- * @since 5.1
+ * @author Ulf Lillengen
*/
public interface DeployLogger {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index d8c1b3f6bfc..38d9483a162 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -21,6 +21,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.logging.Level;
import java.util.stream.Collectors;
/**
@@ -37,11 +38,11 @@ public class TensorFlowImporter {
* The model should be saved as a pbtxt file.
* The name of the model is taken at the pbtxt file name (not including the .pbtxt ending).
*/
- public List<RankingExpression> importModel(String modelDir) {
+ public List<RankingExpression> importModel(String modelDir, MessageLogger logger) {
try {
SavedModel.Builder builder = SavedModel.newBuilder();
TextFormat.getParser().merge(IOUtils.createReader(modelDir + "/saved_model.pbtxt"), builder);
- return importModel(builder.build());
+ return importModel(builder.build(), logger);
// TODO: Support binary reading:
//SavedModel.parseFrom(new FileInputStream(modelDir + "/saved_model.pbtxt"));
@@ -53,19 +54,16 @@ public class TensorFlowImporter {
}
/** Import all declared inputs in all the graphs in the given model */
- private List<RankingExpression> importModel(SavedModel model) {
+ private List<RankingExpression> importModel(SavedModel model, MessageLogger logger) {
// TODO: Handle name conflicts between output keys in different graphs?
return model.getMetaGraphsList().stream()
- .flatMap(graph -> importGraph(graph).stream())
+ .flatMap(graph -> importGraph(graph, logger).stream())
.collect(Collectors.toList());
}
- private List<RankingExpression> importGraph(MetaGraphDef graph) {
- System.out.println("Importing graph");
+ private List<RankingExpression> importGraph(MetaGraphDef graph, MessageLogger logger) {
List<RankingExpression> expressions = new ArrayList<>();
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- System.out.println(" Importing signature def " + signatureEntry.getKey() +
- " with method name " + signatureEntry.getValue().getMethodName());
Map<String, TensorType> inputs = importInputs(signatureEntry.getValue().getInputsMap());
for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
try {
@@ -75,9 +73,9 @@ public class TensorFlowImporter {
expressions.add(new RankingExpression(output.getKey(), result));
}
catch (IllegalArgumentException e) {
- System.err.println("Skipping output '" + output.getValue().getName() + "' of signature '" + // TODO: Log, or ...
- signatureEntry.getValue().getMethodName() +
- "': " + Exceptions.toMessageString(e));
+ logger.log(Level.INFO, "Skipping output '" + output.getValue().getName() + "' of signature '" +
+ signatureEntry.getValue().getMethodName() +
+ "': " + Exceptions.toMessageString(e));
}
}
}
@@ -104,14 +102,12 @@ public class TensorFlowImporter {
}
private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph) {
- System.out.println(" Importing output " + output.getName());
NodeDef node = getNode(nameOf(output.getName()), graph);
return new TensorFunctionNode(importNode(node, inputs, graph, "").function());
}
/** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, String indent) {
- System.out.println(" " + indent + "Importing node " + tfNode.getName() + " with operation " + tfNode.getOp());
return tensorFunctionOf(tfNode, inputs, graph, indent);
}
@@ -151,5 +147,12 @@ public class TensorFlowImporter {
private String nameOf(String name) {
return name.split(":")[0];
}
-
+
+ /** An interface which can be implemented to receive messages emitted during import */
+ public interface MessageLogger {
+
+ void log(Level level, String message);
+
+ }
+
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
index f2164a1b177..c780b3d0c7d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
@@ -4,7 +4,9 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import org.junit.Test;
+import java.util.ArrayList;
import java.util.List;
+import java.util.logging.Level;
import static org.junit.Assert.assertEquals;
@@ -15,8 +17,18 @@ public class TensorFlowImporterTestCase {
@Test
public void testModel1() {
- List<RankingExpression> expressions =
- new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/");
+ TestLogger logger = new TestLogger();
+ List<RankingExpression> expressions =
+ new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", logger);
+
+ // Check logged messages
+ assertEquals(2, logger.messages.size());
+ assertEquals("Skipping output 'index_to_string_Lookup:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'LookupTableFindV2' is not supported",
+ logger.messages.get(0));
+ assertEquals("Skipping output 'TopKV2:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'TopKV2' is not supported",
+ logger.messages.get(1));
+
+ // Check resulting Vespa expression
assertEquals(1, expressions.size());
assertEquals("scores", expressions.get(0).getName());
assertEquals("" +
@@ -32,4 +44,15 @@ public class TensorFlowImporterTestCase {
return ((TensorFunctionNode)expression.getRoot()).function().toString();
}
+ private class TestLogger implements TensorFlowImporter.MessageLogger {
+
+ List<String> messages = new ArrayList<>();
+
+ @Override
+ public void log(Level level, String message) {
+ messages.add(message);
+ }
+
+ }
+
}