diff options
author | HÃ¥vard Pettersen <3535158+havardpe@users.noreply.github.com> | 2021-04-06 14:21:45 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-04-06 14:21:45 +0200 |
commit | 2dc0512d8eda7e851f5649ce0821edef17fc39af (patch) | |
tree | 19726d5bd087caf627d66cb9ff46da3163076330 | |
parent | 5df00bb90a04082847440716bcb6146bdda0ca06 (diff) | |
parent | 936ee5af91cf9b873ee35abf8bedf923019fc15a (diff) |
Merge pull request #17136 from vespa-engine/arnej/add-more-cell-types-32
Arnej/add more cell types 32
14 files changed, 91 insertions, 23 deletions
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp index 120ab2d93e9..a2b25a12b4b 100644 --- a/eval/src/tests/eval/value_type/value_type_test.cpp +++ b/eval/src/tests/eval/value_type/value_type_test.cpp @@ -572,12 +572,22 @@ TEST("require that cell array size can be calculated") { } TEST("require that all cell types can be listed") { - std::vector<CellType> expect = {CellType::FLOAT, CellType::DOUBLE}; + std::vector<CellType> expect = { CellType::DOUBLE, CellType::FLOAT, CellType::BFLOAT16, CellType::INT8 }; + std::vector<CellType> expect_stable; + std::vector<CellType> expect_unstable; auto list = CellTypeUtils::list_types(); ASSERT_EQUAL(list.size(), expect.size()); for (size_t i = 0; i < list.size(); ++i) { EXPECT_TRUE(list[i] == expect[i]); + CellMeta cm(expect[i], false); + if (cm.decay().eq(cm)) { + expect_stable.push_back(cm.cell_type); + } else { + expect_unstable.push_back(cm.cell_type); + } } + EXPECT_TRUE(expect_stable == CellTypeUtils::list_stable_types()); + EXPECT_TRUE(expect_unstable == CellTypeUtils::list_unstable_types()); } TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp index 9fa25466f4a..9da7f379246 100644 --- a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp +++ b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp @@ -39,10 +39,11 @@ struct FunInfo { void verify_optimized(const vespalib::string &expr, const FunInfo &details) { TEST_STATE(expr.c_str()); - auto same_types = CellTypeSpace(CellTypeUtils::list_types(), 2).same(); - EvalFixture::verify<FunInfo>(expr, {details}, same_types); - auto diff_types = CellTypeSpace(CellTypeUtils::list_types(), 2).different(); - EvalFixture::verify<FunInfo>(expr, {}, diff_types); + CellTypeSpace stable_types(CellTypeUtils::list_stable_types(), 2); + CellTypeSpace unstable_types(CellTypeUtils::list_unstable_types(), 2); + EvalFixture::verify<FunInfo>(expr, {details}, CellTypeSpace(stable_types).same()); + EvalFixture::verify<FunInfo>(expr, {}, CellTypeSpace(stable_types).different()); + EvalFixture::verify<FunInfo>(expr, {}, unstable_types); } void verify_not_optimized(const vespalib::string &expr) { diff --git a/eval/src/tests/instruction/inplace_map_function/inplace_map_function_test.cpp b/eval/src/tests/instruction/inplace_map_function/inplace_map_function_test.cpp index 0983f84a4af..1193060b05a 100644 --- a/eval/src/tests/instruction/inplace_map_function/inplace_map_function_test.cpp +++ b/eval/src/tests/instruction/inplace_map_function/inplace_map_function_test.cpp @@ -27,8 +27,10 @@ struct FunInfo { void verify_optimized(const vespalib::string &expr) { SCOPED_TRACE(expr.c_str()); - CellTypeSpace all_types(CellTypeUtils::list_types(), 1); - EvalFixture::verify<FunInfo>(expr, {FunInfo{false}}, all_types); + CellTypeSpace stable_types(CellTypeUtils::list_stable_types(), 1); + CellTypeSpace unstable_types(CellTypeUtils::list_unstable_types(), 1); + EvalFixture::verify<FunInfo>(expr, {FunInfo{false}}, stable_types); + EvalFixture::verify<FunInfo>(expr, {}, unstable_types); } void verify_not_optimized(const vespalib::string &expr) { diff --git a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp index a2f18d7f7f7..a6486de6858 100644 --- a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp +++ b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp @@ -48,9 +48,11 @@ struct FunInfo { void verify_optimized(const vespalib::string &expr, Primary primary, bool inplace) { // fprintf(stderr, "%s\n", expr.c_str()); - const auto stable_types = CellTypeSpace({CellType::FLOAT, CellType::DOUBLE}, 2); + const CellTypeSpace stable_types(CellTypeUtils::list_stable_types(), 2); FunInfo stable_details{primary, inplace}; TEST_DO(EvalFixture::verify<FunInfo>(expr, {stable_details}, stable_types)); + const CellTypeSpace unstable_types(CellTypeUtils::list_unstable_types(), 2); + TEST_DO(EvalFixture::verify<FunInfo>(expr, {}, unstable_types)); } void verify_not_optimized(const vespalib::string &expr) { diff --git a/eval/src/tests/instruction/pow_as_map_optimizer/pow_as_map_optimizer_test.cpp b/eval/src/tests/instruction/pow_as_map_optimizer/pow_as_map_optimizer_test.cpp index e5a8cd9e92c..d123f9c89a6 100644 --- a/eval/src/tests/instruction/pow_as_map_optimizer/pow_as_map_optimizer_test.cpp +++ b/eval/src/tests/instruction/pow_as_map_optimizer/pow_as_map_optimizer_test.cpp @@ -31,15 +31,17 @@ struct InplaceInfo { void verify_optimized(const vespalib::string &expr, op1_t op1, bool inplace = false) { SCOPED_TRACE(expr.c_str()); + CellTypeSpace stable_types(CellTypeUtils::list_stable_types(), 1); if (inplace) { InplaceInfo details{op1}; - auto all_types = CellTypeSpace(CellTypeUtils::list_types(), 1); - EvalFixture::verify<InplaceInfo>(expr, {details}, all_types); + EvalFixture::verify<InplaceInfo>(expr, {details}, stable_types); } else { MapInfo details{op1}; - auto all_types = CellTypeSpace(CellTypeUtils::list_types(), 1); - EvalFixture::verify<MapInfo>(expr, {details}, all_types); + EvalFixture::verify<MapInfo>(expr, {details}, stable_types); } + MapInfo details{op1}; + CellTypeSpace unstable_types(CellTypeUtils::list_unstable_types(), 1); + EvalFixture::verify<MapInfo>(expr, {details}, unstable_types); } void verify_not_optimized(const vespalib::string &expr) { diff --git a/eval/src/vespa/eval/eval/cell_type.cpp b/eval/src/vespa/eval/eval/cell_type.cpp index 753d888c24c..94bd0b14573 100644 --- a/eval/src/vespa/eval/eval/cell_type.cpp +++ b/eval/src/vespa/eval/eval/cell_type.cpp @@ -33,7 +33,19 @@ CellTypeUtils::mem_size(CellType cell_type, size_t sz) std::vector<CellType> CellTypeUtils::list_types() { - return {CellType::FLOAT, CellType::DOUBLE}; + return {CellType::DOUBLE, CellType::FLOAT, CellType::BFLOAT16, CellType::INT8 }; +} + +std::vector<CellType> +CellTypeUtils::list_stable_types() +{ + return {CellType::DOUBLE, CellType::FLOAT}; +} + +std::vector<CellType> +CellTypeUtils::list_unstable_types() +{ + return {CellType::BFLOAT16, CellType::INT8 }; } } diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h index 57f707c2aa8..79750c5c875 100644 --- a/eval/src/vespa/eval/eval/cell_type.h +++ b/eval/src/vespa/eval/eval/cell_type.h @@ -6,15 +6,19 @@ #include <vector> #include <cstdint> #include <cassert> +#include "int8float.h" +#include <vespa/vespalib/util/bfloat16.h> namespace vespalib::eval { -enum class CellType : char { FLOAT, DOUBLE }; +enum class CellType : char { DOUBLE, FLOAT, BFLOAT16, INT8 }; // converts actual cell type to CellType enum value template <typename CT> constexpr CellType get_cell_type(); template <> constexpr CellType get_cell_type<double>() { return CellType::DOUBLE; } template <> constexpr CellType get_cell_type<float>() { return CellType::FLOAT; } +template <> constexpr CellType get_cell_type<BFloat16>() { return CellType::BFLOAT16; } +template <> constexpr CellType get_cell_type<Int8Float>() { return CellType::INT8; } // check if the given CellType enum value and actual cell type match template <typename CT> constexpr bool check_cell_type(CellType type) { @@ -29,6 +33,10 @@ template <CellType cell_type> constexpr auto get_cell_value() { return double(); } else if constexpr (cell_type == CellType::FLOAT) { return float(); + } else if constexpr (cell_type == CellType::BFLOAT16) { + return BFloat16(); + } else if constexpr (cell_type == CellType::INT8) { + return Int8Float(); } else { static_assert((cell_type == CellType::DOUBLE), "unknown cell type"); } @@ -136,9 +144,11 @@ struct CellMeta { struct TypifyCellType { template <typename T> using Result = TypifyResultType<T>; template <typename F> static decltype(auto) resolve(CellType value, F &&f) { - switch (value) { - case CellType::DOUBLE: return f(Result<double>()); - case CellType::FLOAT: return f(Result<float>()); + switch(value) { + case CellType::DOUBLE: return f(Result<double>()); + case CellType::FLOAT: return f(Result<float>()); + case CellType::BFLOAT16: return f(Result<BFloat16>()); + case CellType::INT8: return f(Result<Int8Float>()); } abort(); } @@ -158,8 +168,10 @@ struct TypifyCellMeta { } template <typename F> static decltype(auto) resolve(CellMetaNotScalar value, F &&f) { switch (value.cell_type) { - case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); - case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); + case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); + case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); + case CellType::BFLOAT16: return f(Result<CellMeta(CellType::BFLOAT16, false)>()); + case CellType::INT8: return f(Result<CellMeta(CellType::INT8, false)>()); } abort(); } @@ -175,8 +187,8 @@ struct TypifyCellMeta { } template <typename F> static decltype(auto) resolve(LimitedCellMetaNotScalar value, F &&f) { switch (value.cell_type) { - case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); - case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); + case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>()); + case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>()); default: break; } abort(); @@ -187,6 +199,8 @@ struct CellTypeUtils { static uint32_t alignment(CellType cell_type); static size_t mem_size(CellType cell_type, size_t sz); static std::vector<CellType> list_types(); + static std::vector<CellType> list_stable_types(); + static std::vector<CellType> list_unstable_types(); }; } // namespace diff --git a/eval/src/vespa/eval/eval/dense_cells_value.cpp b/eval/src/vespa/eval/eval/dense_cells_value.cpp index 126ef806668..a699153242d 100644 --- a/eval/src/vespa/eval/eval/dense_cells_value.cpp +++ b/eval/src/vespa/eval/eval/dense_cells_value.cpp @@ -15,5 +15,7 @@ DenseCellsValue<T>::get_memory_usage() const { template class DenseCellsValue<double>; template class DenseCellsValue<float>; +template class DenseCellsValue<BFloat16>; +template class DenseCellsValue<Int8Float>; } diff --git a/eval/src/vespa/eval/eval/typed_cells.h b/eval/src/vespa/eval/eval/typed_cells.h index b65fa2b40e4..581de355694 100644 --- a/eval/src/vespa/eval/eval/typed_cells.h +++ b/eval/src/vespa/eval/eval/typed_cells.h @@ -17,6 +17,8 @@ struct TypedCells { explicit TypedCells(ConstArrayRef<double> cells) : data(cells.begin()), type(CellType::DOUBLE), size(cells.size()) {} explicit TypedCells(ConstArrayRef<float> cells) : data(cells.begin()), type(CellType::FLOAT), size(cells.size()) {} + explicit TypedCells(ConstArrayRef<BFloat16> cells) : data(cells.begin()), type(CellType::BFLOAT16), size(cells.size()) {} + explicit TypedCells(ConstArrayRef<Int8Float> cells) : data(cells.begin()), type(CellType::INT8), size(cells.size()) {} TypedCells() : data(nullptr), type(CellType::DOUBLE), size(0) {} TypedCells(const void *dp, CellType ct, size_t sz) : data(dp), type(ct), size(sz) {} diff --git a/eval/src/vespa/eval/eval/value_codec.cpp b/eval/src/vespa/eval/eval/value_codec.cpp index 85feadca85e..bd9d36bed2f 100644 --- a/eval/src/vespa/eval/eval/value_codec.cpp +++ b/eval/src/vespa/eval/eval/value_codec.cpp @@ -20,11 +20,15 @@ namespace { constexpr uint32_t DOUBLE_CELL_TYPE = 0; constexpr uint32_t FLOAT_CELL_TYPE = 1; +constexpr uint32_t BFLOAT16_CELL_TYPE = 2; +constexpr uint32_t INT8_CELL_TYPE = 3; inline uint32_t cell_type_to_id(CellType cell_type) { switch (cell_type) { case CellType::DOUBLE: return DOUBLE_CELL_TYPE; case CellType::FLOAT: return FLOAT_CELL_TYPE; + case CellType::BFLOAT16: return BFLOAT16_CELL_TYPE; + case CellType::INT8: return INT8_CELL_TYPE; } throw IllegalArgumentException(fmt("Unknown CellType=%u", (uint32_t)cell_type)); } @@ -33,6 +37,8 @@ inline CellType id_to_cell_type(uint32_t id) { switch (id) { case DOUBLE_CELL_TYPE: return CellType::DOUBLE; case FLOAT_CELL_TYPE: return CellType::FLOAT; + case BFLOAT16_CELL_TYPE: return CellType::BFLOAT16; + case INT8_CELL_TYPE: return CellType::INT8; } throw IllegalArgumentException(fmt("Unknown CellType id=%u", id)); } diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp index b518ccd1b30..92646ed1f67 100644 --- a/eval/src/vespa/eval/eval/value_type_spec.cpp +++ b/eval/src/vespa/eval/eval/value_type_spec.cpp @@ -10,8 +10,10 @@ namespace vespalib::eval::value_type { vespalib::string cell_type_to_name(CellType cell_type) { switch (cell_type) { - case CellType::DOUBLE: return "double"; - case CellType::FLOAT: return "float"; + case CellType::DOUBLE: return "double"; + case CellType::FLOAT: return "float"; + case CellType::BFLOAT16: return "bfloat16"; + case CellType::INT8: return "int8"; } abort(); } diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp index 2891b37ebe8..e9758f2ddc8 100644 --- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp @@ -1,6 +1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "onnx_wrapper.h" +#include <vespa/eval/eval/cell_type.h> #include <vespa/eval/eval/dense_cells_value.h> #include <vespa/eval/eval/value_type.h> #include <vespa/vespalib/util/arrayref.h> @@ -21,6 +22,14 @@ using vespalib::ConstArrayRef; using vespalib::make_string_short::fmt; +// as documented in onnxruntime_cxx_api.h : +namespace Ort { +template <> +struct TypeToTensorType<vespalib::BFloat16> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; }; +template <> +struct TypeToTensorType<vespalib::eval::Int8Float> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; }; +} + namespace vespalib::eval { namespace { diff --git a/eval/src/vespa/eval/streamed/streamed_value.cpp b/eval/src/vespa/eval/streamed/streamed_value.cpp index c09e433b9b9..63765d9b1da 100644 --- a/eval/src/vespa/eval/streamed/streamed_value.cpp +++ b/eval/src/vespa/eval/streamed/streamed_value.cpp @@ -22,6 +22,8 @@ StreamedValue<T>::get_memory_usage() const template class StreamedValue<double>; template class StreamedValue<float>; +template class StreamedValue<BFloat16>; +template class StreamedValue<Int8Float>; } // namespace diff --git a/eval/src/vespa/eval/streamed/streamed_value_builder.cpp b/eval/src/vespa/eval/streamed/streamed_value_builder.cpp index 957121c42b7..ba5d92bf5d2 100644 --- a/eval/src/vespa/eval/streamed/streamed_value_builder.cpp +++ b/eval/src/vespa/eval/streamed/streamed_value_builder.cpp @@ -9,5 +9,7 @@ StreamedValueBuilder<T>::~StreamedValueBuilder() = default; template class StreamedValueBuilder<double>; template class StreamedValueBuilder<float>; +template class StreamedValueBuilder<BFloat16>; +template class StreamedValueBuilder<Int8Float>; } // namespace |