summaryrefslogtreecommitdiffstats
path: root/python
diff options
context:
space:
mode:
authortmartins <thigm85@gmail.com>2020-06-11 11:17:48 +0200
committertmartins <thigm85@gmail.com>2020-06-11 11:17:48 +0200
commit4a162fc061f7bbdbbd9ccabe285d5135eab2cabd (patch)
treed853f3e7aa64e893544c9067579e55b577a5c672 /python
parent17a51be5d6fc8a2ea11ecb0eb14bbbf3bd384fa2 (diff)
adapt collect_training_data unit test to use VespaResult
Diffstat (limited to 'python')
-rw-r--r--python/vespa/vespa/test_application.py16
1 files changed, 12 insertions, 4 deletions
diff --git a/python/vespa/vespa/test_application.py b/python/vespa/vespa/test_application.py
index 45e2b30727a..84bd1c0a6ad 100644
--- a/python/vespa/vespa/test_application.py
+++ b/python/vespa/vespa/test_application.py
@@ -6,7 +6,7 @@ from pandas import DataFrame
from pandas.testing import assert_frame_equal
from vespa.application import Vespa
-from vespa.query import Query, OR, RankProfile
+from vespa.query import Query, OR, RankProfile, VespaResult
class TestVespa(unittest.TestCase):
@@ -27,7 +27,9 @@ class TestVespaQuery(unittest.TestCase):
app = Vespa(url="http://localhost", port=8080)
body = {"yql": "select * from sources * where test"}
- self.assertDictEqual(app.query(body=body, debug_request=True).request_body, body)
+ self.assertDictEqual(
+ app.query(body=body, debug_request=True).request_body, body
+ )
self.assertDictEqual(
app.query(
@@ -149,7 +151,10 @@ class TestVespaCollectData(unittest.TestCase):
def test_collect_training_data_point(self):
self.app.query = Mock(
- side_effect=[self.raw_vespa_result_recall, self.raw_vespa_result_additional]
+ side_effect=[
+ VespaResult(self.raw_vespa_result_recall),
+ VespaResult(self.raw_vespa_result_additional),
+ ]
)
query_model = Query(rank_profile=RankProfile(list_features=True))
data = self.app.collect_training_data_point(
@@ -204,7 +209,10 @@ class TestVespaCollectData(unittest.TestCase):
}
}
self.app.query = Mock(
- side_effect=[self.raw_vespa_result_recall, self.raw_vespa_result_additional]
+ side_effect=[
+ VespaResult(self.raw_vespa_result_recall),
+ VespaResult(self.raw_vespa_result_additional),
+ ]
)
query_model = Query(rank_profile=RankProfile(list_features=True))
data = self.app.collect_training_data_point(