From f99ef6d4d400be906d26fbf59762bc27553ed32b Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 12 Dec 2019 09:42:54 -0800 Subject: Initial conversion of TF to ONNX for testing --- .../importer/tensorflow/TensorFlowImporter.java | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) (limited to 'model-integration/src/main') diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java index 2a406f92756..96ea58edc61 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java @@ -4,10 +4,17 @@ package ai.vespa.rankingexpression.importer.tensorflow; import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.ModelImporter; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import com.yahoo.collections.Pair; +import com.yahoo.io.IOUtils; +import com.yahoo.system.ProcessExecuter; import org.tensorflow.SavedModelBundle; import java.io.File; import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.logging.Logger; /** * Converts a saved TensorFlow model into a ranking expression and set of constants. @@ -17,6 +24,12 @@ import java.io.IOException; */ public class TensorFlowImporter extends ModelImporter { + private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName()); + + private final static int defaultOnnxOpset = 8; + + private final OnnxImporter onnxImporter = new OnnxImporter(); + @Override public boolean canImport(String modelPath) { File modelDir = new File(modelPath); @@ -39,6 +52,10 @@ public class TensorFlowImporter extends ModelImporter { */ @Override public ImportedModel importModel(String modelName, String modelDir) { + // Temporary (for testing): if path contains "tf_2_onnx", convert to ONNX then import that model. + if (modelDir.contains("tf_2_onnx")) { + return convertToOnnxAndImport(modelName, modelDir); + } try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { return importModel(modelName, modelDir, model); } @@ -58,5 +75,31 @@ public class TensorFlowImporter extends ModelImporter { } } + private ImportedModel convertToOnnxAndImport(String modelName, String modelDir) { + Path tempDir = null; + try { + log.info("Converting TensorFlow model '" + modelDir + "' to ONNX..."); + tempDir = Files.createTempDirectory("tf2onnx"); + String convertedPath = tempDir.toString() + File.separatorChar + "converted.onnx"; + Pair res = convertToOnnx(modelDir, convertedPath, defaultOnnxOpset); + if (res.getFirst() != 0) { + throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'. " + + "Reason: " + res.getSecond()); + } + return onnxImporter.importModel(modelName, convertedPath); + } catch (IOException e) { + throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'"); + } finally { + if (tempDir != null) { + IOUtils.recursiveDeleteDir(tempDir.toFile()); + } + } + } + + private Pair convertToOnnx(String savedModel, String output, int opset) throws IOException { + ProcessExecuter executer = new ProcessExecuter(); + String job = "python3 -m tf2onnx.convert --saved-model " + savedModel + " --output " + output + " --opset " + opset; + return executer.exec(job); + } } -- cgit v1.2.3