From ce4339ab5224bac5df5fb721cb7d2668ae75c811 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Wed, 3 Apr 2019 16:33:28 +0000 Subject: Keep the serialzation of the cells with the serialization for now. Clean up code for better understanding and reuse. --- .../tensor_serialization_test.cpp | 11 ++--- .../vespa/eval/tensor/dense/dense_tensor_view.h | 12 ++--- eval/src/vespa/eval/tensor/serialization/common.h | 9 ++++ .../tensor/serialization/dense_binary_format.cpp | 56 ++++------------------ .../tensor/serialization/dense_binary_format.h | 7 +-- .../tensor/serialization/typed_binary_format.cpp | 53 ++++++++++++++++---- .../tensor/serialization/typed_binary_format.h | 13 +++-- 7 files changed, 79 insertions(+), 82 deletions(-) create mode 100644 eval/src/vespa/eval/tensor/serialization/common.h (limited to 'eval/src') diff --git a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp index d32fecc5cba..0237f6cc769 100644 --- a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp +++ b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp @@ -16,7 +16,6 @@ using namespace vespalib::tensor; using vespalib::nbostream; using ExpBuffer = std::vector; -using SerializeFormat = vespalib::tensor::DenseTensorView::SerializeFormat; namespace std { @@ -145,10 +144,8 @@ TEST_F("test tensor serialization for SparseTensor", SparseFixture) struct DenseFixture { - Tensor::UP createTensor(SerializeFormat format, const DenseTensorCells &cells) { - auto tensor = TensorFactory::createDense(cells); - dynamic_cast(*tensor).serializeAs(format); - return tensor; + Tensor::UP createTensor(const DenseTensorCells &cells) { + return TensorFactory::createDense(cells); } void serialize(nbostream &stream, const Tensor &tensor) { @@ -166,9 +163,9 @@ struct DenseFixture assertSerialized(exp, SerializeFormat::DOUBLE, rhs); } void assertSerialized(const ExpBuffer &exp, SerializeFormat cellType, const DenseTensorCells &rhs) { - Tensor::UP rhsTensor(createTensor(cellType, rhs)); + Tensor::UP rhsTensor(createTensor(rhs)); nbostream rhsStream; - serialize(rhsStream, *rhsTensor); + TypedBinaryFormat::serialize(rhsStream, *rhsTensor, cellType); EXPECT_EQUAL(exp, rhsStream); auto rhs2 = deserialize(rhsStream); EXPECT_EQUAL(*rhs2, *rhsTensor); diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h index 19a8a66bcf7..09b6b72375e 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h @@ -14,7 +14,6 @@ namespace vespalib::tensor { class DenseTensorView : public Tensor { public: - enum class SerializeFormat {FLOAT, DOUBLE}; using Cells = std::vector; using CellsRef = ConstArrayRef; using CellsIterator = DenseTensorCellsIterator; @@ -22,16 +21,13 @@ public: DenseTensorView(const eval::ValueType &type_in, CellsRef cells_in) : _typeRef(type_in), - _cellsRef(cells_in), - _serializeFormat(SerializeFormat::DOUBLE) + _cellsRef(cells_in) {} explicit DenseTensorView(const eval::ValueType &type_in) : _typeRef(type_in), - _cellsRef(), - _serializeFormat(SerializeFormat::DOUBLE) + _cellsRef() {} - SerializeFormat serializeAs() const { return _serializeFormat; } - void serializeAs(SerializeFormat format) { _serializeFormat = format; } + const eval::ValueType &fast_type() const { return _typeRef; } const CellsRef &cellsRef() const { return _cellsRef; } bool operator==(const DenseTensorView &rhs) const; @@ -58,8 +54,6 @@ private: const eval::ValueType &_typeRef; CellsRef _cellsRef; - //TODO This is a temporary workaround until proper type support for tensors is in place. - SerializeFormat _serializeFormat; }; } diff --git a/eval/src/vespa/eval/tensor/serialization/common.h b/eval/src/vespa/eval/tensor/serialization/common.h new file mode 100644 index 00000000000..9c45bc42136 --- /dev/null +++ b/eval/src/vespa/eval/tensor/serialization/common.h @@ -0,0 +1,9 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +namespace vespalib::tensor { + +enum class SerializeFormat {FLOAT, DOUBLE}; + +} diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp index 2a939963e16..6043153adc3 100644 --- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp @@ -10,17 +10,11 @@ using vespalib::nbostream; namespace vespalib::tensor { -using SerializationFormat = DenseTensorView::SerializeFormat; using Dimension = eval::ValueType::Dimension; namespace { -using EncodeType = DenseBinaryFormat::EncodeType; - -constexpr int DOUBLE_VALUE_TYPE = 0; -constexpr int FLOAT_VALUE_TYPE = 1; - eval::ValueType makeValueType(std::vector &&dimensions) { return (dimensions.empty() ? @@ -48,20 +42,6 @@ encodeCells(nbostream &stream, DenseTensorView::CellsRef cells) { } } -void -encodeValueType(nbostream & stream, SerializationFormat valueType, EncodeType encodeType) { - switch (valueType) { - case SerializationFormat::DOUBLE: - if (encodeType != EncodeType::DOUBLE_IS_DEFAULT) { - stream.putInt1_4Bytes(DOUBLE_VALUE_TYPE); - } - break; - case SerializationFormat::FLOAT: - stream.putInt1_4Bytes(FLOAT_VALUE_TYPE); - break; - } -} - size_t decodeDimensions(nbostream & stream, std::vector & dimensions) { vespalib::string dimensionName; @@ -77,23 +57,6 @@ decodeDimensions(nbostream & stream, std::vector & dimensions) { return cellsSize; } -SerializationFormat -decodeCellType(nbostream & stream, EncodeType encodeType) { - if (encodeType != EncodeType::DOUBLE_IS_DEFAULT) { - uint32_t serializedType = stream.getInt1_4Bytes(); - switch (serializedType) { - case DOUBLE_VALUE_TYPE: - return SerializationFormat::DOUBLE; - case FLOAT_VALUE_TYPE: - return SerializationFormat::FLOAT; - default: - throw IllegalArgumentException(make_string("Received unknown tensor value type = %u. Only 0(double), or 1(float) are legal.", serializedType)); - } - } else { - return SerializationFormat::DOUBLE; - } -} - template void decodeCells(nbostream &stream, size_t cellsSize, DenseTensor::Cells & cells) { @@ -109,17 +72,15 @@ decodeCells(nbostream &stream, size_t cellsSize, DenseTensor::Cells & cells) { void DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor) { - const eval::ValueType & type = tensor.fast_type(); - encodeValueType(stream, tensor.serializeAs(), _encodeType); - size_t cellsSize = encodeDimensions(stream, type); + size_t cellsSize = encodeDimensions(stream, tensor.fast_type()); DenseTensorView::CellsRef cells = tensor.cellsRef(); assert(cells.size() == cellsSize); - switch (tensor.serializeAs()) { - case SerializationFormat::DOUBLE: + switch (_format) { + case SerializeFormat::DOUBLE: encodeCells(stream, cells); break; - case SerializationFormat::FLOAT: + case SerializeFormat::FLOAT: encodeCells(stream, cells); break; } @@ -128,17 +89,16 @@ DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor) std::unique_ptr DenseBinaryFormat::deserialize(nbostream &stream) { - SerializationFormat cellType = decodeCellType(stream, _encodeType); std::vector dimensions; size_t cellsSize = decodeDimensions(stream,dimensions); DenseTensor::Cells cells; cells.reserve(cellsSize); - switch (cellType) { - case SerializationFormat::DOUBLE: + switch (_format) { + case SerializeFormat::DOUBLE: decodeCells(stream, cellsSize,cells); break; - case SerializationFormat::FLOAT: - decodeCells(stream, cellsSize,cells); + case SerializeFormat::FLOAT: + decodeCells(stream, cellsSize, cells); break; } diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h index 3f9236ce222..22c1663719e 100644 --- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h @@ -2,7 +2,9 @@ #pragma once +#include "common.h" #include + namespace vespalib { class nbostream; } namespace vespalib::tensor { @@ -16,12 +18,11 @@ class DenseTensorView; class DenseBinaryFormat { public: - enum class EncodeType { NO_DEFAULT, DOUBLE_IS_DEFAULT}; - DenseBinaryFormat(EncodeType encodeType) : _encodeType(encodeType) { } + DenseBinaryFormat(SerializeFormat format) : _format(format) { } void serialize(nbostream &stream, const DenseTensorView &tensor); std::unique_ptr deserialize(nbostream &stream); private: - EncodeType _encodeType; + SerializeFormat _format; }; } diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp index e98b106d764..d1aa09b6ce3 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp @@ -11,23 +11,61 @@ #include #include +#include +#include + LOG_SETUP(".eval.tensor.serialization.typed_binary_format"); using vespalib::nbostream; namespace vespalib::tensor { +namespace { + +constexpr uint32_t SPARSE_BINARY_FORMAT_TYPE = 1u; +constexpr uint32_t DENSE_BINARY_FORMAT_TYPE = 2u; +constexpr uint32_t MIXED_BINARY_FORMAT_TYPE = 3u; +constexpr uint32_t TYPED_DENSE_BINARY_FORMAT_TYPE = 4u; + +constexpr uint32_t DOUBLE_VALUE_TYPE = 0; +constexpr uint32_t FLOAT_VALUE_TYPE = 1; + +uint32_t +format2Encoding(SerializeFormat format) { + switch (format) { + case SerializeFormat::DOUBLE: + return DOUBLE_VALUE_TYPE; + case SerializeFormat::FLOAT: + return FLOAT_VALUE_TYPE; + } + abort(); +} + +SerializeFormat +encoding2Format(uint32_t serializedType) { + switch (serializedType) { + case DOUBLE_VALUE_TYPE: + return SerializeFormat::DOUBLE; + case FLOAT_VALUE_TYPE: + return SerializeFormat::FLOAT; + default: + throw IllegalArgumentException(make_string("Received unknown tensor value type = %u. Only 0(double), or 1(float) are legal.", serializedType)); + } +} + +} void -TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor) +TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format) { if (auto denseTensor = dynamic_cast(&tensor)) { - if (denseTensor->serializeAs() != DenseTensorView::SerializeFormat::DOUBLE) { + if (format != SerializeFormat::DOUBLE) { stream.putInt1_4Bytes(TYPED_DENSE_BINARY_FORMAT_TYPE); - DenseBinaryFormat(DenseBinaryFormat::EncodeType::NO_DEFAULT).serialize(stream, *denseTensor); + stream.putInt1_4Bytes(format2Encoding(format)); + DenseBinaryFormat(format).serialize(stream, *denseTensor); } else { stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE); - DenseBinaryFormat(DenseBinaryFormat::EncodeType::DOUBLE_IS_DEFAULT).serialize(stream, *denseTensor); + DenseBinaryFormat(SerializeFormat::DOUBLE).serialize(stream, *denseTensor); } } else if (auto wrapped = dynamic_cast(&tensor)) { eval::SimpleTensor::encode(wrapped->get(), stream); @@ -49,17 +87,16 @@ TypedBinaryFormat::deserialize(nbostream &stream) return builder.build(); } if (formatId == DENSE_BINARY_FORMAT_TYPE) { - return DenseBinaryFormat(DenseBinaryFormat::EncodeType::DOUBLE_IS_DEFAULT).deserialize(stream); + return DenseBinaryFormat(SerializeFormat::DOUBLE).deserialize(stream); } if (formatId == TYPED_DENSE_BINARY_FORMAT_TYPE) { - return DenseBinaryFormat(DenseBinaryFormat::EncodeType::NO_DEFAULT).deserialize(stream); + return DenseBinaryFormat(encoding2Format(stream.getInt1_4Bytes())).deserialize(stream); } if (formatId == MIXED_BINARY_FORMAT_TYPE) { stream.adjustReadPos(read_pos - stream.rp()); return std::make_unique(eval::SimpleTensor::decode(stream)); } - LOG_ABORT("should not be reached"); + abort(); } - } diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h index e40ac7bda43..95d9a75488c 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h @@ -2,27 +2,26 @@ #pragma once +#include "common.h" #include -#include namespace vespalib { class nbostream; } namespace vespalib::tensor { class Tensor; -class TensorBuilder; /** * Class for serializing a tensor. */ class TypedBinaryFormat { - static constexpr uint32_t SPARSE_BINARY_FORMAT_TYPE = 1u; - static constexpr uint32_t DENSE_BINARY_FORMAT_TYPE = 2u; - static constexpr uint32_t MIXED_BINARY_FORMAT_TYPE = 3u; - static constexpr uint32_t TYPED_DENSE_BINARY_FORMAT_TYPE = 4u; public: - static void serialize(nbostream &stream, const Tensor &tensor); + static void serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format); + static void serialize(nbostream &stream, const Tensor &tensor) { + serialize(stream, tensor, SerializeFormat::DOUBLE); + } + static std::unique_ptr deserialize(nbostream &stream); }; -- cgit v1.2.3