From 450c0f9bd7eee971c8f63824cbb4f2b7a3a4d1c4 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Mon, 30 Nov 2020 12:51:04 +0000 Subject: compare with reference operation --- eval/src/tests/streamed/value/streamed_value_test.cpp | 16 ++++------------ 1 file changed, 4 insertions(+), 12 deletions(-) (limited to 'eval/src/tests/streamed') diff --git a/eval/src/tests/streamed/value/streamed_value_test.cpp b/eval/src/tests/streamed/value/streamed_value_test.cpp index 3de6ba0fb63..05d6e20451c 100644 --- a/eval/src/tests/streamed/value/streamed_value_test.cpp +++ b/eval/src/tests/streamed/value/streamed_value_test.cpp @@ -1,6 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include +#include #include #include #include @@ -59,16 +60,7 @@ std::vector join_layouts = { float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})}) }; -TensorSpec simple_tensor_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) { - Stash stash; - const auto &engine = SimpleTensorEngine::ref(); - auto lhs = engine.from_spec(a); - auto rhs = engine.from_spec(b); - const auto &result = engine.join(*lhs, *rhs, function, stash); - return engine.to_spec(result); -} - -TensorSpec streamed_value_new_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) { +TensorSpec streamed_value_join(const TensorSpec &a, const TensorSpec &b, join_fun_t function) { Stash stash; const auto &factory = StreamedValueBuilderFactory::get(); auto lhs = value_from_spec(a, factory); @@ -126,8 +118,8 @@ TEST(StreamedValueTest, new_generic_join_works_for_streamed_values) { TensorSpec rhs = spec(join_layouts[i + 1], Div16(N())); for (auto fun: {operation::Add::f, operation::Sub::f, operation::Mul::f, operation::Max::f}) { SCOPED_TRACE(fmt("\n===\nLHS: %s\nRHS: %s\n===\n", lhs.to_string().c_str(), rhs.to_string().c_str())); - auto expect = simple_tensor_join(lhs, rhs, fun); - auto actual = streamed_value_new_join(lhs, rhs, fun); + auto expect = ReferenceOperations::join(lhs, rhs, fun); + auto actual = streamed_value_join(lhs, rhs, fun); EXPECT_EQ(actual, expect); } } -- cgit v1.2.3