summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2019-12-12 09:42:54 -0800
committerLester Solbakken <lesters@oath.com>2019-12-12 09:42:54 -0800
commitf99ef6d4d400be906d26fbf59762bc27553ed32b (patch)
treead25e844c0237673b4549b7d30fd1e420aebb7d3 /model-integration/src/main
parent14b0a54720077edf95d270741d207f9015a1c7aa (diff)
Initial conversion of TF to ONNX for testing
Diffstat (limited to 'model-integration/src/main')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java43
1 files changed, 43 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
index 2a406f92756..96ea58edc61 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
@@ -4,10 +4,17 @@ package ai.vespa.rankingexpression.importer.tensorflow;
import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.ModelImporter;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import com.yahoo.collections.Pair;
+import com.yahoo.io.IOUtils;
+import com.yahoo.system.ProcessExecuter;
import org.tensorflow.SavedModelBundle;
import java.io.File;
import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.logging.Logger;
/**
* Converts a saved TensorFlow model into a ranking expression and set of constants.
@@ -17,6 +24,12 @@ import java.io.IOException;
*/
public class TensorFlowImporter extends ModelImporter {
+ private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());
+
+ private final static int defaultOnnxOpset = 8;
+
+ private final OnnxImporter onnxImporter = new OnnxImporter();
+
@Override
public boolean canImport(String modelPath) {
File modelDir = new File(modelPath);
@@ -39,6 +52,10 @@ public class TensorFlowImporter extends ModelImporter {
*/
@Override
public ImportedModel importModel(String modelName, String modelDir) {
+ // Temporary (for testing): if path contains "tf_2_onnx", convert to ONNX then import that model.
+ if (modelDir.contains("tf_2_onnx")) {
+ return convertToOnnxAndImport(modelName, modelDir);
+ }
try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
return importModel(modelName, modelDir, model);
}
@@ -58,5 +75,31 @@ public class TensorFlowImporter extends ModelImporter {
}
}
+ private ImportedModel convertToOnnxAndImport(String modelName, String modelDir) {
+ Path tempDir = null;
+ try {
+ log.info("Converting TensorFlow model '" + modelDir + "' to ONNX...");
+ tempDir = Files.createTempDirectory("tf2onnx");
+ String convertedPath = tempDir.toString() + File.separatorChar + "converted.onnx";
+ Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, defaultOnnxOpset);
+ if (res.getFirst() != 0) {
+ throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'. " +
+ "Reason: " + res.getSecond());
+ }
+ return onnxImporter.importModel(modelName, convertedPath);
+ } catch (IOException e) {
+ throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'");
+ } finally {
+ if (tempDir != null) {
+ IOUtils.recursiveDeleteDir(tempDir.toFile());
+ }
+ }
+ }
+
+ private Pair<Integer, String> convertToOnnx(String savedModel, String output, int opset) throws IOException {
+ ProcessExecuter executer = new ProcessExecuter();
+ String job = "python3 -m tf2onnx.convert --saved-model " + savedModel + " --output " + output + " --opset " + opset;
+ return executer.exec(job);
+ }
}