diff options
author | Arne Juul <arnej@vespa.ai> | 2024-02-21 14:31:27 +0000 |
---|---|---|
committer | Arne Juul <arnej@vespa.ai> | 2024-02-22 10:36:19 +0000 |
commit | 97eff1f15dcafe4a8d0229fa36b3e0448de2f4d6 (patch) | |
tree | 7bfa620f51263b3e2c6cc1bd319b922a0cb33a09 /container-search/src | |
parent | 4f0947b6617f5c9ff0a133a1f12bb4a5b7d57bb6 (diff) |
allow inputs { query(foo) string }
Diffstat (limited to 'container-search/src')
6 files changed, 56 insertions, 31 deletions
diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java index 82e1409c4eb..ea2120dae3f 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/RankProfileInputProperties.java @@ -6,6 +6,8 @@ import com.yahoo.language.process.Embedder; import com.yahoo.processing.IllegalInputException; import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; +import com.yahoo.search.schema.RankProfile; +import com.yahoo.search.schema.RankProfile.InputType; import com.yahoo.search.schema.SchemaInfo; import com.yahoo.search.schema.internal.TensorConverter; import com.yahoo.search.query.Properties; @@ -39,9 +41,16 @@ public class RankProfileInputProperties extends Properties { public void set(CompoundName name, Object value, Map<String, String> context) { if (RankFeatures.isFeatureName(name.toString())) { try { - TensorType expectedType = typeOf(name); - if (expectedType != null) { - value = tensorConverter.convertTo(expectedType, + var expectedType = typeOf(name); + System.err.println("setting rank feature '" + name + "' -> " + value + " :: " + value.getClass()); + if (expectedType != null && expectedType.declaredString()) { + System.err.println("expected type: declared string"); + var e = new IllegalArgumentException("foo"); + e.fillInStackTrace(); + e.printStackTrace(); + } + if (expectedType != null && ! expectedType.declaredString()) { + value = tensorConverter.convertTo(expectedType.tensorType(), name.last(), value, query.getModel().getLanguage(), @@ -59,14 +68,14 @@ public class RankProfileInputProperties extends Properties { @Override public void requireSettable(CompoundName name, Object value, Map<String, String> context) { if (RankFeatures.isFeatureName(name.toString())) { - TensorType expectedType = typeOf(name); - if (expectedType != null) - verifyType(name, value, expectedType); + var expectedType = typeOf(name); + if (expectedType != null && ! expectedType.declaredString()) + verifyType(name, value, expectedType.tensorType()); } super.requireSettable(name, value, context); } - private TensorType typeOf(CompoundName name) { + private RankProfile.InputType typeOf(CompoundName name) { // Session is lazily resolved because order matters: // model.sources+restrict must be set in the query before this is done if (session == null) diff --git a/container-search/src/main/java/com/yahoo/search/schema/RankProfile.java b/container-search/src/main/java/com/yahoo/search/schema/RankProfile.java index d681062451c..a5b8d328a7a 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/RankProfile.java +++ b/container-search/src/main/java/com/yahoo/search/schema/RankProfile.java @@ -18,10 +18,25 @@ import java.util.Objects; @Beta public class RankProfile { + public record InputType(TensorType tensorType, boolean declaredString) { + public String toString() { + return declaredString ? "string" : tensorType.toString(); + } + public static InputType fromSpec(String spec) { + if ("string".equals(spec)) { + return new InputType(TensorType.empty, true); + } + if ("double".equals(spec)) { + return new InputType(TensorType.empty, false); + } + return new InputType(TensorType.fromSpec(spec), false); + } + } + private final String name; private final boolean hasSummaryFeatures; private final boolean hasRankFeatures; - private final Map<String, TensorType> inputs; + private final Map<String, InputType> inputs; // Assigned when this is added to a schema private Schema schema = null; @@ -52,7 +67,7 @@ public class RankProfile { public boolean hasRankFeatures() { return hasRankFeatures; } /** Returns the inputs explicitly declared in this rank profile. */ - public Map<String, TensorType> inputs() { return inputs; } + public Map<String, InputType> inputs() { return inputs; } @Override public boolean equals(Object o) { @@ -80,7 +95,7 @@ public class RankProfile { private final String name; private boolean hasSummaryFeatures = true; private boolean hasRankFeatures = true; - private final Map<String, TensorType> inputs = new LinkedHashMap<>(); + private final Map<String, InputType> inputs = new LinkedHashMap<>(); public Builder(String name) { this.name = Objects.requireNonNull(name); @@ -96,7 +111,7 @@ public class RankProfile { return this; } - public Builder addInput(String name, TensorType type) { + public Builder addInput(String name, InputType type) { inputs.put(name, type); return this; } diff --git a/container-search/src/main/java/com/yahoo/search/schema/SchemaInfo.java b/container-search/src/main/java/com/yahoo/search/schema/SchemaInfo.java index 263fa4058c7..df7a5cab3c1 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/SchemaInfo.java +++ b/container-search/src/main/java/com/yahoo/search/schema/SchemaInfo.java @@ -6,7 +6,6 @@ import com.yahoo.component.annotation.Inject; import com.yahoo.container.QrSearchersConfig; import com.yahoo.search.Query; import com.yahoo.search.config.SchemaInfoConfig; -import com.yahoo.tensor.TensorType; import java.util.Collection; import java.util.Collections; @@ -184,16 +183,16 @@ public class SchemaInfo { * feature is declared in this rank profile in multiple schemas * of this session with conflicting types */ - public TensorType rankProfileInput(String rankFeature, String rankProfile) { + public RankProfile.InputType rankProfileInput(String rankFeature, String rankProfile) { if (schemas.isEmpty()) return null; // no matching schemas - validated elsewhere List<RankProfile> profiles = profilesNamed(rankProfile); if (profiles.isEmpty()) throw new IllegalArgumentException("No profile named '" + rankProfile + "' exists in schemas [" + schemas.stream().map(Schema::name).collect(Collectors.joining(", ")) + "]"); - TensorType foundType = null; + RankProfile.InputType foundType = null; RankProfile declaringProfile = null; for (RankProfile profile : profiles) { - TensorType newlyFoundType = profile.inputs().get(rankFeature); + RankProfile.InputType newlyFoundType = profile.inputs().get(rankFeature); if (newlyFoundType == null) continue; if (foundType != null && ! newlyFoundType.equals(foundType)) throw new IllegalArgumentException("Conflicting input type declarations for '" + rankFeature + "': " + diff --git a/container-search/src/main/java/com/yahoo/search/schema/SchemaInfoConfigurer.java b/container-search/src/main/java/com/yahoo/search/schema/SchemaInfoConfigurer.java index b70f5145e56..b576260b85d 100644 --- a/container-search/src/main/java/com/yahoo/search/schema/SchemaInfoConfigurer.java +++ b/container-search/src/main/java/com/yahoo/search/schema/SchemaInfoConfigurer.java @@ -27,7 +27,7 @@ class SchemaInfoConfigurer { profileBuilder.setHasSummaryFeatures(profileConfig.hasSummaryFeatures()); profileBuilder.setHasRankFeatures(profileConfig.hasRankFeatures()); for (var inputConfig : profileConfig.input()) - profileBuilder.addInput(inputConfig.name(), TensorType.fromSpec(inputConfig.type())); + profileBuilder.addInput(inputConfig.name(), RankProfile.InputType.fromSpec(inputConfig.type())); builder.add(profileBuilder.build()); } diff --git a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java index 2feac135b03..3cdfc95ea0e 100644 --- a/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java +++ b/container-search/src/test/java/com/yahoo/search/query/RankProfileInputTest.java @@ -7,6 +7,7 @@ import com.yahoo.language.process.Embedder; import com.yahoo.search.Query; import com.yahoo.search.schema.Cluster; import com.yahoo.search.schema.RankProfile; +import com.yahoo.search.schema.RankProfile.InputType; import com.yahoo.search.schema.Schema; import com.yahoo.search.schema.SchemaInfo; import com.yahoo.search.query.profile.QueryProfile; @@ -298,23 +299,23 @@ public class RankProfileInputTest { private SchemaInfo createSchemaInfo() { List<Schema> schemas = new ArrayList<>(); RankProfile.Builder common = new RankProfile.Builder("commonProfile") - .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) - .addInput("query(myTensor2)", TensorType.fromSpec("tensor(x[2],y[2])")) - .addInput("query(myTensor3)", TensorType.fromSpec("tensor(x[2],y[2])")) - .addInput("query(myTensor4)", TensorType.fromSpec("tensor<float>(x[5])")); + .addInput("query(myTensor1)", InputType.fromSpec("tensor(a{},b{})")) + .addInput("query(myTensor2)", InputType.fromSpec("tensor(x[2],y[2])")) + .addInput("query(myTensor3)", InputType.fromSpec("tensor(x[2],y[2])")) + .addInput("query(myTensor4)", InputType.fromSpec("tensor<float>(x[5])")); schemas.add(new Schema.Builder("a") .add(common.build()) .add(new RankProfile.Builder("inconsistent") - .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) + .addInput("query(myTensor1)", InputType.fromSpec("tensor(a{},b{})")) .build()) .build()); schemas.add(new Schema.Builder("b") .add(common.build()) .add(new RankProfile.Builder("inconsistent") - .addInput("query(myTensor1)", TensorType.fromSpec("tensor(x[10])")) + .addInput("query(myTensor1)", InputType.fromSpec("tensor(x[10])")) .build()) .add(new RankProfile.Builder("bOnly") - .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) + .addInput("query(myTensor1)", InputType.fromSpec("tensor(a{},b{})")) .build()) .build()); List<Cluster> clusters = new ArrayList<>(); diff --git a/container-search/src/test/java/com/yahoo/search/schema/SchemaInfoTester.java b/container-search/src/test/java/com/yahoo/search/schema/SchemaInfoTester.java index c27fe9ff15d..553f71d91b2 100644 --- a/container-search/src/test/java/com/yahoo/search/schema/SchemaInfoTester.java +++ b/container-search/src/test/java/com/yahoo/search/schema/SchemaInfoTester.java @@ -6,6 +6,7 @@ import com.yahoo.search.Query; import com.yahoo.search.config.IndexInfoConfig; import com.yahoo.search.config.SchemaInfoConfig; import com.yahoo.search.schema.RankProfile; +import com.yahoo.search.schema.RankProfile.InputType; import com.yahoo.search.schema.Schema; import com.yahoo.search.schema.SchemaInfo; import com.yahoo.tensor.TensorType; @@ -43,7 +44,7 @@ public class SchemaInfoTester { void assertInput(TensorType expectedType, String sources, String restrict, String rankProfile, String feature) { assertEquals(expectedType, - schemaInfo.newSession(query(sources, restrict)).rankProfileInput(feature, rankProfile)); + schemaInfo.newSession(query(sources, restrict)).rankProfileInput(feature, rankProfile).tensorType()); } void assertInputConflict(TensorType expectedType, String sources, String restrict, String rankProfile, String feature) { @@ -59,14 +60,14 @@ public class SchemaInfoTester { static SchemaInfo createSchemaInfo() { List<Schema> schemas = new ArrayList<>(); RankProfile.Builder common = new RankProfile.Builder("commonProfile") - .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) - .addInput("query(myTensor2)", TensorType.fromSpec("tensor(x[2],y[2])")) - .addInput("query(myTensor3)", TensorType.fromSpec("tensor(x[2],y[2])")) - .addInput("query(myTensor4)", TensorType.fromSpec("tensor<float>(x[5])")); + .addInput("query(myTensor1)", InputType.fromSpec("tensor(a{},b{})")) + .addInput("query(myTensor2)", InputType.fromSpec("tensor(x[2],y[2])")) + .addInput("query(myTensor3)", InputType.fromSpec("tensor(x[2],y[2])")) + .addInput("query(myTensor4)", InputType.fromSpec("tensor<float>(x[5])")); schemas.add(new Schema.Builder("a") .add(common.build()) .add(new RankProfile.Builder("inconsistent") - .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) + .addInput("query(myTensor1)", InputType.fromSpec("tensor(a{},b{})")) .build()) .add(new DocumentSummary.Builder("testSummary") .add(new DocumentSummary.Field("field1", "string")) @@ -77,10 +78,10 @@ public class SchemaInfoTester { schemas.add(new Schema.Builder("b") .add(common.build()) .add(new RankProfile.Builder("inconsistent") - .addInput("query(myTensor1)", TensorType.fromSpec("tensor(x[10])")) + .addInput("query(myTensor1)", InputType.fromSpec("tensor(x[10])")) .build()) .add(new RankProfile.Builder("bOnly") - .addInput("query(myTensor1)", TensorType.fromSpec("tensor(a{},b{})")) + .addInput("query(myTensor1)", InputType.fromSpec("tensor(a{},b{})")) .build()) .build()); List<Cluster> clusters = new ArrayList<>(); |