aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-06-12 12:16:56 +0200
committerLester Solbakken <lesters@oath.com>2020-06-12 12:16:56 +0200
commit0a886d74d4c9ffde41eef1f7e3c186b60b9f3726 (patch)
treee142b94341563b28a2b4a0e26fe77458749d2ed9 /model-integration
parent8de8ff4f87295d812d4e660f0216953726200c92 (diff)
Import Tensorflow models vis ONNX conversion
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java8
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java30
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java43
5 files changed, 18 insertions, 67 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index a9d71b7d9d5..8f73cd02184 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -74,7 +74,7 @@ public abstract class ModelImporter implements MlModelImporter {
signature.input(input.getKey(), input.getValue());
}
for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) {
- signature.output(output.getKey(), output.getValue());
+ signature.output(IntermediateOperation.vespaName(output.getKey()), output.getValue());
}
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
index f8c7dc15857..c8d7392bb8d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java
@@ -66,13 +66,13 @@ class TensorConverter {
}
private static class RawBoolValues extends RawValues {
- private final IntBuffer values;
+ private final ByteString values;
private final int size;
RawBoolValues(Onnx.TensorProto tensorProto) {
- values = bytes(tensorProto).asIntBuffer();
- size = values.remaining();
+ values = tensorProto.getRawData();
+ size = values.size();
}
- @Override double get(int i) { return values.get(i); }
+ @Override double get(int i) { return values.byteAt(i) == 0 ? 0.0 : 1.0; }
@Override int size() { return size; }
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 6e637c72d0f..7647161db16 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -166,7 +166,7 @@ public abstract class IntermediateOperation {
return vespaName(name);
}
- public String vespaName(String name) {
+ public static String vespaName(String name) {
return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
index 96ea58edc61..5bf11ed8cf6 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
@@ -26,7 +26,7 @@ public class TensorFlowImporter extends ModelImporter {
private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());
- private final static int defaultOnnxOpset = 8;
+ private final static int[] onnxOpsetsToTry = {8, 10, 12};
private final OnnxImporter onnxImporter = new OnnxImporter();
@@ -52,19 +52,10 @@ public class TensorFlowImporter extends ModelImporter {
*/
@Override
public ImportedModel importModel(String modelName, String modelDir) {
- // Temporary (for testing): if path contains "tf_2_onnx", convert to ONNX then import that model.
- if (modelDir.contains("tf_2_onnx")) {
- return convertToOnnxAndImport(modelName, modelDir);
- }
- try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve")) {
- return importModel(modelName, modelDir, model);
- }
- catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
- }
+ return convertToOnnxAndImport(modelName, modelDir);
}
- /** Imports a TensorFlow model */
+ /** Imports a TensorFlow model - DEPRECATED */
public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
try {
IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
@@ -78,15 +69,18 @@ public class TensorFlowImporter extends ModelImporter {
private ImportedModel convertToOnnxAndImport(String modelName, String modelDir) {
Path tempDir = null;
try {
- log.info("Converting TensorFlow model '" + modelDir + "' to ONNX...");
tempDir = Files.createTempDirectory("tf2onnx");
String convertedPath = tempDir.toString() + File.separatorChar + "converted.onnx";
- Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, defaultOnnxOpset);
- if (res.getFirst() != 0) {
- throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'. " +
- "Reason: " + res.getSecond());
+ for (int opset : onnxOpsetsToTry) {
+ log.info("Converting TensorFlow model '" + modelDir + "' to ONNX with opset " + opset + "...");
+ Pair<Integer, String> res = convertToOnnx(modelDir, convertedPath, opset);
+ if (res.getFirst() == 0) {
+ log.info("Conversion to ONNX with opset " + opset + " successful.");
+ return onnxImporter.importModel(modelName, convertedPath);
+ }
+ log.info("Conversion to ONNX with opset " + opset + " failed. Reason: " + res.getSecond());
}
- return onnxImporter.importModel(modelName, convertedPath);
+ throw new IllegalArgumentException("Unable to convert TensorFlow model in '" + modelDir + "' to ONNX.");
} catch (IOException e) {
throw new IllegalArgumentException("Conversion from TensorFlow to ONNX failed for '" + modelDir + "'");
} finally {
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 35c853bd746..09455abc380 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -55,47 +55,4 @@ public class OnnxMnistSoftmaxImportTestCase {
assertEquals("{Placeholder=tensor<float>(d0[],d1[784])}", output.argumentTypes().toString());
}
- @Test
- public void testComparisonBetweenOnnxAndTensorflow() {
- String tfModelPath = "src/test/models/tensorflow/mnist_softmax/saved";
- String onnxModelPath = "src/test/models/onnx/mnist_softmax/mnist_softmax.onnx";
-
- Tensor argument = placeholderArgument();
- Tensor tensorFlowResult = evaluateTensorFlowModel(tfModelPath, argument, "Placeholder", "add");
- Tensor onnxResult = evaluateOnnxModel(onnxModelPath, argument, "Placeholder", "add");
-
- assertEquals("Operation 'add' produces equal results", tensorFlowResult, onnxResult);
- }
-
- private Tensor evaluateTensorFlowModel(String path, Tensor argument, String input, String output) {
- ImportedModel model = new TensorFlowImporter().importModel("test", path);
- return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
- }
-
- private Tensor evaluateOnnxModel(String path, Tensor argument, String input, String output) {
- ImportedModel model = new OnnxImporter().importModel("test", path);
- return evaluateExpression(model.expressions().get(output), contextFrom(model), argument, input);
- }
-
- private Tensor evaluateExpression(RankingExpression expression, Context context, Tensor argument, String input) {
- context.put(input, new TensorValue(argument));
- return expression.evaluate(context).asTensor();
- }
-
- private Context contextFrom(ImportedModel result) {
- MapContext context = new MapContext();
- result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
- result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
- return context;
- }
-
- private Tensor placeholderArgument() {
- Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", 784).build());
- for (int d0 = 0; d0 < 1; d0++)
- for (int d1 = 0; d1 < 784; d1++)
- b.cell(d1 * 1.0 / 784, d0, d1);
- return b.build();
- }
-
-
}