diff options
author | Jon Bratseth <bratseth@oath.com> | 2020-04-26 13:45:54 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-04-26 13:45:54 +0200 |
commit | 1ea6410fd642bb92a03b5798387fa977286af167 (patch) | |
tree | 2a4d3fa648d581c472e2ceb83484a24edf0b739a /config-model/src | |
parent | 84e4a15ed3545f59f8f38f4fe8c4770098cce642 (diff) | |
parent | d95c4ceb6c4db635e93b605d719f01db9d5f2e6f (diff) |
Merge pull request #13067 from vespa-engine/lesters/bert-searchlib-and-config-model
Lesters/bert searchlib and config model
Diffstat (limited to 'config-model/src')
8 files changed, 98 insertions, 54 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 6de7c985326..4011ce43841 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -41,6 +41,9 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement private final Map<Reference, TensorType> resolvedTypes = new HashMap<>(); + /** To avoid re-resolving diamond-shaped dependencies */ + private final Map<Reference, TensorType> globallyResolvedTypes; + /** For invocation loop detection */ private final Deque<Reference> currentResolutionCallStack; @@ -53,6 +56,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement this.currentResolutionCallStack = new ArrayDeque<>(); this.queryFeaturesNotDeclared = new TreeSet<>(); tensorsAreUsed = false; + globallyResolvedTypes = new HashMap<>(); } private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions, @@ -60,12 +64,14 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement Map<Reference, TensorType> featureTypes, Deque<Reference> currentResolutionCallStack, SortedSet<Reference> queryFeaturesNotDeclared, - boolean tensorsAreUsed) { + boolean tensorsAreUsed, + Map<Reference, TensorType> globallyResolvedTypes) { super(functions, bindings); this.featureTypes.putAll(featureTypes); this.currentResolutionCallStack = currentResolutionCallStack; this.queryFeaturesNotDeclared = queryFeaturesNotDeclared; this.tensorsAreUsed = tensorsAreUsed; + this.globallyResolvedTypes = globallyResolvedTypes; } public void setType(Reference reference, TensorType type) { @@ -82,11 +88,25 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement resolvedTypes.clear(); } - @Override + private boolean referenceCanBeResolvedGlobally(Reference reference) { + Optional<ExpressionFunction> function = functionInvocation(reference); + return function.isPresent() && function.get().arguments().size() == 0; + // are there other cases we would like to resolve globally? + } + + @Override public TensorType getType(Reference reference) { // computeIfAbsent without concurrent modification due to resolve adding more resolved entries: + + boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference); + TensorType resolvedType = resolvedTypes.get(reference); - if (resolvedType != null) return resolvedType; + if (resolvedType == null && canBeResolvedGlobally) { + resolvedType = globallyResolvedTypes.get(reference); + } + if (resolvedType != null) { + return resolvedType; + } resolvedType = resolveType(reference); if (resolvedType == null) @@ -94,6 +114,11 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement resolvedTypes.put(reference, resolvedType); if (resolvedType.rank() > 0) tensorsAreUsed = true; + + if (canBeResolvedGlobally) { + globallyResolvedTypes.put(reference, resolvedType); + } + return resolvedType; } @@ -254,7 +279,8 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement featureTypes, currentResolutionCallStack, queryFeaturesNotDeclared, - tensorsAreUsed); + tensorsAreUsed, + globallyResolvedTypes); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 23eb814de81..ea126123a25 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -680,11 +680,12 @@ public class RankProfile implements Cloneable { Map<String, RankingExpressionFunction> inlineFunctions = compileFunctions(this::getInlineFunctions, queryProfiles, featureTypes, importedModels, Collections.emptyMap(), expressionTransforms); + firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); + secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); + // Function compiling second pass: compile all functions and insert previously compiled inline functions functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms); - firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); - secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); } private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<String, Value> constants) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 1a22b98fd9f..c3c10139684 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -200,9 +200,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { if (functions.isEmpty()) return; List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); - Map<String, String> functionProperties = new LinkedHashMap<>(); - functionProperties.putAll(deriveFunctionProperties(functions, functionExpressions)); if (firstPhaseRanking != null) { functionProperties.putAll(firstPhaseRanking.getRankProperties(functionExpressions)); @@ -211,20 +209,30 @@ public class RawRankProfile implements RankProfilesConfig.Producer { functionProperties.putAll(secondPhaseRanking.getRankProperties(functionExpressions)); } + SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); + replaceFunctionSummaryFeatures(context); + + // First phase, second phase and summary features should add all required functions to the context. + // However, we need to add any functions not referenced in those anyway for model-evaluation. + deriveFunctionProperties(functions, functionExpressions, functionProperties); + for (Map.Entry<String, String> e : functionProperties.entrySet()) { rankProperties.add(new RankProfile.RankProperty(e.getKey(), e.getValue())); } - SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); - replaceFunctionSummaryFeatures(context); } - private Map<String, String> deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions, - List<ExpressionFunction> functionExpressions) { - SerializationContext context = new SerializationContext(functionExpressions); + private void deriveFunctionProperties(Map<String, RankProfile.RankingExpressionFunction> functions, + List<ExpressionFunction> functionExpressions, + Map<String, String> functionProperties) { + SerializationContext context = new SerializationContext(functionExpressions, null, functionProperties); for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) { + String propertyName = RankingExpression.propertyName(e.getKey()); + if (context.serializedFunctions().containsKey(propertyName)) { + continue; + } String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); - context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString); + context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString); for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet()) context.addArgumentTypeSerialization(e.getKey(), argumentType.getKey(), argumentType.getValue()); if (e.getValue().function().returnType().isPresent()) @@ -232,7 +240,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { // else if (e.getValue().function().arguments().isEmpty()) TODO: Enable this check when we resolve all types // throw new IllegalStateException("Type of function '" + e.getKey() + "' is not resolved"); } - return context.serializedFunctions(); + functionProperties.putAll(context.serializedFunctions()); } private void replaceFunctionSummaryFeatures(SerializationContext context) { @@ -241,9 +249,11 @@ public class RawRankProfile implements RankProfilesConfig.Producer { for (Iterator<ReferenceNode> i = summaryFeatures.iterator(); i.hasNext(); ) { ReferenceNode referenceNode = i.next(); // Is the feature a function? - if (context.getFunction(referenceNode.getName()) != null) { - context.addFunctionSerialization(RankingExpression.propertyName(referenceNode.getName()), - referenceNode.toString(new StringBuilder(), context, null, null).toString()); + ExpressionFunction function = context.getFunction(referenceNode.getName()); + if (function != null) { + String propertyName = RankingExpression.propertyName(referenceNode.getName()); + String expressionString = function.getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); + context.addFunctionSerialization(propertyName, expressionString); ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput()); functionSummaryFeatures.put(referenceNode.getName(), newReferenceNode); i.remove(); // Will add the expanded one in next block diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 5c1134f928c..e4ca83640e9 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -50,10 +50,10 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(s)).configProperties(); - assertEquals("(rankingExpression(sin).rankingScript,x * x)", - testRankProperties.get(0).toString()); assertEquals("(rankingExpression(sin@).rankingScript,2 * 2)", - censorBindingHash(testRankProperties.get(1).toString())); + censorBindingHash(testRankProperties.get(0).toString())); + assertEquals("(rankingExpression(sin).rankingScript,x * x)", + testRankProperties.get(1).toString()); assertEquals("(vespa.rank.firstphase,rankingExpression(sin@))", censorBindingHash(testRankProperties.get(2).toString())); } @@ -94,27 +94,26 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(s)).configProperties(); + assertEquals("(rankingExpression(tan@).rankingScript,2 * 2)", + censorBindingHash(testRankProperties.get(0).toString())); + assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))", + censorBindingHash(testRankProperties.get(1).toString())); + assertEquals("(rankingExpression(sin@).rankingScript,rankingExpression(cos@))", + censorBindingHash(testRankProperties.get(2).toString())); assertEquals("(rankingExpression(tan).rankingScript,x * x)", - testRankProperties.get(0).toString()); + testRankProperties.get(3).toString()); assertEquals("(rankingExpression(tan@).rankingScript,x * x)", - censorBindingHash(testRankProperties.get(1).toString())); - assertEquals("(rankingExpression(cos).rankingScript,rankingExpression(tan@))", - censorBindingHash(testRankProperties.get(2).toString())); - assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))", - censorBindingHash(testRankProperties.get(3).toString())); - assertEquals("(rankingExpression(sin).rankingScript,rankingExpression(cos@))", censorBindingHash(testRankProperties.get(4).toString())); - assertEquals("(rankingExpression(tan@).rankingScript,2 * 2)", + assertEquals("(rankingExpression(cos).rankingScript,rankingExpression(tan@))", censorBindingHash(testRankProperties.get(5).toString())); assertEquals("(rankingExpression(cos@).rankingScript,rankingExpression(tan@))", - censorBindingHash(testRankProperties.get(6).toString())); - assertEquals("(rankingExpression(sin@).rankingScript,rankingExpression(cos@))", + censorBindingHash(testRankProperties.get(6).toString())); + assertEquals("(rankingExpression(sin).rankingScript,rankingExpression(cos@))", censorBindingHash(testRankProperties.get(7).toString())); assertEquals("(vespa.rank.firstphase,rankingExpression(sin@))", censorBindingHash(testRankProperties.get(8).toString())); } - @Test public void testFunctionShadowingArguments() throws ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); @@ -144,12 +143,12 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { new QueryProfileRegistry(), new ImportedMlModels(), new AttributeFields(s)).configProperties(); - assertEquals("(rankingExpression(sin).rankingScript,x * x)", - testRankProperties.get(0).toString()); assertEquals("(rankingExpression(sin@).rankingScript,4.0 * 4.0)", - censorBindingHash(testRankProperties.get(1).toString())); + censorBindingHash(testRankProperties.get(0).toString())); assertEquals("(rankingExpression(sin@).rankingScript,cos(5.0) * cos(5.0))", - censorBindingHash(testRankProperties.get(2).toString())); + censorBindingHash(testRankProperties.get(1).toString())); + assertEquals("(rankingExpression(sin).rankingScript,x * x)", + testRankProperties.get(2).toString()); assertEquals("(vespa.rank.firstphase,rankingExpression(firstphase))", censorBindingHash(testRankProperties.get(3).toString())); assertEquals("(rankingExpression(firstphase).rankingScript,cos(rankingExpression(sin@)) + rankingExpression(sin@))", @@ -208,17 +207,17 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { queryProfiles, new ImportedMlModels(), new AttributeFields(s)).configProperties(); - assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", - testRankProperties.get(0).toString()); assertEquals("(rankingExpression(relu@).rankingScript,max(1.0,reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input)))", - censorBindingHash(testRankProperties.get(1).toString())); + censorBindingHash(testRankProperties.get(0).toString())); assertEquals("(rankingExpression(hidden_layer).rankingScript,rankingExpression(relu@))", - censorBindingHash(testRankProperties.get(2).toString())); + censorBindingHash(testRankProperties.get(1).toString())); assertEquals("(rankingExpression(hidden_layer).type,tensor(x[]))", - censorBindingHash(testRankProperties.get(3).toString())); + censorBindingHash(testRankProperties.get(2).toString())); assertEquals("(rankingExpression(final_layer).rankingScript,sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))", - testRankProperties.get(4).toString()); + testRankProperties.get(3).toString()); assertEquals("(rankingExpression(final_layer).type,tensor(x[]))", + testRankProperties.get(4).toString()); + assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(5).toString()); assertEquals("(vespa.rank.secondphase,rankingExpression(secondphase))", testRankProperties.get(6).toString()); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 0cd6674751e..e6616ce0dd1 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -41,6 +41,14 @@ class RankProfileSearchFixture { private Search search; private Map<String, RankProfile> compiledRankProfiles = new HashMap<>(); + public RankProfileRegistry getRankProfileRegistry() { + return rankProfileRegistry; + } + + public QueryProfileRegistry getQueryProfileRegistry() { + return queryProfileRegistry; + } + RankProfileSearchFixture(String rankProfiles) throws ParseException { this(MockApplicationPackage.createEmpty(), new QueryProfileRegistry(), rankProfiles); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index a64a964727c..680f2dd9659 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -319,7 +319,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testFunctionGeneration() { final String name = "mnist_saved"; - final String expression = "join(join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))"; final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))"; final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))"; @@ -349,7 +349,7 @@ public class RankingExpressionWithTensorFlowTestCase { " rank-profile my_profile_child inherits my_profile {\n" + " }"; - final String expression = "join(join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))"; final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))"; final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))"; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index b3eda9b7e13..1567a4c3b5e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -44,20 +44,20 @@ public class RankingExpressionsTestCase extends SchemaTestCase { new AttributeFields(search)).configProperties(); assertEquals(6, rankProperties.size()); - assertEquals("rankingExpression(titlematch$).rankingScript", rankProperties.get(0).getFirst()); - assertEquals("var1 * var2 + 890", rankProperties.get(0).getSecond()); + assertEquals("rankingExpression(titlematch$).rankingScript", rankProperties.get(2).getFirst()); + assertEquals("var1 * var2 + 890", rankProperties.get(2).getSecond()); - assertEquals("rankingExpression(artistmatch).rankingScript", rankProperties.get(1).getFirst()); - assertEquals("78 + closeness(distance)", rankProperties.get(1).getSecond()); + assertEquals("rankingExpression(artistmatch).rankingScript", rankProperties.get(3).getFirst()); + assertEquals("78 + closeness(distance)", rankProperties.get(3).getSecond()); assertEquals("rankingExpression(firstphase).rankingScript", rankProperties.get(5).getFirst()); assertEquals("0.8 + 0.2 * rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c) + 0.8 * rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6) * closeness(distance)", rankProperties.get(5).getSecond()); - assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(3).getFirst()); - assertEquals("7 * 8 + 890", rankProperties.get(3).getSecond()); + assertEquals("rankingExpression(titlematch$@c7e4c2d0e6d9f2a1.1d4ed08e56cce2e6).rankingScript", rankProperties.get(1).getFirst()); + assertEquals("7 * 8 + 890", rankProperties.get(1).getSecond()); - assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(2).getFirst()); - assertEquals("4 * 5 + 890", rankProperties.get(2).getSecond()); + assertEquals("rankingExpression(titlematch$@126063073eb2deb.ab95cd69909927c).rankingScript", rankProperties.get(0).getFirst()); + assertEquals("4 * 5 + 890", rankProperties.get(0).getSecond()); } @Test(expected = IllegalArgumentException.class) diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java index ca84eb5eed7..57dbf132883 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java @@ -64,10 +64,10 @@ public class MlModelsTest { private final String testProfile = "rankingExpression(input).rankingScript: attribute(argument)\n" + "rankingExpression(input).type: tensor<float>(d0[],d1[784])\n" + - "rankingExpression(Placeholder).rankingScript: attribute(argument)\n" + - "rankingExpression(Placeholder).type: tensor<float>(d0[],d1[784])\n" + "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(rankingExpression(input), (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" + "rankingExpression(mnist_tensorflow).rankingScript: join(reduce(join(map(join(reduce(join(join(join(0.009999999776482582, rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.0507009873554805 * if (a >= 0, a, 1.6732632423543772 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" + + "rankingExpression(Placeholder).rankingScript: attribute(argument)\n" + + "rankingExpression(Placeholder).type: tensor<float>(d0[],d1[784])\n" + "rankingExpression(mnist_softmax_tensorflow).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))\n" + "rankingExpression(mnist_softmax_onnx).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))\n" + "rankingExpression(my_xgboost).rankingScript: if (f29 < -0.1234567, if (!(f56 >= -0.242398), 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (!(f60 >= -0.482947), if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)\n" + |