summaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorHåvard Pettersen <3535158+havardpe@users.noreply.github.com>2020-08-28 11:49:53 +0200
committerGitHub <noreply@github.com>2020-08-28 11:49:53 +0200
commite9b01926c98277e489a7ecfb544aaf819b72fe4e (patch)
tree7a711cbc3c4a6b161e7f0ef548030f9ba0e7e194 /searchlib/src
parentd704dc37471f26eb9838766f730983001a4703ac (diff)
parent492c9098b5daa718eadccc3320ca37a7d252f7d2 (diff)
Merge pull request #14187 from vespa-engine/havardpe/infer-unknown-onnx-dimension-sizes
infer unknown onnx dimension sizes
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp15
-rw-r--r--searchlib/src/vespa/searchlib/features/onnx_feature.cpp52
-rw-r--r--searchlib/src/vespa/searchlib/features/onnx_feature.h7
3 files changed, 40 insertions, 34 deletions
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
index cc6b8e0ce29..7a200a46ab2 100644
--- a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
+++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
@@ -25,7 +25,8 @@ std::string get_source_dir() {
}
std::string source_dir = get_source_dir();
std::string vespa_dir = source_dir + "/" + "../../../../..";
-std::string simple_model = vespa_dir + "/" + "model-integration/src/test/models/onnx/simple/simple.onnx";
+std::string simple_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/simple.onnx";
+std::string dynamic_model = vespa_dir + "/" + "eval/src/tests/tensor/onnx_wrapper/dynamic.onnx";
uint32_t default_docid = 1;
@@ -97,4 +98,16 @@ TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) {
EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0));
}
+TEST_F(OnnxFeatureTest, dynamic_onnx_model_can_be_calculated) {
+ add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]");
+ add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
+ add_expr("bias_tensor", "tensor<float>(a[1],b[2]):[[4,5]]");
+ add_onnx("dynamic", dynamic_model);
+ compile(onnx_feature("dynamic"));
+ EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
+ EXPECT_EQ(get("onnxModel(dynamic).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
+ EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0));
+ EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0));
+}
+
GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
index f6d5c37b61d..7433021b9b6 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
@@ -3,7 +3,6 @@
#include "onnx_feature.h"
#include <vespa/searchlib/fef/properties.h>
#include <vespa/searchlib/fef/featureexecutor.h>
-#include <vespa/eval/tensor/dense/onnx_wrapper.h>
#include <vespa/eval/tensor/dense/dense_tensor_view.h>
#include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h>
#include <vespa/vespalib/util/stringfmt.h>
@@ -23,7 +22,7 @@ using vespalib::eval::ValueType;
using vespalib::make_string_short::fmt;
using vespalib::tensor::DenseTensorView;
using vespalib::tensor::MutableDenseTensorView;
-using vespalib::tensor::OnnxWrapper;
+using vespalib::tensor::Onnx;
namespace search::features {
@@ -33,37 +32,28 @@ namespace search::features {
class OnnxFeatureExecutor : public FeatureExecutor
{
private:
- const OnnxWrapper &_model;
- OnnxWrapper::Params _params;
- OnnxWrapper::Result _result;
- std::vector<MutableDenseTensorView> _views;
-
+ Onnx::EvalContext _eval_context;
public:
- OnnxFeatureExecutor(const OnnxWrapper &model)
- : _model(model), _params(), _result(OnnxWrapper::Result::make_empty()), _views()
- {
- _views.reserve(_model.outputs().size());
- for (const auto &output: _model.outputs()) {
- _views.emplace_back(output.make_compatible_type());
- }
- }
+ OnnxFeatureExecutor(const Onnx &model, const Onnx::WireInfo &wire_info)
+ : _eval_context(model, wire_info) {}
bool isPure() override { return true; }
- void execute(uint32_t) override {
- _params = OnnxWrapper::Params();
- for (size_t i = 0; i < _model.inputs().size(); ++i) {
- _params.bind(i, static_cast<const DenseTensorView&>(inputs().get_object(i).get()));
+ void handle_bind_outputs(vespalib::ArrayRef<fef::NumberOrObject>) override {
+ for (size_t i = 0; i < _eval_context.num_results(); ++i) {
+ outputs().set_object(i, _eval_context.get_result(i));
}
- _result = _model.eval(_params);
- for (size_t i = 0; i < _model.outputs().size(); ++i) {
- _result.get(i, _views[i]);
- outputs().set_object(i, _views[i]);
+ }
+ void execute(uint32_t) override {
+ for (size_t i = 0; i < _eval_context.num_params(); ++i) {
+ _eval_context.bind_param(i, inputs().get_object(i).get());
}
+ _eval_context.eval();
}
};
OnnxBlueprint::OnnxBlueprint()
: Blueprint("onnxModel"),
- _model(nullptr)
+ _model(nullptr),
+ _wire_info()
{
}
@@ -74,24 +64,25 @@ OnnxBlueprint::setup(const IIndexEnvironment &env,
const ParameterList &params)
{
auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP)
- ? OnnxWrapper::Optimize::DISABLE
- : OnnxWrapper::Optimize::ENABLE;
+ ? Onnx::Optimize::DISABLE
+ : Onnx::Optimize::ENABLE;
// Note: Using the fileref property with the model name as
// fallback to get a file name. This needs to be replaced with an
// actual file reference obtained through config when available.
vespalib::string file_name = env.getProperties().lookup(getName(), "fileref").get(params[0].getValue());
try {
- _model = std::make_unique<OnnxWrapper>(file_name, optimize);
+ _model = std::make_unique<Onnx>(file_name, optimize);
} catch (std::exception &ex) {
return fail("Model setup failed: %s", ex.what());
}
+ Onnx::WirePlanner planner;
for (size_t i = 0; i < _model->inputs().size(); ++i) {
const auto &model_input = _model->inputs()[i];
if (auto maybe_input = defineInput(fmt("rankingExpression(\"%s\")", model_input.name.c_str()), AcceptInput::OBJECT)) {
const FeatureType &feature_input = maybe_input.value();
assert(feature_input.is_object());
- if (!model_input.is_compatible(feature_input.type())) {
+ if (!planner.bind_input_type(feature_input.type(), model_input)) {
return fail("incompatible type for input '%s': %s -> %s", model_input.name.c_str(),
feature_input.type().to_spec().c_str(), model_input.type_as_string().c_str());
}
@@ -99,13 +90,14 @@ OnnxBlueprint::setup(const IIndexEnvironment &env,
}
for (size_t i = 0; i < _model->outputs().size(); ++i) {
const auto &model_output = _model->outputs()[i];
- ValueType output_type = model_output.make_compatible_type();
+ ValueType output_type = planner.make_output_type(model_output);
if (output_type.is_error()) {
return fail("unable to make compatible type for output '%s': %s -> error",
model_output.name.c_str(), model_output.type_as_string().c_str());
}
describeOutput(model_output.name, "output from onnx model", FeatureType::object(output_type));
}
+ _wire_info = planner.get_wire_info(*_model);
return true;
}
@@ -113,7 +105,7 @@ FeatureExecutor &
OnnxBlueprint::createExecutor(const IQueryEnvironment &, Stash &stash) const
{
assert(_model);
- return stash.create<OnnxFeatureExecutor>(*_model);
+ return stash.create<OnnxFeatureExecutor>(*_model, _wire_info);
}
}
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h
index eb6e368ffbd..19c6338d2ee 100644
--- a/searchlib/src/vespa/searchlib/features/onnx_feature.h
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h
@@ -3,8 +3,7 @@
#pragma once
#include <vespa/searchlib/fef/blueprint.h>
-
-namespace vespalib::tensor { class OnnxWrapper; }
+#include <vespa/eval/tensor/dense/onnx_wrapper.h>
namespace search::features {
@@ -13,7 +12,9 @@ namespace search::features {
**/
class OnnxBlueprint : public fef::Blueprint {
private:
- std::unique_ptr<vespalib::tensor::OnnxWrapper> _model;
+ using Onnx = vespalib::tensor::Onnx;
+ std::unique_ptr<Onnx> _model;
+ Onnx::WireInfo _wire_info;
public:
OnnxBlueprint();
~OnnxBlueprint() override;