summaryrefslogtreecommitdiffstats
path: root/container-search
diff options
context:
space:
mode:
authorArne H Juul <arnej@yahooinc.com>2021-10-28 14:48:25 +0000
committerArne H Juul <arnej@yahooinc.com>2021-10-28 14:48:25 +0000
commit8f489ecf4630406f58492a74f99b5fb1285c1884 (patch)
tree82fe6984a407f1b96a123efed68fceaa9d2b63cb /container-search
parent4e06f925e20788427e10431e03be5be522a4285f (diff)
add skeleton for match features
Diffstat (limited to 'container-search')
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/LeanHit.java10
-rw-r--r--container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java20
2 files changed, 30 insertions, 0 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/LeanHit.java b/container-search/src/main/java/com/yahoo/search/dispatch/LeanHit.java
index ae9d62fca41..b948eab266a 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/LeanHit.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/LeanHit.java
@@ -1,8 +1,11 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.dispatch;
+import com.yahoo.tensor.Tensor;
import java.util.Arrays;
+import java.util.Map;
+import java.util.HashMap;
/**
* @author baldersheim
@@ -14,6 +17,7 @@ public class LeanHit implements Comparable<LeanHit> {
private final byte [] sortData;
private final int partId;
private final int distributionKey;
+ private final Map<String, Tensor> matchFeatures;
public LeanHit(byte [] gid, int partId, int distributionKey, double relevance) {
this(gid, partId, distributionKey, relevance, null);
@@ -24,6 +28,7 @@ public class LeanHit implements Comparable<LeanHit> {
this.sortData = sortData;
this.partId = partId;
this.distributionKey = distributionKey;
+ this.matchFeatures = new HashMap<>();
}
public double getRelevance() { return relevance; }
@@ -32,6 +37,11 @@ public class LeanHit implements Comparable<LeanHit> {
public boolean hasSortData() { return sortData != null; }
public int getPartId() { return partId; }
public int getDistributionKey() { return distributionKey; }
+ public final Map<String, Tensor> getMatchFeatures() { return matchFeatures; }
+
+ public void addMatchFeature(String name, Tensor value) {
+ matchFeatures.put(name, value);
+ }
@Override
public int compareTo(LeanHit o) {
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java
index a57fa6164d7..98172d4a96e 100644
--- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java
+++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java
@@ -30,9 +30,13 @@ import com.yahoo.searchlib.aggregation.Grouping;
import com.yahoo.slime.BinaryFormat;
import com.yahoo.vespa.objects.BufferSerializer;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
+import com.yahoo.tensor.Tensor;
+
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
+import java.util.Optional;
import java.util.function.Consumer;
public class ProtobufSerialization {
@@ -204,6 +208,7 @@ public class ProtobufSerialization {
result.getResult().setCoverage(convertToCoverage(protobuf));
convertSearchReplyErrors(result.getResult(), protobuf.getErrorsList());
+ List<String> featureNames = protobuf.getMatchFeatureNamesList();
var haveGrouping = ! protobuf.getGroupingBlob().isEmpty();
if (haveGrouping) {
@@ -224,6 +229,21 @@ public class ProtobufSerialization {
LeanHit hit = (replyHit.getSortData().isEmpty())
? new LeanHit(replyHit.getGlobalId().toByteArray(), partId, distKey, replyHit.getRelevance())
: new LeanHit(replyHit.getGlobalId().toByteArray(), partId, distKey, replyHit.getRelevance(), replyHit.getSortData().toByteArray());
+ if (! featureNames.isEmpty()) {
+ List<SearchProtocol.Feature> featureValues = replyHit.getFeaturesList();
+ var nameIter = featureNames.iterator();
+ var valueIter = featureValues.iterator();
+ while (nameIter.hasNext() && valueIter.hasNext()) {
+ String name = nameIter.next();
+ SearchProtocol.Feature value = valueIter.next();
+ ByteString tensorBlob = value.getTensorBlob();
+ Tensor tensor = tensorBlob.isEmpty()
+ ? Tensor.from(value.getValue())
+ : TypedBinaryFormat.decode(Optional.empty(),
+ GrowableByteBuffer.wrap(tensorBlob.toByteArray()));
+ hit.addMatchFeature(name, tensor);
+ }
+ }
result.getLeanHits().add(hit);
}