diff options
author | Håvard Pettersen <3535158+havardpe@users.noreply.github.com> | 2020-08-28 11:49:53 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-28 11:49:53 +0200 |
commit | e9b01926c98277e489a7ecfb544aaf819b72fe4e (patch) | |
tree | 7a711cbc3c4a6b161e7f0ef548030f9ba0e7e194 /searchlib/src | |
parent | d704dc37471f26eb9838766f730983001a4703ac (diff) | |
parent | 492c9098b5daa718eadccc3320ca37a7d252f7d2 (diff) |
Merge pull request #14187 from vespa-engine/havardpe/infer-unknown-onnx-dimension-sizes
infer unknown onnx dimension sizes
Diffstat (limited to 'searchlib/src')
3 files changed, 40 insertions, 34 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp index cc6b8e0ce29..7a200a46ab2 100644 --- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -25,7 +25,8 @@ std::string get_source_dir() { } std::string source_dir = get_source_dir(); std::string vespa_dir = source_dir + "/" + "../../../../.."; -std::string simple_model = vespa_dir + "/" + "model-integration/src/test/models/onnx/simple/simple.onnx"; +std::string simple_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/simple.onnx"; +std::string dynamic_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/dynamic.onnx"; uint32_t default_docid = 1; @@ -97,4 +98,16 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) { EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0)); } +TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) { + add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]"); + add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); + add_expr("bias_tensor", "tensor<float>(a[1],b[2]):[[4,5]]"); + add_onnx("dynamic", dynamic_model); + compile(onnx_feature("dynamic")); + EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); + EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); + EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0)); + EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0)); +} + GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp index f6d5c37b61d..7433021b9b6 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp 
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp @@ -3,7 +3,6 @@ #include "onnx_feature.h" #include <vespa/searchlib/fef/properties.h> #include <vespa/searchlib/fef/featureexecutor.h> -#include <vespa/eval/tensor/dense/onnx_wrapper.h> #include <vespa/eval/tensor/dense/dense_tensor_view.h> #include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h> #include <vespa/vespalib/util/stringfmt.h> @@ -23,7 +22,7 @@ using vespalib::eval::ValueType; using vespalib::make_string_short::fmt; using vespalib::tensor::DenseTensorView; using vespalib::tensor::MutableDenseTensorView; -using vespalib::tensor::OnnxWrapper; +using vespalib::tensor::Onnx; namespace search::features { @@ -33,37 +32,28 @@ namespace search::features { class OnnxFeatureExecutor : public FeatureExecutor { private: - const OnnxWrapper &_model; - OnnxWrapper::Params _params; - OnnxWrapper::Result _result; - std::vector<MutableDenseTensorView> _views; - + Onnx::EvalContext _eval_context; public: - OnnxFeatureExecutor(const OnnxWrapper &model) - : _model(model), _params(), _result(OnnxWrapper::Result::make_empty()), _views() - { - _views.reserve(_model.outputs().size()); - for (const auto &output: _model.outputs()) { - _views.emplace_back(output.make_compatible_type()); - } - } + OnnxFeatureExecutor(const Onnx &model, const Onnx::WireInfo &wire_info) + : _eval_context(model, wire_info) {} bool isPure() override { return true; } - void execute(uint32_t) override { - _params = OnnxWrapper::Params(); - for (size_t i = 0; i < _model.inputs().size(); ++i) { - _params.bind(i, static_cast<const DenseTensorView&>(inputs().get_object(i).get())); + void handle_bind_outputs(vespalib::ArrayRef<fef::NumberOrObject>) override { + for (size_t i = 0; i < _eval_context.num_results(); ++i) { + outputs().set_object(i, _eval_context.get_result(i)); } - _result = _model.eval(_params); - for (size_t i = 0; i < _model.outputs().size(); ++i) { - _result.get(i, _views[i]); - outputs().set_object(i, _views[i]); + } + void 
execute(uint32_t) override { + for (size_t i = 0; i < _eval_context.num_params(); ++i) { + _eval_context.bind_param(i, inputs().get_object(i).get()); } + _eval_context.eval(); } }; OnnxBlueprint::OnnxBlueprint() : Blueprint("onnxModel"), - _model(nullptr) + _model(nullptr), + _wire_info() { } @@ -74,24 +64,25 @@ OnnxBlueprint::setup(const IIndexEnvironment &env, const ParameterList &params) { auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) - ? OnnxWrapper::Optimize::DISABLE - : OnnxWrapper::Optimize::ENABLE; + ? Onnx::Optimize::DISABLE + : Onnx::Optimize::ENABLE; // Note: Using the fileref property with the model name as // fallback to get a file name. This needs to be replaced with an // actual file reference obtained through config when available. vespalib::string file_name = env.getProperties().lookup(getName(), "fileref").get(params[0].getValue()); try { - _model = std::make_unique<OnnxWrapper>(file_name, optimize); + _model = std::make_unique<Onnx>(file_name, optimize); } catch (std::exception &ex) { return fail("Model setup failed: %s", ex.what()); } + Onnx::WirePlanner planner; for (size_t i = 0; i < _model->inputs().size(); ++i) { const auto &model_input = _model->inputs()[i]; if (auto maybe_input = defineInput(fmt("rankingExpression(\"%s\")", model_input.name.c_str()), AcceptInput::OBJECT)) { const FeatureType &feature_input = maybe_input.value(); assert(feature_input.is_object()); - if (!model_input.is_compatible(feature_input.type())) { + if (!planner.bind_input_type(feature_input.type(), model_input)) { return fail("incompatible type for input '%s': %s -> %s", model_input.name.c_str(), feature_input.type().to_spec().c_str(), model_input.type_as_string().c_str()); } @@ -99,13 +90,14 @@ OnnxBlueprint::setup(const IIndexEnvironment &env, } for (size_t i = 0; i < _model->outputs().size(); ++i) { const auto &model_output = _model->outputs()[i]; - ValueType output_type = model_output.make_compatible_type(); + ValueType
output_type = planner.make_output_type(model_output); if (output_type.is_error()) { return fail("unable to make compatible type for output '%s': %s -> error", model_output.name.c_str(), model_output.type_as_string().c_str()); } describeOutput(model_output.name, "output from onnx model", FeatureType::object(output_type)); } + _wire_info = planner.get_wire_info(*_model); return true; } @@ -113,7 +105,7 @@ FeatureExecutor & OnnxBlueprint::createExecutor(const IQueryEnvironment &, Stash &stash) const { assert(_model); - return stash.create<OnnxFeatureExecutor>(*_model); + return stash.create<OnnxFeatureExecutor>(*_model, _wire_info); } } diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h index eb6e368ffbd..19c6338d2ee 100644 --- a/searchlib/src/vespa/searchlib/features/onnx_feature.h +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h @@ -3,8 +3,7 @@ #pragma once #include <vespa/searchlib/fef/blueprint.h> - -namespace vespalib::tensor { class OnnxWrapper; } +#include <vespa/eval/tensor/dense/onnx_wrapper.h> namespace search::features { @@ -13,7 +12,9 @@ namespace search::features { **/ class OnnxBlueprint : public fef::Blueprint { private: - std::unique_ptr<vespalib::tensor::OnnxWrapper> _model; + using Onnx = vespalib::tensor::Onnx; + std::unique_ptr<Onnx> _model; + Onnx::WireInfo _wire_info; public: OnnxBlueprint(); ~OnnxBlueprint() override; |