diff options
Diffstat (limited to 'eval')
-rw-r--r-- | eval/CMakeLists.txt | 1 | ||||
-rw-r--r-- | eval/src/tests/eval/value_codec/value_codec_test.cpp | 12 | ||||
-rw-r--r-- | eval/src/tests/tensor/binary_format/.gitignore | 1 | ||||
-rw-r--r-- | eval/src/tests/tensor/binary_format/CMakeLists.txt | 9 | ||||
-rw-r--r-- | eval/src/tests/tensor/binary_format/binary_format_test.cpp | 142 | ||||
-rw-r--r-- | eval/src/tests/tensor/tensor_conformance/tensor_conformance_test.cpp | 7 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/gen_spec.cpp | 7 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/gen_spec.h | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/tensor_conformance.cpp | 108 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/test/test_io.cpp | 16 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/value_codec.cpp | 22 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/value_codec.h | 5 |
12 files changed, 199 insertions, 133 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index 9820163725d..302b6768cea 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -75,6 +75,7 @@ vespa_define_module( src/tests/instruction/sum_max_dot_product_function src/tests/instruction/vector_from_doubles_function src/tests/streamed/value + src/tests/tensor/binary_format src/tests/tensor/instruction_benchmark src/tests/tensor/onnx_wrapper src/tests/tensor/tensor_conformance diff --git a/eval/src/tests/eval/value_codec/value_codec_test.cpp b/eval/src/tests/eval/value_codec/value_codec_test.cpp index 0bb1bcfb337..99afba4aed9 100644 --- a/eval/src/tests/eval/value_codec/value_codec_test.cpp +++ b/eval/src/tests/eval/value_codec/value_codec_test.cpp @@ -335,11 +335,11 @@ TEST(ValueCodecTest, bad_sparse_tensors_are_caught) { bad.encode_default(data_default); bad.encode_with_double(data_double); bad.encode_with_float(data_float); - VESPA_EXPECT_EXCEPTION(decode_value(data_default, factory), vespalib::IllegalStateException, + VESPA_EXPECT_EXCEPTION(decode_value(data_default, factory), vespalib::eval::DecodeValueException, "serialized input claims 12345678 blocks of size 1*8, but only"); - VESPA_EXPECT_EXCEPTION(decode_value(data_double, factory), vespalib::IllegalStateException, + VESPA_EXPECT_EXCEPTION(decode_value(data_double, factory), vespalib::eval::DecodeValueException, "serialized input claims 12345678 blocks of size 1*8, but only"); - VESPA_EXPECT_EXCEPTION(decode_value(data_float, factory), vespalib::IllegalStateException, + VESPA_EXPECT_EXCEPTION(decode_value(data_float, factory), vespalib::eval::DecodeValueException, "serialized input claims 12345678 blocks of size 1*4, but only"); } @@ -388,11 +388,11 @@ TEST(ValueCodecTest, bad_dense_tensors_are_caught) { bad.encode_default(data_default); bad.encode_with_double(data_double); bad.encode_with_float(data_float); - VESPA_EXPECT_EXCEPTION(decode_value(data_default, factory), vespalib::IllegalStateException, + VESPA_EXPECT_EXCEPTION(decode_value(data_default, factory), vespalib::eval::DecodeValueException, "serialized input claims 1 blocks of size 60000*8, but only"); - VESPA_EXPECT_EXCEPTION(decode_value(data_double, factory), vespalib::IllegalStateException, + VESPA_EXPECT_EXCEPTION(decode_value(data_double, factory), vespalib::eval::DecodeValueException, "serialized input claims 1 blocks of size 60000*8, but only"); - VESPA_EXPECT_EXCEPTION(decode_value(data_float, factory), vespalib::IllegalStateException, + VESPA_EXPECT_EXCEPTION(decode_value(data_float, factory), vespalib::eval::DecodeValueException, "serialized input claims 1 blocks of size 60000*4, but only"); } diff --git a/eval/src/tests/tensor/binary_format/.gitignore b/eval/src/tests/tensor/binary_format/.gitignore new file mode 100644 index 00000000000..4f0fdc51492 --- /dev/null +++ b/eval/src/tests/tensor/binary_format/.gitignore @@ -0,0 +1 @@ +/binary_test_spec.json diff --git a/eval/src/tests/tensor/binary_format/CMakeLists.txt b/eval/src/tests/tensor/binary_format/CMakeLists.txt new file mode 100644 index 00000000000..ac52c2b0365 --- /dev/null +++ b/eval/src/tests/tensor/binary_format/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_tensor_binary_format_test_app TEST + SOURCES + binary_format_test.cpp + DEPENDS + vespaeval + GTest::GTest +) +vespa_add_test(NAME eval_tensor_binary_format_test_app COMMAND eval_tensor_binary_format_test_app) diff --git a/eval/src/tests/tensor/binary_format/binary_format_test.cpp b/eval/src/tests/tensor/binary_format/binary_format_test.cpp new file mode 100644 index 00000000000..671765d4050 --- /dev/null +++ b/eval/src/tests/tensor/binary_format/binary_format_test.cpp @@ -0,0 +1,142 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/eval/eval/test/test_io.h> +#include <vespa/eval/eval/test/gen_spec.h> +#include <vespa/eval/eval/cell_type.h> +#include <vespa/eval/eval/tensor_spec.h> +#include <vespa/eval/eval/simple_value.h> +#include <vespa/eval/streamed/streamed_value_builder_factory.h> +#include <vespa/eval/eval/fast_value.h> +#include <vespa/eval/eval/value_codec.h> +#include <vespa/vespalib/io/mapped_file_input.h> +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/objects/nbostream.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib; +using namespace vespalib::eval; +using namespace vespalib::eval::test; +using namespace vespalib::slime::convenience; + +using vespalib::make_string_short::fmt; + +vespalib::string get_source_dir() { + const char *dir = getenv("SOURCE_DIRECTORY"); + return (dir ? dir : "."); +} +vespalib::string source_dir = get_source_dir(); +vespalib::string module_src_path = source_dir + "/../../../../"; +vespalib::string module_build_path = "../../../../"; + +const ValueBuilderFactory &simple = SimpleValueBuilderFactory::get(); +const ValueBuilderFactory &streamed = StreamedValueBuilderFactory::get(); +const ValueBuilderFactory &fast = FastValueBuilderFactory::get(); + +TEST(TensorBinaryFormatTest, tensor_binary_format_test_spec_can_be_generated) { + vespalib::string spec = module_src_path + "src/apps/make_tensor_binary_format_test_spec/test_spec.json"; + vespalib::string binary = module_build_path + "src/apps/make_tensor_binary_format_test_spec/eval_make_tensor_binary_format_test_spec_app"; + EXPECT_EQ(system(fmt("%s > binary_test_spec.json", binary.c_str()).c_str()), 0); + EXPECT_EQ(system(fmt("diff -u %s binary_test_spec.json", spec.c_str()).c_str()), 0); +} + +void verify_encode_decode(const TensorSpec &spec, + const ValueBuilderFactory &encode_factory, + const ValueBuilderFactory &decode_factory) +{ + nbostream data; + auto value = value_from_spec(spec, encode_factory); + encode_value(*value, data); + auto value2 = decode_value(data, decode_factory); + TensorSpec spec2 = spec_from_value(*value2); + EXPECT_EQ(spec2, spec); +} + +void verify_encode_decode(const GenSpec &spec) { + for (CellType ct : CellTypeUtils::list_types()) { + auto my_spec = spec.cpy().cells(ct); + if (my_spec.bad_scalar()) continue; + auto my_tspec = my_spec.gen(); + verify_encode_decode(my_tspec, simple, fast); + verify_encode_decode(my_tspec, fast, simple); + verify_encode_decode(my_tspec, simple, streamed); + verify_encode_decode(my_tspec, streamed, simple); + } +} + +TEST(TensorBinaryFormatTest, encode_decode) { + verify_encode_decode(GenSpec(42)); + verify_encode_decode(GenSpec().idx("x", 3)); + verify_encode_decode(GenSpec().idx("x", 3).idx("y", 5)); + verify_encode_decode(GenSpec().idx("x", 3).idx("y", 5).idx("z", 7)); + verify_encode_decode(GenSpec().map("x", 3)); + verify_encode_decode(GenSpec().map("x", 3).map("y", 2)); + verify_encode_decode(GenSpec().map("x", 3).map("y", 2).map("z", 4)); + verify_encode_decode(GenSpec().idx("x", 3).map("y", 2).idx("z", 7)); + verify_encode_decode(GenSpec().map("x", 3).idx("y", 5).map("z", 4)); +} + +uint8_t unhex(char c) { + if (c >= '0' && c <= '9') { + return (c - '0'); + } + if (c >= 'A' && c <= 'F') { + return ((c - 'A') + 10); + } + EXPECT_TRUE(false) << "bad hex char"; + return 0; +} + +nbostream extract_data(const Memory &hex_dump) { + nbostream data; + if ((hex_dump.size > 2) && (hex_dump.data[0] == '0') && (hex_dump.data[1] == 'x')) { + for (size_t i = 2; i < (hex_dump.size - 1); i += 2) { + data << uint8_t((unhex(hex_dump.data[i]) << 4) | unhex(hex_dump.data[i + 1])); + } + } + return data; +} + +bool is_same(const nbostream &a, const nbostream &b) { + return (Memory(a.peek(), a.size()) == Memory(b.peek(), b.size())); +} + +void test_binary_format_spec(const Inspector &test, const ValueBuilderFactory &factory) { + Stash stash; + TensorSpec spec = TensorSpec::from_slime(test["tensor"]); + const Inspector &binary = test["binary"]; + EXPECT_GT(binary.entries(), 0u); + nbostream encoded; + encode_value(*value_from_spec(spec, factory), encoded); + bool matched_encode = false; + for (size_t i = 0; i < binary.entries(); ++i) { + nbostream data = extract_data(binary[i].asString()); + matched_encode = (matched_encode || is_same(encoded, data)); + EXPECT_EQ(spec_from_value(*decode_value(data, factory)), spec); + EXPECT_EQ(data.size(), 0u); + } + EXPECT_TRUE(matched_encode); +} + +void test_binary_format_spec(Cursor &test) { + test_binary_format_spec(test, simple); + test_binary_format_spec(test, streamed); + test_binary_format_spec(test, fast); +} + +TEST(TensorBinaryFormatTest, tensor_binary_format_test_spec) { + vespalib::string path = module_src_path; + path.append("src/apps/make_tensor_binary_format_test_spec/test_spec.json"); + MappedFileInput file(path); + EXPECT_TRUE(file.valid()); + auto handle_test = [this](Slime &slime) + { + test_binary_format_spec(slime.get()); + }; + auto handle_summary = [](Slime &slime) + { + EXPECT_GT(slime["num_tests"].asLong(), 0); + }; + for_each_test(file, handle_test, handle_summary); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/tests/tensor/tensor_conformance/tensor_conformance_test.cpp b/eval/src/tests/tensor/tensor_conformance/tensor_conformance_test.cpp index 15aa7212a94..3d21f9b4113 100644 --- a/eval/src/tests/tensor/tensor_conformance/tensor_conformance_test.cpp +++ b/eval/src/tests/tensor/tensor_conformance/tensor_conformance_test.cpp @@ -32,13 +32,6 @@ TEST("require that FastValue implementation passes all conformance tests") { TEST_DO(TensorConformance::run_tests(module_src_path, FastValueBuilderFactory::get())); } -TEST("require that tensor serialization test spec can be generated") { - vespalib::string spec = module_src_path + "src/apps/make_tensor_binary_format_test_spec/test_spec.json"; - vespalib::string binary = module_build_path + "src/apps/make_tensor_binary_format_test_spec/eval_make_tensor_binary_format_test_spec_app"; - EXPECT_EQUAL(system(fmt("%s > binary_test_spec.json", binary.c_str()).c_str()), 0); - EXPECT_EQUAL(system(fmt("diff -u %s binary_test_spec.json", spec.c_str()).c_str()), 0); -} - TEST("require that cross-language tensor conformance tests pass with C++ expression evaluation") { vespalib::string result_file = "conformance_result.json"; vespalib::string binary = module_build_path + "src/apps/tensor_conformance/vespa-tensor-conformance"; diff --git a/eval/src/vespa/eval/eval/test/gen_spec.cpp b/eval/src/vespa/eval/eval/test/gen_spec.cpp index 0b624a457d7..fd2c1f39382 100644 --- a/eval/src/vespa/eval/eval/test/gen_spec.cpp +++ b/eval/src/vespa/eval/eval/test/gen_spec.cpp @@ -4,6 +4,7 @@ #include <vespa/eval/eval/string_stuff.h> #include <vespa/vespalib/util/require.h> #include <vespa/vespalib/util/stringfmt.h> +#include <ostream> using vespalib::make_string_short::fmt; @@ -162,4 +163,10 @@ GenSpec::gen() const return result.normalize(); } +std::ostream &operator<<(std::ostream &out, const GenSpec &spec) +{ + out << spec.gen(); + return out; +} + } // namespace diff --git a/eval/src/vespa/eval/eval/test/gen_spec.h b/eval/src/vespa/eval/eval/test/gen_spec.h index f0eca6074dc..3f7550ba644 100644 --- a/eval/src/vespa/eval/eval/test/gen_spec.h +++ b/eval/src/vespa/eval/eval/test/gen_spec.h @@ -153,4 +153,6 @@ public: operator TensorSpec() const { return gen(); } }; +std::ostream &operator<<(std::ostream &out, const GenSpec &spec); + } // namespace diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index 17ad75ae455..c58f8312cbf 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -9,12 +9,12 @@ #include <vespa/eval/eval/simple_value.h> #include <vespa/eval/eval/value_type_spec.h> #include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/util/require.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/io/mapped_file_input.h> #include "tensor_model.h" -#include "test_io.h" #include "reference_evaluation.h" using vespalib::make_string_short::fmt; @@ -47,7 +47,7 @@ TensorSpec eval(const ValueBuilderFactory &factory, const vespalib::string &expr } NodeTypes types(*fun, param_types); const auto &expect_type = types.get_type(fun->root()); - ASSERT_FALSE(expect_type.is_error()); + REQUIRE(!expect_type.is_error()); InterpretedFunction ifun(factory, *fun, types); InterpretedFunction::Context ctx(ifun); const Value &result = ifun.eval(ctx, SimpleObjectParams{param_refs}); @@ -69,31 +69,6 @@ void verify_result(const ValueBuilderFactory &factory, const vespalib::string &e // NaN value const double my_nan = std::numeric_limits<double>::quiet_NaN(); -uint8_t unhex(char c) { - if (c >= '0' && c <= '9') { - return (c - '0'); - } - if (c >= 'A' && c <= 'F') { - return ((c - 'A') + 10); - } - TEST_ERROR("bad hex char"); - return 0; -} - -nbostream extract_data(const Memory &hex_dump) { - nbostream data; - if ((hex_dump.size > 2) && (hex_dump.data[0] == '0') && (hex_dump.data[1] == 'x')) { - for (size_t i = 2; i < (hex_dump.size - 1); i += 2) { - data << uint8_t((unhex(hex_dump.data[i]) << 4) | unhex(hex_dump.data[i + 1])); - } - } - return data; -} - -bool is_same(const nbostream &a, const nbostream &b) { - return (Memory(a.peek(), a.size()) == Memory(b.peek(), b.size())); -} - // Test wrapper to avoid passing global test parameters around struct TestContext { @@ -451,7 +426,7 @@ struct TestContext { {x({"a","b","c"}),y(5)}, float_cells({y(5),z({"i","j","k","l"})}), float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})}) }; - ASSERT_TRUE((layouts.size() % 2) == 0); + REQUIRE((layouts.size() % 2) == 0); for (size_t i = 0; i < layouts.size(); i += 2) { TensorSpec lhs_input = spec(layouts[i], seq); TensorSpec rhs_input = spec(layouts[i + 1], seq); @@ -677,82 +652,6 @@ struct TestContext { //------------------------------------------------------------------------- - void verify_encode_decode(const TensorSpec &spec, - const ValueBuilderFactory &encode_factory, - const ValueBuilderFactory &decode_factory) - { - nbostream data; - auto value = value_from_spec(spec, encode_factory); - encode_value(*value, data); - auto value2 = decode_value(data, decode_factory); - TensorSpec spec2 = spec_from_value(*value2); - EXPECT_EQUAL(spec2, spec); - } - - void verify_encode_decode(const TensorSpec &spec) { - const ValueBuilderFactory &simple = SimpleValueBuilderFactory::get(); - TEST_DO(verify_encode_decode(spec, factory, simple)); - if (&factory != &simple) { - TEST_DO(verify_encode_decode(spec, simple, factory)); - } - } - - void test_binary_format_spec(Cursor &test) { - Stash stash; - TensorSpec spec = TensorSpec::from_slime(test["tensor"]); - const Inspector &binary = test["binary"]; - EXPECT_GREATER(binary.entries(), 0u); - nbostream encoded; - encode_value(*value_from_spec(spec, factory), encoded); - test.setData("encoded", Memory(encoded.peek(), encoded.size())); - bool matched_encode = false; - for (size_t i = 0; i < binary.entries(); ++i) { - nbostream data = extract_data(binary[i].asString()); - matched_encode = (matched_encode || is_same(encoded, data)); - EXPECT_EQUAL(spec_from_value(*decode_value(data, factory)), spec); - EXPECT_EQUAL(data.size(), 0u); - } - EXPECT_TRUE(matched_encode); - } - - void test_binary_format_spec() { - vespalib::string path = module_path; - path.append("src/apps/make_tensor_binary_format_test_spec/test_spec.json"); - MappedFileInput file(path); - EXPECT_TRUE(file.valid()); - auto handle_test = [this](Slime &slime) - { - size_t fail_cnt = TEST_MASTER.getProgress().failCnt; - TEST_DO(test_binary_format_spec(slime.get())); - if (TEST_MASTER.getProgress().failCnt > fail_cnt) { - fprintf(stderr, "failed:\n%s", slime.get().toString().c_str()); - } - }; - auto handle_summary = [](Slime &slime) - { - EXPECT_GREATER(slime["num_tests"].asLong(), 0); - }; - for_each_test(file, handle_test, handle_summary); - } - - void test_binary_format() { - TEST_DO(test_binary_format_spec()); - TEST_DO(verify_encode_decode(spec(42))); - TEST_DO(verify_encode_decode(spec({x(3)}, N()))); - TEST_DO(verify_encode_decode(spec({x(3),y(5)}, N()))); - TEST_DO(verify_encode_decode(spec({x(3),y(5),z(7)}, N()))); - TEST_DO(verify_encode_decode(spec(float_cells({x(3),y(5),z(7)}), N()))); - TEST_DO(verify_encode_decode(spec({x({"a","b","c"})}, N()))); - TEST_DO(verify_encode_decode(spec({x({"a","b","c"}),y({"foo","bar"})}, N()))); - TEST_DO(verify_encode_decode(spec({x({"a","b","c"}),y({"foo","bar"}),z({"i","j","k","l"})}, N()))); - TEST_DO(verify_encode_decode(spec(float_cells({x({"a","b","c"}),y({"foo","bar"}),z({"i","j","k","l"})}), N()))); - TEST_DO(verify_encode_decode(spec({x(3),y({"foo", "bar"}),z(7)}, N()))); - TEST_DO(verify_encode_decode(spec({x({"a","b","c"}),y(5),z({"i","j","k","l"})}, N()))); - TEST_DO(verify_encode_decode(spec(float_cells({x({"a","b","c"}),y(5),z({"i","j","k","l"})}), N()))); - } - - //------------------------------------------------------------------------- - void run_tests() { TEST_DO(test_tensor_create_type()); TEST_DO(test_tensor_reduce()); @@ -766,7 +665,6 @@ struct TestContext { TEST_DO(test_tensor_create()); TEST_DO(test_tensor_peek()); TEST_DO(test_tensor_merge()); - TEST_DO(test_binary_format()); } }; diff --git a/eval/src/vespa/eval/eval/test/test_io.cpp b/eval/src/vespa/eval/eval/test/test_io.cpp index 4ecfb788f28..b53ee864cbe 100644 --- a/eval/src/vespa/eval/eval/test/test_io.cpp +++ b/eval/src/vespa/eval/eval/test/test_io.cpp @@ -1,7 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "test_io.h" -#include <vespa/vespalib/testkit/test_kit.h> +#include <vespa/vespalib/util/require.h> #include <vespa/vespalib/data/slime/slime.h> #include <vespa/vespalib/data/slime/json_format.h> #include <vespa/vespalib/util/size_literals.h> @@ -75,8 +75,8 @@ void TestWriter::maybe_write_test() { if (_test.get().type().getId() != slime::NIX::ID) { - ASSERT_GREATER(_test.get().fields(), 0u); - ASSERT_FALSE(_test[num_tests_str].valid()); + REQUIRE(_test.get().fields() > 0u); + REQUIRE(!_test[num_tests_str].valid()); write_compact(_test, _out); ++_num_tests; } @@ -116,21 +116,21 @@ void for_each_test(Input &in, if (JsonFormat::decode(in, slime)) { bool is_summary = slime[num_tests_str].valid(); bool is_test = (!is_summary && (slime.get().fields() > 0)); - ASSERT_TRUE(is_test != is_summary); + REQUIRE(is_test != is_summary); if (is_test) { ++num_tests; - ASSERT_TRUE(!got_summary); + REQUIRE(!got_summary); handle_test(slime); } else { got_summary = true; - ASSERT_EQUAL(slime[num_tests_str].asLong(), int64_t(num_tests)); + REQUIRE_EQ(slime[num_tests_str].asLong(), int64_t(num_tests)); handle_summary(slime); } } else { - ASSERT_EQUAL(in.obtain().size, 0u); + REQUIRE_EQ(in.obtain().size, 0u); } } - ASSERT_TRUE(got_summary); + REQUIRE(got_summary); } //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/value_codec.cpp b/eval/src/vespa/eval/eval/value_codec.cpp index 50c77eca7e9..bd9d36bed2f 100644 --- a/eval/src/vespa/eval/eval/value_codec.cpp +++ b/eval/src/vespa/eval/eval/value_codec.cpp @@ -14,6 +14,8 @@ using vespalib::make_string_short::fmt; namespace vespalib::eval { +VESPA_IMPLEMENT_EXCEPTION(DecodeValueException, Exception); + namespace { constexpr uint32_t DOUBLE_CELL_TYPE = 0; @@ -314,13 +316,19 @@ void encode_value(const Value &value, nbostream &output) { } std::unique_ptr<Value> decode_value(nbostream &input, const ValueBuilderFactory &factory) { - Format format(input.getInt1_4Bytes()); - ValueType type = decode_type(input, format); - size_t num_mapped_dims = type.count_mapped_dimensions(); - size_t dense_subspace_size = type.dense_subspace_size(); - const size_t num_blocks = maybe_decode_num_blocks(input, (num_mapped_dims > 0), format); - DecodeState state{type, dense_subspace_size, num_blocks, num_mapped_dims}; - return typify_invoke<1,TypifyCellType,ContentDecoder>(type.cell_type(), input, state, factory); + try { + Format format(input.getInt1_4Bytes()); + ValueType type = decode_type(input, format); + size_t num_mapped_dims = type.count_mapped_dimensions(); + size_t dense_subspace_size = type.dense_subspace_size(); + const size_t num_blocks = maybe_decode_num_blocks(input, (num_mapped_dims > 0), format); + DecodeState state{type, dense_subspace_size, num_blocks, num_mapped_dims}; + return typify_invoke<1,TypifyCellType,ContentDecoder>(type.cell_type(), input, state, factory); + } catch (const OOMException &) { + throw; + } catch (const Exception &e) { + throw DecodeValueException("failed to decode value", e); + } } //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/value_codec.h b/eval/src/vespa/eval/eval/value_codec.h index 058b2d7bf4f..23eb2de8e41 100644 --- a/eval/src/vespa/eval/eval/value_codec.h +++ b/eval/src/vespa/eval/eval/value_codec.h @@ -5,11 +5,14 @@ #include "value.h" #include "tensor_spec.h" #include <vespa/vespalib/stllike/string.h> +#include <vespa/vespalib/util/exception.h> namespace vespalib { class nbostream; } namespace vespalib::eval { +VESPA_DEFINE_EXCEPTION(DecodeValueException, Exception); + /** * encode a value (which must support the new APIs) to binary format **/ @@ -17,6 +20,8 @@ void encode_value(const Value &value, nbostream &output); /** * decode a value from binary format + * + * will throw DecodeValueException if input contains malformed data **/ std::unique_ptr<Value> decode_value(nbostream &input, const ValueBuilderFactory &factory); |