diff options
author | Arne H Juul <arnej@yahooinc.com> | 2021-10-28 14:48:25 +0000 |
---|---|---|
committer | Arne H Juul <arnej@yahooinc.com> | 2021-10-28 14:48:25 +0000 |
commit | 8f489ecf4630406f58492a74f99b5fb1285c1884 (patch) | |
tree | 82fe6984a407f1b96a123efed68fceaa9d2b63cb /container-search | |
parent | 4e06f925e20788427e10431e03be5be522a4285f (diff) |
add skeleton for match features
Diffstat (limited to 'container-search')
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/dispatch/LeanHit.java | 10 | ||||
-rw-r--r-- | container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java | 20 |
2 files changed, 30 insertions, 0 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/LeanHit.java b/container-search/src/main/java/com/yahoo/search/dispatch/LeanHit.java index ae9d62fca41..b948eab266a 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/LeanHit.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/LeanHit.java @@ -1,8 +1,11 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.dispatch; +import com.yahoo.tensor.Tensor; import java.util.Arrays; +import java.util.Map; +import java.util.HashMap; /** * @author baldersheim @@ -14,6 +17,7 @@ public class LeanHit implements Comparable<LeanHit> { private final byte [] sortData; private final int partId; private final int distributionKey; + private final Map<String, Tensor> matchFeatures; public LeanHit(byte [] gid, int partId, int distributionKey, double relevance) { this(gid, partId, distributionKey, relevance, null); @@ -24,6 +28,7 @@ public class LeanHit implements Comparable<LeanHit> { this.sortData = sortData; this.partId = partId; this.distributionKey = distributionKey; + this.matchFeatures = new HashMap<>(); } public double getRelevance() { return relevance; } @@ -32,6 +37,11 @@ public class LeanHit implements Comparable<LeanHit> { public boolean hasSortData() { return sortData != null; } public int getPartId() { return partId; } public int getDistributionKey() { return distributionKey; } + public final Map<String, Tensor> getMatchFeatures() { return matchFeatures; } + + public void addMatchFeature(String name, Tensor value) { + matchFeatures.put(name, value); + } @Override public int compareTo(LeanHit o) { diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java index a57fa6164d7..98172d4a96e 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java @@ -30,9 +30,13 @@ import com.yahoo.searchlib.aggregation.Grouping; import com.yahoo.slime.BinaryFormat; import com.yahoo.vespa.objects.BufferSerializer; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import com.yahoo.tensor.Tensor; + import java.nio.ByteBuffer; import java.util.ArrayList; import java.util.List; +import java.util.Optional; import java.util.function.Consumer; public class ProtobufSerialization { @@ -204,6 +208,7 @@ public class ProtobufSerialization { result.getResult().setCoverage(convertToCoverage(protobuf)); convertSearchReplyErrors(result.getResult(), protobuf.getErrorsList()); + List<String> featureNames = protobuf.getMatchFeatureNamesList(); var haveGrouping = ! protobuf.getGroupingBlob().isEmpty(); if (haveGrouping) { @@ -224,6 +229,21 @@ public class ProtobufSerialization { LeanHit hit = (replyHit.getSortData().isEmpty()) ? new LeanHit(replyHit.getGlobalId().toByteArray(), partId, distKey, replyHit.getRelevance()) : new LeanHit(replyHit.getGlobalId().toByteArray(), partId, distKey, replyHit.getRelevance(), replyHit.getSortData().toByteArray()); + if (! featureNames.isEmpty()) { + List<SearchProtocol.Feature> featureValues = replyHit.getFeaturesList(); + var nameIter = featureNames.iterator(); + var valueIter = featureValues.iterator(); + while (nameIter.hasNext() && valueIter.hasNext()) { + String name = nameIter.next(); + SearchProtocol.Feature value = valueIter.next(); + ByteString tensorBlob = value.getTensorBlob(); + Tensor tensor = tensorBlob.isEmpty() + ? Tensor.from(value.getValue()) + : TypedBinaryFormat.decode(Optional.empty(), + GrowableByteBuffer.wrap(tensorBlob.toByteArray())); + hit.addMatchFeature(name, tensor); + } + } result.getLeanHits().add(hit); } |