aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-02-25 08:38:50 +0000
committerArne Juul <arnej@yahooinc.com>2023-02-26 16:02:10 +0000
commit5c3e942429a1333718e0e4b94ae10d066abbd626 (patch)
treea90cbd24b1c4460a8bba88b44006521c1f0dbb4d /model-evaluation
parenta8659be0f80b80e875222baa259f938bde151023 (diff)
pick up declared type of query/attribute features,
and handle globalphase not wrapped in rankExpression().
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java8
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java25
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java18
-rw-r--r--model-evaluation/src/test/resources/config/dotproduct/onnx-models.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/dotproduct/rank-profiles.cfg9
-rw-r--r--model-evaluation/src/test/resources/config/dotproduct/ranking-constants.cfg0
-rw-r--r--model-evaluation/src/test/resources/config/dotproduct/ranking-expressions.cfg0
7 files changed, 57 insertions, 3 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index c317cdc5922..84c8e2b1e38 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -49,6 +49,7 @@ public class Model implements AutoCloseable {
this(name,
functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
Map.of(),
+ Map.of(),
List.of(),
List.of());
}
@@ -56,6 +57,7 @@ public class Model implements AutoCloseable {
Model(String name,
Map<FunctionReference, ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
+ Map<String, TensorType> declaredTypes,
List<Constant> constants,
List<OnnxModel> onnxModels) {
this.name = name;
@@ -85,8 +87,10 @@ public class Model implements AutoCloseable {
}
else {
// External functions have type info (when not scalar) - add argument types
- if (function.getValue().getArgumentType(argument) == null)
- functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
+ if (function.getValue().getArgumentType(argument) == null) {
+ TensorType type = declaredTypes.getOrDefault(argument, TensorType.empty);
+ functions.put(function.getKey(), function.getValue().withArgument(argument, type));
+ }
}
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 098e6e7a1f6..6148287a536 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -90,11 +90,14 @@ public class RankProfilesConfigImporter {
SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
ExpressionFunction firstPhase = null;
ExpressionFunction secondPhase = null;
+ ExpressionFunction globalPhase = null;
+ Map<String, TensorType> declaredTypes = new LinkedHashMap<>();
for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
Optional<FunctionReference> externalReference = FunctionReference.fromExternalSerial(property.name());
Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
+ Optional<String> typeDeclaredFeature = fromTypeDeclarationSerial(property.name());
if (externalReference.isPresent()) {
RankingExpression expression = largeExpressions.get(property.value());
ExpressionFunction function = new ExpressionFunction(externalReference.get().functionName(),
@@ -140,6 +143,13 @@ public class RankProfilesConfigImporter {
secondPhase = new ExpressionFunction("secondphase", new ArrayList<>(),
new RankingExpression("second-phase", property.value()));
}
+ else if (property.name().equals("vespa.rank.globalphase")) { // Include in addition to functions
+ globalPhase = new ExpressionFunction("globalphase", new ArrayList<>(),
+ new RankingExpression("global-phase", property.value()));
+ }
+ else if (typeDeclaredFeature.isPresent()) {
+ declaredTypes.put(typeDeclaredFeature.get(), TensorType.fromSpec(property.value()));
+ }
else {
smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value());
}
@@ -148,11 +158,13 @@ public class RankProfilesConfigImporter {
functions.put(FunctionReference.fromName("firstphase"), firstPhase);
if (functionByName("secondphase", functions.values()) == null && secondPhase != null) // may be already included, depending on body
functions.put(FunctionReference.fromName("secondphase"), secondPhase);
+ if (functionByName("globalphase", functions.values()) == null && globalPhase != null) // may be already included, depending on body
+ functions.put(FunctionReference.fromName("globalphase"), globalPhase);
constants.addAll(smallConstantsInfo.asConstants());
try {
- return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels);
+ return new Model(profile.name(), functions, referencedFunctions, declaredTypes, constants, onnxModels);
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
@@ -299,4 +311,15 @@ public class RankProfilesConfigImporter {
}
+ private static final Pattern typeDeclarationPattern =
+ Pattern.compile("vespa[.]type[.]([a-zA-Z0-9]+)[.](.+)");
+
+ static Optional<String> fromTypeDeclarationSerial(String serialForm) {
+ Matcher expressionMatcher = typeDeclarationPattern.matcher(serialForm);
+ if ( ! expressionMatcher.matches()) return Optional.empty();
+ String name = expressionMatcher.group(1);
+ String argument = expressionMatcher.group(2);
+ return Optional.of(name + "(" + argument + ")");
+ }
+
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java
index 3fdbb370a5c..1a6f6925caf 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java
@@ -31,4 +31,22 @@ public class RankProfileImportingTest {
"4 * (match + rankBoost)", macros);
}
+ @Test
+ public void testImportingSimpleGlobalPhase() {
+ ModelTester tester = new ModelTester("src/test/resources/config/dotproduct/");
+ assertEquals(1, tester.models().size());
+ Model m = tester.models().get("default");
+ assertEquals("default", m.name());
+ assertEquals(1, m.functions().size());
+ tester.assertFunction("globalphase", "reduce(attribute(aa) * query(zz), sum)", m);
+ var f = m.functions().get(0);
+ assertEquals("globalphase", f.getName());
+ assertEquals(2, f.arguments().size());
+ assertEquals("tensor(d0[3])", f.getArgumentType("query(zz)").toString());
+ assertEquals("tensor(d0[3])", f.getArgumentType("attribute(aa)").toString());
+ var rt = f.returnType();
+ assertEquals(true, rt.isPresent());
+ assertEquals("tensor()", rt.get().toString());
+ }
+
}
diff --git a/model-evaluation/src/test/resources/config/dotproduct/onnx-models.cfg b/model-evaluation/src/test/resources/config/dotproduct/onnx-models.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/dotproduct/onnx-models.cfg
diff --git a/model-evaluation/src/test/resources/config/dotproduct/rank-profiles.cfg b/model-evaluation/src/test/resources/config/dotproduct/rank-profiles.cfg
new file mode 100644
index 00000000000..ae1e6791f3e
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/dotproduct/rank-profiles.cfg
@@ -0,0 +1,9 @@
+rankprofile[0].name "default"
+rankprofile[0].fef.property[0].name "vespa.rank.globalphase"
+rankprofile[0].fef.property[0].value "sum(attribute(aa) * query(zz))"
+rankprofile[0].fef.property[1].name "vespa.match.feature"
+rankprofile[0].fef.property[1].value "attribute(aa)"
+rankprofile[0].fef.property[2].name "vespa.type.attribute.aa"
+rankprofile[0].fef.property[2].value "tensor(d0[3])"
+rankprofile[0].fef.property[3].name "vespa.type.query.zz"
+rankprofile[0].fef.property[3].value "tensor(d0[3])"
diff --git a/model-evaluation/src/test/resources/config/dotproduct/ranking-constants.cfg b/model-evaluation/src/test/resources/config/dotproduct/ranking-constants.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/dotproduct/ranking-constants.cfg
diff --git a/model-evaluation/src/test/resources/config/dotproduct/ranking-expressions.cfg b/model-evaluation/src/test/resources/config/dotproduct/ranking-expressions.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/dotproduct/ranking-expressions.cfg