diff options
Diffstat (limited to 'eval')
6 files changed, 72 insertions, 11 deletions
diff --git a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp index 0237f6cc769..b4e0db4fa8f 100644 --- a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp +++ b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp @@ -162,6 +162,16 @@ struct DenseFixture void assertSerialized(const ExpBuffer &exp, const DenseTensorCells &rhs) { assertSerialized(exp, SerializeFormat::DOUBLE, rhs); } + template <typename T> + void assertCellsOnly(const ExpBuffer &exp, const DenseTensorView & rhs) { + nbostream a(&exp[0], exp.size()); + std::vector<T> v; + TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(a, v); + EXPECT_EQUAL(v.size(), rhs.cellsRef().size()); + for (size_t i(0); i < v.size(); i++) { + EXPECT_EQUAL(v[i], rhs.cellsRef()[i]); + } + } void assertSerialized(const ExpBuffer &exp, SerializeFormat cellType, const DenseTensorCells &rhs) { Tensor::UP rhsTensor(createTensor(rhs)); nbostream rhsStream; @@ -169,6 +179,9 @@ struct DenseFixture EXPECT_EQUAL(exp, rhsStream); auto rhs2 = deserialize(rhsStream); EXPECT_EQUAL(*rhs2, *rhsTensor); + + assertCellsOnly<float>(exp, dynamic_cast<const DenseTensorView &>(*rhs2)); + assertCellsOnly<double>(exp, dynamic_cast<const DenseTensorView &>(*rhs2)); } }; diff --git a/eval/src/vespa/eval/tensor/serialization/common.h b/eval/src/vespa/eval/tensor/serialization/common.h index 9c45bc42136..40b1840be6e 100644 --- a/eval/src/vespa/eval/tensor/serialization/common.h +++ b/eval/src/vespa/eval/tensor/serialization/common.h @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp index 6043153adc3..4b1ccc8db5d 100644 --- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp @@ -57,9 +57,9 @@ decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) { return cellsSize; } -template<typename T> +template<typename T, typename V> void -decodeCells(nbostream &stream, size_t cellsSize, DenseTensor::Cells & cells) { +decodeCells(nbostream &stream, size_t cellsSize, V & cells) { T cellValue = 0.0; for (size_t i = 0; i < cellsSize; ++i) { stream >> cellValue; @@ -67,6 +67,19 @@ decodeCells(nbostream &stream, size_t cellsSize, DenseTensor::Cells & cells) { } } +template <typename V> +void decodeCells(SerializeFormat format, nbostream &stream, size_t cellsSize, V & cells) +{ + switch (format) { + case SerializeFormat::DOUBLE: + decodeCells<double>(stream, cellsSize, cells); + break; + case SerializeFormat::FLOAT: + decodeCells<float>(stream, cellsSize, cells); + break; + } +} + } void @@ -93,16 +106,24 @@ DenseBinaryFormat::deserialize(nbostream &stream) size_t cellsSize = decodeDimensions(stream,dimensions); DenseTensor::Cells cells; cells.reserve(cellsSize); - switch (_format) { - case SerializeFormat::DOUBLE: - decodeCells<double>(stream, cellsSize,cells); - break; - case SerializeFormat::FLOAT: - decodeCells<float>(stream, cellsSize, cells); - break; - } + + decodeCells(_format, stream, cellsSize, cells); return std::make_unique<DenseTensor>(makeValueType(std::move(dimensions)), std::move(cells)); } +template <typename T> +void +DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<T> & cells) +{ + std::vector<Dimension> dimensions; + size_t cellsSize = decodeDimensions(stream,dimensions); + cells.clear(); + cells.reserve(cellsSize); + decodeCells(_format, stream, cellsSize, cells); +} + +template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<double> & cells); +template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<float> & cells); + } diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h index 22c1663719e..f9847d37784 100644 --- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h @@ -4,6 +4,7 @@ #include "common.h" #include <memory> +#include <vector> namespace vespalib { class nbostream; } @@ -21,6 +22,10 @@ public: DenseBinaryFormat(SerializeFormat format) : _format(format) { } void serialize(nbostream &stream, const DenseTensorView &tensor); std::unique_ptr<DenseTensor> deserialize(nbostream &stream); + + // This is a temporary method untill we get full support for typed tensors + template <typename T> + void deserializeCellsOnly(nbostream &stream, std::vector<T> & cells); private: SerializeFormat _format; }; diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp index d1aa09b6ce3..813763ba268 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp @@ -99,4 +99,21 @@ TypedBinaryFormat::deserialize(nbostream &stream) abort(); } +template <typename T> +void +TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells) +{ + auto formatId = stream.getInt1_4Bytes(); + if (formatId == DENSE_BINARY_FORMAT_TYPE) { + return DenseBinaryFormat(SerializeFormat::DOUBLE).deserializeCellsOnly(stream, cells); + } + if (formatId == TYPED_DENSE_BINARY_FORMAT_TYPE) { + return DenseBinaryFormat(encoding2Format(stream.getInt1_4Bytes())).deserializeCellsOnly(stream, cells); + } + abort(); +} + +template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<double> & cells); +template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<float> & cells); + } diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h index 95d9a75488c..717d51effef 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h @@ -4,6 +4,7 @@ #include "common.h" #include <memory> +#include <vector> namespace vespalib { class nbostream; } @@ -23,6 +24,10 @@ public: } static std::unique_ptr<Tensor> deserialize(nbostream &stream); + + // This is a temporary method until we get full support for typed tensors + template <typename T> + static void deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells); }; } |