summaryrefslogtreecommitdiffstats
path: root/python
diff options
context:
space:
mode:
authortmartins <thigm85@gmail.com>2020-06-11 11:06:21 +0200
committertmartins <thigm85@gmail.com>2020-06-11 11:06:21 +0200
commit57005775391adc8f8e08288b9663678eb91fead2 (patch)
treee04d946f4f3d5c70a0e7add3c4c81dcc30a67382 /python
parent20fa992576a80dceba5cee4a50c50f620362b5a2 (diff)
query now returns VespaResult
Diffstat (limited to 'python')
-rw-r--r--python/vespa/vespa/application.py8
-rw-r--r--python/vespa/vespa/test_application.py6
2 files changed, 7 insertions, 7 deletions
diff --git a/python/vespa/vespa/application.py b/python/vespa/vespa/application.py
index d875998f4d0..62839de92de 100644
--- a/python/vespa/vespa/application.py
+++ b/python/vespa/vespa/application.py
@@ -4,7 +4,7 @@ from typing import Optional, Dict, Tuple, List
from requests import post
from pandas import DataFrame
-from vespa.query import Query
+from vespa.query import Query, VespaResult
from vespa.evaluation import EvalMetric
@@ -37,7 +37,7 @@ class Vespa(object):
debug_request: bool = False,
recall: Optional[Tuple] = None,
**kwargs
- ) -> Dict:
+ ) -> VespaResult:
"""
Send a query request to the Vespa application.
@@ -71,10 +71,10 @@ class Vespa(object):
body.update(kwargs)
if debug_request:
- return body
+ return VespaResult(vespa_result={}, request_body=body)
else:
r = post(self.search_end_point, json=body)
- return r.json()
+ return VespaResult(vespa_result=r.json())
def collect_training_data_point(
self,
diff --git a/python/vespa/vespa/test_application.py b/python/vespa/vespa/test_application.py
index 57d7d784bde..45e2b30727a 100644
--- a/python/vespa/vespa/test_application.py
+++ b/python/vespa/vespa/test_application.py
@@ -27,7 +27,7 @@ class TestVespaQuery(unittest.TestCase):
app = Vespa(url="http://localhost", port=8080)
body = {"yql": "select * from sources * where test"}
- self.assertDictEqual(app.query(body=body, debug_request=True), body)
+ self.assertDictEqual(app.query(body=body, debug_request=True).request_body, body)
self.assertDictEqual(
app.query(
@@ -35,7 +35,7 @@ class TestVespaQuery(unittest.TestCase):
query_model=Query(match_phase=OR(), rank_profile=RankProfile()),
debug_request=True,
hits=10,
- ),
+ ).request_body,
{
"yql": 'select * from sources * where ([{"grammar": "any"}]userInput("this is a test"));',
"ranking": {"profile": "default", "listFeatures": "false"},
@@ -50,7 +50,7 @@ class TestVespaQuery(unittest.TestCase):
debug_request=True,
hits=10,
recall=("id", [1, 5]),
- ),
+ ).request_body,
{
"yql": 'select * from sources * where ([{"grammar": "any"}]userInput("this is a test"));',
"ranking": {"profile": "default", "listFeatures": "false"},