diff options
author | Arne Juul <arnej@verizonmedia.com> | 2020-10-19 14:03:44 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2020-10-20 10:49:03 +0000 |
commit | a994e37fc60b0040da974d32f084a035b75d00a1 (patch) | |
tree | a19a40a8ddec51d330a1f93c424dbbe883c1c202 | |
parent | 46f5dd7d8eeb1393635ebcd5e5f5f08358b3cc1b (diff) |
wrap SimpleValue and its engine
* as DefaultTensorEngine fallback,
instead of SimpleTensor and SimpleTensorEngine
-rw-r--r-- | eval/src/vespa/eval/tensor/default_tensor_engine.cpp | 50 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp | 23 |
2 files changed, 52 insertions, 21 deletions
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index 962e0360598..3ad53a61dd0 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -2,7 +2,7 @@ #include "default_tensor_engine.h" #include "tensor.h" -#include "wrapped_simple_tensor.h" +#include "wrapped_simple_value.h" #include "serialization/typed_binary_format.h" #include "sparse/sparse_tensor_address_builder.h" #include "sparse/direct_sparse_tensor_builder.h" @@ -28,8 +28,7 @@ #include "dense/dense_tensor_peek_function.h" #include <vespa/eval/eval/value.h> #include <vespa/eval/eval/tensor_spec.h> -#include <vespa/eval/eval/tensor_spec.h> -#include <vespa/eval/eval/simple_tensor_engine.h> +#include <vespa/eval/eval/simple_value.h> #include <vespa/eval/eval/operation.h> #include <vespa/vespalib/objects/nbostream.h> #include <vespa/vespalib/util/exceptions.h> @@ -58,19 +57,22 @@ namespace { constexpr size_t UNDEFINED_IDX = std::numeric_limits<size_t>::max(); -const eval::TensorEngine &simple_engine() { return eval::SimpleTensorEngine::ref(); } +const eval::EngineOrFactory &simple_engine() { + static eval::EngineOrFactory engine(eval::SimpleValueBuilderFactory::get()); + return engine; +} const eval::TensorEngine &default_engine() { return DefaultTensorEngine::ref(); } // map tensors to simple tensors before fall-back evaluation const Value &to_simple(const Value &value, Stash &stash) { if (auto tensor = value.as_tensor()) { - if (auto wrapped = dynamic_cast<const WrappedSimpleTensor *>(tensor)) { - return wrapped->get(); + if (auto wrapped = dynamic_cast<const WrappedSimpleValue *>(tensor)) { + return wrapped->unwrap(); } nbostream data; tensor->engine().encode(*tensor, data); - return *stash.create<Value::UP>(eval::SimpleTensor::decode(data)); + return *stash.create<Value::UP>(simple_engine().decode(data)); } return value; } @@ -78,17 +80,39 @@ const Value &to_simple(const Value &value, Stash &stash) { // map tensors to default tensors after fall-back evaluation const Value &to_default(const Value &value, Stash &stash) { + // case 1 : a tensor with an engine if (auto tensor = value.as_tensor()) { - if (auto simple = dynamic_cast<const eval::SimpleTensor *>(tensor)) { - if (!Tensor::supported({simple->type()})) { - return stash.create<WrappedSimpleTensor>(*simple); - } + // case [1A]: it's already one of "our" tensors + if (&tensor->engine() == &default_engine()) { + return value; } + // case [1B]: it belongs to some other engine nbostream data; tensor->engine().encode(*tensor, data); return *stash.create<Value::UP>(default_engine().decode(data)); } - return value; + // case 2 : some kind of double (possibly in a SimpleValue or DenseTensor) + if (value.type().is_double()) { + // case [2A]: already OK + if (dynamic_cast<const DoubleValue *>(&value)) { + return value; + } + // case [2B]: simplify to DoubleValue + return stash.create<DoubleValue>(value.as_double()); + } + // case 3 : it's a (possibly mixed) SimpleValue + if (auto simple = dynamic_cast<const eval::SimpleValue *>(&value)) { + // case [3A]: not one of our supported types, just wrap it + if (!Tensor::supported({simple->type()})) { + return stash.create<WrappedSimpleValue>(*simple); + } + // case [3B]: we should convert to one of our supported types + } + // case [4]: some other kind of Value, convert to one of our + // supported types or make a WrappedSimpleValue + nbostream data; + simple_engine().encode(value, data); + return *stash.create<Value::UP>(default_engine().decode(data)); } const Value &to_value(std::unique_ptr<Tensor> tensor, Stash &stash) { @@ -229,7 +253,7 @@ DefaultTensorEngine::from_spec(const TensorSpec &spec) const } else if (type.is_sparse()) { return typify_invoke<1,MyTypify,CallSparseTensorBuilder>(type.cell_type(), type, spec); } - return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::create(spec)); + return std::make_unique<WrappedSimpleValue>(simple_engine().from_spec(spec)); } struct CellFunctionFunAdapter : tensor::CellFunction { diff --git a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp index eac5d0aa26c..758ceb43ab4 100644 --- a/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp +++ b/eval/src/vespa/eval/tensor/serialization/typed_binary_format.cpp @@ -6,8 +6,10 @@ #include <vespa/vespalib/objects/nbostream.h> #include <vespa/eval/tensor/tensor.h> #include <vespa/eval/tensor/dense/dense_tensor.h> -#include <vespa/eval/eval/simple_tensor.h> -#include <vespa/eval/tensor/wrapped_simple_tensor.h> +#include <vespa/eval/eval/simple_value.h> +#include <vespa/eval/tensor/wrapped_simple_value.h> +#include <vespa/eval/eval/value_codec.h> +#include <vespa/eval/eval/engine_or_factory.h> #include <vespa/log/log.h> #include <vespa/vespalib/util/stringfmt.h> @@ -23,6 +25,11 @@ namespace vespalib::tensor { namespace { +const eval::EngineOrFactory &simple_engine() { + static eval::EngineOrFactory engine(eval::SimpleValueBuilderFactory::get()); + return engine; +} + constexpr uint32_t SPARSE_BINARY_FORMAT_TYPE = 1u; constexpr uint32_t DENSE_BINARY_FORMAT_TYPE = 2u; constexpr uint32_t MIXED_BINARY_FORMAT_TYPE = 3u; @@ -56,15 +63,15 @@ encoding_to_cell_type(uint32_t cell_encoding) { } std::unique_ptr<Tensor> -wrap_simple_tensor(std::unique_ptr<eval::SimpleTensor> simple) +wrap_simple_value(std::unique_ptr<eval::Value> simple) { if (Tensor::supported({simple->type()})) { nbostream data; - eval::SimpleTensor::encode(*simple, data); + simple_engine().encode(*simple, data); // note: some danger of infinite recursion here return TypedBinaryFormat::deserialize(data); } - return std::make_unique<WrappedSimpleTensor>(std::move(simple)); + return std::make_unique<WrappedSimpleValue>(std::move(simple)); } } // namespace <unnamed> @@ -82,8 +89,8 @@ TypedBinaryFormat::serialize(nbostream &stream, const Tensor &tensor) stream.putInt1_4Bytes(cell_type_to_encoding(cell_type)); } DenseBinaryFormat::serialize(stream, *denseTensor); - } else if (auto wrapped = dynamic_cast<const WrappedSimpleTensor *>(&tensor)) { - eval::SimpleTensor::encode(wrapped->get(), stream); + } else if (dynamic_cast<const WrappedSimpleValue *>(&tensor)) { + eval::encode_value(tensor, stream); } else { if (default_cell_type) { stream.putInt1_4Bytes(SPARSE_BINARY_FORMAT_TYPE); @@ -116,7 +123,7 @@ TypedBinaryFormat::deserialize(nbostream &stream) case MIXED_BINARY_FORMAT_TYPE: case MIXED_BINARY_FORMAT_WITH_CELLTYPE: stream.adjustReadPos(read_pos - stream.rp()); - return wrap_simple_tensor(eval::SimpleTensor::decode(stream)); + return wrap_simple_value(simple_engine().decode(stream)); default: throw IllegalArgumentException(make_string("Received unknown tensor format type = %du.", formatId)); } |