summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-10-02 13:52:56 +0000
committerArne Juul <arnej@verizonmedia.com>2020-10-02 14:20:19 +0000
commite3dc8d27a19e75323ac16322dc0e43a58c49dbae (patch)
tree3dc242b846b4ee5fd2129d12395e19c0d6a7d3f1
parent5638fb88e06e835a03cc0d9c70725a37dacd2974 (diff)
less casting to DenseTensorView
-rw-r--r--eval/src/tests/tensor/dense_replace_type_function/dense_replace_type_function_test.cpp2
-rw-r--r--eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp5
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp6
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp3
4 files changed, 7 insertions, 9 deletions
diff --git a/eval/src/tests/tensor/dense_replace_type_function/dense_replace_type_function_test.cpp b/eval/src/tests/tensor/dense_replace_type_function/dense_replace_type_function_test.cpp
index 732fc9c3e69..9ebcb0ec77c 100644
--- a/eval/src/tests/tensor/dense_replace_type_function/dense_replace_type_function_test.cpp
+++ b/eval/src/tests/tensor/dense_replace_type_function/dense_replace_type_function_test.cpp
@@ -16,7 +16,7 @@ using namespace vespalib;
const TensorEngine &engine = DefaultTensorEngine::ref();
TypedCells getCellsRef(const eval::Value &value) {
- return static_cast<const DenseTensorView &>(value).cellsRef();
+ return value.cells();
}
struct ChildMock : Leaf {
diff --git a/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp b/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp
index 75e9c7868ce..4c96145862c 100644
--- a/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp
+++ b/eval/src/tests/tensor/direct_dense_tensor_builder/direct_dense_tensor_builder_test.cpp
@@ -29,9 +29,8 @@ void assertTensor(const vespalib::string &type_spec,
const std::vector<double> &expCells,
const Tensor &tensor)
{
- const DenseTensorView &realTensor = dynamic_cast<const DenseTensorView &>(tensor);
- EXPECT_EQUAL(ValueType::from_spec(type_spec), realTensor.type());
- EXPECT_EQUAL(expCells, dispatch_1<CallMakeVector>(realTensor.cellsRef()));
+ EXPECT_EQUAL(ValueType::from_spec(type_spec), tensor.type());
+ EXPECT_EQUAL(expCells, dispatch_1<CallMakeVector>(tensor.cells()));
}
void assertTensorSpec(const TensorSpec &expSpec, const Tensor &tensor) {
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index d1d8bc796ba..fce7ccc6411 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -162,7 +162,7 @@ TEST(OnnxTest, simple_onnx_model_can_be_evaluated)
ctx.bind_param(1, attribute);
ctx.bind_param(2, bias);
ctx.eval();
- auto cells = static_cast<const DenseTensorView&>(output).cellsRef();
+ auto cells = output.cells();
EXPECT_EQ(cells.type, ValueType::CellType::FLOAT);
EXPECT_EQ(cells.size, 1);
EXPECT_EQ(GetCell::from(cells, 0), 79.0);
@@ -208,7 +208,7 @@ TEST(OnnxTest, dynamic_onnx_model_can_be_evaluated)
ctx.bind_param(1, attribute);
ctx.bind_param(2, bias);
ctx.eval();
- auto cells = static_cast<const DenseTensorView&>(output).cellsRef();
+ auto cells = output.cells();
EXPECT_EQ(cells.type, ValueType::CellType::FLOAT);
EXPECT_EQ(cells.size, 1);
EXPECT_EQ(GetCell::from(cells, 0), 79.0);
@@ -254,7 +254,7 @@ TEST(OnnxTest, int_types_onnx_model_can_be_evaluated)
ctx.bind_param(1, attribute);
ctx.bind_param(2, bias);
ctx.eval();
- auto cells = static_cast<const DenseTensorView&>(output).cellsRef();
+ auto cells = output.cells();
EXPECT_EQ(cells.type, ValueType::CellType::DOUBLE);
EXPECT_EQ(cells.size, 1);
EXPECT_EQ(GetCell::from(cells, 0), 79.0);
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 7d4bff21380..11c1ce74eca 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -445,8 +445,7 @@ struct CallAppendVector {
template <typename OCT>
void append_vector(OCT *&pos, const Value &value) {
if (auto tensor = value.as_tensor()) {
- const DenseTensorView *view = static_cast<const DenseTensorView *>(tensor);
- dispatch_1<CallAppendVector<OCT> >(view->cellsRef(), pos);
+ dispatch_1<CallAppendVector<OCT> >(tensor->cells(), pos);
} else {
*pos++ = value.as_double();
}