aboutsummaryrefslogtreecommitdiffstats
path: root/eval/src
diff options
context:
space:
mode:
authorHÃ¥vard Pettersen <3535158+havardpe@users.noreply.github.com>2021-04-06 14:21:45 +0200
committerGitHub <noreply@github.com>2021-04-06 14:21:45 +0200
commit2dc0512d8eda7e851f5649ce0821edef17fc39af (patch)
tree19726d5bd087caf627d66cb9ff46da3163076330 /eval/src
parent5df00bb90a04082847440716bcb6146bdda0ca06 (diff)
parent936ee5af91cf9b873ee35abf8bedf923019fc15a (diff)
Merge pull request #17136 from vespa-engine/arnej/add-more-cell-types-32
Arnej/add more cell types 32
Diffstat (limited to 'eval/src')
-rw-r--r--eval/src/tests/eval/value_type/value_type_test.cpp12
-rw-r--r--eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp9
-rw-r--r--eval/src/tests/instruction/inplace_map_function/inplace_map_function_test.cpp6
-rw-r--r--eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp4
-rw-r--r--eval/src/tests/instruction/pow_as_map_optimizer/pow_as_map_optimizer_test.cpp10
-rw-r--r--eval/src/vespa/eval/eval/cell_type.cpp14
-rw-r--r--eval/src/vespa/eval/eval/cell_type.h30
-rw-r--r--eval/src/vespa/eval/eval/dense_cells_value.cpp2
-rw-r--r--eval/src/vespa/eval/eval/typed_cells.h2
-rw-r--r--eval/src/vespa/eval/eval/value_codec.cpp6
-rw-r--r--eval/src/vespa/eval/eval/value_type_spec.cpp6
-rw-r--r--eval/src/vespa/eval/onnx/onnx_wrapper.cpp9
-rw-r--r--eval/src/vespa/eval/streamed/streamed_value.cpp2
-rw-r--r--eval/src/vespa/eval/streamed/streamed_value_builder.cpp2
14 files changed, 91 insertions, 23 deletions
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp
index 120ab2d93e9..a2b25a12b4b 100644
--- a/eval/src/tests/eval/value_type/value_type_test.cpp
+++ b/eval/src/tests/eval/value_type/value_type_test.cpp
@@ -572,12 +572,22 @@ TEST("require that cell array size can be calculated") {
}
TEST("require that all cell types can be listed") {
- std::vector<CellType> expect = {CellType::FLOAT, CellType::DOUBLE};
+ std::vector<CellType> expect = { CellType::DOUBLE, CellType::FLOAT, CellType::BFLOAT16, CellType::INT8 };
+ std::vector<CellType> expect_stable;
+ std::vector<CellType> expect_unstable;
auto list = CellTypeUtils::list_types();
ASSERT_EQUAL(list.size(), expect.size());
for (size_t i = 0; i < list.size(); ++i) {
EXPECT_TRUE(list[i] == expect[i]);
+ CellMeta cm(expect[i], false);
+ if (cm.decay().eq(cm)) {
+ expect_stable.push_back(cm.cell_type);
+ } else {
+ expect_unstable.push_back(cm.cell_type);
+ }
}
+ EXPECT_TRUE(expect_stable == CellTypeUtils::list_stable_types());
+ EXPECT_TRUE(expect_unstable == CellTypeUtils::list_unstable_types());
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
index 9fa25466f4a..9da7f379246 100644
--- a/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
+++ b/eval/src/tests/instruction/dense_multi_matmul_function/dense_multi_matmul_function_test.cpp
@@ -39,10 +39,11 @@ struct FunInfo {
void verify_optimized(const vespalib::string &expr, const FunInfo &details)
{
TEST_STATE(expr.c_str());
- auto same_types = CellTypeSpace(CellTypeUtils::list_types(), 2).same();
- EvalFixture::verify<FunInfo>(expr, {details}, same_types);
- auto diff_types = CellTypeSpace(CellTypeUtils::list_types(), 2).different();
- EvalFixture::verify<FunInfo>(expr, {}, diff_types);
+ CellTypeSpace stable_types(CellTypeUtils::list_stable_types(), 2);
+ CellTypeSpace unstable_types(CellTypeUtils::list_unstable_types(), 2);
+ EvalFixture::verify<FunInfo>(expr, {details}, CellTypeSpace(stable_types).same());
+ EvalFixture::verify<FunInfo>(expr, {}, CellTypeSpace(stable_types).different());
+ EvalFixture::verify<FunInfo>(expr, {}, unstable_types);
}
void verify_not_optimized(const vespalib::string &expr) {
diff --git a/eval/src/tests/instruction/inplace_map_function/inplace_map_function_test.cpp b/eval/src/tests/instruction/inplace_map_function/inplace_map_function_test.cpp
index 0983f84a4af..1193060b05a 100644
--- a/eval/src/tests/instruction/inplace_map_function/inplace_map_function_test.cpp
+++ b/eval/src/tests/instruction/inplace_map_function/inplace_map_function_test.cpp
@@ -27,8 +27,10 @@ struct FunInfo {
void verify_optimized(const vespalib::string &expr) {
SCOPED_TRACE(expr.c_str());
- CellTypeSpace all_types(CellTypeUtils::list_types(), 1);
- EvalFixture::verify<FunInfo>(expr, {FunInfo{false}}, all_types);
+ CellTypeSpace stable_types(CellTypeUtils::list_stable_types(), 1);
+ CellTypeSpace unstable_types(CellTypeUtils::list_unstable_types(), 1);
+ EvalFixture::verify<FunInfo>(expr, {FunInfo{false}}, stable_types);
+ EvalFixture::verify<FunInfo>(expr, {}, unstable_types);
}
void verify_not_optimized(const vespalib::string &expr) {
diff --git a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
index a2f18d7f7f7..a6486de6858 100644
--- a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
+++ b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
@@ -48,9 +48,11 @@ struct FunInfo {
void verify_optimized(const vespalib::string &expr, Primary primary, bool inplace) {
// fprintf(stderr, "%s\n", expr.c_str());
- const auto stable_types = CellTypeSpace({CellType::FLOAT, CellType::DOUBLE}, 2);
+ const CellTypeSpace stable_types(CellTypeUtils::list_stable_types(), 2);
FunInfo stable_details{primary, inplace};
TEST_DO(EvalFixture::verify<FunInfo>(expr, {stable_details}, stable_types));
+ const CellTypeSpace unstable_types(CellTypeUtils::list_unstable_types(), 2);
+ TEST_DO(EvalFixture::verify<FunInfo>(expr, {}, unstable_types));
}
void verify_not_optimized(const vespalib::string &expr) {
diff --git a/eval/src/tests/instruction/pow_as_map_optimizer/pow_as_map_optimizer_test.cpp b/eval/src/tests/instruction/pow_as_map_optimizer/pow_as_map_optimizer_test.cpp
index e5a8cd9e92c..d123f9c89a6 100644
--- a/eval/src/tests/instruction/pow_as_map_optimizer/pow_as_map_optimizer_test.cpp
+++ b/eval/src/tests/instruction/pow_as_map_optimizer/pow_as_map_optimizer_test.cpp
@@ -31,15 +31,17 @@ struct InplaceInfo {
void verify_optimized(const vespalib::string &expr, op1_t op1, bool inplace = false) {
SCOPED_TRACE(expr.c_str());
+ CellTypeSpace stable_types(CellTypeUtils::list_stable_types(), 1);
if (inplace) {
InplaceInfo details{op1};
- auto all_types = CellTypeSpace(CellTypeUtils::list_types(), 1);
- EvalFixture::verify<InplaceInfo>(expr, {details}, all_types);
+ EvalFixture::verify<InplaceInfo>(expr, {details}, stable_types);
} else {
MapInfo details{op1};
- auto all_types = CellTypeSpace(CellTypeUtils::list_types(), 1);
- EvalFixture::verify<MapInfo>(expr, {details}, all_types);
+ EvalFixture::verify<MapInfo>(expr, {details}, stable_types);
}
+ MapInfo details{op1};
+ CellTypeSpace unstable_types(CellTypeUtils::list_unstable_types(), 1);
+ EvalFixture::verify<MapInfo>(expr, {details}, unstable_types);
}
void verify_not_optimized(const vespalib::string &expr) {
diff --git a/eval/src/vespa/eval/eval/cell_type.cpp b/eval/src/vespa/eval/eval/cell_type.cpp
index 753d888c24c..94bd0b14573 100644
--- a/eval/src/vespa/eval/eval/cell_type.cpp
+++ b/eval/src/vespa/eval/eval/cell_type.cpp
@@ -33,7 +33,19 @@ CellTypeUtils::mem_size(CellType cell_type, size_t sz)
std::vector<CellType>
CellTypeUtils::list_types()
{
- return {CellType::FLOAT, CellType::DOUBLE};
+ return {CellType::DOUBLE, CellType::FLOAT, CellType::BFLOAT16, CellType::INT8 };
+}
+
+std::vector<CellType>
+CellTypeUtils::list_stable_types()
+{
+ return {CellType::DOUBLE, CellType::FLOAT};
+}
+
+std::vector<CellType>
+CellTypeUtils::list_unstable_types()
+{
+ return {CellType::BFLOAT16, CellType::INT8 };
}
}
diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h
index 57f707c2aa8..79750c5c875 100644
--- a/eval/src/vespa/eval/eval/cell_type.h
+++ b/eval/src/vespa/eval/eval/cell_type.h
@@ -6,15 +6,19 @@
#include <vector>
#include <cstdint>
#include <cassert>
+#include "int8float.h"
+#include <vespa/vespalib/util/bfloat16.h>
namespace vespalib::eval {
-enum class CellType : char { FLOAT, DOUBLE };
+enum class CellType : char { DOUBLE, FLOAT, BFLOAT16, INT8 };
// converts actual cell type to CellType enum value
template <typename CT> constexpr CellType get_cell_type();
template <> constexpr CellType get_cell_type<double>() { return CellType::DOUBLE; }
template <> constexpr CellType get_cell_type<float>() { return CellType::FLOAT; }
+template <> constexpr CellType get_cell_type<BFloat16>() { return CellType::BFLOAT16; }
+template <> constexpr CellType get_cell_type<Int8Float>() { return CellType::INT8; }
// check if the given CellType enum value and actual cell type match
template <typename CT> constexpr bool check_cell_type(CellType type) {
@@ -29,6 +33,10 @@ template <CellType cell_type> constexpr auto get_cell_value() {
return double();
} else if constexpr (cell_type == CellType::FLOAT) {
return float();
+ } else if constexpr (cell_type == CellType::BFLOAT16) {
+ return BFloat16();
+ } else if constexpr (cell_type == CellType::INT8) {
+ return Int8Float();
} else {
static_assert((cell_type == CellType::DOUBLE), "unknown cell type");
}
@@ -136,9 +144,11 @@ struct CellMeta {
struct TypifyCellType {
template <typename T> using Result = TypifyResultType<T>;
template <typename F> static decltype(auto) resolve(CellType value, F &&f) {
- switch (value) {
- case CellType::DOUBLE: return f(Result<double>());
- case CellType::FLOAT: return f(Result<float>());
+ switch(value) {
+ case CellType::DOUBLE: return f(Result<double>());
+ case CellType::FLOAT: return f(Result<float>());
+ case CellType::BFLOAT16: return f(Result<BFloat16>());
+ case CellType::INT8: return f(Result<Int8Float>());
}
abort();
}
@@ -158,8 +168,10 @@ struct TypifyCellMeta {
}
template <typename F> static decltype(auto) resolve(CellMetaNotScalar value, F &&f) {
switch (value.cell_type) {
- case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>());
- case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>());
+ case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>());
+ case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>());
+ case CellType::BFLOAT16: return f(Result<CellMeta(CellType::BFLOAT16, false)>());
+ case CellType::INT8: return f(Result<CellMeta(CellType::INT8, false)>());
}
abort();
}
@@ -175,8 +187,8 @@ struct TypifyCellMeta {
}
template <typename F> static decltype(auto) resolve(LimitedCellMetaNotScalar value, F &&f) {
switch (value.cell_type) {
- case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>());
- case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>());
+ case CellType::DOUBLE: return f(Result<CellMeta(CellType::DOUBLE, false)>());
+ case CellType::FLOAT: return f(Result<CellMeta(CellType::FLOAT, false)>());
default: break;
}
abort();
@@ -187,6 +199,8 @@ struct CellTypeUtils {
static uint32_t alignment(CellType cell_type);
static size_t mem_size(CellType cell_type, size_t sz);
static std::vector<CellType> list_types();
+ static std::vector<CellType> list_stable_types();
+ static std::vector<CellType> list_unstable_types();
};
} // namespace
diff --git a/eval/src/vespa/eval/eval/dense_cells_value.cpp b/eval/src/vespa/eval/eval/dense_cells_value.cpp
index 126ef806668..a699153242d 100644
--- a/eval/src/vespa/eval/eval/dense_cells_value.cpp
+++ b/eval/src/vespa/eval/eval/dense_cells_value.cpp
@@ -15,5 +15,7 @@ DenseCellsValue<T>::get_memory_usage() const {
template class DenseCellsValue<double>;
template class DenseCellsValue<float>;
+template class DenseCellsValue<BFloat16>;
+template class DenseCellsValue<Int8Float>;
}
diff --git a/eval/src/vespa/eval/eval/typed_cells.h b/eval/src/vespa/eval/eval/typed_cells.h
index b65fa2b40e4..581de355694 100644
--- a/eval/src/vespa/eval/eval/typed_cells.h
+++ b/eval/src/vespa/eval/eval/typed_cells.h
@@ -17,6 +17,8 @@ struct TypedCells {
explicit TypedCells(ConstArrayRef<double> cells) : data(cells.begin()), type(CellType::DOUBLE), size(cells.size()) {}
explicit TypedCells(ConstArrayRef<float> cells) : data(cells.begin()), type(CellType::FLOAT), size(cells.size()) {}
+ explicit TypedCells(ConstArrayRef<BFloat16> cells) : data(cells.begin()), type(CellType::BFLOAT16), size(cells.size()) {}
+ explicit TypedCells(ConstArrayRef<Int8Float> cells) : data(cells.begin()), type(CellType::INT8), size(cells.size()) {}
TypedCells() : data(nullptr), type(CellType::DOUBLE), size(0) {}
TypedCells(const void *dp, CellType ct, size_t sz) : data(dp), type(ct), size(sz) {}
diff --git a/eval/src/vespa/eval/eval/value_codec.cpp b/eval/src/vespa/eval/eval/value_codec.cpp
index 85feadca85e..bd9d36bed2f 100644
--- a/eval/src/vespa/eval/eval/value_codec.cpp
+++ b/eval/src/vespa/eval/eval/value_codec.cpp
@@ -20,11 +20,15 @@ namespace {
constexpr uint32_t DOUBLE_CELL_TYPE = 0;
constexpr uint32_t FLOAT_CELL_TYPE = 1;
+constexpr uint32_t BFLOAT16_CELL_TYPE = 2;
+constexpr uint32_t INT8_CELL_TYPE = 3;
inline uint32_t cell_type_to_id(CellType cell_type) {
switch (cell_type) {
case CellType::DOUBLE: return DOUBLE_CELL_TYPE;
case CellType::FLOAT: return FLOAT_CELL_TYPE;
+ case CellType::BFLOAT16: return BFLOAT16_CELL_TYPE;
+ case CellType::INT8: return INT8_CELL_TYPE;
}
throw IllegalArgumentException(fmt("Unknown CellType=%u", (uint32_t)cell_type));
}
@@ -33,6 +37,8 @@ inline CellType id_to_cell_type(uint32_t id) {
switch (id) {
case DOUBLE_CELL_TYPE: return CellType::DOUBLE;
case FLOAT_CELL_TYPE: return CellType::FLOAT;
+ case BFLOAT16_CELL_TYPE: return CellType::BFLOAT16;
+ case INT8_CELL_TYPE: return CellType::INT8;
}
throw IllegalArgumentException(fmt("Unknown CellType id=%u", id));
}
diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp
index b518ccd1b30..92646ed1f67 100644
--- a/eval/src/vespa/eval/eval/value_type_spec.cpp
+++ b/eval/src/vespa/eval/eval/value_type_spec.cpp
@@ -10,8 +10,10 @@ namespace vespalib::eval::value_type {
vespalib::string cell_type_to_name(CellType cell_type) {
switch (cell_type) {
- case CellType::DOUBLE: return "double";
- case CellType::FLOAT: return "float";
+ case CellType::DOUBLE: return "double";
+ case CellType::FLOAT: return "float";
+ case CellType::BFLOAT16: return "bfloat16";
+ case CellType::INT8: return "int8";
}
abort();
}
diff --git a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
index 2891b37ebe8..e9758f2ddc8 100644
--- a/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/onnx/onnx_wrapper.cpp
@@ -1,6 +1,7 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "onnx_wrapper.h"
+#include <vespa/eval/eval/cell_type.h>
#include <vespa/eval/eval/dense_cells_value.h>
#include <vespa/eval/eval/value_type.h>
#include <vespa/vespalib/util/arrayref.h>
@@ -21,6 +22,14 @@ using vespalib::ConstArrayRef;
using vespalib::make_string_short::fmt;
+// as documented in onnxruntime_cxx_api.h :
+namespace Ort {
+template <>
+struct TypeToTensorType<vespalib::BFloat16> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16; };
+template <>
+struct TypeToTensorType<vespalib::eval::Int8Float> { static constexpr ONNXTensorElementDataType type = ONNX_TENSOR_ELEMENT_DATA_TYPE_INT8; };
+}
+
namespace vespalib::eval {
namespace {
diff --git a/eval/src/vespa/eval/streamed/streamed_value.cpp b/eval/src/vespa/eval/streamed/streamed_value.cpp
index c09e433b9b9..63765d9b1da 100644
--- a/eval/src/vespa/eval/streamed/streamed_value.cpp
+++ b/eval/src/vespa/eval/streamed/streamed_value.cpp
@@ -22,6 +22,8 @@ StreamedValue<T>::get_memory_usage() const
template class StreamedValue<double>;
template class StreamedValue<float>;
+template class StreamedValue<BFloat16>;
+template class StreamedValue<Int8Float>;
} // namespace
diff --git a/eval/src/vespa/eval/streamed/streamed_value_builder.cpp b/eval/src/vespa/eval/streamed/streamed_value_builder.cpp
index 957121c42b7..ba5d92bf5d2 100644
--- a/eval/src/vespa/eval/streamed/streamed_value_builder.cpp
+++ b/eval/src/vespa/eval/streamed/streamed_value_builder.cpp
@@ -9,5 +9,7 @@ StreamedValueBuilder<T>::~StreamedValueBuilder() = default;
template class StreamedValueBuilder<double>;
template class StreamedValueBuilder<float>;
+template class StreamedValueBuilder<BFloat16>;
+template class StreamedValueBuilder<Int8Float>;
} // namespace