summaryrefslogtreecommitdiffstats
path: root/config-model/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2020-04-26 13:45:54 +0200
committerGitHub <noreply@github.com>2020-04-26 13:45:54 +0200
commit1ea6410fd642bb92a03b5798387fa977286af167 (patch)
tree2a4d3fa648d581c472e2ceb83484a24edf0b739a /config-model/src
parent84e4a15ed3545f59f8f38f4fe8c4770098cce642 (diff)
parentd95c4ceb6c4db635e93b605d719f01db9d5f2e6f (diff)
Merge pull request #13067 from vespa-engine/lesters/bert-searchlib-and-config-model
Lesters/bert searchlib and config model
Diffstat (limited to 'config-model/src')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java34
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java5
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java34
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java47
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java8
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java4
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java16
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java4
8 files changed, 98 insertions, 54 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 6de7c985326..4011ce43841 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -41,6 +41,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
private final Map<Reference, TensorType> resolvedTypes = new HashMap<>();
+ /** To avoid re-resolving diamond-shaped dependencies */
+ private final Map<Reference, TensorType> globallyResolvedTypes;
+
/** For invocation loop detection */
private final Deque<Reference> currentResolutionCallStack;
@@ -53,6 +56,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
this.currentResolutionCallStack = new ArrayDeque<>();
this.queryFeaturesNotDeclared = new TreeSet<>();
tensorsAreUsed = false;
+ globallyResolvedTypes = new HashMap<>();
}
private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions,
@@ -60,12 +64,14 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
Map<Reference, TensorType> featureTypes,
Deque<Reference> currentResolutionCallStack,
SortedSet<Reference> queryFeaturesNotDeclared,
- boolean tensorsAreUsed) {
+ boolean tensorsAreUsed,
+ Map<Reference, TensorType> globallyResolvedTypes) {
super(functions, bindings);
this.featureTypes.putAll(featureTypes);
this.currentResolutionCallStack = currentResolutionCallStack;
this.queryFeaturesNotDeclared = queryFeaturesNotDeclared;
this.tensorsAreUsed = tensorsAreUsed;
+ this.globallyResolvedTypes = globallyResolvedTypes;
}
public void setType(Reference reference, TensorType type) {
@@ -82,11 +88,25 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
resolvedTypes.clear();
}
- @Override
+ private boolean referenceCanBeResolvedGlobally(Reference reference) {
+ Optional<ExpressionFunction> function = functionInvocation(reference);
+ return function.isPresent() && function.get().arguments().size() == 0;
+ // are there other cases we would like to resolve globally?
+ }
+
+ @Override
public TensorType getType(Reference reference) {
// computeIfAbsent without concurrent modification due to resolve adding more resolved entries:
+
+ boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference);
+
TensorType resolvedType = resolvedTypes.get(reference);
- if (resolvedType != null) return resolvedType;
+ if (resolvedType == null && canBeResolvedGlobally) {
+ resolvedType = globallyResolvedTypes.get(reference);
+ }
+ if (resolvedType != null) {
+ return resolvedType;
+ }
resolvedType = resolveType(reference);
if (resolvedType == null)
@@ -94,6 +114,11 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
resolvedTypes.put(reference, resolvedType);
if (resolvedType.rank() > 0)
tensorsAreUsed = true;
+
+ if (canBeResolvedGlobally) {
+ globallyResolvedTypes.put(reference, resolvedType);
+ }
+
return resolvedType;
}
@@ -254,7 +279,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
featureTypes,
currentResolutionCallStack,
queryFeaturesNotDeclared,
- tensorsAreUsed);
+ tensorsAreUsed,
+ globallyResolvedTypes);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 23eb814de81..ea126123a25 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -680,11 +680,12 @@ public class RankProfile implements Cloneable {
Map<String, RankingExpressionFunction> inlineFunctions =
compileFunctions(this::getInlineFunctions, queryProfiles, featureTypes, importedModels, Collections.emptyMap(), expressionTransforms);
+ firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
+ secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
+
// Function compiling second pass: compile all functions and insert previously compiled inline functions
functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms);
- firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
- secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
}
private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<String, Value> constants) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 1a22b98fd9f..c3c10139684 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -200,9 +200,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
if (functions.isEmpty()) return;
List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList());
-
Map<String, String> functionProperties = new LinkedHashMap<>();
- functionProperties.putAll(deriveFunctionProperties(functions, functionExpressions));
if (firstPhaseRanking != null) {
functionProperties.putAll(firstPhaseRanking.getRankProperties(functionExpressions));
@@ -211,20 +209,30 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
functionProperties.putAll(secondPhaseRanking.getRankProperties(functionExpressions));
}
+ SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties);
+ replaceFunctionSummaryFeatures(context);
+
+ // First phase, second phase and summary features should add all required functions to the context.
+ // However, we need to add any functions not referenced in those anyway for model-evaluation.
+ deriveFunctionProperties(functions, functionExpressions, functionProperties);
+
for (Map.Entry<String, String> e : functionProperties.entrySet()) {
rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue()));
}
- SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties);
- replaceFunctionSummaryFeatures(context);
}
- private Map<String, String> deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions,
- List<ExpressionFunction> functionExpressions) {
- SerializationContext context = new SerializationContext(functionExpressions);
+ private void deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions,
+ List<ExpressionFunction> functionExpressions,
+ Map<String, String> functionProperties) {
+ SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties);
for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) {
+ String propertyName = RankingExpression.propertyName(e.getKey());
+ if (context.serializedFunctions().containsKey(propertyName)) {
+ continue;
+ }
String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString();
- context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString);
+ context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString);
for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet())
context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue());
if (e.getValue().function().returnType().isPresent())
@@ -232,7 +240,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
// else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types
// throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved");
}
- return context.serializedFunctions();
+ functionProperties.putAll(context.serializedFunctions());
}
private void replaceFunctionSummaryFeatures(SerializationContext context) {
@@ -241,9 +249,11 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) {
ReferenceNode referenceNode = i.next();
// Is the feature a function?
- if (context.getFunction(referenceNode.getName()) != null) {
- context.addFunctionSerialization(RankingExpression.propertyName(referenceNode.getName()),
- referenceNode.toString(new StringBuilder(), context, null, null).toString());
+ ExpressionFunction function = context.getFunction(referenceNode.getName());
+ if (function != null) {
+ String propertyName = RankingExpression.propertyName(referenceNode.getName());
+ String expressionString = function.getBody().getRoot().toString(new StringBuilder(), context, null, null).toString();
+ context.addFunctionSerialization(propertyName, expressionString);
ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput());
functionSummaryFeatures.put(referenceNode.getName(), newReferenceNode);
i.remove(); // Will add the expanded one in next block
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index 5c1134f928c..e4ca83640e9 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -50,10 +50,10 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase {
new QueryProfileRegistry(),
new ImportedMlModels(),
new AttributeFields(s)).configProperties();
- assertEquals("(rankingExpression(sin).rankingScript,x * x)",
- testRankProperties.get(0).toString());
assertEquals("(rankingExpression(sin@).rankingScript,2 * 2)",
- censorBindingHash(testRankProperties.get(1).toString()));
+ censorBindingHash(testRankProperties.get(0).toString()));
+ assertEquals("(rankingExpression(sin).rankingScript,x * x)",
+ testRankProperties.get(1).toString());
assertEquals("(vespa.rank.firstphase,rankingExpression(sin@))",
censorBindingHash(testRankProperties.get(2).toString()));
}
@@ -94,27 +94,26 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase {
new QueryProfileRegistry(),
new ImportedMlModels(),
new AttributeFields(s)).configProperties();
+ assertEquals("(rankingExpression(tan@).rankingScript,2 * 2)",
+ censorBindingHash(testRankProperties.get(0).toString()));
+ assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))",
+ censorBindingHash(testRankProperties.get(1).toString()));
+ assertEquals("(rankingExpression(sin@).rankingScript,rankingExpression(cos@))",
+ censorBindingHash(testRankProperties.get(2).toString()));
assertEquals("(rankingExpression(tan).rankingScript,x * x)",
- testRankProperties.get(0).toString());
+ testRankProperties.get(3).toString());
assertEquals("(rankingExpression(tan@).rankingScript,x * x)",
- censorBindingHash(testRankProperties.get(1).toString()));
- assertEquals("(rankingExpression(cos).rankingScript,rankingExpression(tan@))",
- censorBindingHash(testRankProperties.get(2).toString()));
- assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))",
- censorBindingHash(testRankProperties.get(3).toString()));
- assertEquals("(rankingExpression(sin).rankingScript,rankingExpression(cos@))",
censorBindingHash(testRankProperties.get(4).toString()));
- assertEquals("(rankingExpression(tan@).rankingScript,2 * 2)",
+ assertEquals("(rankingExpression(cos).rankingScript,rankingExpression(tan@))",
censorBindingHash(testRankProperties.get(5).toString()));
assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))",
- censorBindingHash(testRankProperties.get(6).toString()));
- assertEquals("(rankingExpression(sin@).rankingScript,rankingExpression(cos@))",
+ censorBindingHash(testRankProperties.get(6).toString()));
+ assertEquals("(rankingExpression(sin).rankingScript,rankingExpression(cos@))",
censorBindingHash(testRankProperties.get(7).toString()));
assertEquals("(vespa.rank.firstphase,rankingExpression(sin@))",
censorBindingHash(testRankProperties.get(8).toString()));
}
-
@Test
public void testFunctionShadowingArguments() throws ParseException {
RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
@@ -144,12 +143,12 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase {
new QueryProfileRegistry(),
new ImportedMlModels(),
new AttributeFields(s)).configProperties();
- assertEquals("(rankingExpression(sin).rankingScript,x * x)",
- testRankProperties.get(0).toString());
assertEquals("(rankingExpression(sin@).rankingScript,4.0 * 4.0)",
- censorBindingHash(testRankProperties.get(1).toString()));
+ censorBindingHash(testRankProperties.get(0).toString()));
assertEquals("(rankingExpression(sin@).rankingScript,cos(5.0) * cos(5.0))",
- censorBindingHash(testRankProperties.get(2).toString()));
+ censorBindingHash(testRankProperties.get(1).toString()));
+ assertEquals("(rankingExpression(sin).rankingScript,x * x)",
+ testRankProperties.get(2).toString());
assertEquals("(vespa.rank.firstphase,rankingExpression(firstphase))",
censorBindingHash(testRankProperties.get(3).toString()));
assertEquals("(rankingExpression(firstphase).rankingScript,cos(rankingExpression(sin@)) + rankingExpression(sin@))",
@@ -208,17 +207,17 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase {
queryProfiles,
new ImportedMlModels(),
new AttributeFields(s)).configProperties();
- assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))",
- testRankProperties.get(0).toString());
assertEquals("(rankingExpression(relu@).rankingScript,max(1.0,reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input)))",
- censorBindingHash(testRankProperties.get(1).toString()));
+ censorBindingHash(testRankProperties.get(0).toString()));
assertEquals("(rankingExpression(hidden_layer).rankingScript,rankingExpression(relu@))",
- censorBindingHash(testRankProperties.get(2).toString()));
+ censorBindingHash(testRankProperties.get(1).toString()));
assertEquals("(rankingExpression(hidden_layer).type,tensor(x[]))",
- censorBindingHash(testRankProperties.get(3).toString()));
+ censorBindingHash(testRankProperties.get(2).toString()));
assertEquals("(rankingExpression(final_layer).rankingScript,sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))",
- testRankProperties.get(4).toString());
+ testRankProperties.get(3).toString());
assertEquals("(rankingExpression(final_layer).type,tensor(x[]))",
+ testRankProperties.get(4).toString());
+ assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))",
testRankProperties.get(5).toString());
assertEquals("(vespa.rank.secondphase,rankingExpression(secondphase))",
testRankProperties.get(6).toString());
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 0cd6674751e..e6616ce0dd1 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -41,6 +41,14 @@ class RankProfileSearchFixture {
private Search search;
private Map<String, RankProfile> compiledRankProfiles = new HashMap<>();
+ public RankProfileRegistry getRankProfileRegistry() {
+ return rankProfileRegistry;
+ }
+
+ public QueryProfileRegistry getQueryProfileRegistry() {
+ return queryProfileRegistry;
+ }
+
RankProfileSearchFixture(String rankProfiles) throws ParseException {
this(MockApplicationPackage.createEmpty(), new QueryProfileRegistry(), rankProfiles);
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index a64a964727c..680f2dd9659 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -319,7 +319,7 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testFunctionGeneration() {
final String name = "mnist_saved";
- final String expression = "join(join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))";
final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))";
final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))";
@@ -349,7 +349,7 @@ public class RankingExpressionWithTensorFlowTestCase {
" rank-profile my_profile_child inherits my_profile {\n" +
" }";
- final String expression = "join(join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))";
final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))";
final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))";
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
index b3eda9b7e13..1567a4c3b5e 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java
@@ -44,20 +44,20 @@ public class RankingExpressionsTestCase extends SchemaTestCase {
new AttributeFields(search)).configProperties();
assertEquals(6, rankProperties.size());
- assertEquals("rankingExpression(titlematch$).rankingScript", rankProperties.get(0).getFirst());
- assertEquals("var1 * var2 + 890", rankProperties.get(0).getSecond());
+ assertEquals("rankingExpression(titlematch$).rankingScript", rankProperties.get(2).getFirst());
+ assertEquals("var1 * var2 + 890", rankProperties.get(2).getSecond());
- assertEquals("rankingExpression(artistmatch).rankingScript", rankProperties.get(1).getFirst());
- assertEquals("78 + closeness(distance)", rankProperties.get(1).getSecond());
+ assertEquals("rankingExpression(artistmatch).rankingScript", rankProperties.get(3).getFirst());
+ assertEquals("78 + closeness(distance)", rankProperties.get(3).getSecond());
assertEquals("rankingExpression(firstphase).rankingScript", rankProperties.get(5).getFirst());
assertEquals("0.8 + 0.2 * rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c) + 0.8 * rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6) * closeness(distance)", rankProperties.get(5).getSecond());
- assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(3).getFirst());
- assertEquals("7 * 8 + 890", rankProperties.get(3).getSecond());
+ assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(1).getFirst());
+ assertEquals("7 * 8 + 890", rankProperties.get(1).getSecond());
- assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(2).getFirst());
- assertEquals("4 * 5 + 890", rankProperties.get(2).getSecond());
+ assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(0).getFirst());
+ assertEquals("4 * 5 + 890", rankProperties.get(0).getSecond());
}
@Test(expected = IllegalArgumentException.class)
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
index ca84eb5eed7..57dbf132883 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
@@ -64,10 +64,10 @@ public class MlModelsTest {
private final String testProfile =
"rankingExpression(input).rankingScript: attribute(argument)\n" +
"rankingExpression(input).type: tensor<float>(d0[],d1[784])\n" +
- "rankingExpression(Placeholder).rankingScript: attribute(argument)\n" +
- "rankingExpression(Placeholder).type: tensor<float>(d0[],d1[784])\n" +
"rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(rankingExpression(input), (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" +
"rankingExpression(mnist_tensorflow).rankingScript: join(reduce(join(map(join(reduce(join(join(join(0.009999999776482582, rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.0507009873554805 * if (a >= 0, a, 1.6732632423543772 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" +
+ "rankingExpression(Placeholder).rankingScript: attribute(argument)\n" +
+ "rankingExpression(Placeholder).type: tensor<float>(d0[],d1[784])\n" +
"rankingExpression(mnist_softmax_tensorflow).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))\n" +
"rankingExpression(mnist_softmax_onnx).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))\n" +
"rankingExpression(my_xgboost).rankingScript: if (f29 < -0.1234567, if (!(f56 >= -0.242398), 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (!(f60 >= -0.482947), if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)\n" +