From 19711977575c14de515ff1dca768ba6eb11c6be6 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Tue, 16 Mar 2021 12:01:44 +0000 Subject: add cell types int8 and bfloat16 --- eval/src/tests/eval/value_type/value_type_test.cpp | 2 +- eval/src/vespa/eval/eval/cell_type.cpp | 2 +- eval/src/vespa/eval/eval/cell_type.h | 28 +++++++++++++++------- eval/src/vespa/eval/eval/dense_cells_value.cpp | 2 ++ eval/src/vespa/eval/eval/typed_cells.h | 2 ++ eval/src/vespa/eval/eval/value_codec.cpp | 6 +++++ eval/src/vespa/eval/eval/value_type_spec.cpp | 6 +++-- eval/src/vespa/eval/streamed/streamed_value.cpp | 2 ++ .../vespa/eval/streamed/streamed_value_builder.cpp | 2 ++ 9 files changed, 40 insertions(+), 12 deletions(-) (limited to 'eval') diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp index 120ab2d93e9..554cfba8c94 100644 --- a/eval/src/tests/eval/value_type/value_type_test.cpp +++ b/eval/src/tests/eval/value_type/value_type_test.cpp @@ -572,7 +572,7 @@ TEST("require that cell array size can be calculated") { } TEST("require that all cell types can be listed") { - std::vector expect = {CellType::FLOAT, CellType::DOUBLE}; + std::vector expect = { CellType::INT8, CellType::BFLOAT16, CellType::FLOAT, CellType::DOUBLE }; auto list = CellTypeUtils::list_types(); ASSERT_EQUAL(list.size(), expect.size()); for (size_t i = 0; i < list.size(); ++i) { diff --git a/eval/src/vespa/eval/eval/cell_type.cpp b/eval/src/vespa/eval/eval/cell_type.cpp index 753d888c24c..95bbecaa6ee 100644 --- a/eval/src/vespa/eval/eval/cell_type.cpp +++ b/eval/src/vespa/eval/eval/cell_type.cpp @@ -33,7 +33,7 @@ CellTypeUtils::mem_size(CellType cell_type, size_t sz) std::vector CellTypeUtils::list_types() { - return {CellType::FLOAT, CellType::DOUBLE}; + return {CellType::INT8, CellType::BFLOAT16, CellType::FLOAT, CellType::DOUBLE}; } } diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h index 57f707c2aa8..02e321ac744 100644 --- a/eval/src/vespa/eval/eval/cell_type.h +++ b/eval/src/vespa/eval/eval/cell_type.h @@ -6,15 +6,19 @@ #include #include #include +#include "int8float.h" +#include namespace vespalib::eval { -enum class CellType : char { FLOAT, DOUBLE }; +enum class CellType : char { FLOAT, DOUBLE, BFLOAT16, INT8 }; // converts actual cell type to CellType enum value template constexpr CellType get_cell_type(); template <> constexpr CellType get_cell_type() { return CellType::DOUBLE; } template <> constexpr CellType get_cell_type() { return CellType::FLOAT; } +template <> constexpr CellType get_cell_type() { return CellType::BFLOAT16; } +template <> constexpr CellType get_cell_type() { return CellType::INT8; } // check if the given CellType enum value and actual cell type match template constexpr bool check_cell_type(CellType type) { @@ -29,6 +33,10 @@ template constexpr auto get_cell_value() { return double(); } else if constexpr (cell_type == CellType::FLOAT) { return float(); + } else if constexpr (cell_type == CellType::BFLOAT16) { + return BFloat16(); + } else if constexpr (cell_type == CellType::INT8) { + return Int8Float(); } else { static_assert((cell_type == CellType::DOUBLE), "unknown cell type"); } @@ -136,9 +144,11 @@ struct CellMeta { struct TypifyCellType { template using Result = TypifyResultType; template static decltype(auto) resolve(CellType value, F &&f) { - switch (value) { - case CellType::DOUBLE: return f(Result()); - case CellType::FLOAT: return f(Result()); + switch(value) { + case CellType::DOUBLE: return f(Result()); + case CellType::FLOAT: return f(Result()); + case CellType::BFLOAT16: return f(Result()); + case CellType::INT8: return f(Result()); } abort(); } @@ -158,8 +168,10 @@ struct TypifyCellMeta { } template static decltype(auto) resolve(CellMetaNotScalar value, F &&f) { switch (value.cell_type) { - case CellType::DOUBLE: return f(Result()); - case CellType::FLOAT: return f(Result()); + case CellType::DOUBLE: return f(Result()); + case CellType::FLOAT: return f(Result()); + case CellType::BFLOAT16: return f(Result()); + case CellType::INT8: return f(Result()); } abort(); } @@ -175,8 +187,8 @@ struct TypifyCellMeta { } template static decltype(auto) resolve(LimitedCellMetaNotScalar value, F &&f) { switch (value.cell_type) { - case CellType::DOUBLE: return f(Result()); - case CellType::FLOAT: return f(Result()); + case CellType::DOUBLE: return f(Result()); + case CellType::FLOAT: return f(Result()); default: break; } abort(); diff --git a/eval/src/vespa/eval/eval/dense_cells_value.cpp b/eval/src/vespa/eval/eval/dense_cells_value.cpp index 126ef806668..a699153242d 100644 --- a/eval/src/vespa/eval/eval/dense_cells_value.cpp +++ b/eval/src/vespa/eval/eval/dense_cells_value.cpp @@ -15,5 +15,7 @@ DenseCellsValue::get_memory_usage() const { template class DenseCellsValue; template class DenseCellsValue; +template class DenseCellsValue; +template class DenseCellsValue; } diff --git a/eval/src/vespa/eval/eval/typed_cells.h b/eval/src/vespa/eval/eval/typed_cells.h index b65fa2b40e4..581de355694 100644 --- a/eval/src/vespa/eval/eval/typed_cells.h +++ b/eval/src/vespa/eval/eval/typed_cells.h @@ -17,6 +17,8 @@ struct TypedCells { explicit TypedCells(ConstArrayRef cells) : data(cells.begin()), type(CellType::DOUBLE), size(cells.size()) {} explicit TypedCells(ConstArrayRef cells) : data(cells.begin()), type(CellType::FLOAT), size(cells.size()) {} + explicit TypedCells(ConstArrayRef cells) : data(cells.begin()), type(CellType::BFLOAT16), size(cells.size()) {} + explicit TypedCells(ConstArrayRef cells) : data(cells.begin()), type(CellType::INT8), size(cells.size()) {} TypedCells() : data(nullptr), type(CellType::DOUBLE), size(0) {} TypedCells(const void *dp, CellType ct, size_t sz) : data(dp), type(ct), size(sz) {} diff --git a/eval/src/vespa/eval/eval/value_codec.cpp b/eval/src/vespa/eval/eval/value_codec.cpp index bf45f34fd64..50c77eca7e9 100644 --- a/eval/src/vespa/eval/eval/value_codec.cpp +++ b/eval/src/vespa/eval/eval/value_codec.cpp @@ -18,11 +18,15 @@ namespace { constexpr uint32_t DOUBLE_CELL_TYPE = 0; constexpr uint32_t FLOAT_CELL_TYPE = 1; +constexpr uint32_t BFLOAT16_CELL_TYPE = 2; +constexpr uint32_t INT8_CELL_TYPE = 3; inline uint32_t cell_type_to_id(CellType cell_type) { switch (cell_type) { case CellType::DOUBLE: return DOUBLE_CELL_TYPE; case CellType::FLOAT: return FLOAT_CELL_TYPE; + case CellType::BFLOAT16: return BFLOAT16_CELL_TYPE; + case CellType::INT8: return INT8_CELL_TYPE; } throw IllegalArgumentException(fmt("Unknown CellType=%u", (uint32_t)cell_type)); } @@ -31,6 +35,8 @@ inline CellType id_to_cell_type(uint32_t id) { switch (id) { case DOUBLE_CELL_TYPE: return CellType::DOUBLE; case FLOAT_CELL_TYPE: return CellType::FLOAT; + case BFLOAT16_CELL_TYPE: return CellType::BFLOAT16; + case INT8_CELL_TYPE: return CellType::INT8; } throw IllegalArgumentException(fmt("Unknown CellType id=%u", id)); } diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp index b518ccd1b30..92646ed1f67 100644 --- a/eval/src/vespa/eval/eval/value_type_spec.cpp +++ b/eval/src/vespa/eval/eval/value_type_spec.cpp @@ -10,8 +10,10 @@ namespace vespalib::eval::value_type { vespalib::string cell_type_to_name(CellType cell_type) { switch (cell_type) { - case CellType::DOUBLE: return "double"; - case CellType::FLOAT: return "float"; + case CellType::DOUBLE: return "double"; + case CellType::FLOAT: return "float"; + case CellType::BFLOAT16: return "bfloat16"; + case CellType::INT8: return "int8"; } abort(); } diff --git a/eval/src/vespa/eval/streamed/streamed_value.cpp b/eval/src/vespa/eval/streamed/streamed_value.cpp index c09e433b9b9..63765d9b1da 100644 --- a/eval/src/vespa/eval/streamed/streamed_value.cpp +++ b/eval/src/vespa/eval/streamed/streamed_value.cpp @@ -22,6 +22,8 @@ StreamedValue::get_memory_usage() const template class StreamedValue; template class StreamedValue; +template class StreamedValue; +template class StreamedValue; } // namespace diff --git a/eval/src/vespa/eval/streamed/streamed_value_builder.cpp b/eval/src/vespa/eval/streamed/streamed_value_builder.cpp index 957121c42b7..ba5d92bf5d2 100644 --- a/eval/src/vespa/eval/streamed/streamed_value_builder.cpp +++ b/eval/src/vespa/eval/streamed/streamed_value_builder.cpp @@ -9,5 +9,7 @@ StreamedValueBuilder::~StreamedValueBuilder() = default; template class StreamedValueBuilder; template class StreamedValueBuilder; +template class StreamedValueBuilder; +template class StreamedValueBuilder; } // namespace -- cgit v1.2.3