From 11ac39ad8b6ee2c5b9fc122d29f754152b80a85a Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 24 Jan 2018 08:54:41 +0000 Subject: Support binary tensor files with ending .tbf --- eval/src/tests/eval/value_cache/dense.tbf | Bin 0 -> 40 bytes .../tests/eval/value_cache/tensor_loader_test.cpp | 4 ++ .../eval/value_cache/constant_tensor_loader.cpp | 47 +++++++++++++-------- 3 files changed, 34 insertions(+), 17 deletions(-) create mode 100644 eval/src/tests/eval/value_cache/dense.tbf (limited to 'eval') diff --git a/eval/src/tests/eval/value_cache/dense.tbf b/eval/src/tests/eval/value_cache/dense.tbf new file mode 100644 index 00000000000..61012a35ff9 Binary files /dev/null and b/eval/src/tests/eval/value_cache/dense.tbf differ diff --git a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp index 5dd8caa6e27..8180a7daef8 100644 --- a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp +++ b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp @@ -67,6 +67,10 @@ TEST_F("require that lz4 compressed dense tensor can be loaded", ConstantTensorL TEST_DO(verify_tensor(make_dense_tensor(), f1.create(TEST_PATH("dense.json.lz4"), "tensor(x[2],y[2])"))); } +TEST_F("require that a binary tensor can be loaded", ConstantTensorLoader(SimpleTensorEngine::ref())) { + TEST_DO(verify_tensor(make_dense_tensor(), f1.create(TEST_PATH("dense.tbf"), "tensor(x[2],y[2])"))); +} + TEST_F("require that lz4 compressed sparse tensor can be loaded", ConstantTensorLoader(SimpleTensorEngine::ref())) { TEST_DO(verify_tensor(make_sparse_tensor(), f1.create(TEST_PATH("sparse.json.lz4"), "tensor(x{},y{})"))); } diff --git a/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp b/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp index 38d5bbc643b..208b37b455d 100644 --- a/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp +++ b/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp @@ -4,10 +4,13 @@ #include #include #include +#include +#include #include #include #include #include +#include #include LOG_SETUP(".vespalib.eval.value_cache.constant_tensor_loader"); @@ -54,14 +57,14 @@ void decode_json(const vespalib::string &path, Slime &slime) { } else { if (ends_with(path, ".lz4")) { size_t buffer_size = 64 * 1024; - Lz4InputDecoder lz4_decoder(file, buffer_size); + Lz4InputDecoder lz4_decoder(file, buffer_size); decode_json(path, lz4_decoder, slime); if (lz4_decoder.failed()) { LOG(warning, "file contains lz4 errors (%s): %s", lz4_decoder.reason().c_str(), path.c_str()); } } else { - decode_json(path, file, slime); + decode_json(path, file, slime); } } } @@ -76,23 +79,33 @@ ConstantTensorLoader::create(const vespalib::string &path, const vespalib::strin LOG(warning, "invalid type specification: %s", type.c_str()); return std::make_unique(_engine.from_spec(TensorSpec("double"))); } - Slime slime; - decode_json(path, slime); - std::set indexed; - for (const auto &dimension: value_type.dimensions()) { - if (dimension.is_indexed()) { - indexed.insert(dimension.name); - } + if (ends_with(path, ".tbf")) { + vespalib::File file(path); + file.open(File::READONLY); + std::vector content(file.stat()._size); + file.read(&content[0], content.size(), 0); + vespalib::nbostream_longlivedbuf stream(&content[0], content.size()); + return std::make_unique(_engine.decode(stream)); } - TensorSpec spec(type); - const Inspector &cells = slime.get()["cells"]; - for (size_t i = 0; i < cells.entries(); ++i) { - TensorSpec::Address address; - AddressExtractor extractor(indexed, address); - cells[i]["address"].traverse(extractor); - spec.add(address, cells[i]["value"].asDouble()); + else { + Slime slime; + decode_json(path, slime); + std::set indexed; + for (const auto &dimension: value_type.dimensions()) { + if (dimension.is_indexed()) { + indexed.insert(dimension.name); + } + } + TensorSpec spec(type); + const Inspector &cells = slime.get()["cells"]; + for (size_t i = 0; i < cells.entries(); ++i) { + TensorSpec::Address address; + AddressExtractor extractor(indexed, address); + cells[i]["address"].traverse(extractor); + spec.add(address, cells[i]["value"].asDouble()); + } + return std::make_unique(_engine.from_spec(spec)); } - return std::make_unique(_engine.from_spec(spec)); } } // namespace vespalib::eval -- cgit v1.2.3