diff options
Diffstat (limited to 'container-search')
3 files changed, 220 insertions, 17 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/MatchFeatureData.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/MatchFeatureData.java new file mode 100644 index 00000000000..3a7a1796f04 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/MatchFeatureData.java @@ -0,0 +1,98 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.dispatch.rpc; + +import com.yahoo.collections.Hashlet; + +import com.yahoo.data.access.ArrayTraverser; +import com.yahoo.data.access.Inspector; +import com.yahoo.data.access.ObjectTraverser; +import com.yahoo.data.access.Type; +import com.yahoo.data.access.simple.Value; + +import java.util.ArrayList; +import java.util.AbstractMap.SimpleEntry; +import java.util.List; +import java.util.Map; + +/** + * MatchFeatureData helps pack match features for hits into + * inspectable HitValue objects, all sharing the same Hashlet + * for the field names. + * @author arnej + */ +public class MatchFeatureData { + + private final Hashlet<String,Integer> hashlet; + + // package-private: + MatchFeatureData(List<String> keys) { + this.hashlet = new Hashlet<>(); + hashlet.reserve(keys.size()); + int i = 0; + for (String key : keys) { + hashlet.put(key, i++); + } + } + + static class HitValue extends Value { + private final Hashlet<String,Integer> hashlet; + private final byte[][] dataValues; + private final double[] doubleValues; + private int index = 0; + + public Type type() { return Type.OBJECT; } + public boolean valid() { return index == doubleValues.length; } + public int fieldCount() { return hashlet.size(); } + public void traverse(ObjectTraverser ot) { + for (int i = 0; i < hashlet.size(); i++) { + String fn = hashlet.key(i); + int offset = hashlet.value(i); + ot.field(fn, valueAt(offset)); + } + } + public Inspector field(String name) { + int offset = hashlet.getIndexOfKey(name); + if (offset < 0 || ! valid()) { + return invalid(); + } + return valueAt(offset); + } + public Iterable<Map.Entry<String,Inspector>> fields() { + if (! valid()) { return List.of(); } + var list = new ArrayList<Map.Entry<String,Inspector>>(hashlet.size()); + for (int i = 0; i < hashlet.size(); i++) { + String fn = hashlet.key(i); + int offset = hashlet.value(i); + list.add(new SimpleEntry<String,Inspector>(fn, valueAt(offset))); + } + return list; + } + + // use from enclosing class only + HitValue(Hashlet<String,Integer> hashlet) { + this.hashlet = hashlet; + this.dataValues = new byte[hashlet.size()][]; + this.doubleValues = new double[hashlet.size()]; + } + + // package-private: + void add(byte[] data) { + dataValues[index++] = data; + } + void add(double value) { + doubleValues[index++] = value; + } + + private Inspector valueAt(int index) { + if (dataValues[index] != null) { + return new Value.DataValue(dataValues[index]); + } + return new Value.DoubleValue(doubleValues[index]); + } + } + + HitValue addHit() { + return new HitValue(hashlet); + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java index ac41321f639..8f8bfb63447 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/rpc/ProtobufSerialization.java @@ -197,7 +197,7 @@ public class ProtobufSerialization { } static InvokerResult convertToResult(Query query, SearchProtocol.SearchReply protobuf, - DocumentDatabase documentDatabase, int partId, int distKey) + DocumentDatabase documentDatabase, int partId, int distKey) { InvokerResult result = new InvokerResult(query, protobuf.getHitsCount()); @@ -206,7 +206,8 @@ public class ProtobufSerialization { convertSearchReplyErrors(result.getResult(), protobuf.getErrorsList()); List<String> featureNames = protobuf.getMatchFeatureNamesList(); - + var haveMatchFeatures = ! featureNames.isEmpty(); + MatchFeatureData matchFeatures = haveMatchFeatures ? new MatchFeatureData(featureNames) : null; var haveGrouping = ! protobuf.getGroupingBlob().isEmpty(); if (haveGrouping) { BufferSerializer buf = new BufferSerializer(new GrowableByteBuffer(protobuf.getGroupingBlob().asReadOnlyByteBuffer())); @@ -221,27 +222,27 @@ public class ProtobufSerialization { hit.setQuery(query); result.getResult().hits().add(hit); } - + int hitIndex = 0; for (var replyHit : protobuf.getHitsList()) { LeanHit hit = (replyHit.getSortData().isEmpty()) ? new LeanHit(replyHit.getGlobalId().toByteArray(), partId, distKey, replyHit.getRelevance()) : new LeanHit(replyHit.getGlobalId().toByteArray(), partId, distKey, replyHit.getRelevance(), replyHit.getSortData().toByteArray()); - if (! featureNames.isEmpty()) { - List<SearchProtocol.Feature> featureValues = replyHit.getMatchFeaturesList(); - var object = new Value.ObjectValue(); - var nameIter = featureNames.iterator(); - var valueIter = featureValues.iterator(); - while (nameIter.hasNext() && valueIter.hasNext()) { - String name = nameIter.next(); - SearchProtocol.Feature value = valueIter.next(); - ByteString tensorBlob = value.getTensor(); - if (tensorBlob.isEmpty()) { - object.put(name, value.getNumber()); - } else { - object.put(name, new Value.DataValue(tensorBlob.toByteArray())); + if (haveMatchFeatures) { + var hitFeatures = matchFeatures.addHit(); + var featureList = replyHit.getMatchFeaturesList(); + if (featureList.size() == featureNames.size()) { + for (SearchProtocol.Feature value : featureList) { + ByteString tensorBlob = value.getTensor(); + if (tensorBlob.isEmpty()) { + hitFeatures.add(value.getNumber()); + } else { + hitFeatures.add(tensorBlob.toByteArray()); + } } + hit.addMatchFeatures(hitFeatures); + } else { + result.getResult().hits().addError(ErrorMessage.createBackendCommunicationError("mismatch in match feature sizes")); } - hit.addMatchFeatures(object); } result.getLeanHits().add(hit); } diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MatchFeatureDataTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MatchFeatureDataTest.java new file mode 100644 index 00000000000..0ef6fbae062 --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/dispatch/rpc/MatchFeatureDataTest.java @@ -0,0 +1,104 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.search.dispatch.rpc; + +import com.yahoo.data.access.ArrayTraverser; +import com.yahoo.data.access.Inspector; +import com.yahoo.data.access.ObjectTraverser; +import com.yahoo.data.access.Type; +import org.junit.Test; + +import java.util.List; +import java.util.Map; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * @author arnej + */ +public class MatchFeatureDataTest { + + @Test + public void testHitValueAPI() { + List<String> names = List.of("foo", "bar", "baz", "qux", "quux"); + var mf = new MatchFeatureData(names); + var hit = mf.addHit(); + assertEquals(hit.type(), Type.OBJECT); + assertFalse(hit.valid()); + hit.add(1.0); + hit.add(2.0); + byte[] somebytes = { 42, 0, 17 }; + hit.add(somebytes); + hit.add(4.0); + assertFalse(hit.valid()); + hit.add(5.0); + assertTrue(hit.valid()); + assertEquals(0, hit.entryCount()); + assertEquals(5, hit.fieldCount()); + var f0 = hit.field("not"); + assertFalse(f0.valid()); + + var f1 = hit.field("foo"); + assertTrue(f1.valid()); + assertEquals(f1.type(), Type.DOUBLE); + assertEquals(f1.asDouble(), 1.0, 0.0); + + var f2 = hit.field("bar"); + assertTrue(f2.valid()); + assertEquals(f2.type(), Type.DOUBLE); + assertEquals(f2.asDouble(), 2.0, 0.0); + + var f3 = hit.field("baz"); + assertTrue(f3.valid()); + assertEquals(f3.type(), Type.DATA); + var gotbytes = f3.asData(); + assertEquals(3, gotbytes.length); + assertEquals(42, gotbytes[0]); + assertEquals(0, gotbytes[1]); + assertEquals(17, gotbytes[2]); + + var f5 = hit.field("quux"); + assertTrue(f5.valid()); + assertEquals(f5.type(), Type.DOUBLE); + assertEquals(f5.asDouble(), 5.0, 0.0); + + var fields = hit.fields().iterator(); + assertTrue(fields.hasNext()); + Map.Entry<String,Inspector> entry = fields.next(); + assertEquals("foo", entry.getKey()); + assertEquals(f1.type(), entry.getValue().type()); + assertEquals(f1.asDouble(), entry.getValue().asDouble(), 0.0); + + assertTrue(fields.hasNext()); + entry = fields.next(); + assertEquals("bar", entry.getKey()); + + assertTrue(fields.hasNext()); + entry = fields.next(); + assertEquals("baz", entry.getKey()); + assertEquals(f3.type(), entry.getValue().type()); + assertEquals(f3.asData(), entry.getValue().asData()); + + assertTrue(fields.hasNext()); + entry = fields.next(); + assertEquals("qux", entry.getKey()); + var f4 = entry.getValue(); + assertTrue(f4.valid()); + assertEquals(f4.type(), Type.DOUBLE); + assertEquals(f4.asDouble(), 4.0, 0.0); + + assertTrue(fields.hasNext()); + entry = fields.next(); + assertEquals("quux", entry.getKey()); + assertEquals(f5.type(), entry.getValue().type()); + assertEquals(f5.asDouble(), entry.getValue().asDouble(), 0.0); + + assertFalse(fields.hasNext()); + + assertEquals("{\"foo\":1.0,\"bar\":2.0,\"baz\":\"0x2A0011\",\"qux\":4.0,\"quux\":5.0}", + hit.toString()); + } + +} |