summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2021-04-23 13:03:19 +0200
committerGitHub <noreply@github.com>2021-04-23 13:03:19 +0200
commit69d032ee48c4c28fb874020220990392903480d0 (patch)
treeb81c5b646134122b4030f4d76af06a0ca2e92f90 /config-model
parentdbf15114b4505e0d4ebe6ad5263685d64619f0b8 (diff)
parent0f20f60145524b13b11453fa0c92f33be0732707 (diff)
Merge pull request #17560 from vespa-engine/arnej/add-input-params-in-rank-profile
Arnej/add input params in rank profile
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java24
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java42
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java10
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java1
4 files changed, 63 insertions, 14 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 8bef4c39ba1..b460752d7bd 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -103,6 +103,8 @@ public class RankProfile implements Cloneable {
private Map<String, RankingExpressionFunction> functions = new LinkedHashMap<>();
+ private Map<Reference, TensorType> inputFeatures = new LinkedHashMap<>();
+
private Set<String> filterFields = new HashSet<>();
private final RankProfileRegistry rankProfileRegistry;
@@ -578,6 +580,23 @@ public class RankProfile implements Cloneable {
return rankingExpressionFunction;
}
+ /**
+ * Use for rank profiles representing a model evaluation; it will assume
+ * that a input is provided with the declared type (for the purpose of
+ * type resolving).
+ **/
+ public void addInputFeature(String name, TensorType declaredType) {
+ Reference ref = Reference.fromIdentifier(name);
+ if (inputFeatures.containsKey(ref)) {
+ TensorType hadType = inputFeatures.get(ref);
+ if (! declaredType.equals(hadType)) {
+ throw new IllegalArgumentException("Tried to replace input feature "+name+" with different type: "+
+ hadType+" -> "+declaredType);
+ }
+ }
+ inputFeatures.put(ref, declaredType);
+ }
+
public RankingExpressionFunction findFunction(String name) {
RankingExpressionFunction function = functions.get(name);
return ((function == null) && (getInherited() != null))
@@ -677,6 +696,7 @@ public class RankProfile implements Cloneable {
clone.summaryFeatures = summaryFeatures != null ? new LinkedHashSet<>(this.summaryFeatures) : null;
clone.rankFeatures = rankFeatures != null ? new LinkedHashSet<>(this.rankFeatures) : null;
clone.rankProperties = new LinkedHashMap<>(this.rankProperties);
+ clone.inputFeatures = new LinkedHashMap<>(this.inputFeatures);
clone.functions = new LinkedHashMap<>(this.functions);
clone.filterFields = new HashSet<>(this.filterFields);
clone.constants = new HashMap<>(this.constants);
@@ -790,8 +810,12 @@ public class RankProfile implements Cloneable {
return typeContext(queryProfiles, collectFeatureTypes());
}
+ public MapEvaluationTypeContext typeContext() { return typeContext(new QueryProfileRegistry()); }
+
private Map<Reference, TensorType> collectFeatureTypes() {
Map<Reference, TensorType> featureTypes = new HashMap<>();
+ // Add input features
+ inputFeatures.forEach((k, v) -> featureTypes.put(k, v));
// Add attributes
allFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes));
allImportedFields().forEach(field -> addAttributeFeatureTypes(field, featureTypes));
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index b757259102b..9086ca9f40e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -210,6 +210,9 @@ public class ConvertedModel {
Map<String, ExpressionFunction> expressions = new HashMap<>();
for (ImportedMlFunction outputFunction : model.outputExpressions()) {
ExpressionFunction expression = asExpressionFunction(outputFunction);
+ for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) {
+ profile.addInputFeature(input.getKey(), input.getValue());
+ }
addExpression(expression, expression.getName(),
constantsReplacedByFunctions,
model, store, profile, queryProfiles,
@@ -251,13 +254,20 @@ public class ConvertedModel {
QueryProfileRegistry queryProfiles,
Map<String, ExpressionFunction> expressions) {
expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions));
+ if (expression.returnType().isEmpty()) {
+ TensorType type = expression.getBody().type(profile.typeContext(queryProfiles));
+ if (type != null) {
+ expression = expression.withReturnType(type);
+ }
+ }
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) {
- for (Pair<String, Tensor> constant : store.readSmallConstants())
+ for (Pair<String, Tensor> constant : store.readSmallConstants()) {
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
+ }
for (RankingConstant constant : store.readLargeConstants()) {
if ( ! profile.rankingConstants().asMap().containsKey(constant.getName())) {
@@ -269,7 +279,20 @@ public class ConvertedModel {
addGeneratedFunctionToProfile(profile, function.getFirst(), function.getSecond());
}
- return store.readExpressions();
+ Map<String, ExpressionFunction> expressions = new HashMap<>();
+ for (Pair<String, ExpressionFunction> output : store.readExpressions()) {
+ String name = output.getFirst();
+ ExpressionFunction expression = output.getSecond();
+ for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) {
+ profile.addInputFeature(input.getKey(), input.getValue());
+ }
+ TensorType type = expression.getBody().type(profile.typeContext());
+ if (type != null) {
+ expression = expression.withReturnType(type);
+ }
+ expressions.put(name, expression);
+ }
+ return expressions;
}
private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName,
@@ -321,8 +344,9 @@ public class ConvertedModel {
"\nwant to add " + expression + "\n");
return;
}
- var fun = new ExpressionFunction(functionName, expression);
- profile.addFunction(fun, false); // TODO: Inline if only used once
+ ExpressionFunction function = new ExpressionFunction(functionName, expression);
+ // XXX should we resolve type here?
+ profile.addFunction(function, false); // TODO: Inline if only used once
}
/**
@@ -465,14 +489,14 @@ public class ConvertedModel {
application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString()));
}
- Map<String, ExpressionFunction> readExpressions() {
- Map<String, ExpressionFunction> expressions = new HashMap<>();
+ List<Pair<String, ExpressionFunction>> readExpressions() {
+ List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath());
- if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap();
+ if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyList();
for (ApplicationFile expressionFile : expressionPath.listFiles()) {
- try (BufferedReader reader = new BufferedReader(expressionFile.createReader())){
+ try (BufferedReader reader = new BufferedReader(expressionFile.createReader())) {
String name = expressionFile.getPath().getName();
- expressions.put(name, readExpression(name, reader));
+ expressions.add(new Pair<>(name, readExpression(name, reader)));
}
catch (IOException e) {
throw new UncheckedIOException("Failed reading " + expressionFile.getPath(), e);
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index 8fe4a8fb022..d665b7f20f0 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -159,7 +159,7 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase {
public void testNeuralNetworkSetup() throws ParseException {
// Note: the type assigned to query profile and constant tensors here is not the correct type
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
- QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(x[1])");
+ QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(input[1])");
SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles);
builder.importString(
"search test {\n" +
@@ -184,19 +184,19 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase {
" }\n" +
" }\n" +
" constant W_hidden {\n" +
- " type: tensor(x[1])\n" +
+ " type: tensor(hidden[1])\n" +
" file: ignored.json\n" +
" }\n" +
" constant b_input {\n" +
- " type: tensor(x[1])\n" +
+ " type: tensor(hidden[1])\n" +
" file: ignored.json\n" +
" }\n" +
" constant W_final {\n" +
- " type: tensor(x[1])\n" +
+ " type: tensor(final[1])\n" +
" file: ignored.json\n" +
" }\n" +
" constant b_final {\n" +
- " type: tensor(x[1])\n" +
+ " type: tensor(final[1])\n" +
" file: ignored.json\n" +
" }\n" +
"}\n");
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index 4c1c24c9790..1aaa1669377 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -145,6 +145,7 @@ public class ModelEvaluationTest {
private final String profile =
"rankingExpression(imported_ml_function_small_constants_and_functions_exp_output).rankingScript: map(input, f(a)(exp(a)))\n" +
+ "rankingExpression(imported_ml_function_small_constants_and_functions_exp_output).type: tensor<float>(d0[3])\n" +
"rankingExpression(default.output).rankingScript: join(rankingExpression(imported_ml_function_small_constants_and_functions_exp_output), reduce(join(join(reduce(rankingExpression(imported_ml_function_small_constants_and_functions_exp_output), sum, d0), tensor<float>(d0[1])(1.0), f(a,b)(a * b)), 9.999999974752427E-7, f(a,b)(a + b)), sum, d0), f(a,b)(a / b))\n" +
"rankingExpression(default.output).input.type: tensor<float>(d0[3])\n" +
"rankingExpression(default.output).type: tensor<float>(d0[3])\n";