From f4fd8f8e29fa87ec79afce4172ce3d72ab6693f0 Mon Sep 17 00:00:00 2001 From: HÃ¥vard Pettersen Date: Thu, 10 Feb 2022 11:44:36 +0000 Subject: add code to simplify onnx model testing test model that combines probing and inference of dimension sizes --- .../tensor/onnx_wrapper/onnx_wrapper_test.cpp | 22 ++++++++++++++ .../src/tests/tensor/onnx_wrapper/probe_model.onnx | 30 +++++++++++++++++++ eval/src/tests/tensor/onnx_wrapper/probe_model.py | 35 ++++++++++++++++++++++ 3 files changed, 87 insertions(+) create mode 100644 eval/src/tests/tensor/onnx_wrapper/probe_model.onnx create mode 100755 eval/src/tests/tensor/onnx_wrapper/probe_model.py (limited to 'eval/src/tests/tensor') diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp index da957673f95..e50c41e2e09 100644 --- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp +++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp @@ -2,6 +2,7 @@ #include #include +#include #include #include #include @@ -28,6 +29,7 @@ std::string int_types_model = source_dir + "/int_types.onnx"; std::string guess_batch_model = source_dir + "/guess_batch.onnx"; std::string unstable_types_model = source_dir + "/unstable_types.onnx"; std::string float_to_int8_model = source_dir + "/float_to_int8.onnx"; +std::string probe_model = source_dir + "/probe_model.onnx"; void dump_info(const char *ctx, const std::vector &info) { fprintf(stderr, "%s:\n", ctx); @@ -504,4 +506,24 @@ TEST(OnnxModelCacheTest, share_and_evict_onnx_models) { EXPECT_EQ(OnnxModelCache::count_refs(), 0); } +TensorSpec val(const vespalib::string &expr) { + auto result = TensorSpec::from_expr(expr); + EXPECT_FALSE(ValueType::from_spec(result.type()).is_error()); + return result; +} + +TEST(OnnxTest, eval_onnx_with_probe_model) { + Onnx model(probe_model, Onnx::Optimize::ENABLE); + auto in1 = val("tensor( x[2], y[3]):[[ 1, 2, 3],[ 4, 5, 6]]"); + auto in2 = val("tensor( x[2], y[3]):[[ 7, 8, 9],[ 4, 5, 6]]"); + auto out1 = val("tensor(d0[2],d1[3]):[[ 8,10,12],[ 8,10,12]]"); + auto out2 = val("tensor(d0[2],d1[3]):[[-6,-6,-6],[ 0, 0, 0]]"); + auto out3 = val("tensor(d0[2],d1[3]):[[ 7,16,27],[16,25,36]]"); + auto result = test::eval_onnx(model, {in1, in2}); + ASSERT_EQ(result.size(), 3); + EXPECT_EQ(result[0], out1); + EXPECT_EQ(result[1], out2); + EXPECT_EQ(result[2], out3); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/tests/tensor/onnx_wrapper/probe_model.onnx b/eval/src/tests/tensor/onnx_wrapper/probe_model.onnx new file mode 100644 index 00000000000..89dab2e7c4c --- /dev/null +++ b/eval/src/tests/tensor/onnx_wrapper/probe_model.onnx @@ -0,0 +1,30 @@ +probe_model.py:’ + +in1 +in2out1"Add + +in1 +in2out2"Sub + +in1 +in2out3"Mul probe_modelZ# +in1 + + ÿÿÿÿÿÿÿÿÿ +innerZ# +in2 + +outer + ÿÿÿÿÿÿÿÿÿb$ +out1 + + ÿÿÿÿÿÿÿÿÿ +innerb$ +out2 + +outer + ÿÿÿÿÿÿÿÿÿb( +out3 + + ÿÿÿÿÿÿÿÿÿ + ÿÿÿÿÿÿÿÿÿB \ No newline at end of file diff --git a/eval/src/tests/tensor/onnx_wrapper/probe_model.py b/eval/src/tests/tensor/onnx_wrapper/probe_model.py new file mode 100755 index 00000000000..529fa23b2b1 --- /dev/null +++ b/eval/src/tests/tensor/onnx_wrapper/probe_model.py @@ -0,0 +1,35 @@ +# Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +import onnx +from onnx import helper, TensorProto + +IN1 = helper.make_tensor_value_info('in1', TensorProto.FLOAT, [-1, 'inner']) +IN2 = helper.make_tensor_value_info('in2', TensorProto.FLOAT, ['outer', -1]) +OUT1 = helper.make_tensor_value_info('out1', TensorProto.FLOAT, [-1, 'inner']) +OUT2 = helper.make_tensor_value_info('out2', TensorProto.FLOAT, ['outer', -1]) +OUT3 = helper.make_tensor_value_info('out3', TensorProto.FLOAT, [-1, -1]) + +nodes = [ + helper.make_node( + 'Add', + ['in1', 'in2'], + ['out1'], + ), + helper.make_node( + 'Sub', + ['in1', 'in2'], + ['out2'], + ), + helper.make_node( + 'Mul', + ['in1', 'in2'], + ['out3'], + ), +] +graph_def = helper.make_graph( + nodes, + 'probe_model', + [IN1, IN2], + [OUT1, OUT2, OUT3], +) +model_def = helper.make_model(graph_def, producer_name='probe_model.py', opset_imports=[onnx.OperatorSetIdProto(version=12)]) +onnx.save(model_def, 'probe_model.onnx') -- cgit v1.2.3