diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-30 15:43:44 -0800 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-30 15:43:44 -0800 |
commit | fc3a9e518eb0e5904609028e9e388d35ddc61db0 (patch) | |
tree | 7de72421e07b88399a257ffa538ee2991030bab3 | |
parent | 146aff973397215f1f5ab4a9d0e6e1c32a2a2c61 (diff) |
Don't write to System.out/err
3 files changed, 44 insertions, 19 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java index cee501841b4..61cab2f6ce7 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/DeployLogger.java @@ -4,10 +4,9 @@ package com.yahoo.config.application.api; import java.util.logging.Level; /** - * Used during application deployment to persist and propagate messages to end user + * Used during application deployment to propagate messages to the end user * - * @author lulf - * @since 5.1 + * @author Ulf Lillengen */ public interface DeployLogger { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index d8c1b3f6bfc..38d9483a162 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -21,6 +21,7 @@ import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.logging.Level; import java.util.stream.Collectors; /** @@ -37,11 +38,11 @@ public class TensorFlowImporter { * The model should be saved as a pbtxt file. * The name of the model is taken at the pbtxt file name (not including the .pbtxt ending). */ - public List<RankingExpression> importModel(String modelDir) { + public List<RankingExpression> importModel(String modelDir, MessageLogger logger) { try { SavedModel.Builder builder = SavedModel.newBuilder(); TextFormat.getParser().merge(IOUtils.createReader(modelDir + "/saved_model.pbtxt"), builder); - return importModel(builder.build()); + return importModel(builder.build(), logger); // TODO: Support binary reading: //SavedModel.parseFrom(new FileInputStream(modelDir + "/saved_model.pbtxt")); @@ -53,19 +54,16 @@ public class TensorFlowImporter { } /** Import all declared inputs in all the graphs in the given model */ - private List<RankingExpression> importModel(SavedModel model) { + private List<RankingExpression> importModel(SavedModel model, MessageLogger logger) { // TODO: Handle name conflicts between output keys in different graphs? return model.getMetaGraphsList().stream() - .flatMap(graph -> importGraph(graph).stream()) + .flatMap(graph -> importGraph(graph, logger).stream()) .collect(Collectors.toList()); } - private List<RankingExpression> importGraph(MetaGraphDef graph) { - System.out.println("Importing graph"); + private List<RankingExpression> importGraph(MetaGraphDef graph, MessageLogger logger) { List<RankingExpression> expressions = new ArrayList<>(); for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - System.out.println(" Importing signature def " + signatureEntry.getKey() + - " with method name " + signatureEntry.getValue().getMethodName()); Map<String, TensorType> inputs = importInputs(signatureEntry.getValue().getInputsMap()); for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { try { @@ -75,9 +73,9 @@ public class TensorFlowImporter { expressions.add(new RankingExpression(output.getKey(), result)); } catch (IllegalArgumentException e) { - System.err.println("Skipping output '" + output.getValue().getName() + "' of signature '" + // TODO: Log, or ... - signatureEntry.getValue().getMethodName() + - "': " + Exceptions.toMessageString(e)); + logger.log(Level.INFO, "Skipping output '" + output.getValue().getName() + "' of signature '" + + signatureEntry.getValue().getMethodName() + + "': " + Exceptions.toMessageString(e)); } } } @@ -104,14 +102,12 @@ public class TensorFlowImporter { } private ExpressionNode importOutput(TensorInfo output, Map<String, TensorType> inputs, GraphDef graph) { - System.out.println(" Importing output " + output.getName()); NodeDef node = getNode(nameOf(output.getName()), graph); return new TensorFunctionNode(importNode(node, inputs, graph, "").function()); } /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ private TypedTensorFunction importNode(NodeDef tfNode, Map<String, TensorType> inputs, GraphDef graph, String indent) { - System.out.println(" " + indent + "Importing node " + tfNode.getName() + " with operation " + tfNode.getOp()); return tensorFunctionOf(tfNode, inputs, graph, indent); } @@ -151,5 +147,12 @@ public class TensorFlowImporter { private String nameOf(String name) { return name.split(":")[0]; } - + + /** An interface which can be implemented to receive messages emitted during import */ + public interface MessageLogger { + + void log(Level level, String message); + + } + } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java index f2164a1b177..c780b3d0c7d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java @@ -4,7 +4,9 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import org.junit.Test; +import java.util.ArrayList; import java.util.List; +import java.util.logging.Level; import static org.junit.Assert.assertEquals; @@ -15,8 +17,18 @@ public class TensorFlowImporterTestCase { @Test public void testModel1() { - List<RankingExpression> expressions = - new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/"); + TestLogger logger = new TestLogger(); + List<RankingExpression> expressions = + new TensorFlowImporter().importModel("src/test/files/integration/tensorflow/model1/", logger); + + // Check logged messages + assertEquals(2, logger.messages.size()); + assertEquals("Skipping output 'index_to_string_Lookup:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'LookupTableFindV2' is not supported", + logger.messages.get(0)); + assertEquals("Skipping output 'TopKV2:0' of signature 'tensorflow/serving/classify': Conversion of TensorFlow operation 'TopKV2' is not supported", + logger.messages.get(1)); + + // Check resulting Vespa expression assertEquals(1, expressions.size()); assertEquals("scores", expressions.get(0).getName()); assertEquals("" + @@ -32,4 +44,15 @@ public class TensorFlowImporterTestCase { return ((TensorFunctionNode)expression.getRoot()).function().toString(); } + private class TestLogger implements TensorFlowImporter.MessageLogger { + + List<String> messages = new ArrayList<>(); + + @Override + public void log(Level level, String message) { + messages.add(message); + } + + } + } |