summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@oath.com>2019-04-03 17:35:32 +0000
committerHenning Baldersheim <balder@oath.com>2019-04-03 17:35:32 +0000
commit0b132f8a14a13caaa38fe8e4769d60f1ae501eee (patch)
tree6d9bb3ef84d0cdafe775efdb9fac2a5a797e74b5 /eval
parentce4339ab5224bac5df5fb721cb7d2668ae75c811 (diff)
Add a method that will extract cells only to a favoured cell type. This is temporary until full typed tensor support is in place.
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp13
-rw-r--r--eval/src/vespa/eval/tensor/serialization/common.h2
-rw-r--r--eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp41
-rw-r--r--eval/src/vespa/eval/tensor/serialization/dense_binary_format.h5
-rw-r--r--eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp17
-rw-r--r--eval/src/vespa/eval/tensor/serialization/typed_binary_format.h5
6 files changed, 72 insertions, 11 deletions
diff --git a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
index 0237f6cc769..b4e0db4fa8f 100644
--- a/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
+++ b/eval/src/tests/tensor/tensor_serialization/tensor_serialization_test.cpp
@@ -162,6 +162,16 @@ struct DenseFixture
void assertSerialized(const ExpBuffer &exp, const DenseTensorCells &rhs) {
assertSerialized(exp, SerializeFormat::DOUBLE, rhs);
}
+ template <typename T>
+ void assertCellsOnly(const ExpBuffer &exp, const DenseTensorView & rhs) {
+ nbostream a(&exp[0], exp.size());
+ std::vector<T> v;
+ TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(a, v);
+ EXPECT_EQUAL(v.size(), rhs.cellsRef().size());
+ for (size_t i(0); i < v.size(); i++) {
+ EXPECT_EQUAL(v[i], rhs.cellsRef()[i]);
+ }
+ }
void assertSerialized(const ExpBuffer &exp, SerializeFormat cellType, const DenseTensorCells &rhs) {
Tensor::UP rhsTensor(createTensor(rhs));
nbostream rhsStream;
@@ -169,6 +179,9 @@ struct DenseFixture
EXPECT_EQUAL(exp, rhsStream);
auto rhs2 = deserialize(rhsStream);
EXPECT_EQUAL(*rhs2, *rhsTensor);
+
+ assertCellsOnly<float>(exp, dynamic_cast<const DenseTensorView &>(*rhs2));
+ assertCellsOnly<double>(exp, dynamic_cast<const DenseTensorView &>(*rhs2));
}
};
diff --git a/eval/src/vespa/eval/tensor/serialization/common.h b/eval/src/vespa/eval/tensor/serialization/common.h
index 9c45bc42136..40b1840be6e 100644
--- a/eval/src/vespa/eval/tensor/serialization/common.h
+++ b/eval/src/vespa/eval/tensor/serialization/common.h
@@ -1,4 +1,4 @@
-// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
index 6043153adc3..4b1ccc8db5d 100644
--- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.cpp
@@ -57,9 +57,9 @@ decodeDimensions(nbostream & stream, std::vector<Dimension> & dimensions) {
return cellsSize;
}
-template<typename T>
+template<typename T, typename V>
void
-decodeCells(nbostream &stream, size_t cellsSize, DenseTensor::Cells & cells) {
+decodeCells(nbostream &stream, size_t cellsSize, V & cells) {
T cellValue = 0.0;
for (size_t i = 0; i < cellsSize; ++i) {
stream >> cellValue;
@@ -67,6 +67,19 @@ decodeCells(nbostream &stream, size_t cellsSize, DenseTensor::Cells & cells) {
}
}
+template <typename V>
+void decodeCells(SerializeFormat format, nbostream &stream, size_t cellsSize, V & cells)
+{
+ switch (format) {
+ case SerializeFormat::DOUBLE:
+ decodeCells<double>(stream, cellsSize, cells);
+ break;
+ case SerializeFormat::FLOAT:
+ decodeCells<float>(stream, cellsSize, cells);
+ break;
+ }
+}
+
}
void
@@ -93,16 +106,24 @@ DenseBinaryFormat::deserialize(nbostream &stream)
size_t cellsSize = decodeDimensions(stream,dimensions);
DenseTensor::Cells cells;
cells.reserve(cellsSize);
- switch (_format) {
- case SerializeFormat::DOUBLE:
- decodeCells<double>(stream, cellsSize,cells);
- break;
- case SerializeFormat::FLOAT:
- decodeCells<float>(stream, cellsSize, cells);
- break;
- }
+
+ decodeCells(_format, stream, cellsSize, cells);
return std::make_unique<DenseTensor>(makeValueType(std::move(dimensions)), std::move(cells));
}
+template <typename T>
+void
+DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<T> & cells)
+{
+ std::vector<Dimension> dimensions;
+ size_t cellsSize = decodeDimensions(stream,dimensions);
+ cells.clear();
+ cells.reserve(cellsSize);
+ decodeCells(_format, stream, cellsSize, cells);
+}
+
+template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<double> & cells);
+template void DenseBinaryFormat::deserializeCellsOnly(nbostream &stream, std::vector<float> & cells);
+
}
diff --git a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
index 22c1663719e..f9847d37784 100644
--- a/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/dense_binary_format.h
@@ -4,6 +4,7 @@
#include "common.h"
#include <memory>
+#include <vector>
namespace vespalib { class nbostream; }
@@ -21,6 +22,10 @@ public:
DenseBinaryFormat(SerializeFormat format) : _format(format) { }
void serialize(nbostream &stream, const DenseTensorView &tensor);
std::unique_ptr<DenseTensor> deserialize(nbostream &stream);
+
+ // This is a temporary method untill we get full support for typed tensors
+ template <typename T>
+ void deserializeCellsOnly(nbostream &stream, std::vector<T> & cells);
private:
SerializeFormat _format;
};
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
index d1aa09b6ce3..813763ba268 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
@@ -99,4 +99,21 @@ TypedBinaryFormat::deserialize(nbostream &stream)
abort();
}
+template <typename T>
+void
+TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells)
+{
+ auto formatId = stream.getInt1_4Bytes();
+ if (formatId == DENSE_BINARY_FORMAT_TYPE) {
+ return DenseBinaryFormat(SerializeFormat::DOUBLE).deserializeCellsOnly(stream, cells);
+ }
+ if (formatId == TYPED_DENSE_BINARY_FORMAT_TYPE) {
+ return DenseBinaryFormat(encoding2Format(stream.getInt1_4Bytes())).deserializeCellsOnly(stream, cells);
+ }
+ abort();
+}
+
+template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<double> & cells);
+template void TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<float> & cells);
+
}
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
index 95d9a75488c..717d51effef 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.h
@@ -4,6 +4,7 @@
#include "common.h"
#include <memory>
+#include <vector>
namespace vespalib { class nbostream; }
@@ -23,6 +24,10 @@ public:
}
static std::unique_ptr<Tensor> deserialize(nbostream &stream);
+
+ // This is a temporary method until we get full support for typed tensors
+ template <typename T>
+ static void deserializeCellsOnlyFromDenseTensors(nbostream &stream, std::vector<T> & cells);
};
}