diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-12 19:33:11 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-12 19:33:11 +0200 |
commit | 3905dbf4455c4426f86f08ec925d7f66a06e85b8 (patch) | |
tree | 621bbb8832ec4a0c8e674c709bcc99aecdfbc528 /model-integration | |
parent | 599ad95a4e5003b903e464f91210892c1bee44ce (diff) |
Revert "Import Tensorflow models vis ONNX conversion"
This reverts commit 0a886d74d4c9ffde41eef1f7e3c186b60b9f3726.
Diffstat (limited to 'model-integration')
5 files changed, 67 insertions, 18 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 8f73cd02184..a9d71b7d9d5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -74,7 +74,7 @@ public abstract class ModelImporter implements MlModelImporter { signature.input(input.getKey(), input.getValue()); } for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) { - signature.output(IntermediateOperation.vespaName(output.getKey()), output.getValue()); + signature.output(output.getKey(), output.getValue()); } } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index c8d7392bb8d..f8c7dc15857 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -66,13 +66,13 @@ class TensorConverter { } private static class RawBoolValues extends RawValues { - private final ByteString values; + private final IntBuffer values; private final int size; RawBoolValues(Onnx.TensorProto tensorProto) { - values = tensorProto.getRawData(); - size = values.size(); + values = bytes(tensorProto).asIntBuffer(); + size = values.remaining(); } - @Override double get(int i) { return values.byteAt(i) == 0 ? 0.0 : 1.0; } + @Override double get(int i) { return values.get(i); } @Override int size() { return size; } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 7647161db16..6e637c72d0f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -166,7 +166,7 @@ public abstract class IntermediateOperation { return vespaName(name); } - public static String vespaName(String name) { + public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java index 5bf11ed8cf6..96ea58edc61 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java @@ -26,7 +26,7 @@ public class TensorFlowImporter extends ModelImporter { private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName()); - private final static int[] onnxOpsetsToTry = {8, 10, 12}; + private final static int defaultOnnxOpset = 8; private final OnnxImporter onnxImporter = new OnnxImporter(); @@ -52,10 +52,19 @@ public class TensorFlowImporter extends ModelImporter { */ @Override public ImportedModel importModel(String modelName, String modelDir) { - return convertToOnnxAndImport(modelName, modelDir); + // Temporary (for testing): if path contains "tf_2_onnx", convert to ONNX then import that model. + if (modelDir.contains("tf_2_onnx")) { + return convertToOnnxAndImport(modelName, modelDir); + } + try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) { + return importModel(modelName, modelDir, model); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); + } } - /** Imports a TensorFlow model - DEPRECATED */ + /** Imports a TensorFlow model */ public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { try { IntermediateGraph graph = GraphImporter.importGraph(modelName, model); @@ -69,18 +78,15 @@ public class TensorFlowImporter extends ModelImporter { private ImportedModel convertToOnnxAndImport(String modelName, String modelDir) { Path tempDir = null; try { + log.info("Converting TensorFlow model '" + modelDir + "' to ONNX..."); tempDir = Files.createTempDirectory("tf2onnx"); String convertedPath = tempDir.toString() + File.separatorChar + "converted.onnx"; - for (int opset : onnxOpsetsToTry) { - log.info("Converting TensorFlow model '" + modelDir + "' to ONNX with opset " + opset + "..."); - Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, opset); - if (res.getFirst() == 0) { - log.info("Conversion to ONNX with opset " + opset + " successful."); - return onnxImporter.importModel(modelName, convertedPath); - } - log.info("Conversion to ONNX with opset " + opset + " failed. Reason: " + res.getSecond()); + Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, defaultOnnxOpset); + if (res.getFirst() != 0) { + throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'. " + + "Reason: " + res.getSecond()); } - throw new IllegalArgumentException("Unable to convert TensorFlow model in '" + modelDir + "' to ONNX."); + return onnxImporter.importModel(modelName, convertedPath); } catch (IOException e) { throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'"); } finally { diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 09455abc380..35c853bd746 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -55,4 +55,47 @@ public class OnnxMnistSoftmaxImportTestCase { assertEquals("{Placeholder=tensor<float>(d0[],d1[784])}", output.argumentTypes().toString()); } + @Test + public void testComparisonBetweenOnnxAndTensorflow() { + String tfModelPath = "src/test/models/tensorflow/mnist_softmax/saved"; + String onnxModelPath = "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx"; + + Tensor argument = placeholderArgument(); + Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add"); + Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder", "add"); + + assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult); + } + + private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) { + ImportedModel model = new TensorFlowImporter().importModel("test", path); + return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); + } + + private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) { + ImportedModel model = new OnnxImporter().importModel("test", path); + return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input); + } + + private Tensor evaluateExpression(RankingExpression expression, Context context, Tensor argument, String input) { + context.put(input, new TensorValue(argument)); + return expression.evaluate(context).asTensor(); + } + + private Context contextFrom(ImportedModel result) { + MapContext context = new MapContext(); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + return context; + } + + private Tensor placeholderArgument() { + Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", 784).build()); + for (int d0 = 0; d0 < 1; d0++) + for (int d1 = 0; d1 < 784; d1++) + b.cell(d1 * 1.0 / 784, d0, d1); + return b.build(); + } + + } |