aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2020-08-20 12:09:57 +0000
committerHåvard Pettersen <havardpe@oath.com>2020-08-21 11:03:42 +0000
commit7e4d00f04c5cd5ca229a5559f05a069908f6144a (patch)
treeaa6de03f758780ff331bf3bb1658eebd85601aa8
parent67e528443cca68cc527e50c2714ad1717563c458 (diff)
onnx ranking feature
-rw-r--r--eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp45
-rw-r--r--eval/src/vespa/eval/eval/tensor_spec.cpp15
-rw-r--r--eval/src/vespa/eval/eval/tensor_spec.h3
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.h1
-rw-r--r--eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h1
-rw-r--r--eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp32
-rw-r--r--eval/src/vespa/eval/tensor/dense/onnx_wrapper.h4
-rw-r--r--eval/src/vespa/eval/tensor/dense/typed_cells.h3
-rw-r--r--searchlib/CMakeLists.txt1
-rw-r--r--searchlib/src/tests/features/onnx_feature/CMakeLists.txt9
-rw-r--r--searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp100
-rw-r--r--searchlib/src/vespa/searchlib/features/CMakeLists.txt1
-rw-r--r--searchlib/src/vespa/searchlib/features/onnx_feature.cpp119
-rw-r--r--searchlib/src/vespa/searchlib/features/onnx_feature.h31
-rw-r--r--searchlib/src/vespa/searchlib/features/setup.cpp5
15 files changed, 328 insertions, 42 deletions
diff --git a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
index 07e065f9e39..28a4a34b2e4 100644
--- a/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
+++ b/eval/src/tests/tensor/onnx_wrapper/onnx_wrapper_test.cpp
@@ -19,33 +19,10 @@ std::string source_dir = get_source_dir();
std::string vespa_dir = source_dir + "/" + "../../../../..";
std::string simple_model = vespa_dir + "/" + "model-integration/src/test/models/onnx/simple/simple.onnx";
-vespalib::string to_str(const std::vector<size_t> &dim_sizes) {
- vespalib::string res;
- for (size_t dim_size: dim_sizes) {
- if (dim_size == 0) {
- res += "[]";
- } else {
- res += fmt("[%zu]", dim_size);
- }
- }
- return res;
-}
-
-vespalib::string to_str(OnnxWrapper::TensorInfo::ElementType element_type) {
- if (element_type == OnnxWrapper::TensorInfo::ElementType::FLOAT) {
- return "float";
- }
- if (element_type == OnnxWrapper::TensorInfo::ElementType::DOUBLE) {
- return "double";
- }
- return "???";
-}
-
void dump_info(const char *ctx, const std::vector<OnnxWrapper::TensorInfo> &info) {
fprintf(stderr, "%s:\n", ctx);
for (size_t i = 0; i < info.size(); ++i) {
- fprintf(stderr, " %s[%zu]: '%s' %s%s\n", ctx, i, info[i].name.c_str(),
- to_str(info[i].elements).c_str(),to_str(info[i].dimensions).c_str());
+ fprintf(stderr, " %s[%zu]: '%s' %s\n", ctx, i, info[i].name.c_str(), info[i].type_as_string().c_str());
}
}
@@ -57,21 +34,17 @@ TEST(OnnxWrapperTest, onnx_model_can_be_inspected)
ASSERT_EQ(wrapper.inputs().size(), 3);
ASSERT_EQ(wrapper.outputs().size(), 1);
//-------------------------------------------------------------------------
- EXPECT_EQ( wrapper.inputs()[0].name, "query_tensor");
- EXPECT_EQ(to_str(wrapper.inputs()[0].dimensions), "[1][4]");
- EXPECT_EQ(to_str(wrapper.inputs()[0].elements), "float");
+ EXPECT_EQ(wrapper.inputs()[0].name, "query_tensor");
+ EXPECT_EQ(wrapper.inputs()[0].type_as_string(), "float[1][4]");
//-------------------------------------------------------------------------
- EXPECT_EQ( wrapper.inputs()[1].name, "attribute_tensor");
- EXPECT_EQ(to_str(wrapper.inputs()[1].dimensions), "[4][1]");
- EXPECT_EQ(to_str(wrapper.inputs()[1].elements), "float");
+ EXPECT_EQ(wrapper.inputs()[1].name, "attribute_tensor");
+ EXPECT_EQ(wrapper.inputs()[1].type_as_string(), "float[4][1]");
//-------------------------------------------------------------------------
- EXPECT_EQ( wrapper.inputs()[2].name, "bias_tensor");
- EXPECT_EQ(to_str(wrapper.inputs()[2].dimensions), "[1][1]");
- EXPECT_EQ(to_str(wrapper.inputs()[2].elements), "float");
+ EXPECT_EQ(wrapper.inputs()[2].name, "bias_tensor");
+ EXPECT_EQ(wrapper.inputs()[2].type_as_string(), "float[1][1]");
//-------------------------------------------------------------------------
- EXPECT_EQ( wrapper.outputs()[0].name, "output");
- EXPECT_EQ(to_str(wrapper.outputs()[0].dimensions), "[1][1]");
- EXPECT_EQ(to_str(wrapper.outputs()[0].elements), "float");
+ EXPECT_EQ(wrapper.outputs()[0].name, "output");
+ EXPECT_EQ(wrapper.outputs()[0].type_as_string(), "float[1][1]");
}
TEST(OnnxWrapperTest, onnx_model_can_be_evaluated)
diff --git a/eval/src/vespa/eval/eval/tensor_spec.cpp b/eval/src/vespa/eval/eval/tensor_spec.cpp
index a24b45dc3f5..b4b2e3d3afc 100644
--- a/eval/src/vespa/eval/eval/tensor_spec.cpp
+++ b/eval/src/vespa/eval/eval/tensor_spec.cpp
@@ -1,6 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "tensor_spec.h"
+#include "value.h"
+#include "tensor.h"
+#include "tensor_engine.h"
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/data/slime/slime.h>
#include <ostream>
@@ -94,6 +97,18 @@ TensorSpec::from_slime(const slime::Inspector &tensor)
return spec;
}
+TensorSpec
+TensorSpec::from_value(const eval::Value &value)
+{
+ if (const eval::Tensor *tensor = value.as_tensor()) {
+ return tensor->engine().to_spec(*tensor);
+ }
+ if (value.is_double()) {
+ return TensorSpec("double").add({}, value.as_double());
+ }
+ return TensorSpec("error");
+}
+
bool
operator==(const TensorSpec &lhs, const TensorSpec &rhs)
{
diff --git a/eval/src/vespa/eval/eval/tensor_spec.h b/eval/src/vespa/eval/eval/tensor_spec.h
index 22aa47f5ddb..f4f116454d5 100644
--- a/eval/src/vespa/eval/eval/tensor_spec.h
+++ b/eval/src/vespa/eval/eval/tensor_spec.h
@@ -18,6 +18,8 @@ struct Inspector;
namespace eval {
+class Value;
+
/**
* An implementation-independent specification of the type and
* contents of a tensor.
@@ -78,6 +80,7 @@ public:
vespalib::string to_string() const;
void to_slime(slime::Cursor &tensor) const;
static TensorSpec from_slime(const slime::Inspector &tensor);
+ static TensorSpec from_value(const eval::Value &value);
};
bool operator==(const TensorSpec &lhs, const TensorSpec &rhs);
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
index 33183d267c1..93dd2dbedeb 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
@@ -55,6 +55,7 @@ protected:
: _typeRef(type_in),
_cellsRef()
{}
+ DenseTensorView(const DenseTensorView &rhs) : DenseTensorView(rhs._typeRef, rhs._cellsRef) {}
void initCellsRef(TypedCells cells_in) {
assert(_typeRef.cell_type() == cells_in.type);
diff --git a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h
index db241ef6a2b..5e4a48462d7 100644
--- a/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/mutable_dense_tensor_view.h
@@ -17,6 +17,7 @@ private:
public:
MutableDenseTensorView(eval::ValueType type_in);
+ MutableDenseTensorView(MutableDenseTensorView &&) = default;
void setCells(TypedCells cells_in) {
initCellsRef(cells_in);
}
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
index fa0379473c9..125095ff23e 100644
--- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
+++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.cpp
@@ -18,6 +18,16 @@ namespace vespalib::tensor {
namespace {
+vespalib::string to_str(OnnxWrapper::TensorInfo::ElementType element_type) {
+ if (element_type == OnnxWrapper::TensorInfo::ElementType::FLOAT) {
+ return "float";
+ }
+ if (element_type == OnnxWrapper::TensorInfo::ElementType::DOUBLE) {
+ return "double";
+ }
+ return "???";
+}
+
ValueType::CellType as_cell_type(OnnxWrapper::TensorInfo::ElementType type) {
if (type == OnnxWrapper::TensorInfo::ElementType::FLOAT) {
return ValueType::CellType::FLOAT;
@@ -134,6 +144,20 @@ OnnxWrapper::TensorInfo::make_compatible_type() const
return ValueType::tensor_type(std::move(dim_list), as_cell_type(elements));
}
+vespalib::string
+OnnxWrapper::TensorInfo::type_as_string() const
+{
+ vespalib::string res = to_str(elements);
+ for (size_t dim_size: dimensions) {
+ if (dim_size == 0) {
+ res += "[]";
+ } else {
+ res += fmt("[%zu]", dim_size);
+ }
+ }
+ return res;
+}
+
OnnxWrapper::TensorInfo::~TensorInfo() = default;
OnnxWrapper::Shared::Shared()
@@ -222,12 +246,14 @@ OnnxWrapper::OnnxWrapper(const vespalib::string &model_file, Optimize optimize)
OnnxWrapper::~OnnxWrapper() = default;
OnnxWrapper::Result
-OnnxWrapper::eval(const Params &params)
+OnnxWrapper::eval(const Params &params) const
{
assert(params.values.size() == _inputs.size());
Ort::RunOptions run_opts(nullptr);
- return Result(_session.Run(run_opts, _input_name_refs.data(), params.values.data(), _inputs.size(),
- _output_name_refs.data(), _outputs.size()));
+ // NB: Run requires non-const session
+ Ort::Session &session = const_cast<Ort::Session&>(_session);
+ return Result(session.Run(run_opts, _input_name_refs.data(), params.values.data(), _inputs.size(),
+ _output_name_refs.data(), _outputs.size()));
}
}
diff --git a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
index 67a64f2d318..abe1da252c7 100644
--- a/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
+++ b/eval/src/vespa/eval/tensor/dense/onnx_wrapper.h
@@ -28,6 +28,7 @@ public:
ElementType elements;
bool is_compatible(const eval::ValueType &type) const;
eval::ValueType make_compatible_type() const;
+ vespalib::string type_as_string() const;
~TensorInfo();
};
@@ -48,6 +49,7 @@ public:
std::vector<Ort::Value> values;
Result(std::vector<Ort::Value> values_in) : values(std::move(values_in)) {}
public:
+ static Result make_empty() { return Result({}); }
size_t num_values() const { return values.size(); }
void get(size_t idx, MutableDenseTensorView &dst);
};
@@ -78,7 +80,7 @@ public:
~OnnxWrapper();
const std::vector<TensorInfo> &inputs() const { return _inputs; }
const std::vector<TensorInfo> &outputs() const { return _outputs; }
- Result eval(const Params &params); // NB: Run requires non-const session
+ Result eval(const Params &params) const;
};
}
diff --git a/eval/src/vespa/eval/tensor/dense/typed_cells.h b/eval/src/vespa/eval/tensor/dense/typed_cells.h
index d1d2baa535e..6ea2b40689e 100644
--- a/eval/src/vespa/eval/tensor/dense/typed_cells.h
+++ b/eval/src/vespa/eval/tensor/dense/typed_cells.h
@@ -44,6 +44,9 @@ struct TypedCells {
abort();
}
+ TypedCells(TypedCells &&other) = default;
+ TypedCells(const TypedCells &other) = default;
+ TypedCells & operator= (TypedCells &&other) = default;
TypedCells & operator= (const TypedCells &other) = default;
};
diff --git a/searchlib/CMakeLists.txt b/searchlib/CMakeLists.txt
index 1db9018bd21..4e8d93e3f81 100644
--- a/searchlib/CMakeLists.txt
+++ b/searchlib/CMakeLists.txt
@@ -139,6 +139,7 @@ vespa_define_module(
src/tests/features/native_dot_product
src/tests/features/nns_closeness
src/tests/features/nns_distance
+ src/tests/features/onnx_feature
src/tests/features/ranking_expression
src/tests/features/raw_score
src/tests/features/subqueries
diff --git a/searchlib/src/tests/features/onnx_feature/CMakeLists.txt b/searchlib/src/tests/features/onnx_feature/CMakeLists.txt
new file mode 100644
index 00000000000..8657d8987da
--- /dev/null
+++ b/searchlib/src/tests/features/onnx_feature/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(searchlib_onnx_feature_test_app TEST
+ SOURCES
+ onnx_feature_test.cpp
+ DEPENDS
+ searchlib
+ GTest::GTest
+)
+vespa_add_test(NAME searchlib_onnx_feature_test_app COMMAND searchlib_onnx_feature_test_app)
diff --git a/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
new file mode 100644
index 00000000000..cc6b8e0ce29
--- /dev/null
+++ b/searchlib/src/tests/features/onnx_feature/onnx_feature_test.cpp
@@ -0,0 +1,100 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/stllike/string.h>
+#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/searchlib/features/rankingexpressionfeature.h>
+#include <vespa/searchlib/features/onnx_feature.h>
+#include <vespa/searchlib/fef/blueprintfactory.h>
+#include <vespa/searchlib/fef/indexproperties.h>
+#include <vespa/searchlib/fef/matchdatalayout.h>
+#include <vespa/searchlib/fef/test/indexenvironment.h>
+#include <vespa/searchlib/fef/test/queryenvironment.h>
+#include <vespa/searchlib/fef/rank_program.h>
+#include <vespa/searchlib/fef/test/test_features.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace search::fef;
+using namespace search::fef::test;
+using namespace search::features;
+using vespalib::make_string_short::fmt;
+using vespalib::eval::TensorSpec;
+
+std::string get_source_dir() {
+ const char *dir = getenv("SOURCE_DIRECTORY");
+ return (dir ? dir : ".");
+}
+std::string source_dir = get_source_dir();
+std::string vespa_dir = source_dir + "/" + "../../../../..";
+std::string simple_model = vespa_dir + "/" + "model-integration/src/test/models/onnx/simple/simple.onnx";
+
+uint32_t default_docid = 1;
+
+vespalib::string expr_feature(const vespalib::string &name) {
+ return fmt("rankingExpression(%s)", name.c_str());
+}
+
+vespalib::string onnx_feature(const vespalib::string &name) {
+ return fmt("onnxModel(%s)", name.c_str());
+}
+
+struct OnnxFeatureTest : ::testing::Test {
+ BlueprintFactory factory;
+ IndexEnvironment indexEnv;
+ BlueprintResolver::SP resolver;
+ Properties overrides;
+ MatchData::UP match_data;
+ RankProgram program;
+ OnnxFeatureTest() : factory(), indexEnv(), resolver(new BlueprintResolver(factory, indexEnv)),
+ overrides(), match_data(), program(resolver)
+ {
+ factory.addPrototype(std::make_shared<DocidBlueprint>());
+ factory.addPrototype(std::make_shared<RankingExpressionBlueprint>());
+ factory.addPrototype(std::make_shared<OnnxBlueprint>());
+ }
+ void add_expr(const vespalib::string &name, const vespalib::string &expr) {
+ vespalib::string feature_name = expr_feature(name);
+ vespalib::string expr_name = feature_name + ".rankingScript";
+ indexEnv.getProperties().add(expr_name, expr);
+ }
+ void add_onnx(const vespalib::string &name, const vespalib::string &file) {
+ vespalib::string feature_name = onnx_feature(name);
+ vespalib::string file_name = feature_name + ".fileref";
+ indexEnv.getProperties().add(file_name, file);
+ }
+ void compile(const vespalib::string &seed) {
+ resolver->addSeed(seed);
+ ASSERT_TRUE(resolver->compile());
+ MatchDataLayout mdl;
+ QueryEnvironment queryEnv(&indexEnv);
+ match_data = mdl.createMatchData();
+ program.setup(*match_data, queryEnv, overrides);
+ }
+ TensorSpec get(const vespalib::string &feature, uint32_t docid) {
+ auto result = program.get_all_features(false);
+ for (size_t i = 0; i < result.num_features(); ++i) {
+ if (result.name_of(i) == feature) {
+ return TensorSpec::from_value(result.resolve(i).as_object(docid));
+ }
+ }
+ return TensorSpec("error");
+ }
+ TensorSpec get(uint32_t docid) {
+ auto result = program.get_seeds(false);
+ EXPECT_EQ(1u, result.num_features());
+ return TensorSpec::from_value(result.resolve(0).as_object(docid));
+ }
+};
+
+TEST_F(OnnxFeatureTest, simple_onnx_model_can_be_calculated) {
+ add_expr("query_tensor", "tensor<float>(a[1],b[4]):[[docid,2,3,4]]");
+ add_expr("attribute_tensor", "tensor<float>(a[4],b[1]):[[5],[6],[7],[8]]");
+ add_expr("bias_tensor", "tensor<float>(a[1],b[1]):[[9]]");
+ add_onnx("simple", simple_model);
+ compile(onnx_feature("simple"));
+ EXPECT_EQ(get(1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
+ EXPECT_EQ(get("onnxModel(simple).output", 1), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 79.0));
+ EXPECT_EQ(get(2), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 84.0));
+ EXPECT_EQ(get(3), TensorSpec("tensor<float>(d0[1],d1[1])").add({{"d0",0},{"d1",0}}, 89.0));
+}
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/searchlib/src/vespa/searchlib/features/CMakeLists.txt b/searchlib/src/vespa/searchlib/features/CMakeLists.txt
index 215b6ade9fd..93fead713f4 100644
--- a/searchlib/src/vespa/searchlib/features/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/features/CMakeLists.txt
@@ -40,6 +40,7 @@ vespa_add_library(searchlib_features OBJECT
nativeproximityfeature.cpp
nativerankfeature.cpp
nowfeature.cpp
+ onnx_feature.cpp
proximityfeature.cpp
querycompletenessfeature.cpp
queryfeature.cpp
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.cpp b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
new file mode 100644
index 00000000000..f6d5c37b61d
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.cpp
@@ -0,0 +1,119 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "onnx_feature.h"
+#include <vespa/searchlib/fef/properties.h>
+#include <vespa/searchlib/fef/featureexecutor.h>
+#include <vespa/eval/tensor/dense/onnx_wrapper.h>
+#include <vespa/eval/tensor/dense/dense_tensor_view.h>
+#include <vespa/eval/tensor/dense/mutable_dense_tensor_view.h>
+#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/stash.h>
+
+#include <vespa/log/log.h>
+LOG_SETUP(".features.onnx_feature");
+
+using search::fef::Blueprint;
+using search::fef::FeatureExecutor;
+using search::fef::FeatureType;
+using search::fef::IIndexEnvironment;
+using search::fef::IQueryEnvironment;
+using search::fef::ParameterList;
+using vespalib::Stash;
+using vespalib::eval::ValueType;
+using vespalib::make_string_short::fmt;
+using vespalib::tensor::DenseTensorView;
+using vespalib::tensor::MutableDenseTensorView;
+using vespalib::tensor::OnnxWrapper;
+
+namespace search::features {
+
+/**
+ * Feature executor that evaluates an onnx model
+ */
+class OnnxFeatureExecutor : public FeatureExecutor
+{
+private:
+ const OnnxWrapper &_model;
+ OnnxWrapper::Params _params;
+ OnnxWrapper::Result _result;
+ std::vector<MutableDenseTensorView> _views;
+
+public:
+ OnnxFeatureExecutor(const OnnxWrapper &model)
+ : _model(model), _params(), _result(OnnxWrapper::Result::make_empty()), _views()
+ {
+ _views.reserve(_model.outputs().size());
+ for (const auto &output: _model.outputs()) {
+ _views.emplace_back(output.make_compatible_type());
+ }
+ }
+ bool isPure() override { return true; }
+ void execute(uint32_t) override {
+ _params = OnnxWrapper::Params();
+ for (size_t i = 0; i < _model.inputs().size(); ++i) {
+ _params.bind(i, static_cast<const DenseTensorView&>(inputs().get_object(i).get()));
+ }
+ _result = _model.eval(_params);
+ for (size_t i = 0; i < _model.outputs().size(); ++i) {
+ _result.get(i, _views[i]);
+ outputs().set_object(i, _views[i]);
+ }
+ }
+};
+
+OnnxBlueprint::OnnxBlueprint()
+ : Blueprint("onnxModel"),
+ _model(nullptr)
+{
+}
+
+OnnxBlueprint::~OnnxBlueprint() = default;
+
+bool
+OnnxBlueprint::setup(const IIndexEnvironment &env,
+ const ParameterList &params)
+{
+ auto optimize = (env.getFeatureMotivation() == env.FeatureMotivation::VERIFY_SETUP)
+ ? OnnxWrapper::Optimize::DISABLE
+ : OnnxWrapper::Optimize::ENABLE;
+
+ // Note: Using the fileref property with the model name as
+ // fallback to get a file name. This needs to be replaced with an
+ // actual file reference obtained through config when available.
+ vespalib::string file_name = env.getProperties().lookup(getName(), "fileref").get(params[0].getValue());
+ try {
+ _model = std::make_unique<OnnxWrapper>(file_name, optimize);
+ } catch (std::exception &ex) {
+ return fail("Model setup failed: %s", ex.what());
+ }
+ for (size_t i = 0; i < _model->inputs().size(); ++i) {
+ const auto &model_input = _model->inputs()[i];
+ if (auto maybe_input = defineInput(fmt("rankingExpression(\"%s\")", model_input.name.c_str()), AcceptInput::OBJECT)) {
+ const FeatureType &feature_input = maybe_input.value();
+ assert(feature_input.is_object());
+ if (!model_input.is_compatible(feature_input.type())) {
+ return fail("incompatible type for input '%s': %s -> %s", model_input.name.c_str(),
+ feature_input.type().to_spec().c_str(), model_input.type_as_string().c_str());
+ }
+ }
+ }
+ for (size_t i = 0; i < _model->outputs().size(); ++i) {
+ const auto &model_output = _model->outputs()[i];
+ ValueType output_type = model_output.make_compatible_type();
+ if (output_type.is_error()) {
+ return fail("unable to make compatible type for output '%s': %s -> error",
+ model_output.name.c_str(), model_output.type_as_string().c_str());
+ }
+ describeOutput(model_output.name, "output from onnx model", FeatureType::object(output_type));
+ }
+ return true;
+}
+
+FeatureExecutor &
+OnnxBlueprint::createExecutor(const IQueryEnvironment &, Stash &stash) const
+{
+ assert(_model);
+ return stash.create<OnnxFeatureExecutor>(*_model);
+}
+
+}
diff --git a/searchlib/src/vespa/searchlib/features/onnx_feature.h b/searchlib/src/vespa/searchlib/features/onnx_feature.h
new file mode 100644
index 00000000000..eb6e368ffbd
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/features/onnx_feature.h
@@ -0,0 +1,31 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/searchlib/fef/blueprint.h>
+
+namespace vespalib::tensor { class OnnxWrapper; }
+
+namespace search::features {
+
+/**
+ * Blueprint for the ranking feature used to evaluate an onnx model.
+ **/
+class OnnxBlueprint : public fef::Blueprint {
+private:
+ std::unique_ptr<vespalib::tensor::OnnxWrapper> _model;
+public:
+ OnnxBlueprint();
+ ~OnnxBlueprint() override;
+ void visitDumpFeatures(const fef::IIndexEnvironment &, fef::IDumpFeatureVisitor &) const override {}
+ fef::Blueprint::UP createInstance() const override {
+ return Blueprint::UP(new OnnxBlueprint());
+ }
+ fef::ParameterDescriptions getDescriptions() const override {
+ return fef::ParameterDescriptions().desc().string();
+ }
+ bool setup(const fef::IIndexEnvironment &env, const fef::ParameterList &params) override;
+ fef::FeatureExecutor &createExecutor(const fef::IQueryEnvironment &env, vespalib::Stash &stash) const override;
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/features/setup.cpp b/searchlib/src/vespa/searchlib/features/setup.cpp
index bd79f1d4fb5..c97d2d68812 100644
--- a/searchlib/src/vespa/searchlib/features/setup.cpp
+++ b/searchlib/src/vespa/searchlib/features/setup.cpp
@@ -23,6 +23,7 @@
#include "flow_completeness_feature.h"
#include "foreachfeature.h"
#include "freshnessfeature.h"
+#include "global_sequence_feature.h"
#include "item_raw_score_feature.h"
#include "jarowinklerdistancefeature.h"
#include "matchcountfeature.h"
@@ -34,6 +35,7 @@
#include "nativeproximityfeature.h"
#include "nativerankfeature.h"
#include "nowfeature.h"
+#include "onnx_feature.h"
#include "proximityfeature.h"
#include "querycompletenessfeature.h"
#include "queryfeature.h"
@@ -53,7 +55,6 @@
#include "termfeature.h"
#include "terminfofeature.h"
#include "text_similarity_feature.h"
-#include "global_sequence_feature.h"
#include "valuefeature.h"
#include "max_reduce_prod_join_replacer.h"
@@ -123,7 +124,7 @@ void setup_search_features(fef::IBlueprintRegistry & registry)
registry.addPrototype(std::make_shared<TermFieldMdBlueprint>());
registry.addPrototype(std::make_shared<ConstantBlueprint>());
registry.addPrototype(std::make_shared<GlobalSequenceBlueprint>());
-
+ registry.addPrototype(std::make_shared<OnnxBlueprint>());
// Ranking Expression
auto replacers = std::make_unique<ListExpressionReplacer>();