diff options
author | tmartins <thigm85@gmail.com> | 2020-06-11 11:17:48 +0200 |
---|---|---|
committer | tmartins <thigm85@gmail.com> | 2020-06-11 11:17:48 +0200 |
commit | 4a162fc061f7bbdbbd9ccabe285d5135eab2cabd (patch) | |
tree | d853f3e7aa64e893544c9067579e55b577a5c672 /python | |
parent | 17a51be5d6fc8a2ea11ecb0eb14bbbf3bd384fa2 (diff) |
adapt collect_training_data unit test to use VespaResult
Diffstat (limited to 'python')
-rw-r--r-- | python/vespa/vespa/test_application.py | 16 |
1 files changed, 12 insertions, 4 deletions
diff --git a/python/vespa/vespa/test_application.py b/python/vespa/vespa/test_application.py index 45e2b30727a..84bd1c0a6ad 100644 --- a/python/vespa/vespa/test_application.py +++ b/python/vespa/vespa/test_application.py @@ -6,7 +6,7 @@ from pandas import DataFrame from pandas.testing import assert_frame_equal from vespa.application import Vespa -from vespa.query import Query, OR, RankProfile +from vespa.query import Query, OR, RankProfile, VespaResult class TestVespa(unittest.TestCase): @@ -27,7 +27,9 @@ class TestVespaQuery(unittest.TestCase): app = Vespa(url="http://localhost", port=8080) body = {"yql": "select * from sources * where test"} - self.assertDictEqual(app.query(body=body, debug_request=True).request_body, body) + self.assertDictEqual( + app.query(body=body, debug_request=True).request_body, body + ) self.assertDictEqual( app.query( @@ -149,7 +151,10 @@ class TestVespaCollectData(unittest.TestCase): def test_collect_training_data_point(self): self.app.query = Mock( - side_effect=[self.raw_vespa_result_recall, self.raw_vespa_result_additional] + side_effect=[ + VespaResult(self.raw_vespa_result_recall), + VespaResult(self.raw_vespa_result_additional), + ] ) query_model = Query(rank_profile=RankProfile(list_features=True)) data = self.app.collect_training_data_point( @@ -204,7 +209,10 @@ class TestVespaCollectData(unittest.TestCase): } } self.app.query = Mock( - side_effect=[self.raw_vespa_result_recall, self.raw_vespa_result_additional] + side_effect=[ + VespaResult(self.raw_vespa_result_recall), + VespaResult(self.raw_vespa_result_additional), + ] ) query_model = Query(rank_profile=RankProfile(list_features=True)) data = self.app.collect_training_data_point( |