summaryrefslogtreecommitdiffstats
path: root/eval/src
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@oath.com>2019-04-03 16:33:28 +0000
committerHenning Baldersheim <balder@oath.com>2019-04-03 16:33:28 +0000
commitce4339ab5224bac5df5fb721cb7d2668ae75c811 (patch)
treeb0fc7edd32afed431750e65cb84635295666cddb /eval/src
parent6bb8fbba81114e3b7c902b1a9d12f3d91029011e (diff)
Keep the serialzation of the cells with the serialization for now.
Clean up code for better understanding and reuse.
Diffstat (limited to 'eval/src')
-rw-r--r--eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp11
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.h12
-rw-r--r--eval/src/vespa/eval/tensor/serialization/common.h9
-rw-r--r--eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp56
-rw-r--r--eval/src/vespa/eval/tensor/serialization/dense_binary_format.h7
-rw-r--r--eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp53
-rw-r--r--eval/src/vespa/eval/tensor/serialization/typed_binary_format.h13
7 files changed, 79 insertions, 82 deletions
diff --git a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
index d32fecc5cba..0237f6cc769 100644
--- a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
+++ b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
@@ -16,7 +16,6 @@
using namespace vespalib::tensor;
using vespalib::nbostream;
using ExpBuffer = std::vector<uint8_t>;
-using SerializeFormat = vespalib::tensor::DenseTensorView::SerializeFormat;
namespace std {
@@ -145,10 +144,8 @@ TEST_F("test tensor serialization for SparseTensor", SparseFixture)
struct DenseFixture
{
- Tensor::UP createTensor(SerializeFormat format, const DenseTensorCells &cells) {
- auto tensor = TensorFactory::createDense(cells);
- dynamic_cast<DenseTensorView &>(*tensor).serializeAs(format);
- return tensor;
+ Tensor::UP createTensor(const DenseTensorCells &cells) {
+ return TensorFactory::createDense(cells);
}
void serialize(nbostream &stream, const Tensor &tensor) {
@@ -166,9 +163,9 @@ struct DenseFixture
assertSerialized(exp, SerializeFormat::DOUBLE, rhs);
}
void assertSerialized(const ExpBuffer &exp, SerializeFormat cellType, const DenseTensorCells &rhs) {
- Tensor::UP rhsTensor(createTensor(cellType, rhs));
+ Tensor::UP rhsTensor(createTensor(rhs));
nbostream rhsStream;
- serialize(rhsStream, *rhsTensor);
+ TypedBinaryFormat::serialize(rhsStream, *rhsTensor, cellType);
EXPECT_EQUAL(exp, rhsStream);
auto rhs2 = deserialize(rhsStream);
EXPECT_EQUAL(*rhs2, *rhsTensor);
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
index 19a8a66bcf7..09b6b72375e 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
@@ -14,7 +14,6 @@ namespace vespalib::tensor {
class DenseTensorView : public Tensor
{
public:
- enum class SerializeFormat {FLOAT, DOUBLE};
using Cells = std::vector<double>;
using CellsRef = ConstArrayRef<double>;
using CellsIterator = DenseTensorCellsIterator;
@@ -22,16 +21,13 @@ public:
DenseTensorView(const eval::ValueType &type_in, CellsRef cells_in)
: _typeRef(type_in),
- _cellsRef(cells_in),
- _serializeFormat(SerializeFormat::DOUBLE)
+ _cellsRef(cells_in)
{}
explicit DenseTensorView(const eval::ValueType &type_in)
: _typeRef(type_in),
- _cellsRef(),
- _serializeFormat(SerializeFormat::DOUBLE)
+ _cellsRef()
{}
- SerializeFormat serializeAs() const { return _serializeFormat; }
- void serializeAs(SerializeFormat format) { _serializeFormat = format; }
+
const eval::ValueType &fast_type() const { return _typeRef; }
const CellsRef &cellsRef() const { return _cellsRef; }
bool operator==(const DenseTensorView &rhs) const;
@@ -58,8 +54,6 @@ private:
const eval::ValueType &_typeRef;
CellsRef _cellsRef;
- //TODO This is a temporary workaround until proper type support for tensors is in place.
- SerializeFormat _serializeFormat;
};
}
diff --git a/eval/src/vespa/eval/tensor/serialization/common.h b/eval/src/vespa/eval/tensor/serialization/common.h
new file mode 100644
index 00000000000..9c45bc42136
--- /dev/null
+++ b/eval/src/vespa/eval/tensor/serialization/common.h
@@ -0,0 +1,9 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+namespace vespalib::tensor {
+
+enum class SerializeFormat {FLOAT, DOUBLE};
+
+}
diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
index 2a939963e16..6043153adc3 100644
--- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
@@ -10,17 +10,11 @@ using vespalib::nbostream;
namespace vespalib::tensor {
-using SerializationFormat = DenseTensorView::SerializeFormat;
using Dimension = eval::ValueType::Dimension;
namespace {
-using EncodeType = DenseBinaryFormat::EncodeType;
-
-constexpr int DOUBLE_VALUE_TYPE = 0;
-constexpr int FLOAT_VALUE_TYPE = 1;
-
eval::ValueType
makeValueType(std::vector<Dimension> &&dimensions) {
return (dimensions.empty() ?
@@ -48,20 +42,6 @@ encodeCells(nbostream &stream, DenseTensorView::CellsRef cells) {
}
}
-void
-encodeValueType(nbostream & stream, SerializationFormat valueType, EncodeType encodeType) {
- switch (valueType) {
- case SerializationFormat::DOUBLE:
- if (encodeType != EncodeType::DOUBLE_IS_DEFAULT) {
- stream.putInt1_4Bytes(DOUBLE_VALUE_TYPE);
- }
- break;
- case SerializationFormat::FLOAT:
- stream.putInt1_4Bytes(FLOAT_VALUE_TYPE);
- break;
- }
-}
-
size_t
decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) {
vespalib::string dimensionName;
@@ -77,23 +57,6 @@ decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) {
return cellsSize;
}
-SerializationFormat
-decodeCellType(nbostream & stream, EncodeType encodeType) {
- if (encodeType != EncodeType::DOUBLE_IS_DEFAULT) {
- uint32_t serializedType = stream.getInt1_4Bytes();
- switch (serializedType) {
- case DOUBLE_VALUE_TYPE:
- return SerializationFormat::DOUBLE;
- case FLOAT_VALUE_TYPE:
- return SerializationFormat::FLOAT;
- default:
- throw IllegalArgumentException(make_string("Received unknown tensor value type = %u. Only 0(double), or 1(float) are legal.", serializedType));
- }
- } else {
- return SerializationFormat::DOUBLE;
- }
-}
-
template<typename T>
void
decodeCells(nbostream &stream, size_t cellsSize, DenseTensor::Cells & cells) {
@@ -109,17 +72,15 @@ decodeCells(nbostream &stream, size_t cellsSize, DenseTensor::Cells & cells) {
void
DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor)
{
- const eval::ValueType & type = tensor.fast_type();
- encodeValueType(stream, tensor.serializeAs(), _encodeType);
- size_t cellsSize = encodeDimensions(stream, type);
+ size_t cellsSize = encodeDimensions(stream, tensor.fast_type());
DenseTensorView::CellsRef cells = tensor.cellsRef();
assert(cells.size() == cellsSize);
- switch (tensor.serializeAs()) {
- case SerializationFormat::DOUBLE:
+ switch (_format) {
+ case SerializeFormat::DOUBLE:
encodeCells<double>(stream, cells);
break;
- case SerializationFormat::FLOAT:
+ case SerializeFormat::FLOAT:
encodeCells<float>(stream, cells);
break;
}
@@ -128,17 +89,16 @@ DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor)
std::unique_ptr<DenseTensor>
DenseBinaryFormat::deserialize(nbostream &stream)
{
- SerializationFormat cellType = decodeCellType(stream, _encodeType);
std::vector<Dimension> dimensions;
size_t cellsSize = decodeDimensions(stream,dimensions);
DenseTensor::Cells cells;
cells.reserve(cellsSize);
- switch (cellType) {
- case SerializationFormat::DOUBLE:
+ switch (_format) {
+ case SerializeFormat::DOUBLE:
decodeCells<double>(stream, cellsSize,cells);
break;
- case SerializationFormat::FLOAT:
- decodeCells<float>(stream, cellsSize,cells);
+ case SerializeFormat::FLOAT:
+ decodeCells<float>(stream, cellsSize, cells);
break;
}
diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
index 3f9236ce222..22c1663719e 100644
--- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
@@ -2,7 +2,9 @@
#pragma once
+#include "common.h"
#include <memory>
+
namespace vespalib { class nbostream; }
namespace vespalib::tensor {
@@ -16,12 +18,11 @@ class DenseTensorView;
class DenseBinaryFormat
{
public:
- enum class EncodeType { NO_DEFAULT, DOUBLE_IS_DEFAULT};
- DenseBinaryFormat(EncodeType encodeType) : _encodeType(encodeType) { }
+ DenseBinaryFormat(SerializeFormat format) : _format(format) { }
void serialize(nbostream &stream, const DenseTensorView &tensor);
std::unique_ptr<DenseTensor> deserialize(nbostream &stream);
private:
- EncodeType _encodeType;
+ SerializeFormat _format;
};
}
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
index e98b106d764..d1aa09b6ce3 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
@@ -11,23 +11,61 @@
#include <vespa/eval/tensor/wrapped_simple_tensor.h>
#include <vespa/log/log.h>
+#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/exceptions.h>
+
LOG_SETUP(".eval.tensor.serialization.typed_binary_format");
using vespalib::nbostream;
namespace vespalib::tensor {
+namespace {
+
+constexpr uint32_t SPARSE_BINARY_FORMAT_TYPE = 1u;
+constexpr uint32_t DENSE_BINARY_FORMAT_TYPE = 2u;
+constexpr uint32_t MIXED_BINARY_FORMAT_TYPE = 3u;
+constexpr uint32_t TYPED_DENSE_BINARY_FORMAT_TYPE = 4u;
+
+constexpr uint32_t DOUBLE_VALUE_TYPE = 0;
+constexpr uint32_t FLOAT_VALUE_TYPE = 1;
+
+uint32_t
+format2Encoding(SerializeFormat format) {
+ switch (format) {
+ case SerializeFormat::DOUBLE:
+ return DOUBLE_VALUE_TYPE;
+ case SerializeFormat::FLOAT:
+ return FLOAT_VALUE_TYPE;
+ }
+ abort();
+}
+
+SerializeFormat
+encoding2Format(uint32_t serializedType) {
+ switch (serializedType) {
+ case DOUBLE_VALUE_TYPE:
+ return SerializeFormat::DOUBLE;
+ case FLOAT_VALUE_TYPE:
+ return SerializeFormat::FLOAT;
+ default:
+ throw IllegalArgumentException(make_string("Received unknown tensor value type = %u. Only 0(double), or 1(float) are legal.", serializedType));
+ }
+}
+
+}
void
-TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor)
+TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format)
{
if (auto denseTensor = dynamic_cast<const DenseTensorView *>(&tensor)) {
- if (denseTensor->serializeAs() != DenseTensorView::SerializeFormat::DOUBLE) {
+ if (format != SerializeFormat::DOUBLE) {
stream.putInt1_4Bytes(TYPED_DENSE_BINARY_FORMAT_TYPE);
- DenseBinaryFormat(DenseBinaryFormat::EncodeType::NO_DEFAULT).serialize(stream, *denseTensor);
+ stream.putInt1_4Bytes(format2Encoding(format));
+ DenseBinaryFormat(format).serialize(stream, *denseTensor);
} else {
stream.putInt1_4Bytes(DENSE_BINARY_FORMAT_TYPE);
- DenseBinaryFormat(DenseBinaryFormat::EncodeType::DOUBLE_IS_DEFAULT).serialize(stream, *denseTensor);
+ DenseBinaryFormat(SerializeFormat::DOUBLE).serialize(stream, *denseTensor);
}
} else if (auto wrapped = dynamic_cast<const WrappedSimpleTensor *>(&tensor)) {
eval::SimpleTensor::encode(wrapped->get(), stream);
@@ -49,17 +87,16 @@ TypedBinaryFormat::deserialize(nbostream &stream)
return builder.build();
}
if (formatId == DENSE_BINARY_FORMAT_TYPE) {
- return DenseBinaryFormat(DenseBinaryFormat::EncodeType::DOUBLE_IS_DEFAULT).deserialize(stream);
+ return DenseBinaryFormat(SerializeFormat::DOUBLE).deserialize(stream);
}
if (formatId == TYPED_DENSE_BINARY_FORMAT_TYPE) {
- return DenseBinaryFormat(DenseBinaryFormat::EncodeType::NO_DEFAULT).deserialize(stream);
+ return DenseBinaryFormat(encoding2Format(stream.getInt1_4Bytes())).deserialize(stream);
}
if (formatId == MIXED_BINARY_FORMAT_TYPE) {
stream.adjustReadPos(read_pos - stream.rp());
return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::decode(stream));
}
- LOG_ABORT("should not be reached");
+ abort();
}
-
}
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
index e40ac7bda43..95d9a75488c 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
@@ -2,27 +2,26 @@
#pragma once
+#include "common.h"
#include <memory>
-#include <cstdint>
namespace vespalib { class nbostream; }
namespace vespalib::tensor {
class Tensor;
-class TensorBuilder;
/**
* Class for serializing a tensor.
*/
class TypedBinaryFormat
{
- static constexpr uint32_t SPARSE_BINARY_FORMAT_TYPE = 1u;
- static constexpr uint32_t DENSE_BINARY_FORMAT_TYPE = 2u;
- static constexpr uint32_t MIXED_BINARY_FORMAT_TYPE = 3u;
- static constexpr uint32_t TYPED_DENSE_BINARY_FORMAT_TYPE = 4u;
public:
- static void serialize(nbostream &stream, const Tensor &tensor);
+ static void serialize(nbostream &stream, const Tensor &tensor, SerializeFormat format);
+ static void serialize(nbostream &stream, const Tensor &tensor) {
+ serialize(stream, tensor, SerializeFormat::DOUBLE);
+ }
+
static std::unique_ptr<Tensor> deserialize(nbostream &stream);
};