Diffstat (limited to 'eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp')
 eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp | 9 ++++++++-
 1 file changed, 8 insertions(+), 1 deletion(-)
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index 21f0044faf1..da957673f95 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -281,7 +281,7 @@ TEST(OnnxTest, int_types_onnx_model_can_be_evaluated)
     //-------------------------------------------------------------------------
 }
 
-TEST(OnnxTest, we_guess_batch_dimension_size_when_inference_fails) {
+TEST(OnnxTest, we_probe_batch_dimension_size_when_inference_fails) {
     Onnx model(guess_batch_model, Onnx::Optimize::ENABLE);
     Onnx::WirePlanner planner_3;
     Onnx::WirePlanner planner_4;
@@ -298,6 +298,13 @@ TEST(OnnxTest, we_guess_batch_dimension_size_when_inference_fails) {
     EXPECT_TRUE(planner_4.bind_input_type(in_4_type, model.inputs()[0]));
     EXPECT_TRUE(planner_4.bind_input_type(in_4_type, model.inputs()[1]));
+    // without model probe
+    EXPECT_TRUE(planner_3.make_output_type(model.outputs()[0]).is_error());
+    EXPECT_TRUE(planner_4.make_output_type(model.outputs()[0]).is_error());
+
+    // with model probe
+    planner_3.prepare_output_types(model);
+    planner_4.prepare_output_types(model);
     EXPECT_EQ(planner_3.make_output_type(model.outputs()[0]).to_spec(), "tensor<float>(d0[3])");
     EXPECT_EQ(planner_4.make_output_type(model.outputs()[0]).to_spec(), "tensor<float>(d0[4])");
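For reference, the pattern this test exercises is: bind concrete input types on a WirePlanner, and when output type inference fails (typically because the model declares a symbolic batch dimension), run the model probe via prepare_output_types() before asking for output types with make_output_type(). Below is a minimal C++ sketch of that flow based only on the calls visible in this diff; the header paths, namespace, model path, and input type spec are assumptions and may not match the actual test fixture.

// Minimal sketch, assuming header locations and namespace; only the
// Onnx / Onnx::WirePlanner calls are taken from the diff above.
#include <vespa/eval/tensor/dense/onnx_wrapper.h>  // assumed header path
#include <vespa/eval/eval/value_type.h>
#include <cassert>

using vespalib::eval::ValueType;
using vespalib::tensor::Onnx;  // namespace assumed

void resolve_output_type(const char *model_path) {
    Onnx model(model_path, Onnx::Optimize::ENABLE);
    Onnx::WirePlanner planner;

    // Bind a concrete input type (batch size 3) to every model input;
    // the exact type spec here is a placeholder.
    ValueType in_type = ValueType::from_spec("tensor<float>(d0[3],d1[2])");
    for (const auto &input : model.inputs()) {
        planner.bind_input_type(in_type, input);
    }

    // Plain inference can fail for a symbolic batch dimension; in that
    // case probe the model to pin down concrete output shapes.
    if (planner.make_output_type(model.outputs()[0]).is_error()) {
        planner.prepare_output_types(model);
    }

    ValueType out_type = planner.make_output_type(model.outputs()[0]);
    assert(!out_type.is_error());
    // out_type.to_spec() now reflects the probed batch size,
    // e.g. "tensor<float>(d0[3])" as asserted in the test above.
}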