diff options
Diffstat (limited to 'eval/src/apps')
-rw-r--r-- | eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp | 31 |
1 files changed, 21 insertions, 10 deletions
diff --git a/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp b/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp index 6e882fc3d9d..974f95a2add 100644 --- a/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp +++ b/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp @@ -3,6 +3,8 @@ #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/bfloat16.h> +#include <vespa/eval/eval/int8float.h> #include <vespa/eval/eval/tensor_spec.h> #include <vespa/eval/eval/value_type.h> #include <vespa/eval/eval/test/test_io.h> @@ -20,14 +22,20 @@ using Dict = std::vector<vespalib::string>; template <typename T> std::vector<bool> with_cell_type_opts(); template <> std::vector<bool> with_cell_type_opts<double>() { return {false, true}; } template <> std::vector<bool> with_cell_type_opts<float>() { return {true}; } +template <> std::vector<bool> with_cell_type_opts<BFloat16>() { return {true}; } +template <> std::vector<bool> with_cell_type_opts<Int8Float>() { return {true}; } template <typename T> uint8_t cell_type_id(); template <> uint8_t cell_type_id<double>() { return 0; } template <> uint8_t cell_type_id<float>() { return 1; } +template <> uint8_t cell_type_id<BFloat16>() { return 2; } +template <> uint8_t cell_type_id<Int8Float>() { return 3; } template <typename T> const char *cell_type_str(); template <> const char *cell_type_str<double>() { return ""; } template <> const char *cell_type_str<float>() { return "<float>"; } +template <> const char *cell_type_str<BFloat16>() { return "<bfloat16>"; } +template <> const char *cell_type_str<Int8Float>() { return "<int8>"; } template <typename T> nbostream make_sparse(bool with_cell_type) { nbostream data; @@ -62,7 +70,8 @@ template <typename T> nbostream make_mixed(bool with_cell_type) { return data; } -void set_tensor(Cursor &test, const TensorSpec &spec) { +void set_tensor(Cursor &test, const TensorSpec &spec_in) { + auto spec = spec_in.normalize(); const Inspector &old_tensor = test["tensor"]; if (old_tensor.valid()) { TensorSpec old_spec = TensorSpec::from_slime(old_tensor); @@ -183,8 +192,8 @@ void make_vector_test(Cursor &test, size_t x_size) { for (size_t x = 0; x < x_size; ++x) { double value = val(x); spec.add({{"x", x}}, value); - dense << static_cast<T>(value); - mixed << static_cast<T>(value); + dense << T(value); + mixed << T(value); } set_tensor(test, spec); add_binary(test, {dense, mixed}); @@ -212,8 +221,8 @@ void make_matrix_test(Cursor &test, size_t x_size, size_t y_size) { for (size_t y = 0; y < y_size; ++y) { double value = mix({val(x), val(y)}); spec.add({{"x", x}, {"y", y}}, value); - dense << static_cast<T>(value); - mixed << static_cast<T>(value); + dense << T(value); + mixed << T(value); } } set_tensor(test, spec); @@ -245,8 +254,8 @@ void make_map_test(Cursor &test, const Dict &x_dict_in) { spec.add({{"x", x}}, value); sparse.writeSmallString(x); mixed.writeSmallString(x); - sparse << static_cast<T>(value); - mixed << static_cast<T>(value); + sparse << T(value); + mixed << T(value); } set_tensor(test, spec); add_binary(test, {sparse, mixed}); @@ -285,8 +294,8 @@ void make_mesh_test(Cursor &test, const Dict &x_dict_in, const vespalib::string sparse.writeSmallString(y); mixed.writeSmallString(x); mixed.writeSmallString(y); - sparse << static_cast<T>(value); - mixed << static_cast<T>(value); + sparse << T(value); + mixed << T(value); } set_tensor(test, spec); add_binary(test, {sparse, mixed}); @@ -326,7 +335,7 @@ void make_vector_map_test(Cursor &test, for (size_t idx = 0; idx < indexed_size; ++idx) { double value = mix({val(label), val(idx)}); spec.add({{mapped_name, label}, {indexed_name, idx}}, value); - mixed << static_cast<T>(value); + mixed << T(value); } } set_tensor(test, spec); @@ -360,6 +369,8 @@ void make_tests(test::TestWriter &writer) { make_number_test(writer.create(), 42.0); make_typed_tests<double>(writer); make_typed_tests<float>(writer); + make_typed_tests<BFloat16>(writer); + make_typed_tests<Int8Float>(writer); } int main(int, char **) { |