diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-03-16 12:01:44 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-03-23 16:57:37 +0000 |
commit | 19711977575c14de515ff1dca768ba6eb11c6be6 (patch) | |
tree | 82180192df25cb191cd840f72d5820c38febb5af /eval/src | |
parent | f8cf0739f35de7b8da799e5206421ec5dc66df49 (diff) |
add cell types int8 and bfloat16
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/tests/eval/value_type/value_type_test.cpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/cell_type.cpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/cell_type.h | 28 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/dense_cells_value.cpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/typed_cells.h | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/value_codec.cpp | 6 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/value_type_spec.cpp | 6 | ||||
-rw-r--r-- | eval/src/vespa/eval/streamed/streamed_value.cpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/streamed/streamed_value_builder.cpp | 2 |
9 files changed, 40 insertions, 12 deletions
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp index 120ab2d93e9..554cfba8c94 100644 --- a/eval/src/tests/eval/value_type/value_type_test.cpp +++ b/eval/src/tests/eval/value_type/value_type_test.cpp @@ -572,7 +572,7 @@ TEST("require that cell array size can be calculated") { } TEST("require that all cell types can be listed") { - std::vector<CellType> expect = {CellType::FLOAT, CellType::DOUBLE}; + std::vector<CellType> expect = { CellType::INT8, CellType::BFLOAT16, CellType::FLOAT, CellType::DOUBLE }; auto list = CellTypeUtils::list_types(); ASSERT_EQUAL(list.size(), expect.size()); for (size_t i = 0; i < list.size(); ++i) { diff --git a/eval/src/vespa/eval/eval/cell_type.cpp b/eval/src/vespa/eval/eval/cell_type.cpp index 753d888c24c..95bbecaa6ee 100644 --- a/eval/src/vespa/eval/eval/cell_type.cpp +++ b/eval/src/vespa/eval/eval/cell_type.cpp @@ -33,7 +33,7 @@ CellTypeUtils::mem_size(CellType cell_type, size_t sz) std::vector<CellType> CellTypeUtils::list_types() { - return {CellType::FLOAT, CellType::DOUBLE}; + return {CellType::INT8, CellType::BFLOAT16, CellType::FLOAT, CellType::DOUBLE}; } } diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h index 57f707c2aa8..02e321ac744 100644 --- a/eval/src/vespa/eval/eval/cell_type.h +++ b/eval/src/vespa/eval/eval/cell_type.h @@ -6,15 +6,19 @@ #include <vector> #include <cstdint> #include <cassert> +#include "int8float.h" +#include <vespa/vespalib/util/bfloat16.h> namespace vespalib::eval { -enum class CellType : char { FLOAT, DOUBLE }; +enum class CellType : char { FLOAT, DOUBLE, BFLOAT16, INT8 }; // converts actual cell type to CellType enum value template <typename CT> constexpr CellType get_cell_type(); template <> constexpr CellType get_cell_type<double>() { return CellType::DOUBLE; } template <> constexpr CellType get_cell_type<float>() { return CellType::FLOAT; } +template <> constexpr CellType get_cell_type<BFloat16>() { return CellType::BFLOAT16; } +template <> constexpr CellType get_cell_type<Int8Float>() { return CellType::INT8; } // check if the given CellType enum value and actual cell type match template <typename CT> constexpr bool check_cell_type(CellType type) { @@ -29,6 +33,10 @@ template <CellType cell_type> constexpr auto get_cell_value() { return double(); } else if constexpr (cell_type == CellType::FLOAT) { return float(); + } else if constexpr (cell_type == CellType::BFLOAT16) { + return BFloat16(); + } else if constexpr (cell_type == CellType::INT8) { + return Int8Float(); } else { static_assert((cell_type == CellType::DOUBLE), "unknown cell type"); } @@ -136,9 +144,11 @@ struct CellMeta { struct TypifyCellType { template <typename T> using Result = TypifyResultType<T>; template <typename F> static decltype(auto) resolve(CellType value, F &&f) { - switch (value) { - case CellType::DOUBLE: return f(Result<double>()); - case CellType::FLOAT: return f(Result<float>()); + switch(value) { + case CellType::DOUBLE: return f(Result<double>()); + case CellType::FLOAT: return f(Result<float>()); + case CellType::BFLOAT16: return f(Result<BFloat16>()); + case CellType::INT8: return f(Result<Int8Float>()); } abort(); } @@ -158,8 +168,10 @@ struct TypifyCellMeta { } template <typename F> static decltype(auto) resolve(CellMetaNotScalar value, F &&f) { switch (value.cell_type) { - case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); - case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); + case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); + case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); + case CellType::BFLOAT16: return f(Result<CellMeta(CellType::BFLOAT16, false)>()); + case CellType::INT8: return f(Result<CellMeta(CellType::INT8, false)>()); } abort(); } @@ -175,8 +187,8 @@ struct TypifyCellMeta { } template <typename F> static decltype(auto) resolve(LimitedCellMetaNotScalar value, F &&f) { switch (value.cell_type) { - case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); - case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); + case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); + case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); default: break; } abort(); diff --git a/eval/src/vespa/eval/eval/dense_cells_value.cpp b/eval/src/vespa/eval/eval/dense_cells_value.cpp index 126ef806668..a699153242d 100644 --- a/eval/src/vespa/eval/eval/dense_cells_value.cpp +++ b/eval/src/vespa/eval/eval/dense_cells_value.cpp @@ -15,5 +15,7 @@ DenseCellsValue<T>::get_memory_usage() const { template class DenseCellsValue<double>; template class DenseCellsValue<float>; +template class DenseCellsValue<BFloat16>; +template class DenseCellsValue<Int8Float>; } diff --git a/eval/src/vespa/eval/eval/typed_cells.h b/eval/src/vespa/eval/eval/typed_cells.h index b65fa2b40e4..581de355694 100644 --- a/eval/src/vespa/eval/eval/typed_cells.h +++ b/eval/src/vespa/eval/eval/typed_cells.h @@ -17,6 +17,8 @@ struct TypedCells { explicit TypedCells(ConstArrayRef<double> cells) : data(cells.begin()), type(CellType::DOUBLE), size(cells.size()) {} explicit TypedCells(ConstArrayRef<float> cells) : data(cells.begin()), type(CellType::FLOAT), size(cells.size()) {} + explicit TypedCells(ConstArrayRef<BFloat16> cells) : data(cells.begin()), type(CellType::BFLOAT16), size(cells.size()) {} + explicit TypedCells(ConstArrayRef<Int8Float> cells) : data(cells.begin()), type(CellType::INT8), size(cells.size()) {} TypedCells() : data(nullptr), type(CellType::DOUBLE), size(0) {} TypedCells(const void *dp, CellType ct, size_t sz) : data(dp), type(ct), size(sz) {} diff --git a/eval/src/vespa/eval/eval/value_codec.cpp b/eval/src/vespa/eval/eval/value_codec.cpp index bf45f34fd64..50c77eca7e9 100644 --- a/eval/src/vespa/eval/eval/value_codec.cpp +++ b/eval/src/vespa/eval/eval/value_codec.cpp @@ -18,11 +18,15 @@ namespace { constexpr uint32_t DOUBLE_CELL_TYPE = 0; constexpr uint32_t FLOAT_CELL_TYPE = 1; +constexpr uint32_t BFLOAT16_CELL_TYPE = 2; +constexpr uint32_t INT8_CELL_TYPE = 3; inline uint32_t cell_type_to_id(CellType cell_type) { switch (cell_type) { case CellType::DOUBLE: return DOUBLE_CELL_TYPE; case CellType::FLOAT: return FLOAT_CELL_TYPE; + case CellType::BFLOAT16: return BFLOAT16_CELL_TYPE; + case CellType::INT8: return INT8_CELL_TYPE; } throw IllegalArgumentException(fmt("Unknown CellType=%u", (uint32_t)cell_type)); } @@ -31,6 +35,8 @@ inline CellType id_to_cell_type(uint32_t id) { switch (id) { case DOUBLE_CELL_TYPE: return CellType::DOUBLE; case FLOAT_CELL_TYPE: return CellType::FLOAT; + case BFLOAT16_CELL_TYPE: return CellType::BFLOAT16; + case INT8_CELL_TYPE: return CellType::INT8; } throw IllegalArgumentException(fmt("Unknown CellType id=%u", id)); } diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp index b518ccd1b30..92646ed1f67 100644 --- a/eval/src/vespa/eval/eval/value_type_spec.cpp +++ b/eval/src/vespa/eval/eval/value_type_spec.cpp @@ -10,8 +10,10 @@ namespace vespalib::eval::value_type { vespalib::string cell_type_to_name(CellType cell_type) { switch (cell_type) { - case CellType::DOUBLE: return "double"; - case CellType::FLOAT: return "float"; + case CellType::DOUBLE: return "double"; + case CellType::FLOAT: return "float"; + case CellType::BFLOAT16: return "bfloat16"; + case CellType::INT8: return "int8"; } abort(); } diff --git a/eval/src/vespa/eval/streamed/streamed_value.cpp b/eval/src/vespa/eval/streamed/streamed_value.cpp index c09e433b9b9..63765d9b1da 100644 --- a/eval/src/vespa/eval/streamed/streamed_value.cpp +++ b/eval/src/vespa/eval/streamed/streamed_value.cpp @@ -22,6 +22,8 @@ StreamedValue<T>::get_memory_usage() const template class StreamedValue<double>; template class StreamedValue<float>; +template class StreamedValue<BFloat16>; +template class StreamedValue<Int8Float>; } // namespace diff --git a/eval/src/vespa/eval/streamed/streamed_value_builder.cpp b/eval/src/vespa/eval/streamed/streamed_value_builder.cpp index 957121c42b7..ba5d92bf5d2 100644 --- a/eval/src/vespa/eval/streamed/streamed_value_builder.cpp +++ b/eval/src/vespa/eval/streamed/streamed_value_builder.cpp @@ -9,5 +9,7 @@ StreamedValueBuilder<T>::~StreamedValueBuilder() = default; template class StreamedValueBuilder<double>; template class StreamedValueBuilder<float>; +template class StreamedValueBuilder<BFloat16>; +template class StreamedValueBuilder<Int8Float>; } // namespace |