summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/dense_tensor_function_compiler/dense_tensor_function_compiler_test.cpp
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2017-11-09 14:18:11 +0000
committerHåvard Pettersen <havardpe@oath.com>2017-11-09 14:18:11 +0000
commit56c26cfe5d3fd41169d928180f91d029b1295adf (patch)
tree95a77a19c31764e3984877eb516eae8444fc5713 /eval/src/tests/tensor/dense_tensor_function_compiler/dense_tensor_function_compiler_test.cpp
parent9d47aad83ad4750fbfeb934c4a85abf353988974 (diff)
use const references and stash instead of UP
Diffstat (limited to 'eval/src/tests/tensor/dense_tensor_function_compiler/dense_tensor_function_compiler_test.cpp')
-rw-r--r--eval/src/tests/tensor/dense_tensor_function_compiler/dense_tensor_function_compiler_test.cpp27
1 files changed, 16 insertions, 11 deletions
diff --git a/eval/src/tests/tensor/dense_tensor_function_compiler/dense_tensor_function_compiler_test.cpp b/eval/src/tests/tensor/dense_tensor_function_compiler/dense_tensor_function_compiler_test.cpp
index 6dcfc0791e7..63829650cc5 100644
--- a/eval/src/tests/tensor/dense_tensor_function_compiler/dense_tensor_function_compiler_test.cpp
+++ b/eval/src/tests/tensor/dense_tensor_function_compiler/dense_tensor_function_compiler_test.cpp
@@ -3,32 +3,36 @@
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/tensor/dense/dense_dot_product_function.h>
#include <vespa/eval/tensor/dense/dense_tensor_function_compiler.h>
+#include <vespa/eval/eval/operation.h>
using namespace vespalib::eval;
using namespace vespalib::eval::operation;
using namespace vespalib::eval::tensor_function;
using namespace vespalib::tensor;
+using vespalib::Stash;
template <typename T>
const T *as(const TensorFunction &function) { return dynamic_cast<const T *>(&function); }
-TensorFunction::UP
+const TensorFunction &
compileDotProduct(const vespalib::string &lhsType,
- const vespalib::string &rhsType)
+ const vespalib::string &rhsType,
+ Stash &stash)
{
- Node_UP reduceNode = reduce(join(inject(ValueType::from_spec(lhsType), 1),
- inject(ValueType::from_spec(rhsType), 3),
- Mul::f),
- Aggr::SUM, {});
- return DenseTensorFunctionCompiler::compile(std::move(reduceNode));
+ const Node &reduceNode = reduce(join(inject(ValueType::from_spec(lhsType), 1, stash),
+ inject(ValueType::from_spec(rhsType), 3, stash),
+ Mul::f, stash),
+ Aggr::SUM, {}, stash);
+ return DenseTensorFunctionCompiler::compile(reduceNode, stash);
}
void
assertCompiledDotProduct(const vespalib::string &lhsType,
const vespalib::string &rhsType)
{
- TensorFunction::UP func = compileDotProduct(lhsType, rhsType);
- const DenseDotProductFunction *dotProduct = as<DenseDotProductFunction>(*func);
+ Stash stash;
+ const TensorFunction &func = compileDotProduct(lhsType, rhsType, stash);
+ const DenseDotProductFunction *dotProduct = as<DenseDotProductFunction>(func);
ASSERT_TRUE(dotProduct);
EXPECT_EQUAL(1u, dotProduct->lhsTensorId());
EXPECT_EQUAL(3u, dotProduct->rhsTensorId());
@@ -38,8 +42,9 @@ void
assertNotCompiledDotProduct(const vespalib::string &lhsType,
const vespalib::string &rhsType)
{
- TensorFunction::UP func = compileDotProduct(lhsType, rhsType);
- const Reduce *reduce = as<Reduce>(*func);
+ Stash stash;
+ const TensorFunction &func = compileDotProduct(lhsType, rhsType, stash);
+ const Reduce *reduce = as<Reduce>(func);
EXPECT_TRUE(reduce);
}