From 27c83bea57f67d03042a60666f1ce7bcb6c04fe7 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Wed, 18 Oct 2023 07:22:03 +0000 Subject: add getAsTensor() API in RankProperties --- container-search/abi-spec.json | 1 + .../yahoo/search/query/ranking/RankProperties.java | 22 ++++++++++ .../com/yahoo/search/ranking/PreparedInput.java | 25 +++++------ .../query/ranking/RankPropertiesTestCase.java | 49 ++++++++++++++++++++++ 4 files changed, 82 insertions(+), 15 deletions(-) create mode 100644 container-search/src/test/java/com/yahoo/search/query/ranking/RankPropertiesTestCase.java diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index db5c52267aa..31b4dd2c920 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -7076,6 +7076,7 @@ "public void put(java.lang.String, java.lang.String)", "public void put(java.lang.String, java.lang.Object)", "public java.util.List get(java.lang.String)", + "public java.util.Optional getAsTensor(java.lang.String)", "public void remove(java.lang.String)", "public boolean isEmpty()", "public java.util.Map asMap()", diff --git a/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java b/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java index 544f26a7d89..4ac5375807b 100644 --- a/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/ranking/RankProperties.java @@ -3,6 +3,7 @@ package com.yahoo.search.query.ranking; import com.yahoo.fs4.GetDocSumsPacket; import com.yahoo.fs4.MapEncoder; +import com.yahoo.tensor.Tensor; import com.yahoo.text.JSON; import java.nio.ByteBuffer; @@ -11,6 +12,7 @@ import java.util.Collections; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; /** * Contains the properties of a query. @@ -61,6 +63,26 @@ public class RankProperties implements Cloneable { return Collections.unmodifiableList(stringValues); } + /** + * Returns a tensor (as moved from RankFeatures by prepare step) if present + * + * @throws IllegalArgumentException if the value is there but wrong type + */ + public Optional getAsTensor(String name) { + List values = properties.get(name); + if (values == null || values.isEmpty()) return Optional.empty(); + if (values.size() != 1) { + throw new IllegalArgumentException("unexpected multiple [" + values.size() + "] values for property '" + name + "'"); + } + Object feature = values.get(0); + if (feature == null) return Optional.empty(); + if (feature instanceof Tensor t) return Optional.of(t); + if (feature instanceof Double d) return Optional.of(Tensor.from(d)); + throw new IllegalArgumentException("Expected '" + name + "' to be a tensor or double, but it is '" + feature + + "', this usually means that '" + name + "' is not defined in the schema. " + + "See https://docs.vespa.ai/en/tensor-user-guide.html#querying-with-tensors"); + } + /** Removes all properties for a given name */ public void remove(String name) { properties.remove(name); diff --git a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java index 346acccd916..5491724cc08 100644 --- a/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java +++ b/container-search/src/main/java/com/yahoo/search/ranking/PreparedInput.java @@ -26,24 +26,19 @@ record PreparedInput(String name, Tensor value) { List result = new ArrayList<>(); var ranking = query.getRanking(); var rankFeatures = ranking.getFeatures(); - var rankProps = ranking.getProperties().asMap(); + var rankProps = ranking.getProperties(); for (String queryFeatureName : queryFeatures) { String needed = "query(" + queryFeatureName + ")"; - // searchers are recommended to place query features here: - var feature = rankFeatures.getTensor(needed); - if (feature.isPresent()) { - result.add(new PreparedInput(needed, feature.get())); - } else { - // but other ways of setting query features end up in the properties: - var objList = rankProps.get(queryFeatureName); - if (objList != null && objList.size() == 1 && objList.get(0) instanceof Tensor t) { - result.add(new PreparedInput(needed, t)); - } else if (objList != null && objList.size() == 1 && objList.get(0) instanceof Double d) { - result.add(new PreparedInput(needed, Tensor.from(d))); - } else { - throw new IllegalArgumentException("missing query feature: " + queryFeatureName); - } + // after prepare() the query tensor ends up here: + var feature = rankProps.getAsTensor(queryFeatureName); + if (feature.isEmpty()) { + // searchers are recommended to place query features here: + feature = rankFeatures.getTensor(needed); } + if (feature.isEmpty()) { + throw new IllegalArgumentException("missing query feature: " + queryFeatureName); + } + result.add(new PreparedInput(needed, feature.get())); } return result; } diff --git a/container-search/src/test/java/com/yahoo/search/query/ranking/RankPropertiesTestCase.java b/container-search/src/test/java/com/yahoo/search/query/ranking/RankPropertiesTestCase.java new file mode 100644 index 00000000000..81c657f6323 --- /dev/null +++ b/container-search/src/test/java/com/yahoo/search/query/ranking/RankPropertiesTestCase.java @@ -0,0 +1,49 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.query.ranking; + +import com.yahoo.search.Query; +import com.yahoo.search.query.Ranking; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * @author arnej + */ +public class RankPropertiesTestCase { + + @Test + void requireThatGetAsTensorCanGetDoublesAndTensors() { + TensorType ttype = new TensorType.Builder().mapped("cat").build(); + Tensor mappedTensor = Tensor.from(ttype, "{ {cat:foo}:2.5, {cat:bar}:1.25 }"); + RankFeatures f = new RankFeatures(new Ranking(new Query())); + f.put("query(myDouble)", 42.75); + f.put("query(myTensor)", mappedTensor); + RankProperties p = new RankProperties(); + f.prepare(p); + var optT = p.getAsTensor("myDouble"); + assertEquals(true, optT.isPresent()); + assertEquals(TensorType.empty, optT.get().type()); + assertEquals(42.75, optT.get().asDouble()); + optT = p.getAsTensor("myTensor"); + assertEquals(true, optT.isPresent()); + assertEquals(mappedTensor, optT.get()); + } + + @Test + void requireThatGetAsTensorFailsOnStrings() { + RankFeatures f = new RankFeatures(new Ranking(new Query())); + // common mistake: + f.put("query(myTensor)", "{ {cat:foo}:2.5, {cat:bar}:1.25 }"); + RankProperties p = new RankProperties(); + f.prepare(p); + var ex = assertThrows(IllegalArgumentException.class, () -> p.getAsTensor("myTensor")); + assertEquals("Expected 'myTensor' to be a tensor or double, " + + "but it is '{ {cat:foo}:2.5, {cat:bar}:1.25 }', " + + "this usually means that 'myTensor' is not defined in the schema. " + + "See https://docs.vespa.ai/en/tensor-user-guide.html#querying-with-tensors", ex.getMessage()); + } +} -- cgit v1.2.3