summaryrefslogtreecommitdiffstats
path: root/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
blob: 493e4af3cafb89246ccddde9f8e803d021ce673f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "dense_binary_format.h"
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/util/exceptions.h>
#include <cassert>

using vespalib::nbostream;
using vespalib::eval::ValueType;
using CellType = vespalib::eval::ValueType::CellType;

namespace vespalib::tensor {

using Dimension = eval::ValueType::Dimension;


namespace {

/**
 * Writes the dimension count of 'type', followed by the name and size of
 * each dimension in order, to the stream.
 *
 * @param stream output stream
 * @param type   tensor type whose dimensions are encoded
 * @return product of all dimension sizes, i.e. the number of cells a value
 *         of this type holds (1 when the type has no dimensions)
 */
size_t encodeDimensions(nbostream &stream, const eval::ValueType & type) {
    const auto &dims = type.dimensions();
    stream.putInt1_4Bytes(dims.size());
    size_t numCells = 1;
    for (const auto &dim : dims) {
        stream.writeSmallString(dim.name);
        stream.putInt1_4Bytes(dim.size);
        numCells *= dim.size;
    }
    return numCells;
}

/**
 * Appends every cell of 'cells', viewed as element type T, to the stream
 * using the stream's operator<< for T.
 *
 * @tparam T     cell element type to view 'cells' as (e.g. double or float)
 * @param stream output stream
 * @param cells  cell storage to encode
 */
template<typename T>
void encodeCells(nbostream &stream, TypedCells cells) {
    for (const auto &cell : cells.typify<T>()) {
        stream << cell;
    }
}

/**
 * Reads a dimension list in the layout written by encodeDimensions(): a
 * count, then (name, size) per dimension. Appends one Dimension per encoded
 * entry to 'dimensions'.
 *
 * @param stream     input stream positioned at the dimension count
 * @param dimensions receives the decoded dimensions; entries already present
 *                   are kept and do not affect how many entries are read
 * @return product of the decoded dimension sizes, i.e. the number of cells
 *         that follow in the stream (1 when no dimensions are encoded)
 *
 * NOTE(review): the count and sizes come straight from the stream and are
 * not validated; the size_t product can wrap around for hostile input.
 */
size_t decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) {
    const size_t dimensionsSize = stream.getInt1_4Bytes();
    // Reused across iterations to avoid reallocating the name buffer.
    vespalib::string dimensionName;
    size_t cellsSize = 1;
    // Count iterations explicitly rather than comparing against
    // dimensions.size(). The previous loop read too few entries (leaving the
    // stream misaligned) whenever the caller passed a non-empty vector.
    // No reserve() here: dimensionsSize is untrusted input.
    for (size_t i = 0; i < dimensionsSize; ++i) {
        stream.readSmallString(dimensionName);
        const size_t dimensionSize = stream.getInt1_4Bytes();
        dimensions.emplace_back(dimensionName, dimensionSize);
        cellsSize *= dimensionSize;
    }
    return cellsSize;
}

/**
 * Reads 'cellsSize' values of wire type T from the stream and appends each
 * one to 'cells', converting it to the container's element type.
 *
 * @tparam T       type the cells were serialized as
 * @tparam V       destination container (must support emplace_back)
 * @param stream    input stream positioned at the first cell
 * @param cellsSize number of cells to read
 * @param cells     container the decoded values are appended to
 */
template<typename T, typename V>
void decodeCells(nbostream &stream, size_t cellsSize, V &cells) {
    T value = 0.0;
    for (size_t remaining = cellsSize; remaining > 0; --remaining) {
        stream >> value;
        cells.emplace_back(value);
    }
}

/**
 * Runtime dispatch on the serialized cell type. Decodes 'cellsSize' cells
 * stored as 'cell_type' from the stream and appends them to 'cells'.
 * A cell_type other than DOUBLE or FLOAT reads nothing and leaves 'cells'
 * unchanged.
 */
template <typename V>
void decodeCells(CellType cell_type, nbostream &stream, size_t cellsSize, V &cells) {
    if (cell_type == CellType::DOUBLE) {
        decodeCells<double>(stream, cellsSize, cells);
    } else if (cell_type == CellType::FLOAT) {
        decodeCells<float>(stream, cellsSize, cells);
    }
}

}

/**
 * Serializes a dense tensor: first its dimensions (count, then name/size per
 * dimension), then every cell encoded with the tensor's cell type.
 * Asserts that the tensor holds exactly as many cells as its type implies.
 */
void
DenseBinaryFormat::serialize(nbostream &stream, const DenseTensorView &tensor)
{
    const auto &type = tensor.fast_type();
    size_t expectedCells = encodeDimensions(stream, type);
    TypedCells cells = tensor.cellsRef();
    assert(cells.size == expectedCells);
    const CellType cellType = type.cell_type();
    if (cellType == CellType::DOUBLE) {
        encodeCells<double>(stream, cells);
    } else if (cellType == CellType::FLOAT) {
        encodeCells<float>(stream, cells);
    }
}

namespace {

/**
 * Functor passed to dispatch_0 by DenseBinaryFormat::deserialize.
 * call<CT>() decodes 'numCells' cells of element type CT from the stream and
 * returns them wrapped, together with 'newType', in a DenseTensor<CT>.
 * (Presumably dispatch_0 picks CT to match the runtime cell type; confirm in
 * the dispatch helper.)
 *
 * Placed in an anonymous namespace so this file-local helper gets internal
 * linkage. That prevents an ODR clash with a same-named helper in another
 * translation unit of vespalib::tensor.
 */
struct CallDecodeCells {
    template <typename CT>
    static std::unique_ptr<DenseTensorView>
    call(nbostream &stream, size_t numCells, ValueType &&newType) {
        std::vector<CT> newCells;
        newCells.reserve(numCells);
        decodeCells<CT>(stream, numCells, newCells);
        return std::make_unique<DenseTensor<CT>>(std::move(newType), std::move(newCells));
    }
};

}

/**
 * Deserializes a dense tensor written by serialize(). Reads the dimensions,
 * builds a tensor type from them with the given cell type, then decodes the
 * cells as that cell type.
 *
 * @param stream    input stream positioned at the dimension count
 * @param cell_type cell type used both for the resulting type and for decoding
 * @return the decoded tensor
 */
std::unique_ptr<DenseTensorView>
DenseBinaryFormat::deserialize(nbostream &stream, CellType cell_type)
{
    std::vector<Dimension> dims;
    const size_t cellCount = decodeDimensions(stream, dims);
    ValueType tensorType = ValueType::tensor_type(std::move(dims), cell_type);
    return dispatch_0<CallDecodeCells>(cell_type, stream, cellCount, std::move(tensorType));
}

template <typename T>
void
DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<T> &cells, CellType cell_type)
{
    std::vector<Dimension> dimensions;
    size_t cellsSize = decodeDimensions(stream, dimensions);
    cells.clear();
    cells.reserve(cellsSize);
    decodeCells(cell_type, stream, cellsSize, cells);
}

template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<double> &cells, CellType cell_type);
template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<float> &cells, CellType cell_type);

}