aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-11-27 15:52:51 +0000
committerArne Juul <arnej@verizonmedia.com>2020-11-27 15:52:51 +0000
commit4ac10dad24980734a161533c36b232cfc5d3a2f9 (patch)
tree6ce947c9d2c9d8adede8a6334ecb55dbd5c51270 /searchlib
parentda1bc43297be17d4f7ea18026c9984f07098af73 (diff)
check tensor type in attribute, just assert in store
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp5
-rw-r--r--searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp46
2 files changed, 22 insertions, 29 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp
index 27acb4062ed..6e1fb1a0a2f 100644
--- a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp
@@ -155,11 +155,10 @@ SerializedFastValueAttribute::~SerializedFastValueAttribute()
void
SerializedFastValueAttribute::setTensor(DocId docId, const vespalib::eval::Value &tensor)
{
+ checkTensorType(tensor);
EntryRef ref = _streamedValueStore.store_tensor(tensor);
+ assert(ref.valid());
setTensorRef(docId, ref);
- if (!ref.valid()) {
- checkTensorType(tensor);
- }
}
std::unique_ptr<Value>
diff --git a/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp b/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp
index bec5a2799ac..ae2e0e7ed10 100644
--- a/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp
@@ -195,32 +195,26 @@ StreamedValueStore::serialize_labels(const Value::Index &index,
TensorStore::EntryRef
StreamedValueStore::store_tensor(const Value &tensor)
{
- if (tensor.type() == _tensor_type) {
- CellsMemBlock cells_mem(tensor.cells());
- size_t alignment = CellTypeUtils::alignment(_data_from_type.cell_type);
- size_t padding = alignment - 1;
- vespalib::nbostream stream;
- stream << uint32_t(cells_mem.num);
- serialize_labels(tensor.index(), stream);
- size_t mem_size = stream.size() + cells_mem.total_sz + padding;
- auto raw = allocRawBuffer(mem_size);
- char *target = raw.data;
- memcpy(target, stream.peek(), sizeof(uint32_t));
- stream.adjustReadPos(sizeof(uint32_t));
- target += sizeof(uint32_t);
- target = fix_alignment(target, alignment);
- memcpy(target, cells_mem.ptr, cells_mem.total_sz);
- target += cells_mem.total_sz;
- memcpy(target, stream.peek(), stream.size());
- target += stream.size();
- assert(target <= raw.data + mem_size);
- return raw.ref;
- } else {
- LOG(error, "trying to store tensor of type %s in store only allowing %s",
- tensor.type().to_spec().c_str(), _tensor_type.to_spec().c_str());
- TensorStore::EntryRef invalid;
- return invalid;
- }
+ assert(tensor.type() == _tensor_type);
+ CellsMemBlock cells_mem(tensor.cells());
+ size_t alignment = CellTypeUtils::alignment(_data_from_type.cell_type);
+ size_t padding = alignment - 1;
+ vespalib::nbostream stream;
+ stream << uint32_t(cells_mem.num);
+ serialize_labels(tensor.index(), stream);
+ size_t mem_size = stream.size() + cells_mem.total_sz + padding;
+ auto raw = allocRawBuffer(mem_size);
+ char *target = raw.data;
+ memcpy(target, stream.peek(), sizeof(uint32_t));
+ stream.adjustReadPos(sizeof(uint32_t));
+ target += sizeof(uint32_t);
+ target = fix_alignment(target, alignment);
+ memcpy(target, cells_mem.ptr, cells_mem.total_sz);
+ target += cells_mem.total_sz;
+ memcpy(target, stream.peek(), stream.size());
+ target += stream.size();
+ assert(target <= raw.data + mem_size);
+ return raw.ref;
}
TensorStore::EntryRef