diff options
-rw-r--r-- | eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp | 45 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_spec.cpp | 15 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/tensor_spec.h | 3 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/dense_tensor_view.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp | 32 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/onnx_wrapper.h | 4 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/dense/typed_cells.h | 3 | ||||
-rw-r--r-- | searchlib/CMakeLists.txt | 1 | ||||
-rw-r--r-- | searchlib/src/tests/features/onnx_feature/CMakeLists.txt | 9 | ||||
-rw-r--r-- | searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp | 100 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/CMakeLists.txt | 1 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/onnx_feature.cpp | 119 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/onnx_feature.h | 31 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/setup.cpp | 5 |
15 files changed, 328 insertions, 42 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp index 07e065f9e39..28a4a34b2e4 100644 --- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp +++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp @@ -19,33 +19,10 @@ std::string source_dir = get_source_dir(); std::string vespa_dir = source_dir + "/" + "../../../../.."; std::string simple_model = vespa_dir + "/" + "model-integration/src/test/models/onnx/simple/simple.onnx"; -vespalib::string to_str(const std::vector<size_t> &dim_sizes) { - vespalib::string res; - for (size_t dim_size: dim_sizes) { - if (dim_size == 0) { - res += "[]"; - } else { - res += fmt("[%zu]", dim_size); - } - } - return res; -} - -vespalib::string to_str(OnnxWrapper::TensorInfo::ElementType element_type) { - if (element_type == OnnxWrapper::TensorInfo::ElementType::FLOAT) { - return "float"; - } - if (element_type == OnnxWrapper::TensorInfo::ElementType::DOUBLE) { - return "double"; - } - return "???"; -} - void dump_info(const char *ctx, const std::vector<OnnxWrapper::TensorInfo> &info) { fprintf(stderr, "%s:\n", ctx); for (size_t i = 0; i < info.size(); ++i) { - fprintf(stderr, " %s[%zu]: '%s' %s%s\n", ctx, i, info[i].name.c_str(), - to_str(info[i].elements).c_str(),to_str(info[i].dimensions).c_str()); + fprintf(stderr, " %s[%zu]: '%s' %s\n", ctx, i, info[i].name.c_str(), info[i].type_as_string().c_str()); } } @@ -57,21 +34,17 @@ TEST(OnnxWrapperTest, onnx_model_can_be_inspected) ASSERT_EQ(wrapper.inputs().size(), 3); ASSERT_EQ(wrapper.outputs().size(), 1); //------------------------------------------------------------------------- - EXPECT_EQ( wrapper.inputs()[0].name, "query_tensor"); - EXPECT_EQ(to_str(wrapper.inputs()[0].dimensions), "[1][4]"); - EXPECT_EQ(to_str(wrapper.inputs()[0].elements), "float"); + EXPECT_EQ(wrapper.inputs()[0].name, "query_tensor"); + EXPECT_EQ(wrapper.inputs()[0].type_as_string(), "float[1][4]"); 
//------------------------------------------------------------------------- - EXPECT_EQ( wrapper.inputs()[1].name, "attribute_tensor"); - EXPECT_EQ(to_str(wrapper.inputs()[1].dimensions), "[4][1]"); - EXPECT_EQ(to_str(wrapper.inputs()[1].elements), "float"); + EXPECT_EQ(wrapper.inputs()[1].name, "attribute_tensor"); + EXPECT_EQ(wrapper.inputs()[1].type_as_string(), "float[4][1]"); //------------------------------------------------------------------------- - EXPECT_EQ( wrapper.inputs()[2].name, "bias_tensor"); - EXPECT_EQ(to_str(wrapper.inputs()[2].dimensions), "[1][1]"); - EXPECT_EQ(to_str(wrapper.inputs()[2].elements), "float"); + EXPECT_EQ(wrapper.inputs()[2].name, "bias_tensor"); + EXPECT_EQ(wrapper.inputs()[2].type_as_string(), "float[1][1]"); //------------------------------------------------------------------------- - EXPECT_EQ( wrapper.outputs()[0].name, "output"); - EXPECT_EQ(to_str(wrapper.outputs()[0].dimensions), "[1][1]"); - EXPECT_EQ(to_str(wrapper.outputs()[0].elements), "float"); + EXPECT_EQ(wrapper.outputs()[0].name, "output"); + EXPECT_EQ(wrapper.outputs()[0].type_as_string(), "float[1][1]"); } TEST(OnnxWrapperTest, onnx_model_can_be_evaluated) diff --git a/eval/src/vespa/eval/eval/tensor_spec.cpp b/eval/src/vespa/eval/eval/tensor_spec.cpp index a24b45dc3f5..b4b2e3d3afc 100644 --- a/eval/src/vespa/eval/eval/tensor_spec.cpp +++ b/eval/src/vespa/eval/eval/tensor_spec.cpp @@ -1,6 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
#include "tensor_spec.h" +#include "value.h" +#include "tensor.h" +#include "tensor_engine.h" #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/data/slime/slime.h> #include <ostream> @@ -94,6 +97,18 @@ TensorSpec::from_slime(const slime::Inspector &tensor) return spec; } +TensorSpec +TensorSpec::from_value(const eval::Value &value) +{ + if (const eval::Tensor *tensor = value.as_tensor()) { + return tensor->engine().to_spec(*tensor); + } + if (value.is_double()) { + return TensorSpec("double").add({}, value.as_double()); + } + return TensorSpec("error"); +} + bool operator==(const TensorSpec &lhs, const TensorSpec &rhs) { diff --git a/eval/src/vespa/eval/eval/tensor_spec.h b/eval/src/vespa/eval/eval/tensor_spec.h index 22aa47f5ddb..f4f116454d5 100644 --- a/eval/src/vespa/eval/eval/tensor_spec.h +++ b/eval/src/vespa/eval/eval/tensor_spec.h @@ -18,6 +18,8 @@ struct Inspector; namespace eval { +class Value; + /** * An implementation-independent specification of the type and * contents of a tensor. 
@@ -78,6 +80,7 @@ public: vespalib::string to_string() const; void to_slime(slime::Cursor &tensor) const; static TensorSpec from_slime(const slime::Inspector &tensor); + static TensorSpec from_value(const eval::Value &value); }; bool operator==(const TensorSpec &lhs, const TensorSpec &rhs); diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h index 33183d267c1..93dd2dbedeb 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h @@ -55,6 +55,7 @@ protected: : _typeRef(type_in), _cellsRef() {} + DenseTensorView(const DenseTensorView &rhs) : DenseTensorView(rhs._typeRef, rhs._cellsRef) {} void initCellsRef(TypedCells cells_in) { assert(_typeRef.cell_type() == cells_in.type); diff --git a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h index db241ef6a2b..5e4a48462d7 100644 --- a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h @@ -17,6 +17,7 @@ private: public: MutableDenseTensorView(eval::ValueType type_in); + MutableDenseTensorView(MutableDenseTensorView &&) = default; void setCells(TypedCells cells_in) { initCellsRef(cells_in); } diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp index fa0379473c9..125095ff23e 100644 --- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp +++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp @@ -18,6 +18,16 @@ namespace vespalib::tensor { namespace { +vespalib::string to_str(OnnxWrapper::TensorInfo::ElementType element_type) { + if (element_type == OnnxWrapper::TensorInfo::ElementType::FLOAT) { + return "float"; + } + if (element_type == OnnxWrapper::TensorInfo::ElementType::DOUBLE) { + return "double"; + } + return "???"; +} + ValueType::CellType 
as_cell_type(OnnxWrapper::TensorInfo::ElementType type) { if (type == OnnxWrapper::TensorInfo::ElementType::FLOAT) { return ValueType::CellType::FLOAT; @@ -134,6 +144,20 @@ OnnxWrapper::TensorInfo::make_compatible_type() const return ValueType::tensor_type(std::move(dim_list), as_cell_type(elements)); } +vespalib::string +OnnxWrapper::TensorInfo::type_as_string() const +{ + vespalib::string res = to_str(elements); + for (size_t dim_size: dimensions) { + if (dim_size == 0) { + res += "[]"; + } else { + res += fmt("[%zu]", dim_size); + } + } + return res; +} + OnnxWrapper::TensorInfo::~TensorInfo() = default; OnnxWrapper::Shared::Shared() @@ -222,12 +246,14 @@ OnnxWrapper::OnnxWrapper(const vespalib::string &model_file, Optimize optimize) OnnxWrapper::~OnnxWrapper() = default; OnnxWrapper::Result -OnnxWrapper::eval(const Params &params) +OnnxWrapper::eval(const Params &params) const { assert(params.values.size() == _inputs.size()); Ort::RunOptions run_opts(nullptr); - return Result(_session.Run(run_opts, _input_name_refs.data(), params.values.data(), _inputs.size(), - _output_name_refs.data(), _outputs.size())); + // NB: Run requires non-const session + Ort::Session &session = const_cast<Ort::Session&>(_session); + return Result(session.Run(run_opts, _input_name_refs.data(), params.values.data(), _inputs.size(), + _output_name_refs.data(), _outputs.size())); } } diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h index 67a64f2d318..abe1da252c7 100644 --- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h +++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h @@ -28,6 +28,7 @@ public: ElementType elements; bool is_compatible(const eval::ValueType &type) const; eval::ValueType make_compatible_type() const; + vespalib::string type_as_string() const; ~TensorInfo(); }; @@ -48,6 +49,7 @@ public: std::vector<Ort::Value> values; Result(std::vector<Ort::Value> values_in) : values(std::move(values_in)) {} public: + static Result 
make_empty() { return Result({}); } size_t num_values() const { return values.size(); } void get(size_t idx, MutableDenseTensorView &dst); }; @@ -78,7 +80,7 @@ public: ~OnnxWrapper(); const std::vector<TensorInfo> &inputs() const { return _inputs; } const std::vector<TensorInfo> &outputs() const { return _outputs; } - Result eval(const Params &params); // NB: Run requires non-const session + Result eval(const Params &params) const; }; } diff --git a/eval/src/vespa/eval/tensor/dense/typed_cells.h b/eval/src/vespa/eval/tensor/dense/typed_cells.h index d1d2baa535e..6ea2b40689e 100644 --- a/eval/src/vespa/eval/tensor/dense/typed_cells.h +++ b/eval/src/vespa/eval/tensor/dense/typed_cells.h @@ -44,6 +44,9 @@ struct TypedCells { abort(); } + TypedCells(TypedCells &&other) = default; + TypedCells(const TypedCells &other) = default; + TypedCells & operator= (TypedCells &&other) = default; TypedCells & operator= (const TypedCells &other) = default; }; diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt index 1db9018bd21..4e8d93e3f81 100644 --- a/searchlib/CMakeLists.txt +++ b/searchlib/CMakeLists.txt @@ -139,6 +139,7 @@ vespa_define_module( src/tests/features/native_dot_product src/tests/features/nns_closeness src/tests/features/nns_distance + src/tests/features/onnx_feature src/tests/features/ranking_expression src/tests/features/raw_score src/tests/features/subqueries diff --git a/searchlib/src/tests/features/onnx_feature/CMakeLists.txt b/searchlib/src/tests/features/onnx_feature/CMakeLists.txt new file mode 100644 index 00000000000..8657d8987da --- /dev/null +++ b/searchlib/src/tests/features/onnx_feature/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+vespa_add_executable(searchlib_onnx_feature_test_app TEST + SOURCES + onnx_feature_test.cpp + DEPENDS + searchlib + GTest::GTest +) +vespa_add_test(NAME searchlib_onnx_feature_test_app COMMAND searchlib_onnx_feature_test_app) diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp new file mode 100644 index 00000000000..cc6b8e0ce29 --- /dev/null +++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp @@ -0,0 +1,100 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include <vespa/vespalib/stllike/string.h> +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/searchlib/features/rankingexpressionfeature.h> +#include <vespa/searchlib/features/onnx_feature.h> +#include <vespa/searchlib/fef/blueprintfactory.h> +#include <vespa/searchlib/fef/indexproperties.h> +#include <vespa/searchlib/fef/matchdatalayout.h> +#include <vespa/searchlib/fef/test/indexenvironment.h> +#include <vespa/searchlib/fef/test/queryenvironment.h> +#include <vespa/searchlib/fef/rank_program.h> +#include <vespa/searchlib/fef/test/test_features.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace search::fef; +using namespace search::fef::test; +using namespace search::features; +using vespalib::make_string_short::fmt; +using vespalib::eval::TensorSpec; + +std::string get_source_dir() { + const char *dir = getenv("SOURCE_DIRECTORY"); + return (dir ? 
dir : "."); +} +std::string source_dir = get_source_dir(); +std::string vespa_dir = source_dir + "/" + "../../../../.."; +std::string simple_model = vespa_dir + "/" + "model-integration/src/test/models/onnx/simple/simple.onnx"; + +uint32_t default_docid = 1; + +vespalib::string expr_feature(const vespalib::string &name) { + return fmt("rankingExpression(%s)", name.c_str()); +} + +vespalib::string onnx_feature(const vespalib::string &name) { + return fmt("onnxModel(%s)", name.c_str()); +} + +struct OnnxFeatureTest : ::testing::Test { + BlueprintFactory factory; + IndexEnvironment indexEnv; + BlueprintResolver::SP resolver; + Properties overrides; + MatchData::UP match_data; + RankProgram program; + OnnxFeatureTest() : factory(), indexEnv(), resolver(new BlueprintResolver(factory, indexEnv)), + overrides(), match_data(), program(resolver) + { + factory.addPrototype(std::make_shared<DocidBlueprint>()); + factory.addPrototype(std::make_shared<RankingExpressionBlueprint>()); + factory.addPrototype(std::make_shared<OnnxBlueprint>()); + } + void add_expr(const vespalib::string &name, const vespalib::string &expr) { + vespalib::string feature_name = expr_feature(name); + vespalib::string expr_name = feature_name + ".rankingScript"; + indexEnv.getProperties().add(expr_name, expr); + } + void add_onnx(const vespalib::string &name, const vespalib::string &file) { + vespalib::string feature_name = onnx_feature(name); + vespalib::string file_name = feature_name + ".fileref"; + indexEnv.getProperties().add(file_name, file); + } + void compile(const vespalib::string &seed) { + resolver->addSeed(seed); + ASSERT_TRUE(resolver->compile()); + MatchDataLayout mdl; + QueryEnvironment queryEnv(&indexEnv); + match_data = mdl.createMatchData(); + program.setup(*match_data, queryEnv, overrides); + } + TensorSpec get(const vespalib::string &feature, uint32_t docid) { + auto result = program.get_all_features(false); + for (size_t i = 0; i < result.num_features(); ++i) { + if 
(result.name_of(i) == feature) { + return TensorSpec::from_value(result.resolve(i).as_object(docid)); + } + } + return TensorSpec("error"); + } + TensorSpec get(uint32_t docid) { + auto result = program.get_seeds(false); + EXPECT_EQ(1u, result.num_features()); + return TensorSpec::from_value(result.resolve(0).as_object(docid)); + } +}; + +TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) { + add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]"); + add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]"); + add_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]"); + add_onnx("simple", simple_model); + compile(onnx_feature("simple")); + EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); + EXPECT_EQ(get("onnxModel(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0)); + EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0)); + EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0)); +} + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/searchlib/src/vespa/searchlib/features/CMakeLists.txt b/searchlib/src/vespa/searchlib/features/CMakeLists.txt index 215b6ade9fd..93fead713f4 100644 --- a/searchlib/src/vespa/searchlib/features/CMakeLists.txt +++ b/searchlib/src/vespa/searchlib/features/CMakeLists.txt @@ -40,6 +40,7 @@ vespa_add_library(searchlib_features OBJECT nativeproximityfeature.cpp nativerankfeature.cpp nowfeature.cpp + onnx_feature.cpp proximityfeature.cpp querycompletenessfeature.cpp queryfeature.cpp diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp new file mode 100644 index 00000000000..f6d5c37b61d --- /dev/null +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp @@ -0,0 +1,119 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. 
See LICENSE in the project root. + +#include "onnx_feature.h" +#include <vespa/searchlib/fef/properties.h> +#include <vespa/searchlib/fef/featureexecutor.h> +#include <vespa/eval/tensor/dense/onnx_wrapper.h> +#include <vespa/eval/tensor/dense/dense_tensor_view.h> +#include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h> +#include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/stash.h> + +#include <vespa/log/log.h> +LOG_SETUP(".features.onnx_feature"); + +using search::fef::Blueprint; +using search::fef::FeatureExecutor; +using search::fef::FeatureType; +using search::fef::IIndexEnvironment; +using search::fef::IQueryEnvironment; +using search::fef::ParameterList; +using vespalib::Stash; +using vespalib::eval::ValueType; +using vespalib::make_string_short::fmt; +using vespalib::tensor::DenseTensorView; +using vespalib::tensor::MutableDenseTensorView; +using vespalib::tensor::OnnxWrapper; + +namespace search::features { + +/** + * Feature executor that evaluates an onnx model + */ +class OnnxFeatureExecutor : public FeatureExecutor +{ +private: + const OnnxWrapper &_model; + OnnxWrapper::Params _params; + OnnxWrapper::Result _result; + std::vector<MutableDenseTensorView> _views; + +public: + OnnxFeatureExecutor(const OnnxWrapper &model) + : _model(model), _params(), _result(OnnxWrapper::Result::make_empty()), _views() + { + _views.reserve(_model.outputs().size()); + for (const auto &output: _model.outputs()) { + _views.emplace_back(output.make_compatible_type()); + } + } + bool isPure() override { return true; } + void execute(uint32_t) override { + _params = OnnxWrapper::Params(); + for (size_t i = 0; i < _model.inputs().size(); ++i) { + _params.bind(i, static_cast<const DenseTensorView&>(inputs().get_object(i).get())); + } + _result = _model.eval(_params); + for (size_t i = 0; i < _model.outputs().size(); ++i) { + _result.get(i, _views[i]); + outputs().set_object(i, _views[i]); + } + } +}; + +OnnxBlueprint::OnnxBlueprint() + : 
Blueprint("onnxModel"), + _model(nullptr) +{ +} + +OnnxBlueprint::~OnnxBlueprint() = default; + +bool +OnnxBlueprint::setup(const IIndexEnvironment &env, + const ParameterList &params) +{ + auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP) + ? OnnxWrapper::Optimize::DISABLE + : OnnxWrapper::Optimize::ENABLE; + + // Note: Using the fileref property with the model name as + // fallback to get a file name. This needs to be replaced with an + // actual file reference obtained through config when available. + vespalib::string file_name = env.getProperties().lookup(getName(), "fileref").get(params[0].getValue()); + try { + _model = std::make_unique<OnnxWrapper>(file_name, optimize); + } catch (std::exception &ex) { + return fail("Model setup failed: %s", ex.what()); + } + for (size_t i = 0; i < _model->inputs().size(); ++i) { + const auto &model_input = _model->inputs()[i]; + if (auto maybe_input = defineInput(fmt("rankingExpression(\"%s\")", model_input.name.c_str()), AcceptInput::OBJECT)) { + const FeatureType &feature_input = maybe_input.value(); + assert(feature_input.is_object()); + if (!model_input.is_compatible(feature_input.type())) { + return fail("incompatible type for input '%s': %s -> %s", model_input.name.c_str(), + feature_input.type().to_spec().c_str(), model_input.type_as_string().c_str()); + } + } + } + for (size_t i = 0; i < _model->outputs().size(); ++i) { + const auto &model_output = _model->outputs()[i]; + ValueType output_type = model_output.make_compatible_type(); + if (output_type.is_error()) { + return fail("unable to make compatible type for output '%s': %s -> error", + model_output.name.c_str(), model_output.type_as_string().c_str()); + } + describeOutput(model_output.name, "output from onnx model", FeatureType::object(output_type)); + } + return true; +} + +FeatureExecutor & +OnnxBlueprint::createExecutor(const IQueryEnvironment &, Stash &stash) const +{ + assert(_model); + return 
stash.create<OnnxFeatureExecutor>(*_model); +} + +} diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h new file mode 100644 index 00000000000..eb6e368ffbd --- /dev/null +++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h @@ -0,0 +1,31 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/searchlib/fef/blueprint.h> + +namespace vespalib::tensor { class OnnxWrapper; } + +namespace search::features { + +/** + * Blueprint for the ranking feature used to evaluate an onnx model. + **/ +class OnnxBlueprint : public fef::Blueprint { +private: + std::unique_ptr<vespalib::tensor::OnnxWrapper> _model; +public: + OnnxBlueprint(); + ~OnnxBlueprint() override; + void visitDumpFeatures(const fef::IIndexEnvironment &, fef::IDumpFeatureVisitor &) const override {} + fef::Blueprint::UP createInstance() const override { + return Blueprint::UP(new OnnxBlueprint()); + } + fef::ParameterDescriptions getDescriptions() const override { + return fef::ParameterDescriptions().desc().string(); + } + bool setup(const fef::IIndexEnvironment &env, const fef::ParameterList &params) override; + fef::FeatureExecutor &createExecutor(const fef::IQueryEnvironment &env, vespalib::Stash &stash) const override; +}; + +} diff --git a/searchlib/src/vespa/searchlib/features/setup.cpp b/searchlib/src/vespa/searchlib/features/setup.cpp index bd79f1d4fb5..c97d2d68812 100644 --- a/searchlib/src/vespa/searchlib/features/setup.cpp +++ b/searchlib/src/vespa/searchlib/features/setup.cpp @@ -23,6 +23,7 @@ #include "flow_completeness_feature.h" #include "foreachfeature.h" #include "freshnessfeature.h" +#include "global_sequence_feature.h" #include "item_raw_score_feature.h" #include "jarowinklerdistancefeature.h" #include "matchcountfeature.h" @@ -34,6 +35,7 @@ #include "nativeproximityfeature.h" #include "nativerankfeature.h" #include 
"nowfeature.h" +#include "onnx_feature.h" #include "proximityfeature.h" #include "querycompletenessfeature.h" #include "queryfeature.h" @@ -53,7 +55,6 @@ #include "termfeature.h" #include "terminfofeature.h" #include "text_similarity_feature.h" -#include "global_sequence_feature.h" #include "valuefeature.h" #include "max_reduce_prod_join_replacer.h" @@ -123,7 +124,7 @@ void setup_search_features(fef::IBlueprintRegistry & registry) registry.addPrototype(std::make_shared<TermFieldMdBlueprint>()); registry.addPrototype(std::make_shared<ConstantBlueprint>()); registry.addPrototype(std::make_shared<GlobalSequenceBlueprint>()); - + registry.addPrototype(std::make_shared<OnnxBlueprint>()); // Ranking Expression auto replacers = std::make_unique<ListExpressionReplacer>(); |