author     Arne H Juul <arnej27959@users.noreply.github.com>  2019-06-12 13:39:50 +0200
committer  GitHub <noreply@github.com>  2019-06-12 13:39:50 +0200
commit     2b9aa7f40171f918fe1a906de8334d637aa8810b (patch)
tree       c4a030494b4ea4b30ce45982a664cc4f49c80c3d /eval
parent     a7fe82abe533cf8633af17a0e914fdc89d6db231 (diff)
parent     91f3fce64d2d1e97449b55400eee12bd92b650b6 (diff)
Merge pull request #9729 from vespa-engine/havardpe/float-cell-serialization
serialize float cells
Diffstat (limited to 'eval')
-rw-r--r--  eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp |  33
-rw-r--r--  eval/src/vespa/eval/tensor/default_tensor_engine.cpp                     |  14
-rw-r--r--  eval/src/vespa/eval/tensor/serialization/common.h                        |   9
-rw-r--r--  eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp         |  61
-rw-r--r--  eval/src/vespa/eval/tensor/serialization/dense_binary_format.h           |  15
-rw-r--r--  eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp        | 138
-rw-r--r--  eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h          |   7
-rw-r--r--  eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp         | 101
-rw-r--r--  eval/src/vespa/eval/tensor/serialization/typed_binary_format.h           |  11
9 files changed, 208 insertions(+), 181 deletions(-)
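
This merge makes the typed binary format carry the cell type of the tensor being serialized, so tensors declared with float cells are written and read back as 4-byte cells instead of 8-byte doubles. The sketch below is not part of the patch; it only illustrates how the changed API is meant to be driven end to end. The include paths and the DefaultTensorEngine::ref() accessor are assumptions based on surrounding Vespa code rather than anything shown in this diff.

    #include <vespa/eval/eval/tensor_spec.h>
    #include <vespa/eval/tensor/default_tensor_engine.h>
    #include <vespa/eval/tensor/tensor.h>
    #include <vespa/eval/tensor/serialization/typed_binary_format.h>
    #include <vespa/vespalib/objects/nbostream.h>

    using vespalib::nbostream;
    using vespalib::eval::TensorSpec;
    using namespace vespalib::tensor;

    void round_trip_float_tensor() {
        // build a float-cell sparse tensor, mirroring the new test case added below
        TensorSpec spec = TensorSpec("tensor<float>(x{},y{})")
                              .add({{"x", "1"}, {"y", ""}}, 3);
        auto value = DefaultTensorEngine::ref().from_spec(spec);   // assumed accessor
        const auto &tensor = dynamic_cast<const Tensor &>(*value);
        nbostream stream;
        // the cell type is now taken from the tensor's own type; no SerializeFormat argument
        TypedBinaryFormat::serialize(stream, tensor);
        // the decoder reads the cell type back from the header
        auto restored = TypedBinaryFormat::deserialize(stream);
    }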
diff --git a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
index b43a127bc60..d1491e4f758 100644
--- a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
+++ b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
@@ -76,8 +76,8 @@ TEST("test tensor serialization for SparseTensor") {
TensorSpec("tensor(x{})")
.add({{"x", "1"}}, 3)));
TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x00,
- 0x00, 0x40, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
- 0x00 },
+ 0x00, 0x40, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00,
+ 0x00 },
TensorSpec("tensor(x{},y{})")
.add({{"x", ""}, {"y", ""}}, 3)));
TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x01,
@@ -85,23 +85,32 @@ TEST("test tensor serialization for SparseTensor") {
0x00, 0x00 },
TensorSpec("tensor(x{},y{})")
.add({{"x", "1"}, {"y", ""}}, 3)));
- TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x00,
- 0x01, 0x33, 0x40, 0x08, 0x00, 0x00, 0x00, 0x00,
- 0x00, 0x00 },
+ TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x00,
+ 0x01, 0x33, 0x40, 0x08, 0x00, 0x00, 0x00, 0x00,
+ 0x00, 0x00 },
TensorSpec("tensor(x{},y{})")
.add({{"x", ""}, {"y", "3"}}, 3)));
- TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x01,
- 0x32, 0x01, 0x34, 0x40, 0x08, 0x00, 0x00, 0x00,
- 0x00, 0x00, 0x00 },
+ TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79, 0x01, 0x01,
+ 0x32, 0x01, 0x34, 0x40, 0x08, 0x00, 0x00, 0x00,
+ 0x00, 0x00, 0x00 },
TensorSpec("tensor(x{},y{})")
.add({{"x", "2"}, {"y", "4"}}, 3)));
- TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79,
- 0x01, 0x01, 0x31, 0x00, 0x40, 0x08,
- 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
+ TEST_DO(verify_serialized({ 0x01, 0x02, 0x01, 0x78, 0x01, 0x79,
+ 0x01, 0x01, 0x31, 0x00, 0x40, 0x08,
+ 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 },
TensorSpec("tensor(x{},y{})")
.add({{"x", "1"}, {"y", ""}}, 3)));
}
+TEST("test float cells from sparse tensor") {
+ TEST_DO(verify_serialized({ 0x05, 0x01,
+ 0x02, 0x01, 0x78, 0x01, 0x79,
+ 0x01, 0x01, 0x31, 0x00,
+ 0x40, 0x40, 0x00, 0x00 },
+ TensorSpec("tensor<float>(x{},y{})")
+ .add({{"x", "1"}, {"y", ""}}, 3)));
+}
+
TEST("test tensor serialization for DenseTensor") {
TEST_DO(verify_serialized({0x02, 0x00,
0x00, 0x00, 0x00, 0x00,
@@ -187,7 +196,7 @@ TEST("test tensor serialization for DenseTensor") {
.add({{"x", 2}, {"y", 4}}, 3)));
}
-TEST("test 'float' cells") {
+TEST("test float cells for dense tensor") {
TEST_DO(verify_serialized({0x06, 0x01, 0x02, 0x01, 0x78, 0x03,
0x01, 0x79, 0x05,
0x00, 0x00, 0x00, 0x00,
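
For reference, the new float sparse-tensor test vector above decodes as follows. This annotation is not part of the commit; the field meanings follow the sparse serializer in this patch and the byte values are taken from the test itself.

    #include <cstdint>
    #include <vector>

    std::vector<uint8_t> float_sparse_bytes = {
        0x05,                    // SPARSE_BINARY_FORMAT_WITH_CELLTYPE
        0x01,                    // cell type: FLOAT_VALUE_TYPE
        0x02,                    // two mapped dimensions
        0x01, 0x78,              // dimension name "x" (length 1)
        0x01, 0x79,              // dimension name "y" (length 1)
        0x01,                    // one cell follows
        0x01, 0x31,              // label "1" for dimension x
        0x00,                    // empty label for dimension y
        0x40, 0x40, 0x00, 0x00,  // 3.0f as big-endian IEEE 754 (nbostream is network order)
    };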
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 6210295ebd4..2206cde49a9 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -170,8 +170,11 @@ Value::UP
DefaultTensorEngine::from_spec(const TensorSpec &spec) const
{
ValueType type = ValueType::from_spec(spec.type());
- if (!tensor::Tensor::supported({type})) {
- return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::create(spec));
+ if (type.is_error()) {
+ return std::make_unique<ErrorValue>();
+ } else if (type.is_double()) {
+ double value = spec.cells().empty() ? 0.0 : spec.cells().begin()->second.value;
+ return std::make_unique<DoubleValue>(value);
} else if (type.is_dense()) {
DirectDenseTensorBuilder builder(type);
for (const auto &cell: spec.cells()) {
@@ -195,13 +198,8 @@ DefaultTensorEngine::from_spec(const TensorSpec &spec) const
}
}
return builder.build();
- } else if (type.is_double()) {
- double value = spec.cells().empty() ? 0.0 : spec.cells().begin()->second.value;
- return std::make_unique<DoubleValue>(value);
- } else {
- assert(type.is_error());
- return std::make_unique<ErrorValue>();
}
+ return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::create(spec));
}
struct CellFunctionFunAdapter : tensor::CellFunction {
diff --git a/eval/src/vespa/eval/tensor/serialization/common.h b/eval/src/vespa/eval/tensor/serialization/common.h
deleted file mode 100644
index 40b1840be6e..00000000000
--- a/eval/src/vespa/eval/tensor/serialization/common.h
+++ /dev/null
@@ -1,9 +0,0 @@
-// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-#pragma once
-
-namespace vespalib::tensor {
-
-enum class SerializeFormat {FLOAT, DOUBLE};
-
-}
diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
index 4b1ccc8db5d..677fb40b0f4 100644
--- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
@@ -7,6 +7,8 @@
#include <cassert>
using vespalib::nbostream;
+using vespalib::eval::ValueType;
+using CellType = vespalib::eval::ValueType::CellType;
namespace vespalib::tensor {
@@ -15,15 +17,7 @@ using Dimension = eval::ValueType::Dimension;
namespace {
-eval::ValueType
-makeValueType(std::vector<Dimension> &&dimensions) {
- return (dimensions.empty() ?
- eval::ValueType::double_type() :
- eval::ValueType::tensor_type(std::move(dimensions)));
-}
-
-size_t
-encodeDimensions(nbostream &stream, const eval::ValueType & type) {
+size_t encodeDimensions(nbostream &stream, const eval::ValueType & type) {
stream.putInt1_4Bytes(type.dimensions().size());
size_t cellsSize = 1;
for (const auto &dimension : type.dimensions()) {
@@ -35,15 +29,13 @@ encodeDimensions(nbostream &stream, const eval::ValueType & type) {
}
template<typename T>
-void
-encodeCells(nbostream &stream, DenseTensorView::CellsRef cells) {
+void encodeCells(nbostream &stream, DenseTensorView::CellsRef cells) {
for (const auto &value : cells) {
stream << static_cast<T>(value);
}
}
-size_t
-decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) {
+size_t decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) {
vespalib::string dimensionName;
size_t dimensionsSize = stream.getInt1_4Bytes();
size_t dimensionSize;
@@ -58,8 +50,7 @@ decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) {
}
template<typename T, typename V>
-void
-decodeCells(nbostream &stream, size_t cellsSize, V & cells) {
+void decodeCells(nbostream &stream, size_t cellsSize, V &cells) {
T cellValue = 0.0;
for (size_t i = 0; i < cellsSize; ++i) {
stream >> cellValue;
@@ -68,13 +59,12 @@ decodeCells(nbostream &stream, size_t cellsSize, V & cells) {
}
template <typename V>
-void decodeCells(SerializeFormat format, nbostream &stream, size_t cellsSize, V & cells)
-{
- switch (format) {
- case SerializeFormat::DOUBLE:
+void decodeCells(CellType cell_type, nbostream &stream, size_t cellsSize, V &cells) {
+ switch (cell_type) {
+ case CellType::DOUBLE:
decodeCells<double>(stream, cellsSize, cells);
break;
- case SerializeFormat::FLOAT:
+ case CellType::FLOAT:
decodeCells<float>(stream, cellsSize, cells);
break;
}
@@ -86,44 +76,41 @@ void
DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor)
{
size_t cellsSize = encodeDimensions(stream, tensor.fast_type());
-
DenseTensorView::CellsRef cells = tensor.cellsRef();
assert(cells.size() == cellsSize);
- switch (_format) {
- case SerializeFormat::DOUBLE:
- encodeCells<double>(stream, cells);
- break;
- case SerializeFormat::FLOAT:
- encodeCells<float>(stream, cells);
- break;
+ switch (tensor.fast_type().cell_type()) {
+ case CellType::DOUBLE:
+ encodeCells<double>(stream, cells);
+ break;
+ case CellType::FLOAT:
+ encodeCells<float>(stream, cells);
+ break;
}
}
std::unique_ptr<DenseTensor>
-DenseBinaryFormat::deserialize(nbostream &stream)
+DenseBinaryFormat::deserialize(nbostream &stream, CellType cell_type)
{
std::vector<Dimension> dimensions;
size_t cellsSize = decodeDimensions(stream,dimensions);
DenseTensor::Cells cells;
cells.reserve(cellsSize);
-
- decodeCells(_format, stream, cellsSize, cells);
-
- return std::make_unique<DenseTensor>(makeValueType(std::move(dimensions)), std::move(cells));
+ decodeCells(cell_type, stream, cellsSize, cells);
+ return std::make_unique<DenseTensor>(ValueType::tensor_type(std::move(dimensions), cell_type), std::move(cells));
}
template <typename T>
void
-DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<T> & cells)
+DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<T> &cells, CellType cell_type)
{
std::vector<Dimension> dimensions;
size_t cellsSize = decodeDimensions(stream,dimensions);
cells.clear();
cells.reserve(cellsSize);
- decodeCells(_format, stream, cellsSize, cells);
+ decodeCells(cell_type, stream, cellsSize, cells);
}
-template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<double> & cells);
-template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<float> & cells);
+template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<double> &cells, CellType cell_type);
+template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<float> &cells, CellType cell_type);
}
diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
index f9847d37784..9e860b3c1e4 100644
--- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
@@ -2,9 +2,9 @@
#pragma once
-#include "common.h"
#include <memory>
#include <vector>
+#include <vespa/eval/eval/value_type.h>
namespace vespalib { class nbostream; }
@@ -19,15 +19,14 @@ class DenseTensorView;
class DenseBinaryFormat
{
public:
- DenseBinaryFormat(SerializeFormat format) : _format(format) { }
- void serialize(nbostream &stream, const DenseTensorView &tensor);
- std::unique_ptr<DenseTensor> deserialize(nbostream &stream);
-
+ using CellType = eval::ValueType::CellType;
+
+ static void serialize(nbostream &stream, const DenseTensorView &tensor);
+ static std::unique_ptr<DenseTensor> deserialize(nbostream &stream, CellType cell_type);
+
// This is a temporary method until we get full support for typed tensors
template <typename T>
- void deserializeCellsOnly(nbostream &stream, std::vector<T> & cells);
-private:
- SerializeFormat _format;
+ static void deserializeCellsOnly(nbostream &stream, std::vector<T> &cells, CellType cell_type);
};
}
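
DenseBinaryFormat is now stateless: the cell width is taken from the tensor's own type when serializing and passed in explicitly when deserializing, instead of being stored in a SerializeFormat member. A minimal usage sketch of the new static API, using only the names declared above; `view` stands for a DenseTensorView whose type has float cells and is not something from this patch.

    void dense_round_trip(const DenseTensorView &view) {
        using CellType = vespalib::eval::ValueType::CellType;
        vespalib::nbostream stream;
        DenseBinaryFormat::serialize(stream, view);                           // cell width from view.fast_type().cell_type()
        auto copy = DenseBinaryFormat::deserialize(stream, CellType::FLOAT);  // caller states the width that was written
    }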
diff --git a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp
index cca310176f4..06e3f63c8da 100644
--- a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.cpp
@@ -13,6 +13,7 @@
using vespalib::nbostream;
using vespalib::eval::ValueType;
+using CellType = vespalib::eval::ValueType::CellType;
namespace vespalib::tensor {
@@ -20,10 +21,9 @@ namespace {
vespalib::string undefinedLabel("");
-void
-writeTensorAddress(nbostream &output,
- const eval::ValueType &type,
- const TensorAddress &value)
+void writeTensorAddress(nbostream &output,
+ const eval::ValueType &type,
+ const TensorAddress &value)
{
auto elemItr = value.elements().cbegin();
auto elemItrEnd = value.elements().cend();
@@ -38,78 +38,71 @@ writeTensorAddress(nbostream &output,
assert(elemItr == elemItrEnd);
}
-}
-
+template <typename T>
class SparseBinaryFormatSerializer : public TensorVisitor
{
- uint32_t _numCells;
- nbostream _cells;
- eval::ValueType _type;
-
+private:
+ uint32_t _num_cells;
+ nbostream &_cells;
+ const ValueType &_type;
public:
- SparseBinaryFormatSerializer();
+ SparseBinaryFormatSerializer(nbostream &cells, const ValueType &type);
+ size_t num_cells() const { return _num_cells; }
virtual ~SparseBinaryFormatSerializer() override;
virtual void visit(const TensorAddress &address, double value) override;
- void serialize(nbostream &stream, const Tensor &tensor);
};
-SparseBinaryFormatSerializer::SparseBinaryFormatSerializer()
- : _numCells(0u),
- _cells(),
- _type(eval::ValueType::error_type())
+template <typename T>
+SparseBinaryFormatSerializer<T>::SparseBinaryFormatSerializer(nbostream &cells, const ValueType &type)
+ : _num_cells(0),
+ _cells(cells),
+ _type(type)
{
}
+template <typename T>
+SparseBinaryFormatSerializer<T>::~SparseBinaryFormatSerializer() = default;
-SparseBinaryFormatSerializer::~SparseBinaryFormatSerializer() = default;
-
+template <typename T>
void
-SparseBinaryFormatSerializer::visit(const TensorAddress &address, double value)
+SparseBinaryFormatSerializer<T>::visit(const TensorAddress &address, double value)
{
- ++_numCells;
+ ++_num_cells;
writeTensorAddress(_cells, _type, address);
- _cells << value;
+ _cells << static_cast<T>(value);
}
-
-void
-SparseBinaryFormatSerializer::serialize(nbostream &stream, const Tensor &tensor)
-{
- _type = tensor.type();
- tensor.accept(*this);
- stream.putInt1_4Bytes(_type.dimensions().size());
- for (const auto &dimension : _type.dimensions()) {
+void encodeDimensions(nbostream &stream, const eval::ValueType &type) {
+ stream.putInt1_4Bytes(type.dimensions().size());
+ for (const auto &dimension : type.dimensions()) {
stream.writeSmallString(dimension.name);
}
- stream.putInt1_4Bytes(_numCells);
- stream.write(_cells.peek(), _cells.size());
}
-
-void
-SparseBinaryFormat::serialize(nbostream &stream, const Tensor &tensor)
-{
- SparseBinaryFormatSerializer serializer;
- serializer.serialize(stream, tensor);
+template <typename T>
+size_t encodeCells(nbostream &stream, const Tensor &tensor) {
+ SparseBinaryFormatSerializer<T> serializer(stream, tensor.type());
+ tensor.accept(serializer);
+ return serializer.num_cells();
}
+size_t encodeCells(nbostream &stream, const Tensor &tensor, CellType cell_type) {
+ switch (cell_type) {
+ case CellType::DOUBLE:
+ return encodeCells<double>(stream, tensor);
+ break;
+ case CellType::FLOAT:
+ return encodeCells<float>(stream, tensor);
+ break;
+ }
+ return 0;
+}
-std::unique_ptr<Tensor>
-SparseBinaryFormat::deserialize(nbostream &stream)
-{
+template<typename T>
+void decodeCells(nbostream &stream, size_t dimensionsSize, size_t cellsSize, DirectSparseTensorBuilder &builder) {
+ T cellValue = 0.0;
vespalib::string str;
- size_t dimensionsSize = stream.getInt1_4Bytes();
- std::vector<ValueType::Dimension> dimensions;
- while (dimensions.size() < dimensionsSize) {
- stream.readSmallString(str);
- dimensions.emplace_back(str);
- }
- ValueType type = ValueType::tensor_type(std::move(dimensions));
- DirectSparseTensorBuilder builder(type);
SparseTensorAddressBuilder address;
-
- size_t cellsSize = stream.getInt1_4Bytes();
- double cellValue = 0.0;
for (size_t cellIdx = 0; cellIdx < cellsSize; ++cellIdx) {
address.clear();
for (size_t dimension = 0; dimension < dimensionsSize; ++dimension) {
@@ -121,10 +114,49 @@ SparseBinaryFormat::deserialize(nbostream &stream)
}
}
stream >> cellValue;
- builder.insertCell(address, cellValue);
+ builder.insertCell(address, cellValue, [](double, double v){ return v; });
}
- return builder.build();
}
+void decodeCells(CellType cell_type, nbostream &stream, size_t dimensionsSize, size_t cellsSize, DirectSparseTensorBuilder &builder) {
+ switch (cell_type) {
+ case CellType::DOUBLE:
+ decodeCells<double>(stream, dimensionsSize, cellsSize, builder);
+ break;
+ case CellType::FLOAT:
+ decodeCells<float>(stream, dimensionsSize, cellsSize, builder);
+ break;
+ }
+}
+
+}
+
+void
+SparseBinaryFormat::serialize(nbostream &stream, const Tensor &tensor)
+{
+ const auto &type = tensor.type();
+ encodeDimensions(stream, type);
+ nbostream cells;
+ size_t numCells = encodeCells(cells, tensor, type.cell_type());
+ stream.putInt1_4Bytes(numCells);
+ stream.write(cells.peek(), cells.size());
+}
+
+std::unique_ptr<Tensor>
+SparseBinaryFormat::deserialize(nbostream &stream, CellType cell_type)
+{
+ vespalib::string str;
+ size_t dimensionsSize = stream.getInt1_4Bytes();
+ std::vector<ValueType::Dimension> dimensions;
+ while (dimensions.size() < dimensionsSize) {
+ stream.readSmallString(str);
+ dimensions.emplace_back(str);
+ }
+ ValueType type = ValueType::tensor_type(std::move(dimensions), cell_type);
+ DirectSparseTensorBuilder builder(type);
+ size_t cellsSize = stream.getInt1_4Bytes();
+ decodeCells(cell_type, stream, dimensionsSize, cellsSize, builder);
+ return builder.build();
+}
}
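
The restructured sparse serializer writes the dimension names straight to the output but buffers the cell payload in a scratch nbostream, because the cell count has to precede the cells on the wire and is only known after the visitor has walked the tensor. A condensed restatement of that flow, using the helper names from this file:

    void sparse_serialize_flow(vespalib::nbostream &stream, const Tensor &tensor) {
        const auto &type = tensor.type();
        encodeDimensions(stream, type);                                  // dimension count + names go out first
        vespalib::nbostream cells;                                       // scratch buffer for the cell payload
        size_t numCells = encodeCells(cells, tensor, type.cell_type());  // typed visitor writes double or float cells
        stream.putInt1_4Bytes(numCells);                                 // count is only known after the visit
        stream.write(cells.peek(), cells.size());                        // then the buffered (address, value) pairs
    }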
diff --git a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h
index cd68e7eeda4..0611d7d5a23 100644
--- a/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/sparse_binary_format.h
@@ -3,6 +3,7 @@
#pragma once
#include <memory>
+#include <vespa/eval/eval/value_type.h>
namespace vespalib { class nbostream; }
@@ -11,13 +12,15 @@ namespace vespalib::tensor {
class Tensor;
/**
- * Class for serializing a tensor.
+ * Class for serializing a sparse tensor.
*/
class SparseBinaryFormat
{
public:
+ using CellType = eval::ValueType::CellType;
+
static void serialize(nbostream &stream, const Tensor &tensor);
- static std::unique_ptr<Tensor> deserialize(nbostream &stream);
+ static std::unique_ptr<Tensor> deserialize(nbostream &stream, CellType cell_type);
};
}
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
index 23179d4b908..8d9767374a2 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
@@ -16,6 +16,8 @@
LOG_SETUP(".eval.tensor.serialization.typed_binary_format");
using vespalib::nbostream;
+using vespalib::eval::ValueType;
+using CellType = vespalib::eval::ValueType::CellType;
namespace vespalib::tensor {
@@ -31,47 +33,52 @@ constexpr uint32_t MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7u;
constexpr uint32_t DOUBLE_VALUE_TYPE = 0;
constexpr uint32_t FLOAT_VALUE_TYPE = 1;
-uint32_t
-format2Encoding(SerializeFormat format) {
- switch (format) {
- case SerializeFormat::DOUBLE:
- return DOUBLE_VALUE_TYPE;
- case SerializeFormat::FLOAT:
- return FLOAT_VALUE_TYPE;
+uint32_t cell_type_to_encoding(CellType cell_type) {
+ switch (cell_type) {
+ case CellType::DOUBLE:
+ return DOUBLE_VALUE_TYPE;
+ case CellType::FLOAT:
+ return FLOAT_VALUE_TYPE;
}
abort();
}
-SerializeFormat
-encoding2Format(uint32_t serializedType) {
- switch (serializedType) {
- case DOUBLE_VALUE_TYPE:
- return SerializeFormat::DOUBLE;
- case FLOAT_VALUE_TYPE:
- return SerializeFormat::FLOAT;
- default:
- throw IllegalArgumentException(make_string("Received unknown tensor value type = %u. Only 0(double), or 1(float) are legal.", serializedType));
+CellType
+encoding_to_cell_type(uint32_t cell_encoding) {
+ switch (cell_encoding) {
+ case DOUBLE_VALUE_TYPE:
+ return CellType::DOUBLE;
+ case FLOAT_VALUE_TYPE:
+ return CellType::FLOAT;
+ default:
+ throw IllegalArgumentException(make_string("Received unknown tensor value type = %u. Only 0(double), or 1(float) are legal.", cell_encoding));
}
}
}
void
-TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format)
+TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor)
{
+ auto cell_type = tensor.type().cell_type();
+ bool default_cell_type = (cell_type == CellType::DOUBLE);
if (auto denseTensor = dynamic_cast<const DenseTensorView *>(&tensor)) {
- if (format != SerializeFormat::DOUBLE) {
- stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE);
- stream.putInt1_4Bytes(format2Encoding(format));
- DenseBinaryFormat(format).serialize(stream, *denseTensor);
- } else {
+ if (default_cell_type) {
stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
- DenseBinaryFormat(SerializeFormat::DOUBLE).serialize(stream, *denseTensor);
+ } else {
+ stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_WITH_CELLTYPE);
+ stream.putInt1_4Bytes(cell_type_to_encoding(cell_type));
}
+ DenseBinaryFormat::serialize(stream, *denseTensor);
} else if (auto wrapped = dynamic_cast<const WrappedSimpleTensor *>(&tensor)) {
eval::SimpleTensor::encode(wrapped->get(), stream);
} else {
- stream.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
+ if (default_cell_type) {
+ stream.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
+ } else {
+ stream.putInt1_4Bytes(SPARSE_BINARY_FORMAT_WITH_CELLTYPE);
+ stream.putInt1_4Bytes(cell_type_to_encoding(cell_type));
+ }
SparseBinaryFormat::serialize(stream, tensor);
}
}
@@ -80,40 +87,46 @@ TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor, SerializeF
std::unique_ptr<Tensor>
TypedBinaryFormat::deserialize(nbostream &stream)
{
+ auto cell_type = CellType::DOUBLE;
auto read_pos = stream.rp();
auto formatId = stream.getInt1_4Bytes();
- if (formatId == SPARSE_BINARY_FORMAT_TYPE) {
- return SparseBinaryFormat::deserialize(stream);
- }
- if (formatId == DENSE_BINARY_FORMAT_TYPE) {
- return DenseBinaryFormat(SerializeFormat::DOUBLE).deserialize(stream);
- }
- if ((formatId == SPARSE_BINARY_FORMAT_WITH_CELLTYPE) ||
- (formatId == DENSE_BINARY_FORMAT_WITH_CELLTYPE) ||
- (formatId == MIXED_BINARY_FORMAT_TYPE) ||
- (formatId == MIXED_BINARY_FORMAT_WITH_CELLTYPE))
- {
+ switch (formatId) {
+ case SPARSE_BINARY_FORMAT_WITH_CELLTYPE:
+ cell_type = encoding_to_cell_type(stream.getInt1_4Bytes());
+ [[fallthrough]];
+ case SPARSE_BINARY_FORMAT_TYPE:
+ return SparseBinaryFormat::deserialize(stream, cell_type);
+ case DENSE_BINARY_FORMAT_WITH_CELLTYPE:
+ cell_type = encoding_to_cell_type(stream.getInt1_4Bytes());
+ [[fallthrough]];
+ case DENSE_BINARY_FORMAT_TYPE:
+ return DenseBinaryFormat::deserialize(stream, cell_type);
+ case MIXED_BINARY_FORMAT_TYPE:
+ case MIXED_BINARY_FORMAT_WITH_CELLTYPE:
stream.adjustReadPos(read_pos - stream.rp());
return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::decode(stream));
+ default:
+ throw IllegalArgumentException(make_string("Received unknown tensor format type = %du.", formatId));
}
- throw IllegalArgumentException(make_string("Received unknown tensor format type = %du.", formatId));
}
template <typename T>
void
-TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells)
+TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> &cells)
{
+ auto cell_type = CellType::DOUBLE;
auto formatId = stream.getInt1_4Bytes();
- if (formatId == DENSE_BINARY_FORMAT_TYPE) {
- return DenseBinaryFormat(SerializeFormat::DOUBLE).deserializeCellsOnly(stream, cells);
- }
- if (formatId == DENSE_BINARY_FORMAT_WITH_CELLTYPE) {
- return DenseBinaryFormat(encoding2Format(stream.getInt1_4Bytes())).deserializeCellsOnly(stream, cells);
+ switch (formatId) {
+ case DENSE_BINARY_FORMAT_WITH_CELLTYPE:
+ cell_type = encoding_to_cell_type(stream.getInt1_4Bytes());
+ [[fallthrough]];
+ case DENSE_BINARY_FORMAT_TYPE:
+ return DenseBinaryFormat::deserializeCellsOnly(stream, cells, cell_type);
}
abort();
}
-template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<double> & cells);
-template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<float> & cells);
+template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<double> &cells);
+template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<float> &cells);
}
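
The decoder now peels an optional cell-type byte off the header before dispatching on the format tag. A short sketch of that header handling; the tag values are taken from the tests and the constants in this file (the mixed-format tags are left out because their numeric values are not visible in this hunk).

    // tag 1 = sparse, 2 = dense            -> double cells, no extra byte
    // tag 5 = sparse + cell type,
    // tag 6 = dense  + cell type           -> next byte is 0 (double) or 1 (float)
    auto cell_type = CellType::DOUBLE;       // implied when no cell-type byte follows the tag
    auto tag = stream.getInt1_4Bytes();
    if (tag == 5 || tag == 6) {
        cell_type = encoding_to_cell_type(stream.getInt1_4Bytes());
    }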
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
index 717d51effef..198b09ae336 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
@@ -2,7 +2,6 @@
#pragma once
-#include "common.h"
#include <memory>
#include <vector>
@@ -18,16 +17,12 @@ class Tensor;
class TypedBinaryFormat
{
public:
- static void serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format);
- static void serialize(nbostream &stream, const Tensor &tensor) {
- serialize(stream, tensor, SerializeFormat::DOUBLE);
- }
-
+ static void serialize(nbostream &stream, const Tensor &tensor);
static std::unique_ptr<Tensor> deserialize(nbostream &stream);
-
+
// This is a temporary method until we get full support for typed tensors
template <typename T>
- static void deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells);
+ static void deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> &cells);
};
}