summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2017-10-27 14:33:35 +0000
committerHåvard Pettersen <havardpe@oath.com>2017-10-27 14:33:35 +0000
commit3f5670a1c21cb6ce6479cc85bb332b38ed6af0e0 (patch)
treeb9ae9f71ba9fd9b3daa941bbd922c92b64b9da85
parent5fcbb66f52d44b286f0898ab318f7e6269330f4e (diff)
simple optimizations for double map/join
-rw-r--r--eval/src/vespa/eval/eval/interpreted_function.cpp38
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.cpp10
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.h1
-rw-r--r--eval/src/vespa/eval/eval/tensor.h1
-rw-r--r--eval/src/vespa/eval/eval/value.cpp6
-rw-r--r--eval/src/vespa/eval/eval/value.h1
-rw-r--r--eval/src/vespa/eval/tensor/tensor.h1
7 files changed, 54 insertions, 4 deletions
diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp
index 09cc2d57e80..82aa71893e1 100644
--- a/eval/src/vespa/eval/eval/interpreted_function.cpp
+++ b/eval/src/vespa/eval/eval/interpreted_function.cpp
@@ -98,6 +98,24 @@ void op_not_member(State &state, uint64_t) {
//-----------------------------------------------------------------------------
+void op_double_map(State &state, uint64_t param) {
+ state.replace(1, state.stash.create<DoubleValue>(to_map_fun(param)(state.peek(0).as_double())));
+}
+
+void op_double_mul(State &state, uint64_t) {
+ state.replace(2, state.stash.create<DoubleValue>(state.peek(1).as_double() * state.peek(0).as_double()));
+}
+
+void op_double_add(State &state, uint64_t) {
+ state.replace(2, state.stash.create<DoubleValue>(state.peek(1).as_double() + state.peek(0).as_double()));
+}
+
+void op_double_join(State &state, uint64_t param) {
+ state.replace(2, state.stash.create<DoubleValue>(to_join_fun(param)(state.peek(1).as_double(), state.peek(0).as_double())));
+}
+
+//-----------------------------------------------------------------------------
+
void op_tensor_map(State &state, uint64_t param) {
state.replace(1, state.engine.map(state.peek(0), to_map_fun(param), state.stash));
}
@@ -217,13 +235,25 @@ struct ProgramBuilder : public NodeVisitor, public NodeTraverser {
}
void make_map_op(const Node &node, map_fun_t function) {
- (void) node;
- program.emplace_back(op_tensor_map, to_param(function));
+ if (types.get_type(node).is_double()) {
+ program.emplace_back(op_double_map, to_param(function));
+ } else {
+ program.emplace_back(op_tensor_map, to_param(function));
+ }
}
void make_join_op(const Node &node, join_fun_t function) {
- (void) node;
- program.emplace_back(op_tensor_join, to_param(function));
+ if (types.get_type(node).is_double()) {
+ if (function == operation::Mul::f) {
+ program.emplace_back(op_double_mul);
+ } else if (function == operation::Add::f) {
+ program.emplace_back(op_double_add);
+ } else {
+ program.emplace_back(op_double_join, to_param(function));
+ }
+ } else {
+ program.emplace_back(op_tensor_join, to_param(function));
+ }
}
//-------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/simple_tensor.cpp b/eval/src/vespa/eval/eval/simple_tensor.cpp
index 37d2e8747ef..75c170d48ba 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.cpp
+++ b/eval/src/vespa/eval/eval/simple_tensor.cpp
@@ -542,6 +542,16 @@ SimpleTensor::SimpleTensor(const ValueType &type_in, Cells cells_in)
[](const auto &a, const auto &b){ return (a.address < b.address); });
}
+double
+SimpleTensor::as_double() const
+{
+ double sum = 0.0;
+ for (auto &cell: _cells) {
+ sum += cell.value;
+ }
+ return sum;
+}
+
std::unique_ptr<SimpleTensor>
SimpleTensor::map(map_fun_t function) const
{
diff --git a/eval/src/vespa/eval/eval/simple_tensor.h b/eval/src/vespa/eval/eval/simple_tensor.h
index 3de80483fb3..ec154ff969a 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.h
+++ b/eval/src/vespa/eval/eval/simple_tensor.h
@@ -81,6 +81,7 @@ public:
SimpleTensor();
explicit SimpleTensor(double value);
SimpleTensor(const ValueType &type_in, Cells cells_in);
+ double as_double() const final override;
const ValueType &type() const { return _type; }
const Cells &cells() const { return _cells; }
std::unique_ptr<SimpleTensor> map(map_fun_t function) const;
diff --git a/eval/src/vespa/eval/eval/tensor.h b/eval/src/vespa/eval/eval/tensor.h
index ed17a47f775..57cd9abe1f5 100644
--- a/eval/src/vespa/eval/eval/tensor.h
+++ b/eval/src/vespa/eval/eval/tensor.h
@@ -30,6 +30,7 @@ public:
Tensor(Tensor &&) = delete;
Tensor &operator=(const Tensor &) = delete;
Tensor &operator=(Tensor &&) = delete;
+ virtual double as_double() const = 0;
const TensorEngine &engine() const { return _engine; }
virtual ~Tensor() {}
};
diff --git a/eval/src/vespa/eval/eval/value.cpp b/eval/src/vespa/eval/eval/value.cpp
index 2a0a80b8547..0118d95e5cb 100644
--- a/eval/src/vespa/eval/eval/value.cpp
+++ b/eval/src/vespa/eval/eval/value.cpp
@@ -8,6 +8,12 @@ namespace eval {
ErrorValue ErrorValue::instance;
+double
+TensorValue::as_double() const
+{
+ return _tensor->as_double();
+}
+
bool
TensorValue::equal(const Value &rhs) const
{
diff --git a/eval/src/vespa/eval/eval/value.h b/eval/src/vespa/eval/eval/value.h
index 42dddeef188..0d727db6b91 100644
--- a/eval/src/vespa/eval/eval/value.h
+++ b/eval/src/vespa/eval/eval/value.h
@@ -64,6 +64,7 @@ public:
TensorValue(const Tensor &value) : _tensor(&value), _stored() {}
TensorValue(std::unique_ptr<Tensor> value) : _tensor(value.get()), _stored(std::move(value)) {}
bool is_tensor() const override { return true; }
+ double as_double() const override;
const Tensor *as_tensor() const override { return _tensor; }
bool equal(const Value &rhs) const override;
ValueType type() const override;
diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h
index 1f4d599ac18..3b3d7ce4a70 100644
--- a/eval/src/vespa/eval/tensor/tensor.h
+++ b/eval/src/vespa/eval/tensor/tensor.h
@@ -32,6 +32,7 @@ struct Tensor : public eval::Tensor
virtual ~Tensor() {}
virtual const eval::ValueType &getType() const = 0;
virtual double sum() const = 0;
+ virtual double as_double() const final override { return sum(); }
virtual Tensor::UP add(const Tensor &arg) const = 0;
virtual Tensor::UP subtract(const Tensor &arg) const = 0;
virtual Tensor::UP multiply(const Tensor &arg) const = 0;