aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-10-19 14:03:44 +0000
committerArne Juul <arnej@verizonmedia.com>2020-10-20 10:49:03 +0000
commita994e37fc60b0040da974d32f084a035b75d00a1 (patch)
treea19a40a8ddec51d330a1f93c424dbbe883c1c202 /eval
parent46f5dd7d8eeb1393635ebcd5e5f5f08358b3cc1b (diff)
wrap SimpleValue and its engine
* as DefaultTensorEngine fallback, instead of SimpleTensor and SimpleTensorEngine
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp50
-rw-r--r--eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp23
2 files changed, 52 insertions, 21 deletions
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 962e0360598..3ad53a61dd0 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -2,7 +2,7 @@
#include "default_tensor_engine.h"
#include "tensor.h"
-#include "wrapped_simple_tensor.h"
+#include "wrapped_simple_value.h"
#include "serialization/typed_binary_format.h"
#include "sparse/sparse_tensor_address_builder.h"
#include "sparse/direct_sparse_tensor_builder.h"
@@ -28,8 +28,7 @@
#include "dense/dense_tensor_peek_function.h"
#include <vespa/eval/eval/value.h>
#include <vespa/eval/eval/tensor_spec.h>
-#include <vespa/eval/eval/tensor_spec.h>
-#include <vespa/eval/eval/simple_tensor_engine.h>
+#include <vespa/eval/eval/simple_value.h>
#include <vespa/eval/eval/operation.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/util/exceptions.h>
@@ -58,19 +57,22 @@ namespace {
constexpr size_t UNDEFINED_IDX = std::numeric_limits<size_t>::max();
-const eval::TensorEngine &simple_engine() { return eval::SimpleTensorEngine::ref(); }
+const eval::EngineOrFactory &simple_engine() {
+ static eval::EngineOrFactory engine(eval::SimpleValueBuilderFactory::get());
+ return engine;
+}
const eval::TensorEngine &default_engine() { return DefaultTensorEngine::ref(); }
// map tensors to simple tensors before fall-back evaluation
const Value &to_simple(const Value &value, Stash &stash) {
if (auto tensor = value.as_tensor()) {
- if (auto wrapped = dynamic_cast<const WrappedSimpleTensor *>(tensor)) {
- return wrapped->get();
+ if (auto wrapped = dynamic_cast<const WrappedSimpleValue *>(tensor)) {
+ return wrapped->unwrap();
}
nbostream data;
tensor->engine().encode(*tensor, data);
- return *stash.create<Value::UP>(eval::SimpleTensor::decode(data));
+ return *stash.create<Value::UP>(simple_engine().decode(data));
}
return value;
}
@@ -78,17 +80,39 @@ const Value &to_simple(const Value &value, Stash &stash) {
// map tensors to default tensors after fall-back evaluation
const Value &to_default(const Value &value, Stash &stash) {
+ // case 1 : a tensor with an engine
if (auto tensor = value.as_tensor()) {
- if (auto simple = dynamic_cast<const eval::SimpleTensor *>(tensor)) {
- if (!Tensor::supported({simple->type()})) {
- return stash.create<WrappedSimpleTensor>(*simple);
- }
+ // case [1A]: it's already one of "our" tensors
+ if (&tensor->engine() == &default_engine()) {
+ return value;
}
+ // case [1B]: it belongs to some other engine
nbostream data;
tensor->engine().encode(*tensor, data);
return *stash.create<Value::UP>(default_engine().decode(data));
}
- return value;
+ // case 2 : some kind of double (possibly in a SimpleValue or DenseTensor)
+ if (value.type().is_double()) {
+ // case [2A]: already OK
+ if (dynamic_cast<const DoubleValue *>(&value)) {
+ return value;
+ }
+ // case [2B]: simplify to DoubleValue
+ return stash.create<DoubleValue>(value.as_double());
+ }
+ // case 3 : it's a (possibly mixed) SimpleValue
+ if (auto simple = dynamic_cast<const eval::SimpleValue *>(&value)) {
+ // case [3A]: not one of our supported types, just wrap it
+ if (!Tensor::supported({simple->type()})) {
+ return stash.create<WrappedSimpleValue>(*simple);
+ }
+ // case [3B]: we should convert to one of our supported types
+ }
+ // case [4]: some other kind of Value, convert to one of our
+ // supported types or make a WrappedSimpleValue
+ nbostream data;
+ simple_engine().encode(value, data);
+ return *stash.create<Value::UP>(default_engine().decode(data));
}
const Value &to_value(std::unique_ptr<Tensor> tensor, Stash &stash) {
@@ -229,7 +253,7 @@ DefaultTensorEngine::from_spec(const TensorSpec &spec) const
} else if (type.is_sparse()) {
return typify_invoke<1,MyTypify,CallSparseTensorBuilder>(type.cell_type(), type, spec);
}
- return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::create(spec));
+ return std::make_unique<WrappedSimpleValue>(simple_engine().from_spec(spec));
}
struct CellFunctionFunAdapter : tensor::CellFunction {
diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
index eac5d0aa26c..758ceb43ab4 100644
--- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
+++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp
@@ -6,8 +6,10 @@
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/eval/tensor/dense/dense_tensor.h>
-#include <vespa/eval/eval/simple_tensor.h>
-#include <vespa/eval/tensor/wrapped_simple_tensor.h>
+#include <vespa/eval/eval/simple_value.h>
+#include <vespa/eval/tensor/wrapped_simple_value.h>
+#include <vespa/eval/eval/value_codec.h>
+#include <vespa/eval/eval/engine_or_factory.h>
#include <vespa/log/log.h>
#include <vespa/vespalib/util/stringfmt.h>
@@ -23,6 +25,11 @@ namespace vespalib::tensor {
namespace {
+const eval::EngineOrFactory &simple_engine() {
+ static eval::EngineOrFactory engine(eval::SimpleValueBuilderFactory::get());
+ return engine;
+}
+
constexpr uint32_t SPARSE_BINARY_FORMAT_TYPE = 1u;
constexpr uint32_t DENSE_BINARY_FORMAT_TYPE = 2u;
constexpr uint32_t MIXED_BINARY_FORMAT_TYPE = 3u;
@@ -56,15 +63,15 @@ encoding_to_cell_type(uint32_t cell_encoding) {
}
std::unique_ptr<Tensor>
-wrap_simple_tensor(std::unique_ptr<eval::SimpleTensor> simple)
+wrap_simple_value(std::unique_ptr<eval::Value> simple)
{
if (Tensor::supported({simple->type()})) {
nbostream data;
- eval::SimpleTensor::encode(*simple, data);
+ simple_engine().encode(*simple, data);
// note: some danger of infinite recursion here
return TypedBinaryFormat::deserialize(data);
}
- return std::make_unique<WrappedSimpleTensor>(std::move(simple));
+ return std::make_unique<WrappedSimpleValue>(std::move(simple));
}
} // namespace <unnamed>
@@ -82,8 +89,8 @@ TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor)
stream.putInt1_4Bytes(cell_type_to_encoding(cell_type));
}
DenseBinaryFormat::serialize(stream, *denseTensor);
- } else if (auto wrapped = dynamic_cast<const WrappedSimpleTensor *>(&tensor)) {
- eval::SimpleTensor::encode(wrapped->get(), stream);
+ } else if (dynamic_cast<const WrappedSimpleValue *>(&tensor)) {
+ eval::encode_value(tensor, stream);
} else {
if (default_cell_type) {
stream.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE);
@@ -116,7 +123,7 @@ TypedBinaryFormat::deserialize(nbostream &stream)
case MIXED_BINARY_FORMAT_TYPE:
case MIXED_BINARY_FORMAT_WITH_CELLTYPE:
stream.adjustReadPos(read_pos - stream.rp());
- return wrap_simple_tensor(eval::SimpleTensor::decode(stream));
+ return wrap_simple_value(simple_engine().decode(stream));
default:
throw IllegalArgumentException(make_string("Received unknown tensor format type = %du.", formatId));
}