summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@yahoo-inc.com>2016-10-21 12:23:35 +0000
committerTor Egge <Tor.Egge@yahoo-inc.com>2016-10-21 12:23:35 +0000
commit9ad8c2909abc2fe07b14184da3fa16f71f1817c3 (patch)
tree4311aacf2fa81f1ad21c8d190fb3dfaf7b70445c /searchlib
parent0ac0e8167f8dd3753b8bf13e976541fe84dceeca (diff)
If tensor type for dense tensor store is abstract then different tensors
can have different concrete tensor types.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp35
-rw-r--r--searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h1
2 files changed, 32 insertions, 4 deletions
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp
index 4f300db184e..c2f9b67dfc5 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.cpp
@@ -4,6 +4,7 @@
#include "dense_tensor_store.h"
#include <vespa/vespalib/tensor/tensor.h>
#include <vespa/vespalib/tensor/dense/dense_tensor_view.h>
+#include <vespa/vespalib/tensor/dense/mutable_dense_tensor_view.h>
#include <vespa/vespalib/tensor/dense/dense_tensor.h>
#include <vespa/vespalib/tensor/serialization/typed_binary_format.h>
#include <vespa/vespalib/objects/nbostream.h>
@@ -15,6 +16,7 @@
using vespalib::tensor::Tensor;
using vespalib::tensor::DenseTensor;
using vespalib::tensor::DenseTensorView;
+using vespalib::tensor::MutableDenseTensorView;
using vespalib::eval::ValueType;
using vespalib::ConstArrayRef;
@@ -170,6 +172,28 @@ DenseTensorStore::move(EntryRef ref) {
return newraw.second;
}
+namespace {
+
+ValueType makeBoundType(const ValueType &type,
+ const void *buffer,
+ uint32_t numUnboundDims)
+{
+ std::vector<ValueType::Dimension> dimensions(type.dimensions());
+ const uint32_t *unboundDimSizeEnd = static_cast<const uint32_t *>(buffer);
+ const uint32_t *unboundDimSize = unboundDimSizeEnd - numUnboundDims;
+ for (auto &dim : dimensions) {
+ if (!dim.is_bound()) {
+ assert(unboundDimSize != unboundDimSizeEnd);
+ dim.size = *unboundDimSize;
+ ++unboundDimSize;
+ }
+ }
+ assert(unboundDimSize == unboundDimSizeEnd);
+ return ValueType::tensor_type(std::move(dimensions));
+}
+
+}
+
std::unique_ptr<Tensor>
DenseTensorStore::getTensor(EntryRef ref) const
{
@@ -178,8 +202,13 @@ DenseTensorStore::getTensor(EntryRef ref) const
return std::unique_ptr<Tensor>();
}
size_t numCells = getNumCells(raw);
- return std::make_unique<DenseTensorView>
- (_type,
+ if (_numUnboundDims == 0) {
+ return std::make_unique<DenseTensorView>
+ (_type,
+ ConstArrayRef<double>(static_cast<const double *>(raw), numCells));
+ }
+ return std::make_unique<MutableDenseTensorView>
+ (makeBoundType(_type, raw, _numUnboundDims),
ConstArrayRef<double>(static_cast<const double *>(raw), numCells));
}
@@ -233,7 +262,7 @@ DenseTensorStore::setDenseTensor(const TensorType &tensor)
size_t numCells = tensor.cells().size();
checkMatchingType(_type, tensor.type(), numCells);
auto raw = allocRawBuffer(numCells);
- setDenseTensorUnboundDimSizes(raw.first, _type, numUnboundDims(), tensor.type());
+ setDenseTensorUnboundDimSizes(raw.first, _type, _numUnboundDims, tensor.type());
memcpy(raw.first, &tensor.cells()[0], numCells * _cellSize);
return raw.second;
}
diff --git a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h
index 2d754dc48af..6c507289c03 100644
--- a/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h
+++ b/searchlib/src/vespa/searchlib/tensor/dense_tensor_store.h
@@ -66,7 +66,6 @@ public:
DenseTensorStore(const ValueType &type);
virtual ~DenseTensorStore();
- uint32_t numUnboundDims() const { return _numUnboundDims; }
size_t getNumCells(const void *buffer) const;
uint32_t getCellSize() const { return _cellSize; }
const void *getRawBuffer(RefType ref) const;