diff options
author | Lester Solbakken <lesters@oath.com> | 2020-08-31 09:08:26 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-09-04 08:30:25 +0200 |
commit | 67caf3b6eee690bcd0c7fc7a7666bd2cf41b8816 (patch) | |
tree | f0b424c80c357091d738831bfcce16d071f006a1 /model-integration/src/main | |
parent | 9a735030cd58e7f7ce7c2cd9bcaae121089e6ee7 (diff) |
Import TensorFlow models via ONNX conversion
Diffstat (limited to 'model-integration/src/main')
4 files changed, 18 insertions, 24 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index a9d71b7d9d5..8f73cd02184 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -74,7 +74,7 @@ public abstract class ModelImporter implements MlModelImporter { signature.input(input.getKey(), input.getValue()); } for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) { - signature.output(output.getKey(), output.getValue()); + signature.output(IntermediateOperation.vespaName(output.getKey()), output.getValue()); } } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index f8c7dc15857..c8d7392bb8d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -66,13 +66,13 @@ class TensorConverter { } private static class RawBoolValues extends RawValues { - private final IntBuffer values; + private final ByteString values; private final int size; RawBoolValues(Onnx.TensorProto tensorProto) { - values = bytes(tensorProto).asIntBuffer(); - size = values.remaining(); + values = tensorProto.getRawData(); + size = values.size(); } - @Override double get(int i) { return values.get(i); } + @Override double get(int i) { return values.byteAt(i) == 0 ? 0.0 : 1.0; } @Override int size() { return size; } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 6711b999940..a5244125d5a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -166,7 +166,7 @@ public abstract class IntermediateOperation { return vespaName(name); } - public String vespaName(String name) { + public static String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java index 34b9c847a12..71b9c66a5c0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java @@ -26,7 +26,7 @@ public class TensorFlowImporter extends ModelImporter { private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName()); - private final static int defaultOnnxOpset = 8; + private final static int[] onnxOpsetsToTry = {8, 10, 12}; private final OnnxImporter onnxImporter = new OnnxImporter(); @@ -52,19 +52,10 @@ public class TensorFlowImporter extends ModelImporter { */ @Override public ImportedModel importModel(String modelName, String modelDir) { - // Temporary (for testing): if path contains "tf_2_onnx", convert to ONNX then import that model. - if (modelDir.contains("tf_2_onnx")) { - return convertToOnnxAndImport(modelName, modelDir); - } - try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { - return importModel(modelName, modelDir, model); - } - catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); - } + return convertToOnnxAndImport(modelName, modelDir); } - /** Imports a TensorFlow model */ + /** Imports a TensorFlow model - DEPRECATED */ public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { try { IntermediateGraph graph = GraphImporter.importGraph(modelName, model); @@ -78,15 +69,18 @@ public class TensorFlowImporter extends ModelImporter { private ImportedModel convertToOnnxAndImport(String modelName, String modelDir) { Path tempDir = null; try { - log.info("Converting TensorFlow model '" + modelDir + "' to ONNX..."); tempDir = Files.createTempDirectory("tf2onnx"); String convertedPath = tempDir.toString() + File.separatorChar + "converted.onnx"; - Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, defaultOnnxOpset); - if (res.getFirst() != 0) { - throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'. " + - "Reason: " + res.getSecond()); + for (int opset : onnxOpsetsToTry) { + log.info("Converting TensorFlow model '" + modelDir + "' to ONNX with opset " + opset + "..."); + Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, opset); + if (res.getFirst() == 0) { + log.info("Conversion to ONNX with opset " + opset + " successful."); + return onnxImporter.importModel(modelName, convertedPath); + } + log.info("Conversion to ONNX with opset " + opset + " failed. Reason: " + res.getSecond()); } - return onnxImporter.importModel(modelName, convertedPath); + throw new IllegalArgumentException("Unable to convert TensorFlow model in '" + modelDir + "' to ONNX."); } catch (IOException e) { throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'"); } finally { |