summaryrefslogtreecommitdiffstats
path: root/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp')
-rw-r--r--eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp13
1 files changed, 13 insertions, 0 deletions
diff --git a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
index 0237f6cc769..b4e0db4fa8f 100644
--- a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
+++ b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
@@ -162,6 +162,16 @@ struct DenseFixture
void assertSerialized(const ExpBuffer &exp, const DenseTensorCells &rhs) {
assertSerialized(exp, SerializeFormat::DOUBLE, rhs);
}
+ template <typename T>
+ void assertCellsOnly(const ExpBuffer &exp, const DenseTensorView & rhs) {
+ nbostream a(&exp[0], exp.size());
+ std::vector<T> v;
+ TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(a, v);
+ EXPECT_EQUAL(v.size(), rhs.cellsRef().size());
+ for (size_t i(0); i < v.size(); i++) {
+ EXPECT_EQUAL(v[i], rhs.cellsRef()[i]);
+ }
+ }
void assertSerialized(const ExpBuffer &exp, SerializeFormat cellType, const DenseTensorCells &rhs) {
Tensor::UP rhsTensor(createTensor(rhs));
nbostream rhsStream;
@@ -169,6 +179,9 @@ struct DenseFixture
EXPECT_EQUAL(exp, rhsStream);
auto rhs2 = deserialize(rhsStream);
EXPECT_EQUAL(*rhs2, *rhsTensor);
+
+ assertCellsOnly<float>(exp, dynamic_cast<const DenseTensorView &>(*rhs2));
+ assertCellsOnly<double>(exp, dynamic_cast<const DenseTensorView &>(*rhs2));
}
};