From 5c3e942429a1333718e0e4b94ae10d066abbd626 Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Sat, 25 Feb 2023 08:38:50 +0000 Subject: pick up declared type of query/attribute features, and handle globalphase not wrapped in rankExpression(). --- .../java/ai/vespa/models/evaluation/Model.java | 8 +++++-- .../evaluation/RankProfilesConfigImporter.java | 25 +++++++++++++++++++++- .../evaluation/RankProfileImportingTest.java | 18 ++++++++++++++++ .../resources/config/dotproduct/onnx-models.cfg | 0 .../resources/config/dotproduct/rank-profiles.cfg | 9 ++++++++ .../config/dotproduct/ranking-constants.cfg | 0 .../config/dotproduct/ranking-expressions.cfg | 0 7 files changed, 57 insertions(+), 3 deletions(-) create mode 100644 model-evaluation/src/test/resources/config/dotproduct/onnx-models.cfg create mode 100644 model-evaluation/src/test/resources/config/dotproduct/rank-profiles.cfg create mode 100644 model-evaluation/src/test/resources/config/dotproduct/ranking-constants.cfg create mode 100644 model-evaluation/src/test/resources/config/dotproduct/ranking-expressions.cfg (limited to 'model-evaluation') diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index c317cdc5922..84c8e2b1e38 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -49,6 +49,7 @@ public class Model implements AutoCloseable { this(name, functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), Map.of(), + Map.of(), List.of(), List.of()); } @@ -56,6 +57,7 @@ public class Model implements AutoCloseable { Model(String name, Map functions, Map referencedFunctions, + Map declaredTypes, List constants, List onnxModels) { this.name = name; @@ -85,8 +87,10 @@ public class Model implements AutoCloseable { } else { // External functions have type info (when not scalar) - add argument types - if (function.getValue().getArgumentType(argument) == null) - functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); + if (function.getValue().getArgumentType(argument) == null) { + TensorType type = declaredTypes.getOrDefault(argument, TensorType.empty); + functions.put(function.getKey(), function.getValue().withArgument(argument, type)); + } } } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 098e6e7a1f6..6148287a536 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -90,11 +90,14 @@ public class RankProfilesConfigImporter { SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo(); ExpressionFunction firstPhase = null; ExpressionFunction secondPhase = null; + ExpressionFunction globalPhase = null; + Map declaredTypes = new LinkedHashMap<>(); for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional reference = FunctionReference.fromSerial(property.name()); Optional externalReference = FunctionReference.fromExternalSerial(property.name()); Optional> argumentType = FunctionReference.fromTypeArgumentSerial(property.name()); Optional returnType = FunctionReference.fromReturnTypeSerial(property.name()); + Optional typeDeclaredFeature = fromTypeDeclarationSerial(property.name()); if (externalReference.isPresent()) { RankingExpression expression = largeExpressions.get(property.value()); ExpressionFunction function = new ExpressionFunction(externalReference.get().functionName(), @@ -140,6 +143,13 @@ public class RankProfilesConfigImporter { secondPhase = new ExpressionFunction("secondphase", new ArrayList<>(), new RankingExpression("second-phase", property.value())); } + else if (property.name().equals("vespa.rank.globalphase")) { // Include in addition to functions + globalPhase = new ExpressionFunction("globalphase", new ArrayList<>(), + new RankingExpression("global-phase", property.value())); + } + else if (typeDeclaredFeature.isPresent()) { + declaredTypes.put(typeDeclaredFeature.get(), TensorType.fromSpec(property.value())); + } else { smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value()); } @@ -148,11 +158,13 @@ public class RankProfilesConfigImporter { functions.put(FunctionReference.fromName("firstphase"), firstPhase); if (functionByName("secondphase", functions.values()) == null && secondPhase != null) // may be already included, depending on body functions.put(FunctionReference.fromName("secondphase"), secondPhase); + if (functionByName("globalphase", functions.values()) == null && globalPhase != null) // may be already included, depending on body + functions.put(FunctionReference.fromName("globalphase"), globalPhase); constants.addAll(smallConstantsInfo.asConstants()); try { - return new Model(profile.name(), functions, referencedFunctions, constants, onnxModels); + return new Model(profile.name(), functions, referencedFunctions, declaredTypes, constants, onnxModels); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); @@ -299,4 +311,15 @@ public class RankProfilesConfigImporter { } + private static final Pattern typeDeclarationPattern = + Pattern.compile("vespa[.]type[.]([a-zA-Z0-9]+)[.](.+)"); + + static Optional fromTypeDeclarationSerial(String serialForm) { + Matcher expressionMatcher = typeDeclarationPattern.matcher(serialForm); + if ( ! expressionMatcher.matches()) return Optional.empty(); + String name = expressionMatcher.group(1); + String argument = expressionMatcher.group(2); + return Optional.of(name + "(" + argument + ")"); + } + } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java index 3fdbb370a5c..1a6f6925caf 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java @@ -31,4 +31,22 @@ public class RankProfileImportingTest { "4 * (match + rankBoost)", macros); } + @Test + public void testImportingSimpleGlobalPhase() { + ModelTester tester = new ModelTester("src/test/resources/config/dotproduct/"); + assertEquals(1, tester.models().size()); + Model m = tester.models().get("default"); + assertEquals("default", m.name()); + assertEquals(1, m.functions().size()); + tester.assertFunction("globalphase", "reduce(attribute(aa) * query(zz), sum)", m); + var f = m.functions().get(0); + assertEquals("globalphase", f.getName()); + assertEquals(2, f.arguments().size()); + assertEquals("tensor(d0[3])", f.getArgumentType("query(zz)").toString()); + assertEquals("tensor(d0[3])", f.getArgumentType("attribute(aa)").toString()); + var rt = f.returnType(); + assertEquals(true, rt.isPresent()); + assertEquals("tensor()", rt.get().toString()); + } + } diff --git a/model-evaluation/src/test/resources/config/dotproduct/onnx-models.cfg b/model-evaluation/src/test/resources/config/dotproduct/onnx-models.cfg new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model-evaluation/src/test/resources/config/dotproduct/rank-profiles.cfg b/model-evaluation/src/test/resources/config/dotproduct/rank-profiles.cfg new file mode 100644 index 00000000000..ae1e6791f3e --- /dev/null +++ b/model-evaluation/src/test/resources/config/dotproduct/rank-profiles.cfg @@ -0,0 +1,9 @@ +rankprofile[0].name "default" +rankprofile[0].fef.property[0].name "vespa.rank.globalphase" +rankprofile[0].fef.property[0].value "sum(attribute(aa) * query(zz))" +rankprofile[0].fef.property[1].name "vespa.match.feature" +rankprofile[0].fef.property[1].value "attribute(aa)" +rankprofile[0].fef.property[2].name "vespa.type.attribute.aa" +rankprofile[0].fef.property[2].value "tensor(d0[3])" +rankprofile[0].fef.property[3].name "vespa.type.query.zz" +rankprofile[0].fef.property[3].value "tensor(d0[3])" diff --git a/model-evaluation/src/test/resources/config/dotproduct/ranking-constants.cfg b/model-evaluation/src/test/resources/config/dotproduct/ranking-constants.cfg new file mode 100644 index 00000000000..e69de29bb2d diff --git a/model-evaluation/src/test/resources/config/dotproduct/ranking-expressions.cfg b/model-evaluation/src/test/resources/config/dotproduct/ranking-expressions.cfg new file mode 100644 index 00000000000..e69de29bb2d -- cgit v1.2.3