summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-12-03 09:41:34 +0000
committerArne Juul <arnej@verizonmedia.com>2020-12-03 18:22:03 +0000
commitc4854a1ec33e7dddf20b93a33faff70f1a9c1041 (patch)
tree49a34edf6183b391b9763dfc43df2f59dd5bef23 /eval
parentf88b7e59d00a59b623221d453e24c74e5b26c677 (diff)
no more getting engine from tensor
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/tensor.cpp6
-rw-r--r--eval/src/vespa/eval/eval/tensor.h2
-rw-r--r--eval/src/vespa/eval/eval/tensor_spec.cpp6
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp21
4 files changed, 8 insertions, 27 deletions
diff --git a/eval/src/vespa/eval/eval/tensor.cpp b/eval/src/vespa/eval/eval/tensor.cpp
index 645208ba8fb..9de812b46f0 100644
--- a/eval/src/vespa/eval/eval/tensor.cpp
+++ b/eval/src/vespa/eval/eval/tensor.cpp
@@ -10,15 +10,15 @@ namespace eval {
bool
operator==(const Tensor &lhs, const Tensor &rhs)
{
- auto lhs_spec = lhs.engine().to_spec(lhs);
- auto rhs_spec = rhs.engine().to_spec(rhs);
+ auto lhs_spec = TensorSpec::from_value(lhs);
+ auto rhs_spec = TensorSpec::from_value(rhs);
return (lhs_spec == rhs_spec);
}
std::ostream &
operator<<(std::ostream &out, const Tensor &tensor)
{
- out << tensor.engine().to_spec(tensor).to_string();
+ out << TensorSpec::from_value(tensor).to_string();
return out;
}
diff --git a/eval/src/vespa/eval/eval/tensor.h b/eval/src/vespa/eval/eval/tensor.h
index ddc341ed910..4f0ccd1114e 100644
--- a/eval/src/vespa/eval/eval/tensor.h
+++ b/eval/src/vespa/eval/eval/tensor.h
@@ -23,6 +23,7 @@ class Tensor : public Value
{
private:
const TensorEngine &_engine;
+ const TensorEngine &engine() const { return _engine; }
protected:
explicit Tensor(const TensorEngine &engine_in)
: _engine(engine_in) {}
@@ -33,7 +34,6 @@ public:
Tensor &operator=(Tensor &&) = delete;
bool is_tensor() const override { return true; }
const Tensor *as_tensor() const override { return this; }
- const TensorEngine &engine() const { return _engine; }
virtual ~Tensor() {}
};
diff --git a/eval/src/vespa/eval/eval/tensor_spec.cpp b/eval/src/vespa/eval/eval/tensor_spec.cpp
index ec94602781c..e082d3bc2ba 100644
--- a/eval/src/vespa/eval/eval/tensor_spec.cpp
+++ b/eval/src/vespa/eval/eval/tensor_spec.cpp
@@ -195,11 +195,7 @@ TensorSpec::from_slime(const slime::Inspector &tensor)
TensorSpec
TensorSpec::from_value(const eval::Value &value)
{
- if (auto tensor = dynamic_cast<const vespalib::eval::Tensor *>(&value)) {
- return tensor->engine().to_spec(value);
- } else {
- return spec_from_value(value);
- }
+ return spec_from_value(value);
}
TensorSpec
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 3ac342217ac..49b5118f777 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -27,7 +27,7 @@
#include "dense/dense_tensor_create_function.h"
#include <vespa/eval/instruction/dense_tensor_peek_function.h>
#include <vespa/eval/eval/value.h>
-#include <vespa/eval/eval/engine_or_factory.h>
+#include <vespa/eval/eval/value_codec.h>
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/eval/simple_value.h>
#include <vespa/eval/eval/operation.h>
@@ -71,7 +71,7 @@ const Value &to_simple(const Value &value, Stash &stash) {
return wrapped->unwrap();
}
nbostream data;
- tensor->engine().encode(*tensor, data);
+ encode_value(*tensor, data);
return *stash.create<Value::UP>(simple_engine().decode(data));
}
return value;
@@ -164,15 +164,7 @@ const DefaultTensorEngine DefaultTensorEngine::_engine;
TensorSpec
DefaultTensorEngine::to_spec(const Value &value) const
{
- if (value.is_double()) {
- return TensorSpec("double").add({}, value.as_double());
- } else if (auto tensor = value.as_tensor()) {
- assert(&tensor->engine() == this);
- const tensor::Tensor &my_tensor = static_cast<const tensor::Tensor &>(*tensor);
- return my_tensor.toSpec();
- } else {
- return TensorSpec("error");
- }
+ return TensorSpec::from_value(value);
}
struct CallDenseTensorBuilder {
@@ -324,7 +316,6 @@ const Value &
DefaultTensorEngine::map(const Value &a, map_fun_t function, Stash &stash) const
{
if (auto tensor = a.as_tensor()) {
- assert(&tensor->engine() == this);
const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor);
if (!tensor::Tensor::supported({my_a.type()})) {
return to_default(simple_engine().map(to_simple(a, stash), function, stash), stash);
@@ -340,10 +331,8 @@ const Value &
DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const
{
if (auto tensor_a = a.as_tensor()) {
- assert(&tensor_a->engine() == this);
const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor_a);
if (auto tensor_b = b.as_tensor()) {
- assert(&tensor_b->engine() == this);
const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(*tensor_b);
if (!tensor::Tensor::supported({my_a.type(), my_b.type()})) {
return fallback_join(a, b, function, stash);
@@ -358,7 +347,6 @@ DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, S
}
} else {
if (auto tensor_b = b.as_tensor()) {
- assert(&tensor_b->engine() == this);
const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(*tensor_b);
if (!tensor::Tensor::supported({my_b.type()})) {
return fallback_join(a, b, function, stash);
@@ -377,8 +365,6 @@ DefaultTensorEngine::merge(const Value &a, const Value &b, join_fun_t function,
if (auto tensor_a = a.as_tensor()) {
auto tensor_b = b.as_tensor();
assert(tensor_b);
- assert(&tensor_a->engine() == this);
- assert(&tensor_b->engine() == this);
const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor_a);
const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(*tensor_b);
if (!tensor::Tensor::supported({my_a.type(), my_b.type()})) {
@@ -394,7 +380,6 @@ const Value &
DefaultTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const
{
if (auto tensor = a.as_tensor()) {
- assert(&tensor->engine() == this);
const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor);
if (!tensor::Tensor::supported({my_a.type()})) {
return fallback_reduce(a, aggr, dimensions, stash);