summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2021-04-08 11:14:55 +0200
committerGitHub <noreply@github.com>2021-04-08 11:14:55 +0200
commit907638cc59eba8f6e80efcc980ca3f35f30c8f20 (patch)
tree8a2482f5fd71072a4a017fd11137cc5ff686d96e /eval
parent1edf0b9cba9d0f36dc7f4a615c486f3b5ca735d6 (diff)
parent71dc1922dd23599b56cab8fafc10c75f00138e0c (diff)
Merge pull request #17282 from vespa-engine/arnej/optimize-unstable-join-with-number
add optimization for join between bfloat16 tensor and number
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp16
-rw-r--r--eval/src/vespa/eval/instruction/join_with_number_function.cpp70
-rw-r--r--eval/src/vespa/eval/instruction/join_with_number_function.h2
3 files changed, 52 insertions, 36 deletions
diff --git a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
index a6486de6858..4a0eb00e7e7 100644
--- a/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
+++ b/eval/src/tests/instruction/join_with_number/join_with_number_function_test.cpp
@@ -5,8 +5,8 @@
#include <vespa/eval/eval/test/eval_fixture.h>
#include <vespa/eval/eval/test/gen_spec.h>
#include <vespa/eval/instruction/join_with_number_function.h>
-
#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/unwind_message.h>
using namespace vespalib;
using namespace vespalib::eval;
@@ -33,11 +33,12 @@ std::ostream &operator<<(std::ostream &os, Primary primary)
struct FunInfo {
using LookFor = JoinWithNumberFunction;
Primary primary;
+ bool pri_mut;
bool inplace;
void verify(const EvalFixture &fixture, const LookFor &fun) const {
EXPECT_TRUE(fun.result_is_mutable());
EXPECT_EQUAL(fun.primary(), primary);
- EXPECT_EQUAL(fun.inplace(), inplace);
+ EXPECT_EQUAL(fun.primary_is_mutable(), pri_mut);
if (inplace) {
size_t idx = (fun.primary() == Primary::LHS) ? 0 : 1;
EXPECT_EQUAL(fixture.result_value().cells().data,
@@ -46,17 +47,18 @@ struct FunInfo {
}
};
-void verify_optimized(const vespalib::string &expr, Primary primary, bool inplace) {
- // fprintf(stderr, "%s\n", expr.c_str());
+void verify_optimized(const vespalib::string &expr, Primary primary, bool pri_mut) {
+ UNWIND_MSG("optimize %s", expr.c_str());
const CellTypeSpace stable_types(CellTypeUtils::list_stable_types(), 2);
- FunInfo stable_details{primary, inplace};
+ FunInfo stable_details{primary, pri_mut, pri_mut};
TEST_DO(EvalFixture::verify<FunInfo>(expr, {stable_details}, stable_types));
const CellTypeSpace unstable_types(CellTypeUtils::list_unstable_types(), 2);
- TEST_DO(EvalFixture::verify<FunInfo>(expr, {}, unstable_types));
+ FunInfo unstable_details{primary, pri_mut, false};
+ TEST_DO(EvalFixture::verify<FunInfo>(expr, {unstable_details}, unstable_types));
}
void verify_not_optimized(const vespalib::string &expr) {
- // fprintf(stderr, "%s\n", expr.c_str());
+ UNWIND_MSG("not: %s", expr.c_str());
CellTypeSpace all_types(CellTypeUtils::list_types(), 2);
TEST_DO(EvalFixture::verify<FunInfo>(expr, {}, all_types));
}
diff --git a/eval/src/vespa/eval/instruction/join_with_number_function.cpp b/eval/src/vespa/eval/instruction/join_with_number_function.cpp
index c574e3f8ad9..592076f23ce 100644
--- a/eval/src/vespa/eval/instruction/join_with_number_function.cpp
+++ b/eval/src/vespa/eval/instruction/join_with_number_function.cpp
@@ -14,39 +14,55 @@ namespace vespalib::eval {
using Instruction = InterpretedFunction::Instruction;
using State = InterpretedFunction::State;
+using vespalib::eval::tensor_function::unwrap_param;
+using vespalib::eval::tensor_function::wrap_param;
namespace {
-template <typename CT, bool inplace>
-ArrayRef<CT> make_dst_cells(ConstArrayRef<CT> src_cells, Stash &stash) {
- if (inplace) {
+struct JoinWithNumberParam {
+ const ValueType res_type;
+ const join_fun_t function;
+ JoinWithNumberParam(const ValueType &r, join_fun_t f) : res_type(r), function(f) {}
+};
+
+template <typename ICT, typename OCT, bool inplace>
+ArrayRef<OCT> make_dst_cells(ConstArrayRef<ICT> src_cells, Stash &stash) {
+ if constexpr (inplace) {
+ static_assert(std::is_same_v<ICT,OCT>);
return unconstify(src_cells);
} else {
- return stash.create_uninitialized_array<CT>(src_cells.size());
+ return stash.create_uninitialized_array<OCT>(src_cells.size());
}
}
-template <typename CT, typename Fun, bool inplace, bool swap>
-void my_number_join_op(State &state, uint64_t param) {
+template <typename ICT, typename OCT, typename Fun, bool inplace, bool swap>
+void my_number_join_op(State &state, uint64_t param_in) {
+ const auto &param = unwrap_param<JoinWithNumberParam>(param_in);
using OP = typename std::conditional<swap,SwapArgs2<Fun>,Fun>::type;
- OP my_op((join_fun_t)param);
+ OP my_op(param.function);
const Value &tensor = state.peek(swap ? 0 : 1);
- CT number = state.peek(swap ? 1 : 0).as_double();
- auto src_cells = tensor.cells().typify<CT>();
- auto dst_cells = make_dst_cells<CT, inplace>(src_cells, state.stash);
+ OCT number = state.peek(swap ? 1 : 0).as_double();
+ auto src_cells = tensor.cells().typify<ICT>();
+ auto dst_cells = make_dst_cells<ICT, OCT, inplace>(src_cells, state.stash);
apply_op2_vec_num(dst_cells.begin(), src_cells.begin(), number, dst_cells.size(), my_op);
if (inplace) {
state.pop_pop_push(tensor);
} else {
- state.pop_pop_push(state.stash.create<ValueView>(tensor.type(), tensor.index(), TypedCells(dst_cells)));
+ state.pop_pop_push(state.stash.create<ValueView>(param.res_type, tensor.index(), TypedCells(dst_cells)));
}
}
struct SelectJoinWithNumberOp {
- template<typename CT, typename Fun,
- typename InputIsMutable, typename NumberWasLeft>
+ template<typename CM, typename Fun,
+ typename PrimaryMutable, typename NumberWasLeft>
static auto invoke() {
- return my_number_join_op<CT, Fun, InputIsMutable::value, NumberWasLeft::value>;
+ constexpr CellMeta icm = CM::value;
+ constexpr CellMeta num(CellType::DOUBLE, true);
+ constexpr CellMeta ocm = CellMeta::join(icm, num);
+ using ICT = CellValueType<icm.cell_type>;
+ using OCT = CellValueType<ocm.cell_type>;
+ constexpr bool inplace = (PrimaryMutable::value && std::is_same_v<ICT,OCT>);
+ return my_number_join_op<ICT, OCT, Fun, inplace, NumberWasLeft::value>;
}
};
@@ -62,7 +78,7 @@ JoinWithNumberFunction::JoinWithNumberFunction(const Join &original, bool tensor
JoinWithNumberFunction::~JoinWithNumberFunction() = default;
bool
-JoinWithNumberFunction::inplace() const {
+JoinWithNumberFunction::primary_is_mutable() const {
if (_primary == Primary::LHS) {
return lhs().result_is_mutable();
} else {
@@ -70,16 +86,19 @@ JoinWithNumberFunction::inplace() const {
}
}
-using MyTypify = TypifyValue<TypifyCellType,vespalib::TypifyBool,operation::TypifyOp2>;
+using MyTypify = TypifyValue<TypifyCellMeta,vespalib::TypifyBool,operation::TypifyOp2>;
InterpretedFunction::Instruction
-JoinWithNumberFunction::compile_self(const ValueBuilderFactory &, Stash &) const
+JoinWithNumberFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const
{
- auto op = typify_invoke<4,MyTypify,SelectJoinWithNumberOp>(result_type().cell_type(),
+ const auto &param = stash.create<JoinWithNumberParam>(result_type(), _function);
+ auto input_type = (_primary == Primary::LHS) ? lhs().result_type() : rhs().result_type();
+ assert(result_type() == ValueType::join(input_type, ValueType::double_type()));
+ auto op = typify_invoke<4,MyTypify,SelectJoinWithNumberOp>(input_type.cell_meta(),
_function,
- inplace(),
+ primary_is_mutable(),
(_primary == Primary::RHS));
- return Instruction(op, (uint64_t)(_function));
+ return Instruction(op, wrap_param<JoinWithNumberParam>(param));
}
void
@@ -87,7 +106,7 @@ JoinWithNumberFunction::visit_self(vespalib::ObjectVisitor &visitor) const
{
Super::visit_self(visitor);
visitor.visitBool("tensor_was_right", (_primary == Primary::RHS));
- visitor.visitBool("is_inplace", inplace());
+ visitor.visitBool("primary_is_mutable", primary_is_mutable());
}
const TensorFunction &
@@ -95,17 +114,12 @@ JoinWithNumberFunction::optimize(const TensorFunction &expr, Stash &stash)
{
if (! expr.result_type().is_double()) {
if (const auto *join = as<Join>(expr)) {
- const ValueType &result_type = join->result_type();
const TensorFunction &lhs = join->lhs();
const TensorFunction &rhs = join->rhs();
- if (lhs.result_type().is_double() &&
- (result_type == rhs.result_type()))
- {
+ if (lhs.result_type().is_double()) {
return stash.create<JoinWithNumberFunction>(*join, true);
}
- if (rhs.result_type().is_double() &&
- (result_type == lhs.result_type()))
- {
+ if (rhs.result_type().is_double()) {
return stash.create<JoinWithNumberFunction>(*join, false);
}
}
diff --git a/eval/src/vespa/eval/instruction/join_with_number_function.h b/eval/src/vespa/eval/instruction/join_with_number_function.h
index 546ff75b175..9d29ad5eb5d 100644
--- a/eval/src/vespa/eval/instruction/join_with_number_function.h
+++ b/eval/src/vespa/eval/instruction/join_with_number_function.h
@@ -23,7 +23,7 @@ public:
JoinWithNumberFunction(const tensor_function::Join &original_join, bool number_on_left);
~JoinWithNumberFunction();
Primary primary() const { return _primary; }
- bool inplace() const;
+ bool primary_is_mutable() const;
bool result_is_mutable() const override { return true; }
InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;