summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Håvard Pettersen <havardpe@oath.com> 2019-07-12 12:49:22 +0000
committer Håvard Pettersen <havardpe@oath.com> 2019-07-15 15:15:33 +0000
commit 1dd2c0e3329f199ea7279e948a790dc1af75ed61 (patch)
tree 15a88dc3cfefd34617b3c316a87ce9ea54575313
parent c3e4fdeb920dabf96b1c007d7d7972faaf997d14 (diff)
also optimize for float cells
-rw-r--r-- eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp 16
-rw-r--r-- eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp 4
-rw-r--r-- eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp 13
-rw-r--r-- eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp 4
-rw-r--r-- eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp 4
-rw-r--r-- eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp 15
-rw-r--r-- eval/src/vespa/eval/eval/value_type.cpp 40
-rw-r--r-- eval/src/vespa/eval/eval/value_type.h 16
-rw-r--r-- eval/src/vespa/eval/tensor/default_tensor_engine.cpp 36
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_add_dimension_optimizer.cpp 28
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp 10
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_fast_rename_optimizer.cpp 14
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_generic_join.hpp 2
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp 71
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_inplace_map_function.cpp 25
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_remove_dimension_optimizer.cpp 12
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp 5
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_tensor_view.h 7
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp 116
-rw-r--r-- eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h 3
-rw-r--r-- eval/src/vespa/eval/tensor/dense/typed_cells.h 29
21 files changed, 262 insertions, 208 deletions
diff --git a/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp b/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp
index eaf4623afea..274117ea693 100644
--- a/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp
+++ b/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp
@@ -25,6 +25,7 @@ const TensorEngine &prod_engine = DefaultTensorEngine::ref();
EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
.add("x5", spec({x(5)}, N()))
+ .add("x5f", spec(float_cells({x(5)}), N()))
.add("x5y1", spec({x(5),y(1)}, N()))
.add("y1z1", spec({y(1),z(1)}, N()))
.add("x_m", spec({x({"a"})}, N()));
@@ -78,9 +79,9 @@ TEST("require that non-canonical dimension addition is not optimized") {
TEST_DO(verify_not_optimized("tensor(y[1])(1)/x5"));
}
-TEST("require that dimension addition with overlapping dimensions is not optimized") {
- TEST_DO(verify_not_optimized("x5y1*tensor(y[1],z[1])(1)"));
- TEST_DO(verify_not_optimized("tensor(y[1],z[1])(1)*x5y1"));
+TEST("require that dimension addition with overlapping dimensions is optimized") {
+ TEST_DO(verify_optimized("x5y1*tensor(y[1],z[1])(1)"));
+ TEST_DO(verify_optimized("tensor(y[1],z[1])(1)*x5y1"));
}
TEST("require that dimension addition with inappropriate dimensions is not optimized") {
@@ -99,8 +100,13 @@ TEST("require that dimension addition optimization requires unit constant tensor
TEST_DO(verify_not_optimized("tensor(x[2])(1)*tensor(y[2])(1)"));
}
-TEST("require that optimization is disabled for tensors with non-double cells") {
- TEST_DO(verify_not_optimized("x5*tensor<float>(a[1],b[1],c[1])(1)"));
+TEST("require that optimization also works for float cells") {
+ TEST_DO(verify_optimized("x5*tensor<float>(a[1],b[1],c[1])(1)"));
+ TEST_DO(verify_optimized("x5f*tensor<float>(a[1],b[1],c[1])(1)"));
+}
+
+TEST("require that optimization is disabled if unit vector would promote tensor cell types") {
+ TEST_DO(verify_not_optimized("x5f*tensor(a[1],b[1],c[1])(1)"));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp b/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp
index 4995ea89735..55a9414f82b 100644
--- a/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp
+++ b/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp
@@ -72,8 +72,8 @@ TEST("require that chained optimized renames are compacted into a single operati
TEST_DO(verify_optimized("rename(rename(x5,x,y),y,z)"));
}
-TEST("require that optimization is disabled for tensors with non-double cells") {
- TEST_DO(verify_not_optimized("rename(x5f,x,y)"));
+TEST("require that optimization works for float cells") {
+ TEST_DO(verify_optimized("rename(x5f,x,y)"));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
index 083ed1c7071..80321ac3d22 100644
--- a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
+++ b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
@@ -144,10 +144,15 @@ TEST("require that inplace join can be debug dumped") {
fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}
-TEST("require that optimization is disabled for tensors with non-double cells") {
- TEST_DO(verify_not_optimized("mut_x5_A-mut_x5f_D"));
- TEST_DO(verify_not_optimized("mut_x5f_D-mut_x5_A"));
- TEST_DO(verify_not_optimized("mut_x5f_D-mut_x5f_E"));
+TEST("require that optimization works with float cells") {
+ TEST_DO(verify_p0_optimized("mut_x5f_D-mut_x5f_E", 1));
+}
+
+TEST("require that overwritten value must have same cell type as result") {
+ TEST_DO(verify_p0_optimized("mut_x5_A-mut_x5f_D", 1));
+ TEST_DO(verify_p1_optimized("mut_x5f_D-mut_x5_A", 1));
+ TEST_DO(verify_not_optimized("con_x5_A-mut_x5f_D"));
+ TEST_DO(verify_not_optimized("mut_x5f_D-con_x5_A"));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp b/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp
index 314d3a6186c..f85742b4e0f 100644
--- a/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp
+++ b/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp
@@ -72,8 +72,8 @@ TEST("require that mapped tensors are not optimized") {
TEST_DO(verify_not_optimized("map(_x_m,f(x)(x+10))"));
}
-TEST("require that optimization is disabled for tensors with non-double cells") {
- TEST_DO(verify_not_optimized("map(_x5f,f(x)(x+10))"));
+TEST("require that optimization works for float cells") {
+ TEST_DO(verify_optimized("map(_x5f,f(x)(x+10))", 1));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp b/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp
index 7856775ae30..179fdd3eff4 100644
--- a/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp
+++ b/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp
@@ -78,8 +78,8 @@ TEST("require that inappropriate tensor types cannot be optimized") {
TEST_DO(verify_not_optimized("reduce(x1y5z_m,sum,z)"));
}
-TEST("require that optimization is disabled for tensors with non-double cells") {
- TEST_DO(verify_not_optimized("reduce(x1y5z1f,avg,x)"));
+TEST("require that optimization works for float cells") {
+ TEST_DO(verify_optimized("reduce(x1y5z1f,avg,x)"));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index 335aa4791a4..426281686d7 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -45,6 +45,7 @@ EvalFixture::ParamRepo make_params() {
.add("y1z1", spec({y(1),z(1)}, MyMatSeq()))
.add("x2y3", spec({x(2),y(3)}, MyMatSeq()))
.add("x2y3f", spec(float_cells({x(2),y(3)}), MyMatSeq()))
+ .add("y3z2f", spec(float_cells({y(3),z(2)}), MyMatSeq()))
.add("x2z3", spec({x(2),z(3)}, MyMatSeq()))
.add("y3z2", spec({y(3),z(2)}, MyMatSeq()))
.add("x8y5", spec({x(8),y(5)}, MyMatSeq()))
@@ -118,10 +119,16 @@ TEST("require that xw product can be debug dumped") {
fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}
-TEST("require that optimization is disabled for tensors with non-double cells") {
- TEST_DO(verify_not_optimized("reduce(y3f*x2y3,sum,y)"));
- TEST_DO(verify_not_optimized("reduce(y3*x2y3f,sum,y)"));
- TEST_DO(verify_not_optimized("reduce(y3f*x2y3f,sum,y)"));
+TEST("require that optimization works for float cells") {
+ TEST_DO(verify_optimized("reduce(y3f*x2y3,sum,y)", 3, 2, true));
+ TEST_DO(verify_optimized("reduce(y3*x2y3f,sum,y)", 3, 2, true));
+ TEST_DO(verify_optimized("reduce(y3f*x2y3f,sum,y)", 3, 2, true));
+}
+
+TEST("require that optimization works for float cells with inconvenient dimension nesting") {
+ TEST_DO(verify_optimized("reduce(y3f*y3z2,sum,y)", 3, 2, false));
+ TEST_DO(verify_optimized("reduce(y3*y3z2f,sum,y)", 3, 2, false));
+ TEST_DO(verify_optimized("reduce(y3f*y3z2f,sum,y)", 3, 2, false));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp
index fc0f3cc5414..d6ba8e83855 100644
--- a/eval/src/vespa/eval/eval/value_type.cpp
+++ b/eval/src/vespa/eval/eval/value_type.cpp
@@ -12,21 +12,27 @@ using CellType = ValueType::CellType;
using Dimension = ValueType::Dimension;
using DimensionList = std::vector<Dimension>;
-CellType unify(CellType a, CellType b) {
- if (a == b) {
- return a;
- } else {
- return CellType::DOUBLE;
+template <typename A, typename B>
+CellType unify() {
+ using type = typename UnifyCellTypes<A,B>::type;
+ return get_cell_type<type>();
+}
+
+template <typename A>
+CellType unify(CellType b) {
+ switch (b) {
+ case CellType::DOUBLE: return unify<A,double>();
+ case CellType::FLOAT: return unify<A,float>();
}
+ abort();
}
-CellType unify_cell_type(const ValueType &a, const ValueType &b) {
- if (a.is_double()) {
- return b.cell_type();
- } else if (b.is_double()) {
- return a.cell_type();
+CellType unify(CellType a, CellType b) {
+ switch (a) {
+ case CellType::DOUBLE: return unify<double>(b);
+ case CellType::FLOAT: return unify<float>(b);
}
- return unify(a.cell_type(), b.cell_type());
+ abort();
}
size_t my_dimension_index(const std::vector<Dimension> &list, const vespalib::string &name) {
@@ -265,6 +271,16 @@ ValueType::join(const ValueType &lhs, const ValueType &rhs)
return tensor_type(std::move(result.dimensions), unify(lhs._cell_type, rhs._cell_type));
}
+CellType
+ValueType::unify_cell_types(const ValueType &a, const ValueType &b) {
+ if (a.is_double()) {
+ return b.cell_type();
+ } else if (b.is_double()) {
+ return a.cell_type();
+ }
+ return unify(a.cell_type(), b.cell_type());
+}
+
ValueType
ValueType::concat(const ValueType &lhs, const ValueType &rhs, const vespalib::string &dimension)
{
@@ -278,7 +294,7 @@ ValueType::concat(const ValueType &lhs, const ValueType &rhs, const vespalib::st
if (!find_dimension(result.dimensions, dimension)) {
result.dimensions.emplace_back(dimension, 2);
}
- return tensor_type(std::move(result.dimensions), unify_cell_type(lhs, rhs));
+ return tensor_type(std::move(result.dimensions), unify_cell_types(lhs, rhs));
}
ValueType
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index 0eb3e1ca28e..64003e2636e 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -78,15 +78,27 @@ public:
static ValueType from_spec(const vespalib::string &spec);
vespalib::string to_spec() const;
static ValueType join(const ValueType &lhs, const ValueType &rhs);
+ static CellType unify_cell_types(const ValueType &a, const ValueType &b);
static ValueType concat(const ValueType &lhs, const ValueType &rhs, const vespalib::string &dimension);
static ValueType either(const ValueType &one, const ValueType &other);
};
std::ostream &operator<<(std::ostream &os, const ValueType &type);
-// utility template
-template <typename T> inline bool check_cell_type(ValueType::CellType type);
+// utility templates
+
+template <typename CT> inline bool check_cell_type(ValueType::CellType type);
template <> inline bool check_cell_type<double>(ValueType::CellType type) { return (type == ValueType::CellType::DOUBLE); }
template <> inline bool check_cell_type<float>(ValueType::CellType type) { return (type == ValueType::CellType::FLOAT); }
+template <typename LCT, typename RCT> struct UnifyCellTypes{};
+template <> struct UnifyCellTypes<double, double> { using type = double; };
+template <> struct UnifyCellTypes<double, float> { using type = double; };
+template <> struct UnifyCellTypes<float, double> { using type = double; };
+template <> struct UnifyCellTypes<float, float> { using type = float; };
+
+template <typename CT> inline ValueType::CellType get_cell_type();
+template <> inline ValueType::CellType get_cell_type<double>() { return ValueType::CellType::DOUBLE; }
+template <> inline ValueType::CellType get_cell_type<float>() { return ValueType::CellType::FLOAT; }
+
} // namespace
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 58db90f5557..f1eb9ff1523 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -37,6 +37,7 @@ using eval::TensorFunction;
using eval::TensorSpec;
using eval::Value;
using eval::ValueType;
+using CellType = eval::ValueType::CellType;
using vespalib::IllegalArgumentException;
using vespalib::make_string;
@@ -355,8 +356,7 @@ DefaultTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespali
size_t vector_size(const ValueType &type, const vespalib::string &dimension) {
if (type.is_double()) {
return 1;
- } else if ((type.cell_type() == ValueType::CellType::DOUBLE) &&
- (type.dimensions().size() == 1) &&
+ } else if ((type.dimensions().size() == 1) &&
(type.dimensions()[0].is_indexed()) &&
(type.dimensions()[0].name == dimension))
{
@@ -366,40 +366,50 @@ size_t vector_size(const ValueType &type, const vespalib::string &dimension) {
}
}
+template <typename OCT>
struct CallAppendVector {
template <typename CT>
- static void call(const ConstArrayRef<CT> &arr, double *&pos) {
- for (CT cell : arr) { *pos++ = cell; }
+ static void call(const ConstArrayRef<CT> &arr, OCT *&pos) {
+ for (CT cell: arr) { *pos++ = cell; }
}
};
-void append_vector(double *&pos, const Value &value) {
+template <typename OCT>
+void append_vector(OCT *&pos, const Value &value) {
if (auto tensor = value.as_tensor()) {
const DenseTensorView *view = static_cast<const DenseTensorView *>(tensor);
- TypedCells cellsRef = view->cellsRef();
- dispatch_1<CallAppendVector>(cellsRef, pos);
+ dispatch_1<CallAppendVector<OCT> >(view->cellsRef(), pos);
} else {
*pos++ = value.as_double();
}
}
+template <typename OCT>
const Value &concat_vectors(const Value &a, const Value &b, const vespalib::string &dimension, size_t vector_size, Stash &stash) {
- ArrayRef<double> cells = stash.create_array<double>(vector_size);
- double *pos = cells.begin();
- append_vector(pos, a);
- append_vector(pos, b);
+ ArrayRef<OCT> cells = stash.create_array<OCT>(vector_size);
+ OCT *pos = cells.begin();
+ append_vector<OCT>(pos, a);
+ append_vector<OCT>(pos, b);
assert(pos == cells.end());
- const ValueType &type = stash.create<ValueType>(ValueType::tensor_type({ValueType::Dimension(dimension, vector_size)}));
+ const ValueType &type = stash.create<ValueType>(ValueType::tensor_type({ValueType::Dimension(dimension, vector_size)}, ValueType::unify_cell_types(a.type(), b.type())));
return stash.create<DenseTensorView>(type, TypedCells(cells));
}
+struct CallConcatVectors {
+ template <typename OCT>
+ static const Value &call(const Value &a, const Value &b, const vespalib::string &dimension, size_t vector_size, Stash &stash) {
+ return concat_vectors<OCT>(a, b, dimension, vector_size, stash);
+ }
+};
+
const Value &
DefaultTensorEngine::concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const
{
size_t a_size = vector_size(a.type(), dimension);
size_t b_size = vector_size(b.type(), dimension);
if ((a_size > 0) && (b_size > 0)) {
- return concat_vectors(a, b, dimension, a_size + b_size, stash);
+ CellType result_cell_type = ValueType::unify_cell_types(a.type(), b.type());
+ return dispatch_0<CallConcatVectors>(result_cell_type, a, b, dimension, (a_size + b_size), stash);
}
return to_default(simple_engine().concat(to_simple(a, stash), to_simple(b, stash), dimension, stash), stash);
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_add_dimension_optimizer.cpp b/eval/src/vespa/eval/tensor/dense/dense_add_dimension_optimizer.cpp
index 842e064de43..a4331b6b251 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_add_dimension_optimizer.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_add_dimension_optimizer.cpp
@@ -19,21 +19,8 @@ using namespace eval::operation;
namespace {
-bool is_concrete_dense_tensor(const ValueType &type) {
- if (type.cell_type() != ValueType::CellType::DOUBLE) {
- return false; // non-double cell types not supported
- }
- return type.is_dense();
-}
-
-bool not_overlapping(const ValueType &a, const ValueType &b) {
- size_t npos = ValueType::Dimension::npos;
- for (const auto &dim: b.dimensions()) {
- if (a.dimension_index(dim.name) != npos) {
- return false;
- }
- }
- return true;
+bool same_cell_type(const TensorFunction &a, const TensorFunction &b) {
+ return (a.result_type().cell_type() == b.result_type().cell_type());
}
bool is_unit_constant(const TensorFunction &node) {
@@ -57,15 +44,14 @@ DenseAddDimensionOptimizer::optimize(const eval::TensorFunction &expr, Stash &st
const TensorFunction &lhs = join->lhs();
const TensorFunction &rhs = join->rhs();
if ((join->function() == Mul::f) &&
- is_concrete_dense_tensor(lhs.result_type()) &&
- is_concrete_dense_tensor(rhs.result_type()) &&
- not_overlapping(lhs.result_type(), rhs.result_type()))
+ lhs.result_type().is_dense() &&
+ rhs.result_type().is_dense())
{
- if (is_unit_constant(lhs)) {
+ if (is_unit_constant(lhs) && same_cell_type(rhs, expr)) {
return DenseReplaceTypeFunction::create_compact(expr.result_type(), rhs, stash);
}
- if (is_unit_constant(rhs)) {
- return DenseReplaceTypeFunction::create_compact(expr.result_type(), lhs, stash);
+ if (is_unit_constant(rhs) && same_cell_type(lhs, expr)) {
+ return DenseReplaceTypeFunction::create_compact(expr.result_type(), lhs, stash);
}
}
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
index 9b839e1b12f..8bcaddba3b4 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
@@ -18,12 +18,6 @@ using namespace eval::operation;
namespace {
-template <typename T>
-ConstArrayRef<T> getCellsRef(const eval::Value &value) {
- const DenseTensorView &denseTensor = static_cast<const DenseTensorView &>(value);
- return denseTensor.cellsRef().typify<T>();
-}
-
template <typename LCT, typename RCT>
struct HWSupport {
static double call(hwaccelrated::IAccelrated *, const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs) {
@@ -48,8 +42,8 @@ template <> struct HWSupport<double, double> {
template <typename LCT, typename RCT>
void my_dot_product_op(eval::InterpretedFunction::State &state, uint64_t param) {
auto *hw = (hwaccelrated::IAccelrated *)(param);
- auto lhs = getCellsRef<LCT>(state.peek(1));
- auto rhs = getCellsRef<RCT>(state.peek(0));
+ auto lhs = DenseTensorView::typify_cells<LCT>(state.peek(1));
+ auto rhs = DenseTensorView::typify_cells<RCT>(state.peek(0));
double result = HWSupport<LCT,RCT>::call(hw, lhs, rhs);
state.pop_pop_push(state.stash.create<eval::DoubleValue>(result));
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_fast_rename_optimizer.cpp b/eval/src/vespa/eval/tensor/dense/dense_fast_rename_optimizer.cpp
index d8e1876ac64..ac8442477e4 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_fast_rename_optimizer.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_fast_rename_optimizer.cpp
@@ -17,15 +17,10 @@ using namespace eval::tensor_function;
namespace {
-bool is_concrete_dense_stable_rename(const ValueType &from_type, const ValueType &to_type,
- const std::vector<vespalib::string> &from,
- const std::vector<vespalib::string> &to)
+bool is_dense_stable_rename(const ValueType &from_type, const ValueType &to_type,
+ const std::vector<vespalib::string> &from,
+ const std::vector<vespalib::string> &to)
{
- if (from_type.cell_type() != ValueType::CellType::DOUBLE ||
- to_type.cell_type() != ValueType::CellType::DOUBLE)
- {
- return false; // non-double cell types not supported
- }
if (!from_type.is_dense() ||
!to_type.is_dense() ||
(from.size() != to.size()))
@@ -51,7 +46,8 @@ DenseFastRenameOptimizer::optimize(const eval::TensorFunction &expr, Stash &stas
if (auto rename = as<Rename>(expr)) {
const ValueType &from_type = rename->child().result_type();
const ValueType &to_type = expr.result_type();
- if (is_concrete_dense_stable_rename(from_type, to_type, rename->from(), rename->to())) {
+ if (is_dense_stable_rename(from_type, to_type, rename->from(), rename->to())) {
+ assert(to_type.cell_type() == from_type.cell_type());
return DenseReplaceTypeFunction::create_compact(to_type, rename->child(), stash);
}
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_generic_join.hpp b/eval/src/vespa/eval/tensor/dense/dense_generic_join.hpp
index aa08e6982bb..cdc89b30fff 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_generic_join.hpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_generic_join.hpp
@@ -43,7 +43,7 @@ struct CallGenericJoin {
DenseDimensionCombiner & combiner,
Function &&func)
{
- using OCT = typename OutputCellType<LCT, RCT>::output_type;
+ using OCT = typename eval::UnifyCellTypes<LCT, RCT>::type;
TypedDenseTensorBuilder<OCT> builder(combiner.result_type);
return generic_join(combiner, builder, lhsArr, rhsArr, std::move(func));
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp
index 5fdfdbc4e9f..0b5bba88d37 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp
@@ -17,35 +17,45 @@ using namespace eval::tensor_function;
namespace {
-TypedCells getCellsRef(const eval::Value &value) {
- const DenseTensorView &denseTensor = static_cast<const DenseTensorView &>(value);
- return denseTensor.cellsRef();
+template <typename LCT, typename RCT>
+void my_inplace_join_left_op(eval::InterpretedFunction::State &state, uint64_t param) {
+ join_fun_t function = (join_fun_t)param;
+ auto lhs_cells = unconstify(DenseTensorView::typify_cells<LCT>(state.peek(1)));
+ auto rhs_cells = DenseTensorView::typify_cells<RCT>(state.peek(0));
+ for (size_t i = 0; i < lhs_cells.size(); ++i) {
+ lhs_cells[i] = function(lhs_cells[i], rhs_cells[i]);
+ }
+ state.stack.pop_back();
}
-template <bool write_left>
-void my_inplace_join_op(eval::InterpretedFunction::State &state, uint64_t param) {
+template <typename LCT, typename RCT>
+void my_inplace_join_right_op(eval::InterpretedFunction::State &state, uint64_t param) {
join_fun_t function = (join_fun_t)param;
- ConstArrayRef<double> lhs_cells = getCellsRef(state.peek(1)).typify<double>();
- ConstArrayRef<double> rhs_cells = getCellsRef(state.peek(0)).typify<double>();
- auto dst_cells = unconstify(write_left ? lhs_cells : rhs_cells);
- for (size_t i = 0; i < dst_cells.size(); ++i) {
- dst_cells[i] = function(lhs_cells[i], rhs_cells[i]);
- }
- if (write_left) {
- state.stack.pop_back();
- } else {
- const Value &result = state.stack.back();
- state.pop_pop_push(result);
+ auto lhs_cells = DenseTensorView::typify_cells<LCT>(state.peek(1));
+ auto rhs_cells = unconstify(DenseTensorView::typify_cells<RCT>(state.peek(0)));
+ for (size_t i = 0; i < rhs_cells.size(); ++i) {
+ rhs_cells[i] = function(lhs_cells[i], rhs_cells[i]);
}
+ const Value &result = state.stack.back();
+ state.pop_pop_push(result);
}
-bool sameShapeConcreteDenseTensors(const ValueType &a, const ValueType &b) {
- if (a.cell_type() != ValueType::CellType::DOUBLE ||
- b.cell_type() != ValueType::CellType::DOUBLE)
- {
- return false; // non-double cell types not supported
+struct MyInplaceJoinLeftOp {
+ template <typename LCT, typename RCT>
+ static auto get_fun() { return my_inplace_join_left_op<LCT,RCT>; }
+};
+
+struct MyInplaceJoinRightOp {
+ template <typename LCT, typename RCT>
+ static auto get_fun() { return my_inplace_join_right_op<LCT,RCT>; }
+};
+
+eval::InterpretedFunction::op_function my_select(CellType lct, CellType rct, bool write_left) {
+ if (write_left) {
+ return select_2<MyInplaceJoinLeftOp>(lct, rct);
+ } else {
+ return select_2<MyInplaceJoinRightOp>(lct, rct);
}
- return (a.is_dense() && (a == b));
}
} // namespace vespalib::tensor::<unnamed>
@@ -68,7 +78,8 @@ DenseInplaceJoinFunction::~DenseInplaceJoinFunction()
eval::InterpretedFunction::Instruction
DenseInplaceJoinFunction::compile_self(Stash &) const
{
- auto op = _write_left ? my_inplace_join_op<true> : my_inplace_join_op<false>;
+ auto op = my_select(lhs().result_type().cell_type(),
+ rhs().result_type().cell_type(), _write_left);
return eval::InterpretedFunction::Instruction(op, (uint64_t)function());
}
@@ -85,11 +96,17 @@ DenseInplaceJoinFunction::optimize(const eval::TensorFunction &expr, Stash &stas
if (auto join = as<Join>(expr)) {
const TensorFunction &lhs = join->lhs();
const TensorFunction &rhs = join->rhs();
- if ((lhs.result_is_mutable() || rhs.result_is_mutable()) &&
- sameShapeConcreteDenseTensors(lhs.result_type(), rhs.result_type()))
+ if (lhs.result_type().is_dense() &&
+ (lhs.result_type().dimensions() == rhs.result_type().dimensions()))
{
- return stash.create<DenseInplaceJoinFunction>(join->result_type(), lhs, rhs,
- join->function(), lhs.result_is_mutable());
+ if (lhs.result_is_mutable() && (lhs.result_type() == expr.result_type())) {
+ return stash.create<DenseInplaceJoinFunction>(join->result_type(), lhs, rhs,
+ join->function(), /* write left: */ true);
+ }
+ if (rhs.result_is_mutable() && (rhs.result_type() == expr.result_type())) {
+ return stash.create<DenseInplaceJoinFunction>(join->result_type(), lhs, rhs,
+ join->function(), /* write left: */ false);
+ }
}
}
return expr;
diff --git a/eval/src/vespa/eval/tensor/dense/dense_inplace_map_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_inplace_map_function.cpp
index b38a6b175dc..c82cda34a28 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_inplace_map_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_inplace_map_function.cpp
@@ -16,24 +16,19 @@ using namespace eval::tensor_function;
namespace {
-ArrayRef<double> getMutableCells(const eval::Value &value) {
- const DenseTensorView &denseTensor = static_cast<const DenseTensorView &>(value);
- return unconstify(denseTensor.cellsRef().typify<double>());
-}
-
+template <typename CT>
void my_inplace_map_op(eval::InterpretedFunction::State &state, uint64_t param) {
map_fun_t function = (map_fun_t)param;
- for (double &cell: getMutableCells(state.peek(0))) {
+ ArrayRef<CT> cells = unconstify(DenseTensorView::typify_cells<CT>(state.peek(0)));
+ for (CT &cell: cells) {
cell = function(cell);
}
}
-bool isConcreteDenseTensor(const ValueType &type) {
- if (type.cell_type() != ValueType::CellType::DOUBLE) {
- return false; // non-double cell types not supported
- }
- return type.is_dense();
-}
+struct MyInplaceMapOp {
+ template <typename CT>
+ static auto get_fun() { return my_inplace_map_op<CT>; }
+};
} // namespace vespalib::tensor::<unnamed>
@@ -51,14 +46,16 @@ DenseInplaceMapFunction::~DenseInplaceMapFunction()
eval::InterpretedFunction::Instruction
DenseInplaceMapFunction::compile_self(Stash &) const
{
- return eval::InterpretedFunction::Instruction(my_inplace_map_op, (uint64_t)function());
+ auto op = select_1<MyInplaceMapOp>(result_type().cell_type());
+ return eval::InterpretedFunction::Instruction(op, (uint64_t)function());
}
const TensorFunction &
DenseInplaceMapFunction::optimize(const eval::TensorFunction &expr, Stash &stash)
{
if (auto map = as<Map>(expr)) {
- if (map->child().result_is_mutable() && isConcreteDenseTensor(map->result_type())) {
+ if (map->child().result_is_mutable() && map->result_type().is_dense()) {
+ assert(map->result_type().cell_type() == map->child().result_type().cell_type());
return stash.create<DenseInplaceMapFunction>(map->result_type(), map->child(), map->function());
}
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_remove_dimension_optimizer.cpp b/eval/src/vespa/eval/tensor/dense/dense_remove_dimension_optimizer.cpp
index 3c58320a6e6..a64d5edbb37 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_remove_dimension_optimizer.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_remove_dimension_optimizer.cpp
@@ -14,13 +14,6 @@ using namespace eval::tensor_function;
namespace {
-bool is_concrete_dense_tensor(const ValueType &type) {
- if (type.cell_type() != ValueType::CellType::DOUBLE) {
- return false; // non-double cell types not supported
- }
- return type.is_dense();
-}
-
bool is_ident_aggr(Aggr aggr) {
return ((aggr == Aggr::AVG) ||
(aggr == Aggr::PROD) ||
@@ -47,11 +40,12 @@ DenseRemoveDimensionOptimizer::optimize(const eval::TensorFunction &expr, Stash
{
if (auto reduce = as<Reduce>(expr)) {
const TensorFunction &child = reduce->child();
- if (is_concrete_dense_tensor(expr.result_type()) &&
- is_concrete_dense_tensor(child.result_type()) &&
+ if (expr.result_type().is_dense() &&
+ child.result_type().is_dense() &&
is_ident_aggr(reduce->aggr()) &&
is_trivial_dim_list(child.result_type(), reduce->dimensions()))
{
+ assert(expr.result_type().cell_type() == child.result_type().cell_type());
return DenseReplaceTypeFunction::create_compact(expr.result_type(), child, stash);
}
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
index d98cf52d279..3fed84323ca 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
@@ -95,8 +95,7 @@ sameShapeJoin(const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs,
{
size_t sz = lhs.size();
assert(sz == rhs.size());
- using OutputSelector = OutputCellType<LCT, RCT>;
- using OCT = typename OutputSelector::output_type;
+ using OCT = typename eval::UnifyCellTypes<LCT,RCT>::type;
std::vector<OCT> newCells;
newCells.reserve(sz);
auto rhsCellItr = rhs.cbegin();
@@ -107,7 +106,7 @@ sameShapeJoin(const ConstArrayRef<LCT> &lhs, const ConstArrayRef<RCT> &rhs,
}
assert(rhsCellItr == rhs.cend());
assert(newCells.size() == sz);
- auto newType = eval::ValueType::tensor_type(lhs_type.dimensions(), OutputSelector::output_cell_type());
+ auto newType = eval::ValueType::tensor_type(lhs_type.dimensions(), eval::get_cell_type<OCT>());
return std::make_unique<DenseTensor<OCT>>(std::move(newType), std::move(newCells));
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
index 1ec4daf40fd..778f2aa2871 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
@@ -42,6 +42,13 @@ public:
Tensor::UP clone() const override;
eval::TensorSpec toSpec() const override;
void accept(TensorVisitor &visitor) const override;
+
+    /**
+     * Treat 'self' as a DenseTensorView and hand back its cells as an
+     * array of T, obtained through TypedCells::typify.
+     */
+    template <typename T> static ConstArrayRef<T> typify_cells(const eval::Value &self) {
+        const DenseTensorView &view = static_cast<const DenseTensorView &>(self);
+        return view.cellsRef().typify<T>();
+    }
+    /**
+     * Same as typify_cells, but goes through TypedCells::unsafe_typify.
+     * NOTE(review): presumably this skips the cell type check; confirm
+     * against TypedCells before relying on it.
+     */
+    template <typename T> static ConstArrayRef<T> unsafe_typify_cells(const eval::Value &self) {
+        const DenseTensorView &view = static_cast<const DenseTensorView &>(self);
+        return view.cellsRef().unsafe_typify<T>();
+    }
protected:
explicit DenseTensorView(const eval::ValueType &type_in)
: _typeRef(type_in),
diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
index b6ac87ce012..2db5b4e8f92 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
@@ -21,21 +21,36 @@ using namespace eval::operation;
namespace {
-XWInput getCellsRef(const eval::Value &value) {
- const DenseTensorView &denseTensor = static_cast<const DenseTensorView &>(value);
- TypedCells ref = denseTensor.cellsRef();
- assert(ref.type == CellType::DOUBLE);
- return ref.typify<double>();
-}
+/**
+ * Dot product helper for an arbitrary pair of cell types. The primary
+ * template ignores the accelerator argument: it multiplies the inputs
+ * pairwise and accumulates the sum in a double.
+ */
+template <typename LCT, typename RCT>
+struct HWSupport {
+    static double call(hwaccelrated::IAccelrated *, const LCT *lhs, const RCT *rhs, size_t len) {
+        double sum = 0.0;
+        const LCT *const lhs_end = lhs + len;
+        while (lhs != lhs_end) {
+            sum += ((*lhs++) * (*rhs++));
+        }
+        return sum;
+    }
+};
+/** float x float: the whole dot product is delegated to the accelerator. */
+template <> struct HWSupport<float, float> {
+    static double call(hwaccelrated::IAccelrated *hw, const float *lhs, const float *rhs, size_t len) {
+        const auto product = hw->dotProduct(lhs, rhs, len);
+        return product;
+    }
+};
+/** double x double: the whole dot product is delegated to the accelerator. */
+template <> struct HWSupport<double, double> {
+    static double call(hwaccelrated::IAccelrated *hw, const double *lhs, const double *rhs, size_t len) {
+        const auto product = hw->dotProduct(lhs, rhs, len);
+        return product;
+    }
+};
+template <typename LCT, typename RCT, typename OCT>
void multiDotProduct(const DenseXWProductFunction::Self &self,
- const XWInput &vectorCells, const XWInput &matrixCells, XWOutput &result)
+ const ConstArrayRef<LCT> &vectorCells, const ConstArrayRef<RCT> &matrixCells, ArrayRef<OCT> &result)
{
- double *out = result.begin();
- const double *matrixP = matrixCells.cbegin();
- const double * const vectorP = vectorCells.cbegin();
+ OCT *out = result.begin();
+ const RCT *matrixP = matrixCells.cbegin();
+ const LCT * const vectorP = vectorCells.cbegin();
for (size_t row = 0; row < self._resultSize; ++row) {
- double cell = self._hwAccelerator->dotProduct(vectorP, matrixP, self._vectorSize);
+ double cell = HWSupport<LCT,RCT>::call(self._hwAccelerator.get(), vectorP, matrixP, self._vectorSize);
*out++ = cell;
matrixP += self._vectorSize;
}
@@ -43,12 +58,13 @@ void multiDotProduct(const DenseXWProductFunction::Self &self,
assert(matrixP == matrixCells.cend());
}
+template <typename LCT, typename RCT, typename OCT>
void transposedProduct(const DenseXWProductFunction::Self &self,
- const XWInput &vectorCells, const XWInput &matrixCells, XWOutput &result)
+ const ConstArrayRef<LCT> &vectorCells, const ConstArrayRef<RCT> &matrixCells, ArrayRef<OCT> &result)
{
- double *out = result.begin();
- const double * const matrixP = matrixCells.cbegin();
- const double * const vectorP = vectorCells.cbegin();
+ OCT *out = result.begin();
+ const RCT * const matrixP = matrixCells.cbegin();
+ const LCT * const vectorP = vectorCells.cbegin();
for (size_t row = 0; row < self._resultSize; ++row) {
double cell = 0;
for (size_t col = 0; col < self._vectorSize; ++col) {
@@ -59,41 +75,54 @@ void transposedProduct(const DenseXWProductFunction::Self &self,
assert(out == result.end());
}
-template <bool commonDimensionInnermost>
+template <typename LCT, typename RCT, bool commonDimensionInnermost> // LCT: cell type of the vector operand, RCT: cell type of the matrix operand
void my_xw_product_op(eval::InterpretedFunction::State &state, uint64_t param) {
DenseXWProductFunction::Self *self = (DenseXWProductFunction::Self *)(param); // param carries the address of the Self object
- XWInput vectorCells = getCellsRef(state.peek(1));
- XWInput matrixCells = getCellsRef(state.peek(0));
-
- ArrayRef<double> outputCells = state.stash.create_array<double>(self->_resultSize);
+ using OCT = typename eval::UnifyCellTypes<LCT,RCT>::type; // output cell type derived from both input cell types
+ auto vectorCells = DenseTensorView::typify_cells<LCT>(state.peek(1)); // vector is the second value from the top of the stack
+ auto matrixCells = DenseTensorView::typify_cells<RCT>(state.peek(0)); // matrix is on top of the stack
+ auto outputCells = state.stash.create_array<OCT>(self->_resultSize); // output buffer is allocated from the state's stash
if (commonDimensionInnermost) {
multiDotProduct(*self, vectorCells, matrixCells, outputCells);
} else {
transposedProduct(*self, vectorCells, matrixCells, outputCells);
}
+
state.pop_pop_push(state.stash.create<DenseTensorView>(self->_resultType, TypedCells(outputCells))); // result view wraps outputCells with the stored result type
}
-bool isConcreteDenseTensor(const ValueType &type, size_t d) {
- if (type.cell_type() != ValueType::CellType::DOUBLE) {
- return false; // non-double cell types not supported
+/**
+ * Adapter passed to select_2: get_fun<LCT,RCT>() yields the
+ * my_xw_product_op instantiation for those cell types, with the
+ * 'common_inner' flag fixed at compile time.
+ */
+template <bool common_inner>
+struct MyXWProductOp {
+    template <typename LCT, typename RCT>
+    static auto get_fun() {
+        return &my_xw_product_op<LCT, RCT, common_inner>;
+    }
+};
+
+/**
+ * Pick the xw-product op instantiation matching the cell types of the
+ * two inputs and the requested loop layout.
+ */
+eval::InterpretedFunction::op_function my_select(CellType lct, CellType rct, bool common_innermost) {
+    if (!common_innermost) {
+        return select_2<MyXWProductOp<false> >(lct, rct);
     }
+    return select_2<MyXWProductOp<true> >(lct, rct);
+}
+
+bool isDenseTensor(const ValueType &type, size_t d) { // true when 'type' is dense and has exactly 'd' dimensions
return (type.is_dense() && (type.dimensions().size() == d));
}
bool isDenseXWProduct(const ValueType &res, const ValueType &vec, const ValueType &mat) {
- if (isConcreteDenseTensor(res, 1) &&
- isConcreteDenseTensor(vec, 1) &&
- isConcreteDenseTensor(mat, 2))
+ if (isDenseTensor(res, 1) &&
+ isDenseTensor(vec, 1) &&
+ isDenseTensor(mat, 2))
{
size_t res_idx = mat.dimension_index(res.dimensions()[0].name);
size_t vec_idx = mat.dimension_index(vec.dimensions()[0].name);
size_t npos = ValueType::Dimension::npos;
if ((res_idx != npos) && (vec_idx != npos) && (res_idx != vec_idx)) {
- return ((mat.dimensions()[res_idx].size == res.dimensions()[0].size) &&
- (mat.dimensions()[vec_idx].size == vec.dimensions()[0].size));
+ assert(mat.dimensions()[res_idx].size == res.dimensions()[0].size);
+ assert(mat.dimensions()[vec_idx].size == vec.dimensions()[0].size);
+ return true;
}
}
return false;
@@ -134,7 +163,8 @@ eval::InterpretedFunction::Instruction
DenseXWProductFunction::compile_self(Stash &stash) const
{
Self &self = stash.create<Self>(result_type(), _vectorSize, _resultSize);
- auto op = _commonDimensionInnermost ? my_xw_product_op<true> : my_xw_product_op<false>;
+ auto op = my_select(lhs().result_type().cell_type(),
+ rhs().result_type().cell_type(), _commonDimensionInnermost);
return eval::InterpretedFunction::Instruction(op, (uint64_t)(&self));
}
@@ -150,22 +180,22 @@ DenseXWProductFunction::visit_self(vespalib::ObjectVisitor &visitor) const
const TensorFunction &
DenseXWProductFunction::optimize(const eval::TensorFunction &expr, Stash &stash)
{
- const Reduce *reduce = as<Reduce>(expr);
- if (reduce && (reduce->aggr() == Aggr::SUM)) {
- const ValueType &result_type = reduce->result_type();
- const Join *join = as<Join>(reduce->child());
- if (join && (join->function() == Mul::f)) {
- const TensorFunction &lhs = join->lhs();
- const TensorFunction &rhs = join->rhs();
- if (isDenseXWProduct(result_type, lhs.result_type(), rhs.result_type())) {
- return createDenseXWProduct(result_type, lhs, rhs, stash);
- }
- if (isDenseXWProduct(result_type, rhs.result_type(), lhs.result_type())) {
- return createDenseXWProduct(result_type, rhs, lhs, stash);
- }
+ const Reduce *reduce = as<Reduce>(expr);
+ if (reduce && (reduce->aggr() == Aggr::SUM)) { // only sum-reductions are candidates
+ const ValueType &result_type = reduce->result_type();
+ const Join *join = as<Join>(reduce->child());
+ if (join && (join->function() == Mul::f)) { // the reduced child must be a join using Mul::f
+ const TensorFunction &lhs = join->lhs();
+ const TensorFunction &rhs = join->rhs();
+ if (isDenseXWProduct(result_type, lhs.result_type(), rhs.result_type())) { // lhs as vector, rhs as matrix
+ return createDenseXWProduct(result_type, lhs, rhs, stash);
+ }
+ if (isDenseXWProduct(result_type, rhs.result_type(), lhs.result_type())) { // same check with the operands swapped
+ return createDenseXWProduct(result_type, rhs, lhs, stash);
}
}
- return expr;
+ }
+ return expr; // no match: hand back the expression unchanged
}
} // namespace vespalib::tensor
diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
index 9f1bc12b110..f2f4d67c0f0 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
@@ -8,9 +8,6 @@
namespace vespalib::tensor {
-using XWInput = ConstArrayRef<double>;
-using XWOutput = ArrayRef<double>;
-
/**
* Tensor function for product of one 1-dimensional and one 2-dimensional dense tensor.
*/
diff --git a/eval/src/vespa/eval/tensor/dense/typed_cells.h b/eval/src/vespa/eval/tensor/dense/typed_cells.h
index 98f95d54d9b..0f22c85735e 100644
--- a/eval/src/vespa/eval/tensor/dense/typed_cells.h
+++ b/eval/src/vespa/eval/tensor/dense/typed_cells.h
@@ -12,25 +12,6 @@ namespace vespalib::tensor {
using CellType = vespalib::eval::ValueType::CellType;
-
-template<typename LCT, typename RCT> struct OutputCellType;
-template<> struct OutputCellType<double, double> {
- typedef double output_type;
- static constexpr CellType output_cell_type() { return CellType::DOUBLE; };
-};
-template<> struct OutputCellType<float, double> {
- typedef double output_type;
- static constexpr CellType output_cell_type() { return CellType::DOUBLE; };
-};
-template<> struct OutputCellType<double, float> {
- typedef double output_type;
- static constexpr CellType output_cell_type() { return CellType::DOUBLE; };
-};
-template<> struct OutputCellType<float, float> {
- typedef float output_type;
- static constexpr CellType output_cell_type() { return CellType::FLOAT; };
-};
-
struct TypedCells {
const void *data;
CellType type;
@@ -67,7 +48,7 @@ struct TypedCells {
};
template <typename TGT, typename... Args>
-auto dispatch_0(CellType ct, Args &&...args) {
+decltype(auto) dispatch_0(CellType ct, Args &&...args) {
switch (ct) {
case CellType::DOUBLE: return TGT::template call<double>(std::forward<Args>(args)...);
case CellType::FLOAT: return TGT::template call<float>(std::forward<Args>(args)...);
@@ -76,7 +57,7 @@ auto dispatch_0(CellType ct, Args &&...args) {
}
template <typename TGT, typename... Args>
-auto dispatch_1(const TypedCells &a, Args &&...args) {
+decltype(auto) dispatch_1(const TypedCells &a, Args &&...args) {
switch (a.type) {
case CellType::DOUBLE: return TGT::call(a.unsafe_typify<double>(), std::forward<Args>(args)...);
case CellType::FLOAT: return TGT::call(a.unsafe_typify<float>(), std::forward<Args>(args)...);
@@ -85,7 +66,7 @@ auto dispatch_1(const TypedCells &a, Args &&...args) {
}
template <typename TGT, typename A1, typename... Args>
-auto dispatch_2(A1 &&a, const TypedCells &b, Args &&...args) {
+decltype(auto) dispatch_2(A1 &&a, const TypedCells &b, Args &&...args) {
switch (b.type) {
case CellType::DOUBLE: return dispatch_1<TGT>(std::forward<A1>(a), b.unsafe_typify<double>(), std::forward<Args>(args)...);
case CellType::FLOAT: return dispatch_1<TGT>(std::forward<A1>(a), b.unsafe_typify<float>(), std::forward<Args>(args)...);
@@ -94,7 +75,7 @@ auto dispatch_2(A1 &&a, const TypedCells &b, Args &&...args) {
}
template <typename T, typename... Args>
-auto select_1(CellType a_type) {
+decltype(auto) select_1(CellType a_type) {
switch(a_type) {
case CellType::DOUBLE: return T::template get_fun<double, Args...>();
case CellType::FLOAT: return T::template get_fun<float, Args...>();
@@ -103,7 +84,7 @@ auto select_1(CellType a_type) {
}
template <typename T>
-auto select_2(CellType a_type, CellType b_type) {
+decltype(auto) select_2(CellType a_type, CellType b_type) {
switch(b_type) {
case CellType::DOUBLE: return select_1<T, double>(a_type);
case CellType::FLOAT: return select_1<T, float>(a_type);