diff options
Diffstat (limited to 'searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp')
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp | 50 |
1 files changed, 43 insertions, 7 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp index fba1d494690..1184cca37e7 100644 --- a/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/direct_tensor_store.cpp @@ -1,7 +1,10 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "direct_tensor_store.h" +#include "tensor_deserialize.h" +#include <vespa/eval/eval/fast_value.h> #include <vespa/eval/eval/value.h> +#include <vespa/eval/eval/value_codec.h> #include <vespa/vespalib/datastore/compacting_buffers.h> #include <vespa/vespalib/datastore/compaction_context.h> #include <vespa/vespalib/datastore/compaction_strategy.h> @@ -14,6 +17,8 @@ using vespalib::datastore::CompactionSpec; using vespalib::datastore::CompactionStrategy; using vespalib::datastore::EntryRef; using vespalib::datastore::ICompactionContext; +using vespalib::eval::FastValueBuilderFactory; +using vespalib::eval::Value; namespace search::tensor { @@ -54,13 +59,6 @@ DirectTensorStore::DirectTensorStore() DirectTensorStore::~DirectTensorStore() = default; -EntryRef -DirectTensorStore::store_tensor(std::unique_ptr<vespalib::eval::Value> tensor) -{ - assert(tensor); - return add_entry(TensorSP(std::move(tensor))); -} - void DirectTensorStore::holdTensor(EntryRef ref) { @@ -100,4 +98,42 @@ DirectTensorStore::start_compact(const CompactionStrategy& compaction_strategy) return std::make_unique<CompactionContext>(*this, std::move(compacting_buffers)); } +EntryRef +DirectTensorStore::store_tensor(std::unique_ptr<Value> tensor) +{ + assert(tensor); + return add_entry(std::move(tensor)); +} + +EntryRef +DirectTensorStore::store_tensor(const Value& tensor) +{ + return add_entry(FastValueBuilderFactory::get().copy(tensor)); +} + +EntryRef +DirectTensorStore::store_encoded_tensor(vespalib::nbostream& encoded) +{ + return add_entry(deserialize_tensor(encoded)); +} + +std::unique_ptr<Value> +DirectTensorStore::get_tensor(EntryRef ref) const +{ + if (!ref.valid()) { + return {}; + } + return FastValueBuilderFactory::get().copy(*_tensor_store.getEntry(ref)); +} + +bool +DirectTensorStore::encode_stored_tensor(EntryRef ref, vespalib::nbostream& target) const +{ + if (!ref.valid()) { + return false; + } + vespalib::eval::encode_value(*_tensor_store.getEntry(ref), target); + return true; +} + } |