author    tmartins <thigm85@gmail.com>  2020-06-11 11:33:49 +0200
committer tmartins <thigm85@gmail.com>  2020-06-11 11:33:49 +0200
commit    a07691defc45df452a4d220d55d3e2f5801b030f (patch)
tree      10a309703029d3e9ab75483b4459ed09b4c707fb /python
parent    ab9a57477511d98706ef7b34e0215e8deba32b31 (diff)
use VespaResult in the evaluation code
Diffstat (limited to 'python')
-rw-r--r--  python/vespa/vespa/evaluation.py       30
-rw-r--r--  python/vespa/vespa/test_evaluation.py  87
2 files changed, 50 insertions(+), 67 deletions(-)
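
The change relies on three accessors on VespaResult — hits, number_documents_retrieved and number_documents_indexed — in place of the raw-dict lookups and try/except blocks removed below. The real class lives in vespa/query.py and is not part of this diff; what follows is a minimal sketch, assuming getters that fall back to the same defaults the removed try clauses used:

from typing import Dict, List


class VespaResult(object):
    """Sketch of the wrapper assumed by this diff; the actual class is
    defined in vespa/query.py. Getters return defaults instead of raising
    KeyError, which is what lets evaluation.py drop its try/except blocks."""

    def __init__(self, vespa_result: Dict) -> None:
        self.vespa_result = vespa_result

    @property
    def hits(self) -> List:
        # Default [] mirrors the removed `except KeyError: hits = []`.
        return self.vespa_result.get("root", {}).get("children", [])

    @property
    def number_documents_retrieved(self) -> int:
        # Default 0 mirrors the removed `except KeyError: retrieved_docs = 0`.
        return self.vespa_result.get("root", {}).get("fields", {}).get("totalCount", 0)

    @property
    def number_documents_indexed(self) -> int:
        # Default 0 mirrors the removed `except KeyError: docs_available = 0`.
        return self.vespa_result.get("root", {}).get("coverage", {}).get("documents", 0)

Returning defaults from the getters centralizes the KeyError handling that was previously duplicated in every metric.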
diff --git a/python/vespa/vespa/evaluation.py b/python/vespa/vespa/evaluation.py
index 98365640fb3..4ca7a1d136b 100644
--- a/python/vespa/vespa/evaluation.py
+++ b/python/vespa/vespa/evaluation.py
@@ -1,8 +1,7 @@
 # Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
 from typing import Dict, List
-
-# todo: When creating a VespaResult class use getters with appropriate defaults to avoid the try clauses here.
+from vespa.query import VespaResult


 class EvalMetric(object):
@@ -25,7 +24,7 @@ class MatchRatio(EvalMetric):

     def evaluate_query(
         self,
-        query_results: Dict,
+        query_results: VespaResult,
         relevant_docs: List[Dict],
         id_field: str,
         default_score: int,
@@ -40,16 +39,11 @@ class MatchRatio(EvalMetric):
         :return: Dict containing the number of retrieved docs (_retrieved_docs), the number of docs available in
             the corpus (_docs_available) and the match ratio (_value).
         """
-        try:
-            retrieved_docs = query_results["root"]["fields"]["totalCount"]
-        except KeyError:
-            retrieved_docs = 0
-        try:
-            docs_available = query_results["root"]["coverage"]["documents"]
+        retrieved_docs = query_results.number_documents_retrieved
+        docs_available = query_results.number_documents_indexed
+        value = 0
+        if docs_available > 0:
             value = retrieved_docs / docs_available
-        except KeyError:
-            docs_available = 0
-            value = 0
         return {
             str(self.name) + "_retrieved_docs": retrieved_docs,
             str(self.name) + "_docs_available": docs_available,
@@ -70,7 +64,7 @@ class Recall(EvalMetric):

     def evaluate_query(
         self,
-        query_results: Dict,
+        query_results: VespaResult,
         relevant_docs: List[Dict],
         id_field: str,
         default_score: int,
@@ -88,8 +82,7 @@ class Recall(EvalMetric):
         relevant_ids = {str(doc["id"]) for doc in relevant_docs}
         try:
             retrieved_ids = {
-                str(hit["fields"][id_field])
-                for hit in query_results["root"]["children"][: self.at]
+                str(hit["fields"][id_field]) for hit in query_results.hits[: self.at]
             }
         except KeyError:
             retrieved_ids = set()
@@ -113,7 +106,7 @@ class ReciprocalRank(EvalMetric):

     def evaluate_query(
         self,
-        query_results: Dict,
+        query_results: VespaResult,
         relevant_docs: List[Dict],
         id_field: str,
         default_score: int,
@@ -130,10 +123,7 @@ class ReciprocalRank(EvalMetric):
         relevant_ids = {str(doc["id"]) for doc in relevant_docs}
         rr = 0
-        try:
-            hits = query_results["root"]["children"][: self.at]
-        except KeyError:
-            hits = []
+        hits = query_results.hits[: self.at]
         for index, hit in enumerate(hits):
             if hit["fields"][id_field] in relevant_ids:
                 rr = 1 / (index + 1)
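
Callers now wrap the raw response dict in a VespaResult before evaluating, as the updated tests below do. A hypothetical end-to-end call (the response dict and the relevant_docs entry are illustrative; the result keys assume the metric's default name is "match_ratio", consistent with the "recall" and "reciprocal_rank" prefixes asserted in the tests):

from vespa.query import VespaResult
from vespa.evaluation import MatchRatio

# Illustrative raw response; in practice this comes from a Vespa query.
raw_response = {
    "root": {
        "fields": {"totalCount": 1083},
        "coverage": {"documents": 62529},
    }
}

metric = MatchRatio()
evaluation = metric.evaluate_query(
    query_results=VespaResult(raw_response),
    relevant_docs=[{"id": "abc"}],  # hypothetical labelled data
    id_field="vespa_id_field",
    default_score=0,
)
# Assuming the metric name "match_ratio", evaluation is:
# {"match_ratio_retrieved_docs": 1083,
#  "match_ratio_docs_available": 62529,
#  "match_ratio_value": 1083 / 62529}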
diff --git a/python/vespa/vespa/test_evaluation.py b/python/vespa/vespa/test_evaluation.py
index 5fa29eb3907..b6941985d94 100644
--- a/python/vespa/vespa/test_evaluation.py
+++ b/python/vespa/vespa/test_evaluation.py
@@ -2,6 +2,7 @@
 import unittest

+from vespa.query import VespaResult
 from vespa.evaluation import MatchRatio, Recall, ReciprocalRank

@@ -61,7 +62,7 @@ class TestEvalMetric(unittest.TestCase):
         metric = MatchRatio()
         evaluation = metric.evaluate_query(
-            query_results=self.query_results,
+            query_results=VespaResult(self.query_results),
             relevant_docs=self.labelled_data[0]["relevant_docs"],
             id_field="vespa_id_field",
             default_score=0,
@@ -77,20 +78,22 @@ class TestEvalMetric(unittest.TestCase):
         )
         evaluation = metric.evaluate_query(
-            query_results={
-                "root": {
-                    "id": "toplevel",
-                    "relevance": 1.0,
-                    "coverage": {
-                        "coverage": 100,
-                        "documents": 62529,
-                        "full": True,
-                        "nodes": 2,
-                        "results": 1,
-                        "resultsFull": 1,
-                    },
+            query_results=VespaResult(
+                {
+                    "root": {
+                        "id": "toplevel",
+                        "relevance": 1.0,
+                        "coverage": {
+                            "coverage": 100,
+                            "documents": 62529,
+                            "full": True,
+                            "nodes": 2,
+                            "results": 1,
+                            "resultsFull": 1,
+                        },
+                    }
                 }
-            },
+            ),
             relevant_docs=self.labelled_data[0]["relevant_docs"],
             id_field="vespa_id_field",
             default_score=0,
@@ -106,20 +109,22 @@ class TestEvalMetric(unittest.TestCase):
         )
         evaluation = metric.evaluate_query(
-            query_results={
-                "root": {
-                    "id": "toplevel",
-                    "relevance": 1.0,
-                    "fields": {"totalCount": 1083},
-                    "coverage": {
-                        "coverage": 100,
-                        "full": True,
-                        "nodes": 2,
-                        "results": 1,
-                        "resultsFull": 1,
-                    },
+            query_results=VespaResult(
+                {
+                    "root": {
+                        "id": "toplevel",
+                        "relevance": 1.0,
+                        "fields": {"totalCount": 1083},
+                        "coverage": {
+                            "coverage": 100,
+                            "full": True,
+                            "nodes": 2,
+                            "results": 1,
+                            "resultsFull": 1,
+                        },
+                    }
                 }
-            },
+            ),
             relevant_docs=self.labelled_data[0]["relevant_docs"],
             id_field="vespa_id_field",
             default_score=0,
@@ -137,57 +142,45 @@ class TestEvalMetric(unittest.TestCase):

     def test_recall(self):
         metric = Recall(at=2)
         evaluation = metric.evaluate_query(
-            query_results=self.query_results,
+            query_results=VespaResult(self.query_results),
             relevant_docs=self.labelled_data[0]["relevant_docs"],
             id_field="vespa_id_field",
             default_score=0,
         )
         self.assertDictEqual(
-            evaluation,
-            {
-                "recall_2_value": 0.5,
-            },
+            evaluation, {"recall_2_value": 0.5,},
         )
         metric = Recall(at=1)
         evaluation = metric.evaluate_query(
-            query_results=self.query_results,
+            query_results=VespaResult(self.query_results),
             relevant_docs=self.labelled_data[0]["relevant_docs"],
             id_field="vespa_id_field",
             default_score=0,
         )
         self.assertDictEqual(
-            evaluation,
-            {
-                "recall_1_value": 0.0,
-            },
+            evaluation, {"recall_1_value": 0.0,},
         )

     def test_reciprocal_rank(self):
         metric = ReciprocalRank(at=2)
         evaluation = metric.evaluate_query(
-            query_results=self.query_results,
+            query_results=VespaResult(self.query_results),
             relevant_docs=self.labelled_data[0]["relevant_docs"],
             id_field="vespa_id_field",
             default_score=0,
         )
         self.assertDictEqual(
-            evaluation,
-            {
-                "reciprocal_rank_2_value": 0.5,
-            },
+            evaluation, {"reciprocal_rank_2_value": 0.5,},
         )
         metric = ReciprocalRank(at=1)
         evaluation = metric.evaluate_query(
-            query_results=self.query_results,
+            query_results=VespaResult(self.query_results),
             relevant_docs=self.labelled_data[0]["relevant_docs"],
             id_field="vespa_id_field",
             default_score=0,
         )
         self.assertDictEqual(
-            evaluation,
-            {
-                "reciprocal_rank_1_value": 0.0,
-            },
+            evaluation, {"reciprocal_rank_1_value": 0.0,},
         )
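
The two MatchRatio fixtures above exercise the wrapper's defaults from opposite sides: the first response has no fields.totalCount, the second no coverage.documents. A quick sketch of the expected behaviour, assuming the getter defaults of 0 described earlier:

from vespa.query import VespaResult
from vespa.evaluation import MatchRatio

metric = MatchRatio()

# No "fields" key: number_documents_retrieved falls back to 0.
no_total_count = VespaResult({"root": {"coverage": {"documents": 62529}}})
metric.evaluate_query(no_total_count, [], "vespa_id_field", 0)
# -> retrieved_docs 0, docs_available 62529, value 0.0

# No "documents" key: docs_available falls back to 0, and the new
# `if docs_available > 0` guard keeps value at 0 instead of raising
# ZeroDivisionError.
no_documents = VespaResult({"root": {"fields": {"totalCount": 1083}}})
metric.evaluate_query(no_documents, [], "vespa_id_field", 0)
# -> retrieved_docs 1083, docs_available 0, value 0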