summaryrefslogtreecommitdiffstats
path: root/eval/src/apps/tensor_conformance/tensor_conformance.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/apps/tensor_conformance/tensor_conformance.cpp')
-rw-r--r--eval/src/apps/tensor_conformance/tensor_conformance.cpp25
1 files changed, 17 insertions, 8 deletions
diff --git a/eval/src/apps/tensor_conformance/tensor_conformance.cpp b/eval/src/apps/tensor_conformance/tensor_conformance.cpp
index 37ecce51714..47089e1a8ff 100644
--- a/eval/src/apps/tensor_conformance/tensor_conformance.cpp
+++ b/eval/src/apps/tensor_conformance/tensor_conformance.cpp
@@ -164,8 +164,8 @@ void print_test(const Inspector &test, OutputWriter &dst) {
auto value = extract_value(test["inputs"][input]);
dst.printf("input '%s': %s\n", input.c_str(), value.to_string().c_str());
}
- auto result = extract_value(test["result"]["expect"]);
- dst.printf("expected result: %s\n", result.to_string().c_str());
+ auto result = eval_expr(test, prod_factory);
+ dst.printf("result: %s\n", result.to_string().c_str());
}
//-----------------------------------------------------------------------------
@@ -184,20 +184,24 @@ public:
for (const auto& [name, spec]: inputs_in) {
insert_value(inputs, name, spec);
}
- insert_value(test.setObject("result"), "expect", ref_eval(test));
+ test.setObject("result");
}
- void add_failing_test() {
+ void add_failing_test(bool ignore_fail) {
Cursor &test = _writer.create();
test.setString("expression", "a");
insert_value(test.setObject("inputs"), "a", GenSpec(1).idx("x", 3));
insert_value(test.setObject("result"), "dummy", GenSpec(2).idx("x", 3));
+ if (ignore_fail) {
+ test.setBool("ignore_fail", true);
+ }
}
};
void generate(Output &out, bool full) {
MyTestBuilder my_test_builder(full, out);
Generator::generate(my_test_builder);
- // my_test_builder.add_failing_test();
+ // my_test_builder.add_failing_test(true);
+ // my_test_builder.add_failing_test(false);
}
//-----------------------------------------------------------------------------
@@ -228,9 +232,12 @@ void verify(Input &in, Output &out) {
++result_map[result];
auto actual_result = extract_value(slime["result"][result]);
if (!require_impl::eq(actual_result, reference_result)) {
- ++fail_cnt;
- fprintf(stderr, "expression failed('%s'): '%s'\n", result.c_str(),
- slime["expression"].asString().make_string().c_str());
+ bool ignore_fail = slime["ignore_fail"].asBool();
+ if (!ignore_fail) {
+ ++fail_cnt;
+ }
+ fprintf(stderr, "%sexpression failed('%s'): '%s'\n", ignore_fail ? "IGNORED: " : "",
+ result.c_str(), slime["expression"].asString().make_string().c_str());
fprintf(stderr, "%s", TensorSpec::diff(actual_result, "actual", reference_result, "expected").c_str());
dump_test(slime.get());
}
@@ -241,6 +248,8 @@ void verify(Input &in, Output &out) {
for (const auto &entry: result_map) {
stats.setLong(entry.first, entry.second);
}
+ REQUIRE(!slime["fail_cnt"].valid());
+ slime.get().setLong("fail_cnt", fail_cnt);
JsonFormat::encode(slime, out, false);
};
for_each_test(in, handle_test, handle_summary);