diff options
author | Håvard Pettersen <havardpe@oath.com> | 2019-06-07 12:00:17 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2019-06-07 14:05:44 +0000 |
commit | 91f3fce64d2d1e97449b55400eee12bd92b650b6 (patch) | |
tree | f352a9a3060029729b01b478e96c89fa4dc7f727 /eval/src | |
parent | c4b7047b056dfab94c4b8a3d6575d6de8e482ffb (diff) |
serialize float cells
preserve tensor type across encode/decode
Diffstat (limited to 'eval/src')
9 files changed, 208 insertions, 181 deletions
diff --git a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp index b43a127bc60..d1491e4f758 100644 --- a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp +++ b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp @@ -76,8 +76,8 @@ TEST("test tensor serialization for SparseTensor") { TensorSpec("tensor(x{})") .add({{"x", "1"}}, 3))); TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x00, - 0x00, 0x40, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00 }, + 0x00, 0x40, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00 }, TensorSpec("tensor(x{},y{})") .add({{"x", ""}, {"y", ""}}, 3))); TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x01, @@ -85,23 +85,32 @@ TEST("test tensor serialization for SparseTensor") { 0x00, 0x00 }, TensorSpec("tensor(x{},y{})") .add({{"x", "1"}, {"y", ""}}, 3))); - TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x00, - 0x01, 0x33, 0x40, 0x08, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00 }, + TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x00, + 0x01, 0x33, 0x40, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00 }, TensorSpec("tensor(x{},y{})") .add({{"x", ""}, {"y", "3"}}, 3))); - TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x01, - 0x32, 0x01, 0x34, 0x40, 0x08, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00 }, + TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x01, + 0x32, 0x01, 0x34, 0x40, 0x08, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00 }, TensorSpec("tensor(x{},y{})") .add({{"x", "2"}, {"y", "4"}}, 3))); - TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, - 0x01, 0x01, 0x31, 0x00, 0x40, 0x08, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, + TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, + 0x01, 0x01, 0x31, 0x00, 0x40, 0x08, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }, 
TensorSpec("tensor(x{},y{})") .add({{"x", "1"}, {"y", ""}}, 3))); } +TEST("test float cells from sparse tensor") { + TEST_DO(verify_serialized({ 0x05, 0x01, + 0x02, 0x01, 0x78, 0x01, 0x79, + 0x01, 0x01, 0x31, 0x00, + 0x40, 0x40, 0x00, 0x00 }, + TensorSpec("tensor<float>(x{},y{})") + .add({{"x", "1"}, {"y", ""}}, 3))); +} + TEST("test tensor serialization for DenseTensor") { TEST_DO(verify_serialized({0x02, 0x00, 0x00, 0x00, 0x00, 0x00, @@ -187,7 +196,7 @@ TEST("test tensor serialization for DenseTensor") { .add({{"x", 2}, {"y", 4}}, 3))); } -TEST("test 'float' cells") { +TEST("test float cells for dense tensor") { TEST_DO(verify_serialized({0x06, 0x01, 0x02, 0x01, 0x78, 0x03, 0x01, 0x79, 0x05, 0x00, 0x00, 0x00, 0x00, diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index 6210295ebd4..2206cde49a9 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -170,8 +170,11 @@ Value::UP DefaultTensorEngine::from_spec(const TensorSpec &spec) const { ValueType type = ValueType::from_spec(spec.type()); - if (!tensor::Tensor::supported({type})) { - return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::create(spec)); + if (type.is_error()) { + return std::make_unique<ErrorValue>(); + } else if (type.is_double()) { + double value = spec.cells().empty() ? 0.0 : spec.cells().begin()->second.value; + return std::make_unique<DoubleValue>(value); } else if (type.is_dense()) { DirectDenseTensorBuilder builder(type); for (const auto &cell: spec.cells()) { @@ -195,13 +198,8 @@ DefaultTensorEngine::from_spec(const TensorSpec &spec) const } } return builder.build(); - } else if (type.is_double()) { - double value = spec.cells().empty() ? 
0.0 : spec.cells().begin()->second.value; - return std::make_unique<DoubleValue>(value); - } else { - assert(type.is_error()); - return std::make_unique<ErrorValue>(); } + return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::create(spec)); } struct CellFunctionFunAdapter : tensor::CellFunction { diff --git a/eval/src/vespa/eval/tensor/serialization/common.h b/eval/src/vespa/eval/tensor/serialization/common.h deleted file mode 100644 index 40b1840be6e..00000000000 --- a/eval/src/vespa/eval/tensor/serialization/common.h +++ /dev/null @@ -1,9 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#pragma once - -namespace vespalib::tensor { - -enum class SerializeFormat {FLOAT, DOUBLE}; - -} diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp index 4b1ccc8db5d..677fb40b0f4 100644 --- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp @@ -7,6 +7,8 @@ #include <cassert> using vespalib::nbostream; +using vespalib::eval::ValueType; +using CellType = vespalib::eval::ValueType::CellType; namespace vespalib::tensor { @@ -15,15 +17,7 @@ using Dimension = eval::ValueType::Dimension; namespace { -eval::ValueType -makeValueType(std::vector<Dimension> &&dimensions) { - return (dimensions.empty() ? 
- eval::ValueType::double_type() : - eval::ValueType::tensor_type(std::move(dimensions))); -} - -size_t -encodeDimensions(nbostream &stream, const eval::ValueType & type) { +size_t encodeDimensions(nbostream &stream, const eval::ValueType & type) { stream.putInt1_4Bytes(type.dimensions().size()); size_t cellsSize = 1; for (const auto &dimension : type.dimensions()) { @@ -35,15 +29,13 @@ encodeDimensions(nbostream &stream, const eval::ValueType & type) { } template<typename T> -void -encodeCells(nbostream &stream, DenseTensorView::CellsRef cells) { +void encodeCells(nbostream &stream, DenseTensorView::CellsRef cells) { for (const auto &value : cells) { stream << static_cast<T>(value); } } -size_t -decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) { +size_t decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) { vespalib::string dimensionName; size_t dimensionsSize = stream.getInt1_4Bytes(); size_t dimensionSize; @@ -58,8 +50,7 @@ decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) { } template<typename T, typename V> -void -decodeCells(nbostream &stream, size_t cellsSize, V & cells) { +void decodeCells(nbostream &stream, size_t cellsSize, V &cells) { T cellValue = 0.0; for (size_t i = 0; i < cellsSize; ++i) { stream >> cellValue; @@ -68,13 +59,12 @@ decodeCells(nbostream &stream, size_t cellsSize, V & cells) { } template <typename V> -void decodeCells(SerializeFormat format, nbostream &stream, size_t cellsSize, V & cells) -{ - switch (format) { - case SerializeFormat::DOUBLE: +void decodeCells(CellType cell_type, nbostream &stream, size_t cellsSize, V &cells) { + switch (cell_type) { + case CellType::DOUBLE: decodeCells<double>(stream, cellsSize, cells); break; - case SerializeFormat::FLOAT: + case CellType::FLOAT: decodeCells<float>(stream, cellsSize, cells); break; } @@ -86,44 +76,41 @@ void DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor) { size_t cellsSize = 
encodeDimensions(stream, tensor.fast_type()); - DenseTensorView::CellsRef cells = tensor.cellsRef(); assert(cells.size() == cellsSize); - switch (_format) { - case SerializeFormat::DOUBLE: - encodeCells<double>(stream, cells); - break; - case SerializeFormat::FLOAT: - encodeCells<float>(stream, cells); - break; + switch (tensor.fast_type().cell_type()) { + case CellType::DOUBLE: + encodeCells<double>(stream, cells); + break; + case CellType::FLOAT: + encodeCells<float>(stream, cells); + break; } } std::unique_ptr<DenseTensor> -DenseBinaryFormat::deserialize(nbostream &stream) +DenseBinaryFormat::deserialize(nbostream &stream, CellType cell_type) { std::vector<Dimension> dimensions; size_t cellsSize = decodeDimensions(stream,dimensions); DenseTensor::Cells cells; cells.reserve(cellsSize); - - decodeCells(_format, stream, cellsSize, cells); - - return std::make_unique<DenseTensor>(makeValueType(std::move(dimensions)), std::move(cells)); + decodeCells(cell_type, stream, cellsSize, cells); + return std::make_unique<DenseTensor>(ValueType::tensor_type(std::move(dimensions), cell_type), std::move(cells)); } template <typename T> void -DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<T> & cells) +DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<T> &cells, CellType cell_type) { std::vector<Dimension> dimensions; size_t cellsSize = decodeDimensions(stream,dimensions); cells.clear(); cells.reserve(cellsSize); - decodeCells(_format, stream, cellsSize, cells); + decodeCells(cell_type, stream, cellsSize, cells); } -template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<double> & cells); -template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<float> & cells); +template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<double> &cells, CellType cell_type); +template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<float> 
&cells, CellType cell_type); } diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h index f9847d37784..9e860b3c1e4 100644 --- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h @@ -2,9 +2,9 @@ #pragma once -#include "common.h" #include <memory> #include <vector> +#include <vespa/eval/eval/value_type.h> namespace vespalib { class nbostream; } @@ -19,15 +19,14 @@ class DenseTensorView; class DenseBinaryFormat { public: - DenseBinaryFormat(SerializeFormat format) : _format(format) { } - void serialize(nbostream &stream, const DenseTensorView &tensor); - std::unique_ptr<DenseTensor> deserialize(nbostream &stream); - + using CellType = eval::ValueType::CellType; + + static void serialize(nbostream &stream, const DenseTensorView &tensor); + static std::unique_ptr<DenseTensor> deserialize(nbostream &stream, CellType cell_type); + // This is a temporary method until we get full support for typed tensors template <typename T> - void deserializeCellsOnly(nbostream &stream, std::vector<T> & cells); -private: - SerializeFormat _format; + static void deserializeCellsOnly(nbostream &stream, std::vector<T> &cells, CellType cell_type); }; } diff --git a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp index cca310176f4..06e3f63c8da 100644 --- a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp @@ -13,6 +13,7 @@ using vespalib::nbostream; using vespalib::eval::ValueType; +using CellType = vespalib::eval::ValueType::CellType; namespace vespalib::tensor { @@ -20,10 +21,9 @@ namespace { vespalib::string undefinedLabel(""); -void -writeTensorAddress(nbostream &output, - const eval::ValueType &type, - const TensorAddress &value) +void 
writeTensorAddress(nbostream &output, + const eval::ValueType &type, + const TensorAddress &value) { auto elemItr = value.elements().cbegin(); auto elemItrEnd = value.elements().cend(); @@ -38,78 +38,71 @@ writeTensorAddress(nbostream &output, assert(elemItr == elemItrEnd); } -} - +template <typename T> class SparseBinaryFormatSerializer : public TensorVisitor { - uint32_t _numCells; - nbostream _cells; - eval::ValueType _type; - +private: + uint32_t _num_cells; + nbostream &_cells; + const ValueType &_type; public: - SparseBinaryFormatSerializer(); + SparseBinaryFormatSerializer(nbostream &cells, const ValueType &type); + size_t num_cells() const { return _num_cells; } virtual ~SparseBinaryFormatSerializer() override; virtual void visit(const TensorAddress &address, double value) override; - void serialize(nbostream &stream, const Tensor &tensor); }; -SparseBinaryFormatSerializer::SparseBinaryFormatSerializer() - : _numCells(0u), - _cells(), - _type(eval::ValueType::error_type()) +template <typename T> +SparseBinaryFormatSerializer<T>::SparseBinaryFormatSerializer(nbostream &cells, const ValueType &type) + : _num_cells(0), + _cells(cells), + _type(type) { } +template <typename T> +SparseBinaryFormatSerializer<T>::~SparseBinaryFormatSerializer() = default; -SparseBinaryFormatSerializer::~SparseBinaryFormatSerializer() = default; - +template <typename T> void -SparseBinaryFormatSerializer::visit(const TensorAddress &address, double value) +SparseBinaryFormatSerializer<T>::visit(const TensorAddress &address, double value) { - ++_numCells; + ++_num_cells; writeTensorAddress(_cells, _type, address); - _cells << value; + _cells << static_cast<T>(value); } - -void -SparseBinaryFormatSerializer::serialize(nbostream &stream, const Tensor &tensor) -{ - _type = tensor.type(); - tensor.accept(*this); - stream.putInt1_4Bytes(_type.dimensions().size()); - for (const auto &dimension : _type.dimensions()) { +void encodeDimensions(nbostream &stream, const eval::ValueType &type) { 
+ stream.putInt1_4Bytes(type.dimensions().size()); + for (const auto &dimension : type.dimensions()) { stream.writeSmallString(dimension.name); } - stream.putInt1_4Bytes(_numCells); - stream.write(_cells.peek(), _cells.size()); } - -void -SparseBinaryFormat::serialize(nbostream &stream, const Tensor &tensor) -{ - SparseBinaryFormatSerializer serializer; - serializer.serialize(stream, tensor); +template <typename T> +size_t encodeCells(nbostream &stream, const Tensor &tensor) { + SparseBinaryFormatSerializer<T> serializer(stream, tensor.type()); + tensor.accept(serializer); + return serializer.num_cells(); } +size_t encodeCells(nbostream &stream, const Tensor &tensor, CellType cell_type) { + switch (cell_type) { + case CellType::DOUBLE: + return encodeCells<double>(stream, tensor); + break; + case CellType::FLOAT: + return encodeCells<float>(stream, tensor); + break; + } + return 0; +} -std::unique_ptr<Tensor> -SparseBinaryFormat::deserialize(nbostream &stream) -{ +template<typename T> +void decodeCells(nbostream &stream, size_t dimensionsSize, size_t cellsSize, DirectSparseTensorBuilder &builder) { + T cellValue = 0.0; vespalib::string str; - size_t dimensionsSize = stream.getInt1_4Bytes(); - std::vector<ValueType::Dimension> dimensions; - while (dimensions.size() < dimensionsSize) { - stream.readSmallString(str); - dimensions.emplace_back(str); - } - ValueType type = ValueType::tensor_type(std::move(dimensions)); - DirectSparseTensorBuilder builder(type); SparseTensorAddressBuilder address; - - size_t cellsSize = stream.getInt1_4Bytes(); - double cellValue = 0.0; for (size_t cellIdx = 0; cellIdx < cellsSize; ++cellIdx) { address.clear(); for (size_t dimension = 0; dimension < dimensionsSize; ++dimension) { @@ -121,10 +114,49 @@ SparseBinaryFormat::deserialize(nbostream &stream) } } stream >> cellValue; - builder.insertCell(address, cellValue); + builder.insertCell(address, cellValue, [](double, double v){ return v; }); } - return builder.build(); } +void 
decodeCells(CellType cell_type, nbostream &stream, size_t dimensionsSize, size_t cellsSize, DirectSparseTensorBuilder &builder) { + switch (cell_type) { + case CellType::DOUBLE: + decodeCells<double>(stream, dimensionsSize, cellsSize, builder); + break; + case CellType::FLOAT: + decodeCells<float>(stream, dimensionsSize, cellsSize, builder); + break; + } +} + +} + +void +SparseBinaryFormat::serialize(nbostream &stream, const Tensor &tensor) +{ + const auto &type = tensor.type(); + encodeDimensions(stream, type); + nbostream cells; + size_t numCells = encodeCells(cells, tensor, type.cell_type()); + stream.putInt1_4Bytes(numCells); + stream.write(cells.peek(), cells.size()); +} + +std::unique_ptr<Tensor> +SparseBinaryFormat::deserialize(nbostream &stream, CellType cell_type) +{ + vespalib::string str; + size_t dimensionsSize = stream.getInt1_4Bytes(); + std::vector<ValueType::Dimension> dimensions; + while (dimensions.size() < dimensionsSize) { + stream.readSmallString(str); + dimensions.emplace_back(str); + } + ValueType type = ValueType::tensor_type(std::move(dimensions), cell_type); + DirectSparseTensorBuilder builder(type); + size_t cellsSize = stream.getInt1_4Bytes(); + decodeCells(cell_type, stream, dimensionsSize, cellsSize, builder); + return builder.build(); +} } diff --git a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h index cd68e7eeda4..0611d7d5a23 100644 --- a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h @@ -3,6 +3,7 @@ #pragma once #include <memory> +#include <vespa/eval/eval/value_type.h> namespace vespalib { class nbostream; } @@ -11,13 +12,15 @@ namespace vespalib::tensor { class Tensor; /** - * Class for serializing a tensor. + * Class for serializing a sparse tensor. 
*/ class SparseBinaryFormat { public: + using CellType = eval::ValueType::CellType; + static void serialize(nbostream &stream, const Tensor &tensor); - static std::unique_ptr<Tensor> deserialize(nbostream &stream); + static std::unique_ptr<Tensor> deserialize(nbostream &stream, CellType cell_type); }; } diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp index 23179d4b908..8d9767374a2 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp @@ -16,6 +16,8 @@ LOG_SETUP(".eval.tensor.serialization.typed_binary_format"); using vespalib::nbostream; +using vespalib::eval::ValueType; +using CellType = vespalib::eval::ValueType::CellType; namespace vespalib::tensor { @@ -31,47 +33,52 @@ constexpr uint32_t MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7u; constexpr uint32_t DOUBLE_VALUE_TYPE = 0; constexpr uint32_t FLOAT_VALUE_TYPE = 1; -uint32_t -format2Encoding(SerializeFormat format) { - switch (format) { - case SerializeFormat::DOUBLE: - return DOUBLE_VALUE_TYPE; - case SerializeFormat::FLOAT: - return FLOAT_VALUE_TYPE; +uint32_t cell_type_to_encoding(CellType cell_type) { + switch (cell_type) { + case CellType::DOUBLE: + return DOUBLE_VALUE_TYPE; + case CellType::FLOAT: + return FLOAT_VALUE_TYPE; } abort(); } -SerializeFormat -encoding2Format(uint32_t serializedType) { - switch (serializedType) { - case DOUBLE_VALUE_TYPE: - return SerializeFormat::DOUBLE; - case FLOAT_VALUE_TYPE: - return SerializeFormat::FLOAT; - default: - throw IllegalArgumentException(make_string("Received unknown tensor value type = %u. 
Only 0(double), or 1(float) are legal.", serializedType)); +CellType +encoding_to_cell_type(uint32_t cell_encoding) { + switch (cell_encoding) { + case DOUBLE_VALUE_TYPE: + return CellType::DOUBLE; + case FLOAT_VALUE_TYPE: + return CellType::FLOAT; + default: + throw IllegalArgumentException(make_string("Received unknown tensor value type = %u. Only 0(double), or 1(float) are legal.", cell_encoding)); } } } void -TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format) +TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor) { + auto cell_type = tensor.type().cell_type(); + bool default_cell_type = (cell_type == CellType::DOUBLE); if (auto denseTensor = dynamic_cast<const DenseTensorView *>(&tensor)) { - if (format != SerializeFormat::DOUBLE) { - stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE); - stream.putInt1_4Bytes(format2Encoding(format)); - DenseBinaryFormat(format).serialize(stream, *denseTensor); - } else { + if (default_cell_type) { stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE); - DenseBinaryFormat(SerializeFormat::DOUBLE).serialize(stream, *denseTensor); + } else { + stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE); + stream.putInt1_4Bytes(cell_type_to_encoding(cell_type)); } + DenseBinaryFormat::serialize(stream, *denseTensor); } else if (auto wrapped = dynamic_cast<const WrappedSimpleTensor *>(&tensor)) { eval::SimpleTensor::encode(wrapped->get(), stream); } else { - stream.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE); + if (default_cell_type) { + stream.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE); + } else { + stream.putInt1_4Bytes(SPARSE_BINARY_FORMAT_WITH_CELLTYPE); + stream.putInt1_4Bytes(cell_type_to_encoding(cell_type)); + } SparseBinaryFormat::serialize(stream, tensor); } } @@ -80,40 +87,46 @@ TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor, SerializeF std::unique_ptr<Tensor> TypedBinaryFormat::deserialize(nbostream &stream) { + auto cell_type = CellType::DOUBLE; 
auto read_pos = stream.rp(); auto formatId = stream.getInt1_4Bytes(); - if (formatId == SPARSE_BINARY_FORMAT_TYPE) { - return SparseBinaryFormat::deserialize(stream); - } - if (formatId == DENSE_BINARY_FORMAT_TYPE) { - return DenseBinaryFormat(SerializeFormat::DOUBLE).deserialize(stream); - } - if ((formatId == SPARSE_BINARY_FORMAT_WITH_CELLTYPE) || - (formatId == DENSE_BINARY_FORMAT_WITH_CELLTYPE) || - (formatId == MIXED_BINARY_FORMAT_TYPE) || - (formatId == MIXED_BINARY_FORMAT_WITH_CELLTYPE)) - { + switch (formatId) { + case SPARSE_BINARY_FORMAT_WITH_CELLTYPE: + cell_type = encoding_to_cell_type(stream.getInt1_4Bytes()); + [[fallthrough]]; + case SPARSE_BINARY_FORMAT_TYPE: + return SparseBinaryFormat::deserialize(stream, cell_type); + case DENSE_BINARY_FORMAT_WITH_CELLTYPE: + cell_type = encoding_to_cell_type(stream.getInt1_4Bytes()); + [[fallthrough]]; + case DENSE_BINARY_FORMAT_TYPE: + return DenseBinaryFormat::deserialize(stream, cell_type); + case MIXED_BINARY_FORMAT_TYPE: + case MIXED_BINARY_FORMAT_WITH_CELLTYPE: stream.adjustReadPos(read_pos - stream.rp()); return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::decode(stream)); + default: + throw IllegalArgumentException(make_string("Received unknown tensor format type = %du.", formatId)); } - throw IllegalArgumentException(make_string("Received unknown tensor format type = %du.", formatId)); } template <typename T> void -TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells) +TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> &cells) { + auto cell_type = CellType::DOUBLE; auto formatId = stream.getInt1_4Bytes(); - if (formatId == DENSE_BINARY_FORMAT_TYPE) { - return DenseBinaryFormat(SerializeFormat::DOUBLE).deserializeCellsOnly(stream, cells); - } - if (formatId == DENSE_BINARY_FORMAT_WITH_CELLTYPE) { - return DenseBinaryFormat(encoding2Format(stream.getInt1_4Bytes())).deserializeCellsOnly(stream, cells); + switch 
(formatId) { + case DENSE_BINARY_FORMAT_WITH_CELLTYPE: + cell_type = encoding_to_cell_type(stream.getInt1_4Bytes()); + [[fallthrough]]; + case DENSE_BINARY_FORMAT_TYPE: + return DenseBinaryFormat::deserializeCellsOnly(stream, cells, cell_type); } abort(); } -template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<double> & cells); -template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<float> & cells); +template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<double> &cells); +template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<float> &cells); } diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h index 717d51effef..198b09ae336 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h @@ -2,7 +2,6 @@ #pragma once -#include "common.h" #include <memory> #include <vector> @@ -18,16 +17,12 @@ class Tensor; class TypedBinaryFormat { public: - static void serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format); - static void serialize(nbostream &stream, const Tensor &tensor) { - serialize(stream, tensor, SerializeFormat::DOUBLE); - } - + static void serialize(nbostream &stream, const Tensor &tensor); static std::unique_ptr<Tensor> deserialize(nbostream &stream); - + // This is a temporary method until we get full support for typed tensors template <typename T> - static void deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells); + static void deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> &cells); }; } |