diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-11-27 15:52:51 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-11-27 15:52:51 +0000 |
commit | 4ac10dad24980734a161533c36b232cfc5d3a2f9 (patch) | |
tree | 6ce947c9d2c9d8adede8a6334ecb55dbd5c51270 /searchlib | |
parent | da1bc43297be17d4f7ea18026c9984f07098af73 (diff) |
check tensor type in attribute, just assert in store
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp | 5 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp | 46 |
2 files changed, 22 insertions, 29 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp index 27acb4062ed..6e1fb1a0a2f 100644 --- a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp @@ -155,11 +155,10 @@ SerializedFastValueAttribute::~SerializedFastValueAttribute() void SerializedFastValueAttribute::setTensor(DocId docId, const vespalib::eval::Value &tensor) { + checkTensorType(tensor); EntryRef ref = _streamedValueStore.store_tensor(tensor); + assert(ref.valid()); setTensorRef(docId, ref); - if (!ref.valid()) { - checkTensorType(tensor); - } } std::unique_ptr<Value> diff --git a/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp b/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp index bec5a2799ac..ae2e0e7ed10 100644 --- a/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp @@ -195,32 +195,26 @@ StreamedValueStore::serialize_labels(const Value::Index &index, TensorStore::EntryRef StreamedValueStore::store_tensor(const Value &tensor) { - if (tensor.type() == _tensor_type) { - CellsMemBlock cells_mem(tensor.cells()); - size_t alignment = CellTypeUtils::alignment(_data_from_type.cell_type); - size_t padding = alignment - 1; - vespalib::nbostream stream; - stream << uint32_t(cells_mem.num); - serialize_labels(tensor.index(), stream); - size_t mem_size = stream.size() + cells_mem.total_sz + padding; - auto raw = allocRawBuffer(mem_size); - char *target = raw.data; - memcpy(target, stream.peek(), sizeof(uint32_t)); - stream.adjustReadPos(sizeof(uint32_t)); - target += sizeof(uint32_t); - target = fix_alignment(target, alignment); - memcpy(target, cells_mem.ptr, cells_mem.total_sz); - target += cells_mem.total_sz; - memcpy(target, stream.peek(), stream.size()); - target += stream.size(); - assert(target <= raw.data + mem_size); - return raw.ref; - } else { - LOG(error, "trying to store tensor of type %s in store only allowing %s", - tensor.type().to_spec().c_str(), _tensor_type.to_spec().c_str()); - TensorStore::EntryRef invalid; - return invalid; - } + assert(tensor.type() == _tensor_type); + CellsMemBlock cells_mem(tensor.cells()); + size_t alignment = CellTypeUtils::alignment(_data_from_type.cell_type); + size_t padding = alignment - 1; + vespalib::nbostream stream; + stream << uint32_t(cells_mem.num); + serialize_labels(tensor.index(), stream); + size_t mem_size = stream.size() + cells_mem.total_sz + padding; + auto raw = allocRawBuffer(mem_size); + char *target = raw.data; + memcpy(target, stream.peek(), sizeof(uint32_t)); + stream.adjustReadPos(sizeof(uint32_t)); + target += sizeof(uint32_t); + target = fix_alignment(target, alignment); + memcpy(target, cells_mem.ptr, cells_mem.total_sz); + target += cells_mem.total_sz; + memcpy(target, stream.peek(), stream.size()); + target += stream.size(); + assert(target <= raw.data + mem_size); + return raw.ref; } TensorStore::EntryRef |