aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-03-16 12:01:44 +0000
committerArne Juul <arnej@verizonmedia.com>2021-03-23 16:57:37 +0000
commit19711977575c14de515ff1dca768ba6eb11c6be6 (patch)
tree82180192df25cb191cd840f72d5820c38febb5af /eval
parentf8cf0739f35de7b8da799e5206421ec5dc66df49 (diff)
add cell types int8 and bfloat16
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/eval/value_type/value_type_test.cpp2
-rw-r--r--eval/src/vespa/eval/eval/cell_type.cpp2
-rw-r--r--eval/src/vespa/eval/eval/cell_type.h28
-rw-r--r--eval/src/vespa/eval/eval/dense_cells_value.cpp2
-rw-r--r--eval/src/vespa/eval/eval/typed_cells.h2
-rw-r--r--eval/src/vespa/eval/eval/value_codec.cpp6
-rw-r--r--eval/src/vespa/eval/eval/value_type_spec.cpp6
-rw-r--r--eval/src/vespa/eval/streamed/streamed_value.cpp2
-rw-r--r--eval/src/vespa/eval/streamed/streamed_value_builder.cpp2
9 files changed, 40 insertions, 12 deletions
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp
index 120ab2d93e9..554cfba8c94 100644
--- a/eval/src/tests/eval/value_type/value_type_test.cpp
+++ b/eval/src/tests/eval/value_type/value_type_test.cpp
@@ -572,7 +572,7 @@ TEST("require that cell array size can be calculated") {
}
TEST("require that all cell types can be listed") {
- std::vector<CellType> expect = {CellType::FLOAT, CellType::DOUBLE};
+ std::vector<CellType> expect = { CellType::INT8, CellType::BFLOAT16, CellType::FLOAT, CellType::DOUBLE };
auto list = CellTypeUtils::list_types();
ASSERT_EQUAL(list.size(), expect.size());
for (size_t i = 0; i < list.size(); ++i) {
diff --git a/eval/src/vespa/eval/eval/cell_type.cpp b/eval/src/vespa/eval/eval/cell_type.cpp
index 753d888c24c..95bbecaa6ee 100644
--- a/eval/src/vespa/eval/eval/cell_type.cpp
+++ b/eval/src/vespa/eval/eval/cell_type.cpp
@@ -33,7 +33,7 @@ CellTypeUtils::mem_size(CellType cell_type, size_t sz)
std::vector<CellType>
CellTypeUtils::list_types()
{
- return {CellType::FLOAT, CellType::DOUBLE};
+ return {CellType::INT8, CellType::BFLOAT16, CellType::FLOAT, CellType::DOUBLE};
}
}
diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h
index 57f707c2aa8..02e321ac744 100644
--- a/eval/src/vespa/eval/eval/cell_type.h
+++ b/eval/src/vespa/eval/eval/cell_type.h
@@ -6,15 +6,19 @@
#include <vector>
#include <cstdint>
#include <cassert>
+#include "int8float.h"
+#include <vespa/vespalib/util/bfloat16.h>
namespace vespalib::eval {
-enum class CellType : char { FLOAT, DOUBLE };
+enum class CellType : char { FLOAT, DOUBLE, BFLOAT16, INT8 };
// converts actual cell type to CellType enum value
template <typename CT> constexpr CellType get_cell_type();
template <> constexpr CellType get_cell_type<double>() { return CellType::DOUBLE; }
template <> constexpr CellType get_cell_type<float>() { return CellType::FLOAT; }
+template <> constexpr CellType get_cell_type<BFloat16>() { return CellType::BFLOAT16; }
+template <> constexpr CellType get_cell_type<Int8Float>() { return CellType::INT8; }
// check if the given CellType enum value and actual cell type match
template <typename CT> constexpr bool check_cell_type(CellType type) {
@@ -29,6 +33,10 @@ template <CellType cell_type> constexpr auto get_cell_value() {
return double();
} else if constexpr (cell_type == CellType::FLOAT) {
return float();
+ } else if constexpr (cell_type == CellType::BFLOAT16) {
+ return BFloat16();
+ } else if constexpr (cell_type == CellType::INT8) {
+ return Int8Float();
} else {
static_assert((cell_type == CellType::DOUBLE), "unknown cell type");
}
@@ -136,9 +144,11 @@ struct CellMeta {
struct TypifyCellType {
template <typename T> using Result = TypifyResultType<T>;
template <typename F> static decltype(auto) resolve(CellType value, F &&f) {
- switch (value) {
- case CellType::DOUBLE: return f(Result<double>());
- case CellType::FLOAT: return f(Result<float>());
+ switch(value) {
+ case CellType::DOUBLE: return f(Result<double>());
+ case CellType::FLOAT: return f(Result<float>());
+ case CellType::BFLOAT16: return f(Result<BFloat16>());
+ case CellType::INT8: return f(Result<Int8Float>());
}
abort();
}
@@ -158,8 +168,10 @@ struct TypifyCellMeta {
}
template <typename F> static decltype(auto) resolve(CellMetaNotScalar value, F &&f) {
switch (value.cell_type) {
- case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>());
- case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>());
+ case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>());
+ case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>());
+ case CellType::BFLOAT16: return f(Result<CellMeta(CellType::BFLOAT16, false)>());
+ case CellType::INT8: return f(Result<CellMeta(CellType::INT8, false)>());
}
abort();
}
@@ -175,8 +187,8 @@ struct TypifyCellMeta {
}
template <typename F> static decltype(auto) resolve(LimitedCellMetaNotScalar value, F &&f) {
switch (value.cell_type) {
- case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>());
- case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>());
+ case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>());
+ case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>());
default: break;
}
abort();
diff --git a/eval/src/vespa/eval/eval/dense_cells_value.cpp b/eval/src/vespa/eval/eval/dense_cells_value.cpp
index 126ef806668..a699153242d 100644
--- a/eval/src/vespa/eval/eval/dense_cells_value.cpp
+++ b/eval/src/vespa/eval/eval/dense_cells_value.cpp
@@ -15,5 +15,7 @@ DenseCellsValue<T>::get_memory_usage() const {
template class DenseCellsValue<double>;
template class DenseCellsValue<float>;
+template class DenseCellsValue<BFloat16>;
+template class DenseCellsValue<Int8Float>;
}
diff --git a/eval/src/vespa/eval/eval/typed_cells.h b/eval/src/vespa/eval/eval/typed_cells.h
index b65fa2b40e4..581de355694 100644
--- a/eval/src/vespa/eval/eval/typed_cells.h
+++ b/eval/src/vespa/eval/eval/typed_cells.h
@@ -17,6 +17,8 @@ struct TypedCells {
explicit TypedCells(ConstArrayRef<double> cells) : data(cells.begin()), type(CellType::DOUBLE), size(cells.size()) {}
explicit TypedCells(ConstArrayRef<float> cells) : data(cells.begin()), type(CellType::FLOAT), size(cells.size()) {}
+ explicit TypedCells(ConstArrayRef<BFloat16> cells) : data(cells.begin()), type(CellType::BFLOAT16), size(cells.size()) {}
+ explicit TypedCells(ConstArrayRef<Int8Float> cells) : data(cells.begin()), type(CellType::INT8), size(cells.size()) {}
TypedCells() : data(nullptr), type(CellType::DOUBLE), size(0) {}
TypedCells(const void *dp, CellType ct, size_t sz) : data(dp), type(ct), size(sz) {}
diff --git a/eval/src/vespa/eval/eval/value_codec.cpp b/eval/src/vespa/eval/eval/value_codec.cpp
index bf45f34fd64..50c77eca7e9 100644
--- a/eval/src/vespa/eval/eval/value_codec.cpp
+++ b/eval/src/vespa/eval/eval/value_codec.cpp
@@ -18,11 +18,15 @@ namespace {
constexpr uint32_t DOUBLE_CELL_TYPE = 0;
constexpr uint32_t FLOAT_CELL_TYPE = 1;
+constexpr uint32_t BFLOAT16_CELL_TYPE = 2;
+constexpr uint32_t INT8_CELL_TYPE = 3;
inline uint32_t cell_type_to_id(CellType cell_type) {
switch (cell_type) {
case CellType::DOUBLE: return DOUBLE_CELL_TYPE;
case CellType::FLOAT: return FLOAT_CELL_TYPE;
+ case CellType::BFLOAT16: return BFLOAT16_CELL_TYPE;
+ case CellType::INT8: return INT8_CELL_TYPE;
}
throw IllegalArgumentException(fmt("Unknown CellType=%u", (uint32_t)cell_type));
}
@@ -31,6 +35,8 @@ inline CellType id_to_cell_type(uint32_t id) {
switch (id) {
case DOUBLE_CELL_TYPE: return CellType::DOUBLE;
case FLOAT_CELL_TYPE: return CellType::FLOAT;
+ case BFLOAT16_CELL_TYPE: return CellType::BFLOAT16;
+ case INT8_CELL_TYPE: return CellType::INT8;
}
throw IllegalArgumentException(fmt("Unknown CellType id=%u", id));
}
diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp
index b518ccd1b30..92646ed1f67 100644
--- a/eval/src/vespa/eval/eval/value_type_spec.cpp
+++ b/eval/src/vespa/eval/eval/value_type_spec.cpp
@@ -10,8 +10,10 @@ namespace vespalib::eval::value_type {
vespalib::string cell_type_to_name(CellType cell_type) {
switch (cell_type) {
- case CellType::DOUBLE: return "double";
- case CellType::FLOAT: return "float";
+ case CellType::DOUBLE: return "double";
+ case CellType::FLOAT: return "float";
+ case CellType::BFLOAT16: return "bfloat16";
+ case CellType::INT8: return "int8";
}
abort();
}
diff --git a/eval/src/vespa/eval/streamed/streamed_value.cpp b/eval/src/vespa/eval/streamed/streamed_value.cpp
index c09e433b9b9..63765d9b1da 100644
--- a/eval/src/vespa/eval/streamed/streamed_value.cpp
+++ b/eval/src/vespa/eval/streamed/streamed_value.cpp
@@ -22,6 +22,8 @@ StreamedValue<T>::get_memory_usage() const
template class StreamedValue<double>;
template class StreamedValue<float>;
+template class StreamedValue<BFloat16>;
+template class StreamedValue<Int8Float>;
} // namespace
diff --git a/eval/src/vespa/eval/streamed/streamed_value_builder.cpp b/eval/src/vespa/eval/streamed/streamed_value_builder.cpp
index 957121c42b7..ba5d92bf5d2 100644
--- a/eval/src/vespa/eval/streamed/streamed_value_builder.cpp
+++ b/eval/src/vespa/eval/streamed/streamed_value_builder.cpp
@@ -9,5 +9,7 @@ StreamedValueBuilder<T>::~StreamedValueBuilder() = default;
template class StreamedValueBuilder<double>;
template class StreamedValueBuilder<float>;
+template class StreamedValueBuilder<BFloat16>;
+template class StreamedValueBuilder<Int8Float>;
} // namespace