author    Håvard Pettersen <havardpe@oath.com>  2018-01-16 13:43:32 +0000
committer Håvard Pettersen <havardpe@oath.com>  2018-01-16 13:43:32 +0000
commit    9f3409ebed82fbaf6b7b18347211e2cc2277a6d2 (patch)
tree      fbfeb930876e90ab18314569b7c6e04c46ec5157 /eval
parent    2f55bdbcdfb09c0a1f031c93cb3152582f2d0f81 (diff)
tweak tensor IR semantics and add more encapsulation
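
In practice this means that result_type changes from a public data member on tensor_function::Node to a virtual result_type() accessor on the TensorFunction interface, the Child/push_children machinery moves from Node up to TensorFunction, Inject's tensor_id becomes param_idx, and the node members go private behind accessors. A rough sketch of the updated call-site pattern (tensor types and parameter indices are made up for illustration; the factory functions, ValueType::from_spec, Aggr::SUM and operation::Mul::f are the names touched in the diff below):

    using namespace vespalib::eval;
    using namespace vespalib::eval::tensor_function;

    Stash stash;
    // IR for: reduce(a * b, sum, y), with a and b taken from parameters 0 and 1
    const Node &a     = inject(ValueType::from_spec("tensor(x[3],y[2])"), 0, stash);
    const Node &b     = inject(ValueType::from_spec("tensor(y[2],z[3])"), 1, stash);
    const Node &prod  = join(a, b, operation::Mul::f, stash);
    const Node &sum_y = reduce(prod, Aggr::SUM, {"y"}, stash);

    // previously: sum_y.result_type (public data member); now a virtual accessor
    const ValueType &type = sum_y.result_type();   // tensor(x[3],z[3])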
Diffstat (limited to 'eval')
-rw-r--r--  eval/src/tests/eval/tensor_function/tensor_function_test.cpp         |  10
-rw-r--r--  eval/src/vespa/eval/eval/tensor_function.cpp                          |  74
-rw-r--r--  eval/src/vespa/eval/eval/tensor_function.h                            | 187
-rw-r--r--  eval/src/vespa/eval/eval/test/tensor_conformance.cpp                  |   6
-rw-r--r--  eval/src/vespa/eval/eval/value.h                                      |   1
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h         |   4
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp  |  26
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h          |   2
8 files changed, 170 insertions, 140 deletions
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index 9f4889ee5be..a5dbf390562 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -113,7 +113,7 @@ TEST("require that tensor injection works") {
size_t a_id = ctx.add_tensor(ctx.make_tensor_inject());
Value::UP expect = ctx.make_tensor_inject();
const auto &fun = inject(ValueType::from_spec("tensor(x[2],y[2])"), a_id, ctx.stash);
- EXPECT_EQUAL(expect->type(), fun.result_type);
+ EXPECT_EQUAL(expect->type(), fun.result_type());
const auto &prog = ctx.compile(fun);
TEST_DO(verify_equal(*expect, ctx.eval(prog)));
}
@@ -123,7 +123,7 @@ TEST("require that partial tensor reduction works") {
size_t a_id = ctx.add_tensor(ctx.make_tensor_reduce_input());
Value::UP expect = ctx.make_tensor_reduce_y_output();
const auto &fun = reduce(inject(ValueType::from_spec("tensor(x[3],y[2])"), a_id, ctx.stash), Aggr::SUM, {"y"}, ctx.stash);
- EXPECT_EQUAL(expect->type(), fun.result_type);
+ EXPECT_EQUAL(expect->type(), fun.result_type());
const auto &prog = ctx.compile(fun);
TEST_DO(verify_equal(*expect, ctx.eval(prog)));
}
@@ -132,7 +132,7 @@ TEST("require that full tensor reduction works") {
EvalCtx ctx(SimpleTensorEngine::ref());
size_t a_id = ctx.add_tensor(ctx.make_tensor_reduce_input());
const auto &fun = reduce(inject(ValueType::from_spec("tensor(x[3],y[2])"), a_id, ctx.stash), Aggr::SUM, {}, ctx.stash);
- EXPECT_EQUAL(ValueType::from_spec("double"), fun.result_type);
+ EXPECT_EQUAL(ValueType::from_spec("double"), fun.result_type());
const auto &prog = ctx.compile(fun);
const Value &result = ctx.eval(prog);
EXPECT_TRUE(result.is_double());
@@ -144,7 +144,7 @@ TEST("require that tensor map works") {
size_t a_id = ctx.add_tensor(ctx.make_tensor_map_input());
Value::UP expect = ctx.make_tensor_map_output();
const auto &fun = map(inject(ValueType::from_spec("tensor(x{},y{})"), a_id, ctx.stash), operation::Neg::f, ctx.stash);
- EXPECT_EQUAL(expect->type(), fun.result_type);
+ EXPECT_EQUAL(expect->type(), fun.result_type());
const auto &prog = ctx.compile(fun);
TEST_DO(verify_equal(*expect, ctx.eval(prog)));
}
@@ -157,7 +157,7 @@ TEST("require that tensor join works") {
const auto &fun = join(inject(ValueType::from_spec("tensor(x{},y{})"), a_id, ctx.stash),
inject(ValueType::from_spec("tensor(y{},z{})"), b_id, ctx.stash),
operation::Mul::f, ctx.stash);
- EXPECT_EQUAL(expect->type(), fun.result_type);
+ EXPECT_EQUAL(expect->type(), fun.result_type());
const auto &prog = ctx.compile(fun);
TEST_DO(verify_equal(*expect, ctx.eval(prog)));
}
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp
index 83ac51bdc09..94b30257e64 100644
--- a/eval/src/vespa/eval/eval/tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/tensor_function.cpp
@@ -22,86 +22,86 @@ const TensorEngine &infer_engine(const std::initializer_list<Value::CREF> &value
//-----------------------------------------------------------------------------
+void
+Inject::push_children(std::vector<Child::CREF> &) const
+{
+}
+
const Value &
Inject::eval(const LazyParams &params, Stash &stash) const
{
- return params.resolve(tensor_id, stash);
+ return params.resolve(_param_idx, stash);
}
+//-----------------------------------------------------------------------------
+
void
-Inject::push_children(std::vector<Child::CREF> &) const
+Reduce::push_children(std::vector<Child::CREF> &children) const
{
+ children.emplace_back(_child);
}
-//-----------------------------------------------------------------------------
-
const Value &
Reduce::eval(const LazyParams &params, Stash &stash) const
{
- const Value &a = tensor.get().eval(params, stash);
+ const Value &a = child().eval(params, stash);
const TensorEngine &engine = infer_engine({a});
- return engine.reduce(a, aggr, dimensions, stash);
+ return engine.reduce(a, _aggr, _dimensions, stash);
}
+//-----------------------------------------------------------------------------
+
void
-Reduce::push_children(std::vector<Child::CREF> &children) const
+Map::push_children(std::vector<Child::CREF> &children) const
{
- children.emplace_back(tensor);
+ children.emplace_back(_child);
}
-//-----------------------------------------------------------------------------
-
const Value &
Map::eval(const LazyParams &params, Stash &stash) const
{
- const Value &a = tensor.get().eval(params, stash);
+ const Value &a = child().eval(params, stash);
const TensorEngine &engine = infer_engine({a});
- return engine.map(a, function, stash);
+ return engine.map(a, _function, stash);
}
+//-----------------------------------------------------------------------------
+
void
-Map::push_children(std::vector<Child::CREF> &children) const
+Join::push_children(std::vector<Child::CREF> &children) const
{
- children.emplace_back(tensor);
+ children.emplace_back(_lhs);
+ children.emplace_back(_rhs);
}
-//-----------------------------------------------------------------------------
-
const Value &
Join::eval(const LazyParams &params, Stash &stash) const
{
- const Value &a = lhs_tensor.get().eval(params, stash);
- const Value &b = rhs_tensor.get().eval(params, stash);
+ const Value &a = lhs().eval(params, stash);
+ const Value &b = rhs().eval(params, stash);
const TensorEngine &engine = infer_engine({a,b});
- return engine.join(a, b, function, stash);
-}
-
-void
-Join::push_children(std::vector<Child::CREF> &children) const
-{
- children.emplace_back(lhs_tensor);
- children.emplace_back(rhs_tensor);
+ return engine.join(a, b, _function, stash);
}
//-----------------------------------------------------------------------------
-const Node &inject(const ValueType &type, size_t tensor_id, Stash &stash) {
- return stash.create<Inject>(type, tensor_id);
+const Node &inject(const ValueType &type, size_t param_idx, Stash &stash) {
+ return stash.create<Inject>(type, param_idx);
}
-const Node &reduce(const Node &tensor, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) {
- ValueType result_type = tensor.result_type.reduce(dimensions);
- return stash.create<Reduce>(result_type, tensor, aggr, dimensions);
+const Node &reduce(const Node &child, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) {
+ ValueType result_type = child.result_type().reduce(dimensions);
+ return stash.create<Reduce>(result_type, child, aggr, dimensions);
}
-const Node &map(const Node &tensor, map_fun_t function, Stash &stash) {
- ValueType result_type = tensor.result_type;
- return stash.create<Map>(result_type, tensor, function);
+const Node &map(const Node &child, map_fun_t function, Stash &stash) {
+ ValueType result_type = child.result_type();
+ return stash.create<Map>(result_type, child, function);
}
-const Node &join(const Node &lhs_tensor, const Node &rhs_tensor, join_fun_t function, Stash &stash) {
- ValueType result_type = ValueType::join(lhs_tensor.result_type, rhs_tensor.result_type);
- return stash.create<Join>(result_type, lhs_tensor, rhs_tensor, function);
+const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash) {
+ ValueType result_type = ValueType::join(lhs.result_type(), rhs.result_type());
+ return stash.create<Join>(result_type, lhs, rhs, function);
}
} // namespace vespalib::eval::tensor_function
diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h
index f2075e5d4ea..a93cf95dce4 100644
--- a/eval/src/vespa/eval/eval/tensor_function.h
+++ b/eval/src/vespa/eval/eval/tensor_function.h
@@ -22,21 +22,62 @@ class Tensor;
//-----------------------------------------------------------------------------
/**
- * A tensor function that can be evaluated. A TensorFunction will
- * typically be produced by an implementation-specific compile step
- * that takes an implementation-independent intermediate
- * representation of the tensor function as input (tree of
- * tensor_function::Node objects).
+ * Interface used to describe a tensor function as a tree of nodes
+ * with information about operation sequencing and intermediate result
+ * types. Each node in the tree describes a single tensor
+ * operation. This is the intermediate representation of a tensor
+ * function.
+ *
+ * A tensor function will initially be created based on a Function
+ * (expression AST) and associated type-resolving. In this tree, each
+ * node will directly represent a single call to the tensor engine
+ * immediate API.
+ *
+ * The generic implementation-independent tree will then be optimized
+ * (in-place, bottom-up) where sub-expressions may be replaced with
+ * optimized implementation-specific alternatives.
+ *
+ * This leaves us with a mixed-mode tree with some generic and some
+ * specialized nodes, that may be evaluated recursively.
**/
struct TensorFunction
{
+ TensorFunction(const TensorFunction &) = delete;
+ TensorFunction &operator=(const TensorFunction &) = delete;
+ TensorFunction(TensorFunction &&) = delete;
+ TensorFunction &operator=(TensorFunction &&) = delete;
+ TensorFunction() {}
+
+ /**
+ * Reference to a sub-tree. References are replaceable to enable
+ * in-place bottom-up optimization.
+ **/
+ class Child {
+ private:
+ mutable const TensorFunction *ptr;
+ public:
+ using CREF = std::reference_wrapper<const Child>;
+ Child(const TensorFunction &child) : ptr(&child) {}
+ const TensorFunction &get() const { return *ptr; }
+ void set(const TensorFunction &child) const { ptr = &child; }
+ };
+ virtual const ValueType &result_type() const = 0;
+
+ /**
+ * Push references to all children (NB: implementation must use
+ * Child class for all sub-expression references) on the given
+ * vector. This is needed to enable optimization of trees where
+ * the core algorithm does not need to know concrete node types.
+ *
+ * @param children where to put your child references
+ **/
+ virtual void push_children(std::vector<Child::CREF> &children) const = 0;
+
/**
* Evaluate this tensor function based on the given
* parameters. The given stash can be used to store temporary
* objects that need to be kept alive for the return value to be
- * valid. The return value must conform to the result type
- * indicated by the intermediate representation describing this
- * tensor function.
+ * valid. The return value must conform to 'result_type'.
*
* @return result of evaluating this tensor function
* @param params external values needed to evaluate this function
@@ -59,101 +100,87 @@ namespace tensor_function {
using map_fun_t = double (*)(double);
using join_fun_t = double (*)(double, double);
-/**
- * Interface used to describe a tensor function as a tree of nodes
- * with information about operation sequencing and intermediate result
- * types. Each node in the tree will describe a single tensor
- * operation. This is the intermediate representation of a tensor
- * function.
- *
- * The intermediate representation of a tensor function can also be
- * used to evaluate the tensor function it represents directly. This
- * will invoke the immediate API on the tensor engine associated with
- * the input tensors. In other words, the intermediate representation
- * 'compiles to itself'.
- *
- * The reason for using the top-level TensorFunction interface when
- * referencing downwards in the tree is to enable mixed-mode execution
- * resulting from partial optimization where the intermediate
- * representation is partially replaced by implementation-specific
- * tensor functions, which may or may not rely on lower-level tensor
- * functions that may in turn be mixed-mode.
- **/
-struct Node : public TensorFunction
+class Node : public TensorFunction
{
- /**
- * Reference to a sub-tree. References are replaceable to enable
- * in-place bottom-up optimization during compilation.
- **/
- class Child {
- private:
- mutable const TensorFunction *ptr;
- public:
- using CREF = std::reference_wrapper<const Child>;
- Child(const TensorFunction &child) : ptr(&child) {}
- const TensorFunction &get() const { return *ptr; }
- void set(const TensorFunction &child) const { ptr = &child; }
- };
- const ValueType result_type;
- Node(const ValueType &result_type_in) : result_type(result_type_in) {}
- Node(const Node &) = delete;
- Node &operator=(const Node &) = delete;
- Node(Node &&) = delete;
- Node &operator=(Node &&) = delete;
- virtual void push_children(std::vector<Child::CREF> &children) const = 0;
+private:
+ ValueType _result_type;
+public:
+ Node(const ValueType &result_type_in) : _result_type(result_type_in) {}
+ const ValueType &result_type() const override { return _result_type; }
};
-struct Inject : Node {
- const size_t tensor_id;
+class Inject : public Node
+{
+private:
+ size_t _param_idx;
+public:
Inject(const ValueType &result_type_in,
- size_t tensor_id_in)
- : Node(result_type_in), tensor_id(tensor_id_in) {}
- const Value &eval(const LazyParams &params, Stash &) const override;
+ size_t param_idx_in)
+ : Node(result_type_in), _param_idx(param_idx_in) {}
+ size_t param_idx() const { return _param_idx; }
void push_children(std::vector<Child::CREF> &children) const override;
+ const Value &eval(const LazyParams &params, Stash &) const override;
};
-struct Reduce : Node {
- Child tensor;
- const Aggr aggr;
- const std::vector<vespalib::string> dimensions;
+class Reduce : public Node
+{
+private:
+ Child _child;
+ Aggr _aggr;
+ std::vector<vespalib::string> _dimensions;
+public:
Reduce(const ValueType &result_type_in,
- const TensorFunction &tensor_in,
+ const TensorFunction &child_in,
Aggr aggr_in,
const std::vector<vespalib::string> &dimensions_in)
- : Node(result_type_in), tensor(tensor_in), aggr(aggr_in), dimensions(dimensions_in) {}
- const Value &eval(const LazyParams &params, Stash &stash) const override;
+ : Node(result_type_in), _child(child_in), _aggr(aggr_in), _dimensions(dimensions_in) {}
+ const TensorFunction &child() const { return _child.get(); }
+ Aggr aggr() const { return _aggr; }
+ const std::vector<vespalib::string> dimensions() const { return _dimensions; }
void push_children(std::vector<Child::CREF> &children) const override;
+ const Value &eval(const LazyParams &params, Stash &stash) const override;
};
-struct Map : Node {
- Child tensor;
- const map_fun_t function;
+class Map : public Node
+{
+private:
+ Child _child;
+ map_fun_t _function;
+public:
Map(const ValueType &result_type_in,
- const TensorFunction &tensor_in,
+ const TensorFunction &child_in,
map_fun_t function_in)
- : Node(result_type_in), tensor(tensor_in), function(function_in) {}
- const Value &eval(const LazyParams &params, Stash &stash) const override;
+ : Node(result_type_in), _child(child_in), _function(function_in) {}
+ const TensorFunction &child() const { return _child.get(); }
+ map_fun_t function() const { return _function; }
void push_children(std::vector<Child::CREF> &children) const override;
+ const Value &eval(const LazyParams &params, Stash &stash) const override;
};
-struct Join : Node {
- Child lhs_tensor;
- Child rhs_tensor;
- const join_fun_t function;
+class Join : public Node
+{
+private:
+ Child _lhs;
+ Child _rhs;
+ join_fun_t _function;
+public:
Join(const ValueType &result_type_in,
- const TensorFunction &lhs_tensor_in,
- const TensorFunction &rhs_tensor_in,
+ const TensorFunction &lhs_in,
+ const TensorFunction &rhs_in,
join_fun_t function_in)
- : Node(result_type_in), lhs_tensor(lhs_tensor_in),
- rhs_tensor(rhs_tensor_in), function(function_in) {}
- const Value &eval(const LazyParams &params, Stash &stash) const override;
+ : Node(result_type_in), _lhs(lhs_in),
+ _rhs(rhs_in), _function(function_in) {}
+ const TensorFunction &lhs() const { return _lhs.get(); }
+ const TensorFunction &rhs() const { return _rhs.get(); }
+ join_fun_t function() const { return _function; }
void push_children(std::vector<Child::CREF> &children) const override;
+ const Value &eval(const LazyParams &params, Stash &stash) const override;
};
-const Node &inject(const ValueType &type, size_t tensor_id, Stash &stash);
-const Node &reduce(const Node &tensor, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash);
-const Node &map(const Node &tensor, map_fun_t function, Stash &stash);
-const Node &join(const Node &lhs_tensor, const Node &rhs_tensor, join_fun_t function, Stash &stash);
+const Node &inject(const ValueType &type, size_t param_idx, Stash &stash);
+const Node &reduce(const Node &child, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash);
+const Node &map(const Node &child, map_fun_t function, Stash &stash);
+const Node &join(const Node &lhs, const Node &rhs, join_fun_t function, Stash &stash);
} // namespace vespalib::eval::tensor_function
} // namespace vespalib::eval
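
The Child/push_children pair described in the comments above is what lets an optimizer rewrite the tree without knowing concrete node types. A minimal sketch of such a bottom-up, in-place pass, assuming a per-node hook named optimize_leaf (hypothetical) and the usual include path for this header:

    #include <vespa/eval/eval/tensor_function.h>

    using vespalib::Stash;
    using vespalib::eval::TensorFunction;

    // hypothetical hook: returns either 'expr' itself or a specialized replacement
    const TensorFunction &optimize_leaf(const TensorFunction &expr, Stash &stash);

    void optimize_tree(const TensorFunction::Child &root, Stash &stash) {
        std::vector<TensorFunction::Child::CREF> nodes;
        nodes.emplace_back(root);
        // discover the whole tree through push_children; no concrete types needed
        for (size_t i = 0; i < nodes.size(); ++i) {
            nodes[i].get().get().push_children(nodes);
        }
        // children were pushed after their parents, so walking the vector backwards
        // rewrites the tree bottom-up through the replaceable Child slots
        while (!nodes.empty()) {
            const TensorFunction::Child &child = nodes.back();
            child.set(optimize_leaf(child.get(), stash));
            nodes.pop_back();
        }
    }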
diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
index 05acc0912ef..6fa2fc2574d 100644
--- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
+++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
@@ -275,7 +275,7 @@ struct RetainedReduce : Eval {
Stash stash;
auto a_type = ValueType::from_spec(a.type());
const auto &ir = tensor_function::reduce(tensor_function::inject(a_type, tensor_id_a, stash), aggr, dimensions, stash);
- ValueType expect_type = ir.result_type;
+ ValueType expect_type = ir.result_type();
const auto &fun = engine.compile(ir, stash);
Input input(engine.from_spec(a));
return Result(engine, check_type(fun.eval(input.get(), stash), expect_type));
@@ -290,7 +290,7 @@ struct RetainedMap : Eval {
Stash stash;
auto a_type = ValueType::from_spec(a.type());
const auto &ir = tensor_function::map(tensor_function::inject(a_type, tensor_id_a, stash), function, stash);
- ValueType expect_type = ir.result_type;
+ ValueType expect_type = ir.result_type();
const auto &fun = engine.compile(ir, stash);
Input input(engine.from_spec(a));
return Result(engine, check_type(fun.eval(input.get(), stash), expect_type));
@@ -308,7 +308,7 @@ struct RetainedJoin : Eval {
const auto &ir = tensor_function::join(tensor_function::inject(a_type, tensor_id_a, stash),
tensor_function::inject(b_type, tensor_id_b, stash),
function, stash);
- ValueType expect_type = ir.result_type;
+ ValueType expect_type = ir.result_type();
const auto &fun = engine.compile(ir, stash);
Input input(engine.from_spec(a), engine.from_spec(b));
return Result(engine, check_type(fun.eval(input.get(), stash), expect_type));
diff --git a/eval/src/vespa/eval/eval/value.h b/eval/src/vespa/eval/eval/value.h
index 08ca9792739..f14034968be 100644
--- a/eval/src/vespa/eval/eval/value.h
+++ b/eval/src/vespa/eval/eval/value.h
@@ -51,6 +51,7 @@ public:
bool is_double() const override { return true; }
double as_double() const override { return _value; }
const ValueType &type() const override { return _type; }
+ static const ValueType &double_type() { return _type; }
};
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h
index 20c293444f3..1bca0ce4c8d 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.h
@@ -13,8 +13,6 @@ namespace vespalib::tensor {
class DenseDotProductFunction : public eval::TensorFunction
{
private:
- using InjectUP = std::unique_ptr<eval::tensor_function::Inject>;
-
size_t _lhsTensorId;
size_t _rhsTensorId;
hwaccelrated::IAccelrated::UP _hwAccelerator;
@@ -23,6 +21,8 @@ public:
DenseDotProductFunction(size_t lhsTensorId_, size_t rhsTensorId_);
size_t lhsTensorId() const { return _lhsTensorId; }
size_t rhsTensorId() const { return _rhsTensorId; }
+ const eval::ValueType &result_type() const override { return eval::DoubleValue::double_type(); }
+ void push_children(std::vector<Child::CREF> &) const override {}
const eval::Value &eval(const eval::LazyParams &params, Stash &stash) const override;
};
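
Since Child, push_children() and result_type() now live on TensorFunction itself, a specialized function such as DenseDotProductFunction has to implement the full interface directly. A bare-bones sketch of what a parameter-reading leaf node (one with no child nodes of its own) looks like after this change; the class name here is made up for illustration:

    class MyLeafFunction : public vespalib::eval::TensorFunction
    {
    private:
        vespalib::eval::ValueType _result_type;
        size_t                    _param_idx;
    public:
        MyLeafFunction(const vespalib::eval::ValueType &result_type_in, size_t param_idx_in)
            : _result_type(result_type_in), _param_idx(param_idx_in) {}
        const vespalib::eval::ValueType &result_type() const override { return _result_type; }
        void push_children(std::vector<Child::CREF> &) const override {} // leaf: no children
        const vespalib::eval::Value &eval(const vespalib::eval::LazyParams &params,
                                          vespalib::Stash &stash) const override {
            return params.resolve(_param_idx, stash); // just forward the parameter
        }
    };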
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp
index bd57db009b9..7e167a92819 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp
@@ -47,9 +47,9 @@ bool isDenseXWProduct(const ValueType &res, const ValueType &vec, const ValueTyp
}
const TensorFunction &createDenseXWProduct(const ValueType &res, const Inject &vec, const Inject &mat, Stash &stash) {
- bool common_is_inner = (mat.result_type.dimension_index(vec.result_type.dimensions()[0].name) == 1);
- return stash.create<DenseXWProductFunction>(res, vec.tensor_id, mat.tensor_id,
- vec.result_type.dimensions()[0].size,
+ bool common_is_inner = (mat.result_type().dimension_index(vec.result_type().dimensions()[0].name) == 1);
+ return stash.create<DenseXWProductFunction>(res, vec.param_idx(), mat.param_idx(),
+ vec.result_type().dimensions()[0].size,
res.dimensions()[0].size,
common_is_inner);
}
@@ -58,20 +58,20 @@ struct InnerProductFunctionOptimizer
{
static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash) {
const Reduce *reduce = as<Reduce>(expr);
- if (reduce && (reduce->aggr == Aggr::SUM)) {
- const ValueType &result_type = reduce->result_type;
- const Join *join = as<Join>(reduce->tensor.get());
- if (join && (join->function == Mul::f)) {
- const Inject *lhs = as<Inject>(join->lhs_tensor.get());
- const Inject *rhs = as<Inject>(join->rhs_tensor.get());
+ if (reduce && (reduce->aggr() == Aggr::SUM)) {
+ const ValueType &result_type = reduce->result_type();
+ const Join *join = as<Join>(reduce->child());
+ if (join && (join->function() == Mul::f)) {
+ const Inject *lhs = as<Inject>(join->lhs());
+ const Inject *rhs = as<Inject>(join->rhs());
if (lhs && rhs) {
- if (isDenseDotProduct(result_type, lhs->result_type, rhs->result_type)) {
- return stash.create<DenseDotProductFunction>(lhs->tensor_id, rhs->tensor_id);
+ if (isDenseDotProduct(result_type, lhs->result_type(), rhs->result_type())) {
+ return stash.create<DenseDotProductFunction>(lhs->param_idx(), rhs->param_idx());
}
- if (isDenseXWProduct(result_type, lhs->result_type, rhs->result_type)) {
+ if (isDenseXWProduct(result_type, lhs->result_type(), rhs->result_type())) {
return createDenseXWProduct(result_type, *lhs, *rhs, stash);
}
- if (isDenseXWProduct(result_type, rhs->result_type, lhs->result_type)) {
+ if (isDenseXWProduct(result_type, rhs->result_type(), lhs->result_type())) {
return createDenseXWProduct(result_type, *rhs, *lhs, stash);
}
}
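
For reference, the sub-tree the rewritten matcher above looks for is still Reduce(SUM) over Join(Mul) of two Inject leaves, only reached through the new accessors. A rough sketch of an IR that the optimizer would now turn into a DenseDotProductFunction (tensor types and parameter indices picked for illustration):

    using namespace vespalib::eval;
    using namespace vespalib::eval::tensor_function;

    Stash stash;
    const Node &a   = inject(ValueType::from_spec("tensor(x[3])"), 0, stash);
    const Node &b   = inject(ValueType::from_spec("tensor(x[3])"), 1, stash);
    const Node &mul = join(a, b, operation::Mul::f, stash);   // matched via join->function() == Mul::f
    const Node &sum = reduce(mul, Aggr::SUM, {}, stash);      // matched via reduce->aggr() == Aggr::SUM
    // the matcher then reads lhs->param_idx() / rhs->param_idx() instead of the
    // old public tensor_id members when creating the DenseDotProductFunction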
diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
index 6e8104fff44..bc0a63bc79e 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.h
@@ -45,6 +45,8 @@ public:
bool matrixHasCommonDimensionInnermost() const { return _commonDimensionInnermost; }
+ const eval::ValueType &result_type() const override { return _resultType; }
+ void push_children(std::vector<Child::CREF> &) const override {}
const eval::Value &eval(const eval::LazyParams &params, Stash &stash) const override;
};