diff options
30 files changed, 556 insertions, 74 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java b/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java index 2f41b172ab6..1e133d0b8f4 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java @@ -34,6 +34,12 @@ public class FeatureNames { return name.equals("attribute") || name.equals("constant") || name.equals("query"); } + /** Returns true if this is a constant */ + public static boolean isConstantFeature(Reference reference) { + if ( ! isSimpleFeature(reference)) return false; + return reference.name().equals("constant"); + } + /** * Returns the single argument of the given feature name, without any quotes, * or empty if it is not a valid query, attribute or constant feature name diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index b3853b36aa5..d738929f721 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -756,7 +756,7 @@ public class RankProfile implements Serializable, Cloneable { for (FieldDescription field : queryProfileType.declaredFields().values()) { TensorType type = field.getType().asTensorType(); Optional<Reference> feature = Reference.simple(field.getName()); - if ( ! feature.isPresent() || ! feature.get().name().equals("query")) continue; + if ( feature.isEmpty() || ! feature.get().name().equals("query")) continue; TensorType existingType = context.getType(feature.get()); if ( ! Objects.equals(existingType, context.defaultTypeOf(feature.get()))) diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java index 9804b0b6329..a84db895b02 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java @@ -8,7 +8,9 @@ import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.document.DocumentTypeManager; import com.yahoo.io.IOUtils; import com.yahoo.io.reader.NamedReader; +import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.config.QueryProfileXMLReader; import com.yahoo.searchdefinition.derived.SearchOrderer; import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.parser.ParseException; @@ -394,14 +396,13 @@ public class SearchBuilder { } public static SearchBuilder createFromDirectory(String dir) throws IOException, ParseException { - return createFromDirectory(dir, new RankProfileRegistry(), new QueryProfileRegistry()); + return createFromDirectory(dir, new RankProfileRegistry()); } public static SearchBuilder createFromDirectory(String dir, - RankProfileRegistry rankProfileRegistry, - QueryProfileRegistry queryProfileRegistry) throws IOException, ParseException { + RankProfileRegistry rankProfileRegistry) throws IOException, ParseException { SearchBuilder builder = new SearchBuilder(MockApplicationPackage.fromSearchDefinitionDirectory(dir), rankProfileRegistry, - queryProfileRegistry); + createQueryProfileRegistryFromDirectory(dir)); for (Iterator<Path> i = Files.list(new File(dir).toPath()).filter(p -> p.getFileName().toString().endsWith(".sd")).iterator(); i.hasNext(); ) { builder.importFile(i.next()); } @@ -409,6 +410,12 @@ public class SearchBuilder { return builder; } + private static QueryProfileRegistry createQueryProfileRegistryFromDirectory(String dir) { + File queryProfilesDir = new File(dir, "query-profiles"); + if ( ! queryProfilesDir.exists()) return new QueryProfileRegistry(); + return new QueryProfileXMLReader().read(queryProfilesDir.toString()); + } + // TODO: The build methods below just call the create methods above - remove /** diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java index caf5f0442eb..6991e2b978b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java @@ -49,13 +49,13 @@ public class ConstantTensorTransformer extends ExpressionTransformer<RankProfile } private ExpressionNode transformConstantReference(ReferenceNode node, RankProfileTransformContext context) { - Reference constantReference = FeatureNames.asConstantFeature(node.getName()); + Reference constantReference = node.reference(); + if ( ! FeatureNames.isConstantFeature(constantReference) && constantReference.isIdentifier()) + constantReference = FeatureNames.asConstantFeature(node.getName()); + Value value = context.constants().get(node.getName()); - if (value == null || value.type().rank() == 0) { - if (context.rankProfile().rankingConstants().get(node.getName()) != null) // Large constants: Transform reference but don't add value - return new ReferenceNode(constantReference); - return node; - } + if (value == null || value.type().rank() == 0) return node; + TensorValue tensorValue = (TensorValue)value; String tensorType = tensorValue.asTensor().type().toString(); context.rankProperties().put(constantReference.toString() + ".value", tensorValue.toString()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java index cbabfffb7a1..6fdf448a39b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -6,6 +6,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.transform.ConstantDereferencer; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.searchlib.rankingexpression.transform.Simplifier; +import com.yahoo.searchlib.rankingexpression.transform.TensorMaxMinTransformer; import java.util.List; @@ -30,7 +31,7 @@ public class ExpressionTransforms { new ConstantTensorTransformer(), new FunctionInliner(), new FunctionShadower(), - new TensorTransformer(), + new TensorMaxMinTransformer(), new Simplifier()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java index 2c0e1eaa56a..630c8644eb1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/RankProfileTransformContext.java @@ -24,19 +24,17 @@ public class RankProfileTransformContext extends TransformContext { private final ImportedMlModels importedModels; private final Map<String, RankProfile.RankingExpressionFunction> inlineFunctions; private final Map<String, String> rankProperties = new HashMap<>(); - private final MapEvaluationTypeContext types; public RankProfileTransformContext(RankProfile rankProfile, QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, Map<String, Value> constants, Map<String, RankProfile.RankingExpressionFunction> inlineFunctions) { - super(constants); + super(constants, rankProfile.typeContext(queryProfiles)); this.rankProfile = rankProfile; this.queryProfiles = queryProfiles; this.importedModels = importedModels; this.inlineFunctions = inlineFunctions; - this.types = rankProfile.typeContext(queryProfiles); } public RankProfile rankProfile() { return rankProfile; } @@ -45,10 +43,4 @@ public class RankProfileTransformContext extends TransformContext { public Map<String, RankProfile.RankingExpressionFunction> inlineFunctions() { return inlineFunctions; } public Map<String, String> rankProperties() { return rankProperties; } - /** - * Returns the types known in this context. We may have type information for references - * for which no value is available - */ - public MapEvaluationTypeContext types() { return types; } - } diff --git a/config-model/src/test/derived/neuralnet/neuralnet.sd b/config-model/src/test/derived/neuralnet/neuralnet.sd new file mode 100644 index 00000000000..f916b35cb75 --- /dev/null +++ b/config-model/src/test/derived/neuralnet/neuralnet.sd @@ -0,0 +1,238 @@ +search neuralnet { + + document neuralnet { + + field pinned type int { + indexing: attribute + } + + field createdAt type long { + indexing: attribute + } + + field updatedAt type long { + indexing: attribute + } + + field uvCount type int { + indexing: attribute + } + + field dvCount type int { + indexing: attribute + } + + field aVoteCount type int { + indexing: attribute + } + + field rCount type int { + indexing: attribute + } + + field uniqueRACount type int { + indexing: attribute + } + + field rTo type string { + indexing: attribute + } + + field markedAsAAt type long { + indexing: attribute + } + + field normalizedTextScore type float { + indexing: attribute + } + + field t type float { + indexing: attribute + } + + field relevance type float { + indexing: attribute + } + + field normalizedCS type float { + indexing: attribute + } + + field laAt type long { + indexing: attribute + } + + field hsScore type double { + indexing: attribute + } + + } + + rank-profile defaultRankProfile inherits default { + + constants { + maxSignedSixtyFourBitInteger: 9223372036854775807 + } + + macro log10_1p(x) { + expression: log10(x+1) + } + + macro textScoreToUse() { + expression: if(isNan(attribute(normalizedTextScore)) == 1, 0, attribute(normalizedTextScore)) + } + + macro rCountToUse() { + expression: if(isNan(attribute(rCount)) == 1, 0, if(attribute(rCount) < 0, 0, attribute(rCount))) + } + + macro uniqueRCountToUse() { + expression: if(isNan(attribute(uniqueRCount)) == 1, 0, if(attribute(uniqueRACount) < 0, 0, attribute(uniqueRACount))) + } + + macro uvCountToUse() { + expression: if(isNan(attribute(uvCount)) == 1, 0, if(attribute(uvCount) < 0, 0, attribute(uvCount))) + } + + macro dvCountToUse() { + expression: if(isNan(attribute(dvCount)) == 1, 0, if(attribute(dvCount) < 0, 0, attribute(dvCount))) + } + + macro aVoteCountToUse() { + expression: if(isNan(attribute(aVoteCount)) == 1, 0, if(attribute(aVoteCount) < 0, 0, attribute(aVoteCount))) + } + + macro totalPR() { + expression: uniqueRCountToUse + query(voteToRRatio) * (uvCountToUse - dvCountToUse) - aVoteCountToUse + } + + macro totalvote() { + expression: query(reportaweight) * aVoteCountToUse + dvCountToUse + query(rweight) * uniqueRCountToUse + uvCountToUse + } + + macro phat() { + expression: if (totalvote == 0, 0, ( query(rweight) * uniqueRCountToUse + uvCountToUse) / totalvote) + } + + macro nCScoreToUse() { + expression: if (totalPR > 0, log10(totalPR), 0) + } + + macro hsScoreToUse() { + expression: attribute(hsScore) + } + + macro tScoreToUse() { + expression: if (isNan(attribute(t)) == 1, 0.6, attribute(t)) + } + + macro relevanceScoreToUse() { + expression: if (isNan(attribute(relevance)) == 1, 0.254, attribute(relevance)) + } + + macro freshnessToUse() { + expression: if (freshness(createdAt).logscale < 0.01, 0.01, freshness(createdAt).logscale) + } + + macro rankedAt() { + expression: now + } + + macro createdAtToUse() { + expression: if(isNan(attribute(createdAt)) == 1, rankedAt, attribute(createdAt)) + } + + macro laAtToUse() { + expression: if(isNan(attribute(laAt)) == 1, attribute(createdAt), attribute(laAt)) + } + + macro markedAsAAtToUse() { + expression: if(isNan(attribute(markedAsAAt)) == 1, maxSignedSixtyFourBitInteger, attribute(markedAsAAt)) + } + + macro tdToUse() { + expression: pow(2, 0 - ((rankedAt - createdAtToUse) / query(decay))) + } + + macro commentOverallScore() { + expression: query(textweight) * textScoreToUse + query(communityweight) * nCScoreToUse + } + + macro pinScore() { + expression: if(isNan(attribute(pinned)) == 1, 0, query(pinweight) * attribute(pinned)) + } + + macro freshnessRank() { + expression: nativeRank + freshness(createdAt) + } + + first-phase { + expression: nativeRank + } + + } + + rank-profile neuralNetworkProfile inherits defaultRankProfile { + macro nn_input() { + expression { + concat(log10_1p(aVoteCountToUse), + concat(log10_1p(dvCountToUse), + concat(log10_1p(uniqueRCountToUse), + concat(log10_1p(uvCountToUse), + concat(phat, + concat(log10_1p(totalvote), + concat(hsScoreToUse, + concat(tdToUse, + tScoreToUse, x), x), x), x), x), x), x), x) + } + } + + macro get_model_weights(field) { + expression: if(query(field) == 0, constant(field), query(field)) + } + + macro layer_0() { + expression: elu(xw_plus_b(nn_input, get_model_weights(W_0), get_model_weights(b_0), x)) + } + macro layer_1() { + expression: elu(xw_plus_b(layer_0, get_model_weights(W_1), get_model_weights(b_1), hidden)) + } + macro layer_out() { + expression: sum(xw_plus_b(layer_1, get_model_weights(W_out), get_model_weights(b_out), out)) + } + first-phase { + expression: freshnessRank + } + second-phase { + expression: layer_out + rerank-count: 2000 + } + + } + + constant W_0 { + file: neural-network-201805/W_0.json + type: tensor(x[9],hidden[9]) + } + constant b_0 { + file: neural-network-201805/b_0.json + type: tensor(hidden[9]) + } + constant W_1 { + file: neural-network-201805/W_1.json + type: tensor(hidden[9],out[9]) + } + constant b_1 { + file: neural-network-201805/b_1.json + type: tensor(out[9]) + } + constant W_out { + file: neural-network-201805/W_out.json + type: tensor(out[9]) + } + constant b_out { + file: neural-network-201805/b_out.json + type: tensor(out[1]) + } + +}
\ No newline at end of file diff --git a/config-model/src/test/derived/neuralnet/query-profiles/default.xml b/config-model/src/test/derived/neuralnet/query-profiles/default.xml new file mode 100644 index 00000000000..eef1aaa7f53 --- /dev/null +++ b/config-model/src/test/derived/neuralnet/query-profiles/default.xml @@ -0,0 +1,2 @@ +<query-profile id="default" type="DefaultQueryProfileType"> +</query-profile> diff --git a/config-model/src/test/derived/neuralnet/query-profiles/types/DefaultQueryProfileType.xml b/config-model/src/test/derived/neuralnet/query-profiles/types/DefaultQueryProfileType.xml new file mode 100644 index 00000000000..e1659479135 --- /dev/null +++ b/config-model/src/test/derived/neuralnet/query-profiles/types/DefaultQueryProfileType.xml @@ -0,0 +1,8 @@ +<query-profile-type id="DefaultQueryProfileType"> + <field name="ranking.features.query(W_0)" type="tensor(x[9],hidden[9])" /> + <field name="ranking.features.query(b_0)" type="tensor(hidden[9])" /> + <field name="ranking.features.query(W_1)" type="tensor(hidden[9],out[9])" /> + <field name="ranking.features.query(b_1)" type="tensor(out[9])" /> + <field name="ranking.features.query(W_out)" type="tensor(out[9])" /> + <field name="ranking.features.query(b_out)" type="tensor(out[1])" /> +</query-profile-type> diff --git a/config-model/src/test/derived/neuralnet/rank-profiles.cfg b/config-model/src/test/derived/neuralnet/rank-profiles.cfg new file mode 100644 index 00000000000..4530bff2e20 --- /dev/null +++ b/config-model/src/test/derived/neuralnet/rank-profiles.cfg @@ -0,0 +1,198 @@ +rankprofile[].name "default" +rankprofile[].fef.property[].name "vespa.type.query.b_out" +rankprofile[].fef.property[].value "tensor(out[1])" +rankprofile[].fef.property[].name "vespa.type.query.W_out" +rankprofile[].fef.property[].value "tensor(out[9])" +rankprofile[].fef.property[].name "vespa.type.query.b_0" +rankprofile[].fef.property[].value "tensor(hidden[9])" +rankprofile[].fef.property[].name "vespa.type.query.b_1" +rankprofile[].fef.property[].value "tensor(out[9])" +rankprofile[].fef.property[].name "vespa.type.query.W_1" +rankprofile[].fef.property[].value "tensor(hidden[9],out[9])" +rankprofile[].fef.property[].name "vespa.type.query.W_0" +rankprofile[].fef.property[].value "tensor(hidden[9],x[9])" +rankprofile[].name "unranked" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "value(0)" +rankprofile[].fef.property[].name "vespa.hitcollector.heapsize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.hitcollector.arraysize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures" +rankprofile[].fef.property[].value "true" +rankprofile[].fef.property[].name "vespa.type.query.b_out" +rankprofile[].fef.property[].value "tensor(out[1])" +rankprofile[].fef.property[].name "vespa.type.query.W_out" +rankprofile[].fef.property[].value "tensor(out[9])" +rankprofile[].fef.property[].name "vespa.type.query.b_0" +rankprofile[].fef.property[].value "tensor(hidden[9])" +rankprofile[].fef.property[].name "vespa.type.query.b_1" +rankprofile[].fef.property[].value "tensor(out[9])" +rankprofile[].fef.property[].name "vespa.type.query.W_1" +rankprofile[].fef.property[].value "tensor(hidden[9],out[9])" +rankprofile[].fef.property[].name "vespa.type.query.W_0" +rankprofile[].fef.property[].value "tensor(hidden[9],x[9])" +rankprofile[].name "defaultRankProfile" +rankprofile[].fef.property[].name "rankingExpression(log10_1p).rankingScript" +rankprofile[].fef.property[].value "log10(x + 1)" +rankprofile[].fef.property[].name "rankingExpression(textScoreToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(normalizedTextScore)) == 1, 0, attribute(normalizedTextScore))" +rankprofile[].fef.property[].name "rankingExpression(rCountToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(rCount)) == 1, 0, if (attribute(rCount) < 0, 0, attribute(rCount)))" +rankprofile[].fef.property[].name "rankingExpression(uniqueRCountToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(uniqueRCount)) == 1, 0, if (attribute(uniqueRACount) < 0, 0, attribute(uniqueRACount)))" +rankprofile[].fef.property[].name "rankingExpression(uvCountToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(uvCount)) == 1, 0, if (attribute(uvCount) < 0, 0, attribute(uvCount)))" +rankprofile[].fef.property[].name "rankingExpression(dvCountToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(dvCount)) == 1, 0, if (attribute(dvCount) < 0, 0, attribute(dvCount)))" +rankprofile[].fef.property[].name "rankingExpression(aVoteCountToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(aVoteCount)) == 1, 0, if (attribute(aVoteCount) < 0, 0, attribute(aVoteCount)))" +rankprofile[].fef.property[].name "rankingExpression(totalPR).rankingScript" +rankprofile[].fef.property[].value "rankingExpression(uniqueRCountToUse) + query(voteToRRatio) * (rankingExpression(uvCountToUse) - rankingExpression(dvCountToUse)) - rankingExpression(aVoteCountToUse)" +rankprofile[].fef.property[].name "rankingExpression(totalvote).rankingScript" +rankprofile[].fef.property[].value "query(reportaweight) * rankingExpression(aVoteCountToUse) + rankingExpression(dvCountToUse) + query(rweight) * rankingExpression(uniqueRCountToUse) + rankingExpression(uvCountToUse)" +rankprofile[].fef.property[].name "rankingExpression(phat).rankingScript" +rankprofile[].fef.property[].value "if (rankingExpression(totalvote) == 0, 0, (query(rweight) * rankingExpression(uniqueRCountToUse) + rankingExpression(uvCountToUse)) / rankingExpression(totalvote))" +rankprofile[].fef.property[].name "rankingExpression(nCScoreToUse).rankingScript" +rankprofile[].fef.property[].value "if (rankingExpression(totalPR) > 0, log10(rankingExpression(totalPR)), 0)" +rankprofile[].fef.property[].name "rankingExpression(hsScoreToUse).rankingScript" +rankprofile[].fef.property[].value "attribute(hsScore)" +rankprofile[].fef.property[].name "rankingExpression(tScoreToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(t)) == 1, 0.6, attribute(t))" +rankprofile[].fef.property[].name "rankingExpression(relevanceScoreToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(relevance)) == 1, 0.254, attribute(relevance))" +rankprofile[].fef.property[].name "rankingExpression(freshnessToUse).rankingScript" +rankprofile[].fef.property[].value "if (freshness(createdAt).logscale < 0.01, 0.01, freshness(createdAt).logscale)" +rankprofile[].fef.property[].name "rankingExpression(rankedAt).rankingScript" +rankprofile[].fef.property[].value "now" +rankprofile[].fef.property[].name "rankingExpression(createdAtToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(createdAt)) == 1, rankingExpression(rankedAt), attribute(createdAt))" +rankprofile[].fef.property[].name "rankingExpression(laAtToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(laAt)) == 1, attribute(createdAt), attribute(laAt))" +rankprofile[].fef.property[].name "rankingExpression(markedAsAAtToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(markedAsAAt)) == 1, 9.223372036854776E18, attribute(markedAsAAt))" +rankprofile[].fef.property[].name "rankingExpression(tdToUse).rankingScript" +rankprofile[].fef.property[].value "pow(2,0 - ((rankingExpression(rankedAt) - rankingExpression(createdAtToUse)) / query(decay)))" +rankprofile[].fef.property[].name "rankingExpression(commentOverallScore).rankingScript" +rankprofile[].fef.property[].value "query(textweight) * rankingExpression(textScoreToUse) + query(communityweight) * rankingExpression(nCScoreToUse)" +rankprofile[].fef.property[].name "rankingExpression(pinScore).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(pinned)) == 1, 0, query(pinweight) * attribute(pinned))" +rankprofile[].fef.property[].name "rankingExpression(freshnessRank).rankingScript" +rankprofile[].fef.property[].value "nativeRank + freshness(createdAt)" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "nativeRank" +rankprofile[].fef.property[].name "vespa.type.query.b_out" +rankprofile[].fef.property[].value "tensor(out[1])" +rankprofile[].fef.property[].name "vespa.type.query.W_out" +rankprofile[].fef.property[].value "tensor(out[9])" +rankprofile[].fef.property[].name "vespa.type.query.b_0" +rankprofile[].fef.property[].value "tensor(hidden[9])" +rankprofile[].fef.property[].name "vespa.type.query.b_1" +rankprofile[].fef.property[].value "tensor(out[9])" +rankprofile[].fef.property[].name "vespa.type.query.W_1" +rankprofile[].fef.property[].value "tensor(hidden[9],out[9])" +rankprofile[].fef.property[].name "vespa.type.query.W_0" +rankprofile[].fef.property[].value "tensor(hidden[9],x[9])" +rankprofile[].name "neuralNetworkProfile" +rankprofile[].fef.property[].name "rankingExpression(log10_1p).rankingScript" +rankprofile[].fef.property[].value "log10(x + 1)" +rankprofile[].fef.property[].name "rankingExpression(textScoreToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(normalizedTextScore)) == 1, 0, attribute(normalizedTextScore))" +rankprofile[].fef.property[].name "rankingExpression(rCountToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(rCount)) == 1, 0, if (attribute(rCount) < 0, 0, attribute(rCount)))" +rankprofile[].fef.property[].name "rankingExpression(uniqueRCountToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(uniqueRCount)) == 1, 0, if (attribute(uniqueRACount) < 0, 0, attribute(uniqueRACount)))" +rankprofile[].fef.property[].name "rankingExpression(uvCountToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(uvCount)) == 1, 0, if (attribute(uvCount) < 0, 0, attribute(uvCount)))" +rankprofile[].fef.property[].name "rankingExpression(dvCountToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(dvCount)) == 1, 0, if (attribute(dvCount) < 0, 0, attribute(dvCount)))" +rankprofile[].fef.property[].name "rankingExpression(aVoteCountToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(aVoteCount)) == 1, 0, if (attribute(aVoteCount) < 0, 0, attribute(aVoteCount)))" +rankprofile[].fef.property[].name "rankingExpression(totalPR).rankingScript" +rankprofile[].fef.property[].value "rankingExpression(uniqueRCountToUse) + query(voteToRRatio) * (rankingExpression(uvCountToUse) - rankingExpression(dvCountToUse)) - rankingExpression(aVoteCountToUse)" +rankprofile[].fef.property[].name "rankingExpression(totalvote).rankingScript" +rankprofile[].fef.property[].value "query(reportaweight) * rankingExpression(aVoteCountToUse) + rankingExpression(dvCountToUse) + query(rweight) * rankingExpression(uniqueRCountToUse) + rankingExpression(uvCountToUse)" +rankprofile[].fef.property[].name "rankingExpression(phat).rankingScript" +rankprofile[].fef.property[].value "if (rankingExpression(totalvote) == 0, 0, (query(rweight) * rankingExpression(uniqueRCountToUse) + rankingExpression(uvCountToUse)) / rankingExpression(totalvote))" +rankprofile[].fef.property[].name "rankingExpression(nCScoreToUse).rankingScript" +rankprofile[].fef.property[].value "if (rankingExpression(totalPR) > 0, log10(rankingExpression(totalPR)), 0)" +rankprofile[].fef.property[].name "rankingExpression(hsScoreToUse).rankingScript" +rankprofile[].fef.property[].value "attribute(hsScore)" +rankprofile[].fef.property[].name "rankingExpression(tScoreToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(t)) == 1, 0.6, attribute(t))" +rankprofile[].fef.property[].name "rankingExpression(relevanceScoreToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(relevance)) == 1, 0.254, attribute(relevance))" +rankprofile[].fef.property[].name "rankingExpression(freshnessToUse).rankingScript" +rankprofile[].fef.property[].value "if (freshness(createdAt).logscale < 0.01, 0.01, freshness(createdAt).logscale)" +rankprofile[].fef.property[].name "rankingExpression(rankedAt).rankingScript" +rankprofile[].fef.property[].value "now" +rankprofile[].fef.property[].name "rankingExpression(createdAtToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(createdAt)) == 1, rankingExpression(rankedAt), attribute(createdAt))" +rankprofile[].fef.property[].name "rankingExpression(laAtToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(laAt)) == 1, attribute(createdAt), attribute(laAt))" +rankprofile[].fef.property[].name "rankingExpression(markedAsAAtToUse).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(markedAsAAt)) == 1, 9.223372036854776E18, attribute(markedAsAAt))" +rankprofile[].fef.property[].name "rankingExpression(tdToUse).rankingScript" +rankprofile[].fef.property[].value "pow(2,0 - ((rankingExpression(rankedAt) - rankingExpression(createdAtToUse)) / query(decay)))" +rankprofile[].fef.property[].name "rankingExpression(commentOverallScore).rankingScript" +rankprofile[].fef.property[].value "query(textweight) * rankingExpression(textScoreToUse) + query(communityweight) * rankingExpression(nCScoreToUse)" +rankprofile[].fef.property[].name "rankingExpression(pinScore).rankingScript" +rankprofile[].fef.property[].value "if (isNan(attribute(pinned)) == 1, 0, query(pinweight) * attribute(pinned))" +rankprofile[].fef.property[].name "rankingExpression(freshnessRank).rankingScript" +rankprofile[].fef.property[].value "nativeRank + freshness(createdAt)" +rankprofile[].fef.property[].name "rankingExpression(log10_1p@af9a8c53ba738798).rankingScript" +rankprofile[].fef.property[].value "log10(rankingExpression(aVoteCountToUse) + 1)" +rankprofile[].fef.property[].name "rankingExpression(log10_1p@6ad21b437fe95dd9).rankingScript" +rankprofile[].fef.property[].value "log10(rankingExpression(dvCountToUse) + 1)" +rankprofile[].fef.property[].name "rankingExpression(log10_1p@c05478688f81fe20).rankingScript" +rankprofile[].fef.property[].value "log10(rankingExpression(uniqueRCountToUse) + 1)" +rankprofile[].fef.property[].name "rankingExpression(log10_1p@53f0a2c000e82f4).rankingScript" +rankprofile[].fef.property[].value "log10(rankingExpression(uvCountToUse) + 1)" +rankprofile[].fef.property[].name "rankingExpression(log10_1p@d7da61ad34902e89).rankingScript" +rankprofile[].fef.property[].value "log10(rankingExpression(totalvote) + 1)" +rankprofile[].fef.property[].name "rankingExpression(nn_input).rankingScript" +rankprofile[].fef.property[].value "concat(rankingExpression(log10_1p@af9a8c53ba738798), concat(rankingExpression(log10_1p@6ad21b437fe95dd9), concat(rankingExpression(log10_1p@c05478688f81fe20), concat(rankingExpression(log10_1p@53f0a2c000e82f4), concat(rankingExpression(phat), concat(rankingExpression(log10_1p@d7da61ad34902e89), concat(rankingExpression(hsScoreToUse), concat(rankingExpression(tdToUse), rankingExpression(tScoreToUse), x), x), x), x), x), x), x), x)" +rankprofile[].fef.property[].name "rankingExpression(nn_input).type" +rankprofile[].fef.property[].value "tensor(x[9])" +rankprofile[].fef.property[].name "rankingExpression(get_model_weights).rankingScript" +rankprofile[].fef.property[].value "if (query(field) == 0, constant(field), query(field))" +rankprofile[].fef.property[].name "rankingExpression(get_model_weights@1f2b4afc2c45fbee).rankingScript" +rankprofile[].fef.property[].value "if (query(W_0) == 0, constant(W_0), query(W_0))" +rankprofile[].fef.property[].name "rankingExpression(get_model_weights@e752cecc7900ff3e).rankingScript" +rankprofile[].fef.property[].value "if (query(b_0) == 0, constant(b_0), query(b_0))" +rankprofile[].fef.property[].name "rankingExpression(layer_0).rankingScript" +rankprofile[].fef.property[].value "elu(join(reduce(join(rankingExpression(nn_input), rankingExpression(get_model_weights@1f2b4afc2c45fbee), f(a,b)(a * b)), sum, x), rankingExpression(get_model_weights@e752cecc7900ff3e), f(a,b)(a + b)))" +rankprofile[].fef.property[].name "rankingExpression(layer_0).type" +rankprofile[].fef.property[].value "tensor(hidden[9])" +rankprofile[].fef.property[].name "rankingExpression(get_model_weights@eac265fa16b752cf).rankingScript" +rankprofile[].fef.property[].value "if (query(W_1) == 0, constant(W_1), query(W_1))" +rankprofile[].fef.property[].name "rankingExpression(get_model_weights@b953c19adb7d2154).rankingScript" +rankprofile[].fef.property[].value "if (query(b_1) == 0, constant(b_1), query(b_1))" +rankprofile[].fef.property[].name "rankingExpression(layer_1).rankingScript" +rankprofile[].fef.property[].value "elu(join(reduce(join(rankingExpression(layer_0), rankingExpression(get_model_weights@eac265fa16b752cf), f(a,b)(a * b)), sum, hidden), rankingExpression(get_model_weights@b953c19adb7d2154), f(a,b)(a + b)))" +rankprofile[].fef.property[].name "rankingExpression(layer_1).type" +rankprofile[].fef.property[].value "tensor(out[9])" +rankprofile[].fef.property[].name "rankingExpression(get_model_weights@418462473aa32b7d).rankingScript" +rankprofile[].fef.property[].value "if (query(W_out) == 0, constant(W_out), query(W_out))" +rankprofile[].fef.property[].name "rankingExpression(get_model_weights@23f46853cab72961).rankingScript" +rankprofile[].fef.property[].value "if (query(b_out) == 0, constant(b_out), query(b_out))" +rankprofile[].fef.property[].name "rankingExpression(layer_out).rankingScript" +rankprofile[].fef.property[].value "reduce(join(reduce(join(rankingExpression(layer_1), rankingExpression(get_model_weights@418462473aa32b7d), f(a,b)(a * b)), sum, out), rankingExpression(get_model_weights@23f46853cab72961), f(a,b)(a + b)), sum)" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(freshnessRank)" +rankprofile[].fef.property[].name "vespa.rank.secondphase" +rankprofile[].fef.property[].value "rankingExpression(layer_out)" +rankprofile[].fef.property[].name "vespa.hitcollector.heapsize" +rankprofile[].fef.property[].value "2000" +rankprofile[].fef.property[].name "vespa.type.query.b_out" +rankprofile[].fef.property[].value "tensor(out[1])" +rankprofile[].fef.property[].name "vespa.type.query.W_out" +rankprofile[].fef.property[].value "tensor(out[9])" +rankprofile[].fef.property[].name "vespa.type.query.b_0" +rankprofile[].fef.property[].value "tensor(hidden[9])" +rankprofile[].fef.property[].name "vespa.type.query.b_1" +rankprofile[].fef.property[].value "tensor(out[9])" +rankprofile[].fef.property[].name "vespa.type.query.W_1" +rankprofile[].fef.property[].value "tensor(hidden[9],out[9])" +rankprofile[].fef.property[].name "vespa.type.query.W_0" +rankprofile[].fef.property[].value "tensor(hidden[9],x[9])" diff --git a/config-model/src/test/integration/vespa/models/example.model b/config-model/src/test/integration/vespa/models/example.model index 9579be4e44c..e9725d14923 100644 --- a/config-model/src/test/integration/vespa/models/example.model +++ b/config-model/src/test/integration/vespa/models/example.model @@ -19,7 +19,7 @@ model example { } function foo2() { - expression: max(sum(input1 * input2, name) * constant1asLarge, x) * constant2 + expression: max(sum(input1 * input2, name) * constant(constant1asLarge), x) * constant2 } }
\ No newline at end of file diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java index 0ff8a5cc7ca..9a0dcc7dd07 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionLoopDetectionTestCase.java @@ -40,7 +40,7 @@ public class RankingExpressionLoopDetectionTestCase { fail("Excepted exception"); } catch (IllegalArgumentException e) { - assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: Invocation loop: foo -> foo", + assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: foo is invalid: Invocation loop: foo -> foo", Exceptions.toMessageString(e)); } } @@ -75,7 +75,7 @@ public class RankingExpressionLoopDetectionTestCase { fail("Excepted exception"); } catch (IllegalArgumentException e) { - assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: Invocation loop: arg(5) -> foo -> arg(5)", + assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: arg(5) is invalid: foo is invalid: arg(5) is invalid: Invocation loop: arg(5) -> foo -> arg(5)", Exceptions.toMessageString(e)); } } @@ -110,7 +110,7 @@ public class RankingExpressionLoopDetectionTestCase { fail("Excepted exception"); } catch (IllegalArgumentException e) { - assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: Invocation loop: arg(foo) -> foo -> arg(foo)", + assertEquals("In search definition 'test', rank profile 'test': The function 'foo' is invalid: arg(foo) is invalid: a1 is invalid: foo is invalid: arg(foo) is invalid: Invocation loop: arg(foo) -> foo -> arg(foo)", Exceptions.toMessageString(e)); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java index ef99ec28686..7fbca88cb61 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/AbstractExportingTestCase.java @@ -3,6 +3,7 @@ package com.yahoo.searchdefinition.derived; import com.yahoo.document.DocumenttypesConfig; import com.yahoo.document.config.DocumentmanagerConfig; +import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.SearchBuilder; import com.yahoo.searchdefinition.SearchDefinitionTestCase; @@ -29,11 +30,10 @@ public abstract class AbstractExportingTestCase extends SearchDefinitionTestCase deleteContent(toDir); SearchBuilder builder = SearchBuilder.createFromDirectory(searchDefRoot + dirName + "/"); - //SearchBuilder builder = SearchBuilder.createFromFile(searchDefDir + name + ".sd"); return derive(dirName, searchDefinitionName, builder); } - protected DerivedConfiguration derive(String dirName, String searchDefinitionName, SearchBuilder builder) throws IOException { + private DerivedConfiguration derive(String dirName, String searchDefinitionName, SearchBuilder builder) throws IOException { DerivedConfiguration config = new DerivedConfiguration(builder.getSearch(searchDefinitionName), builder.getRankProfileRegistry(), builder.getQueryProfileRegistry(), @@ -85,14 +85,14 @@ public abstract class AbstractExportingTestCase extends SearchDefinitionTestCase * Asserts config is correctly derived given a builder. * This will fail if the builder contains multiple search definitions. */ - protected DerivedConfiguration assertCorrectDeriving(SearchBuilder builder, String dirName) throws IOException, ParseException { + protected DerivedConfiguration assertCorrectDeriving(SearchBuilder builder, String dirName) throws IOException { builder.build(); DerivedConfiguration derived = derive(dirName, null, builder); assertCorrectConfigFiles(dirName); return derived; } - protected DerivedConfiguration assertCorrectDeriving(SearchBuilder builder, Search search, String name) throws IOException, ParseException { + protected DerivedConfiguration assertCorrectDeriving(SearchBuilder builder, Search search, String name) throws IOException { DerivedConfiguration derived = derive(name, builder, search); assertCorrectConfigFiles(name); return derived; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/derived/NeuralNetTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/derived/NeuralNetTestCase.java new file mode 100644 index 00000000000..b299c7fa299 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/derived/NeuralNetTestCase.java @@ -0,0 +1,16 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.derived; + +import com.yahoo.searchdefinition.parser.ParseException; +import org.junit.Test; + +import java.io.IOException; + +public class NeuralNetTestCase extends AbstractExportingTestCase { + + @Test + public void testNeuralNet() throws IOException, ParseException { + assertCorrectDeriving("neuralnet"); + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java index 1b917b6f3a3..3b3ce712387 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolverTestCase.java @@ -103,7 +103,9 @@ public class RankingExpressionTypeResolverTestCase { fail("Expected exception"); } catch (IllegalArgumentException expected) { - assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression is invalid: An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[],y[]) while the 'false' type is tensor(z[10])", + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression is invalid: An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[],y[]) while the 'false' type is tensor(z[10])" + + "\n'true' branch: attribute(a)" + + "\n'false' branch: attribute(b)", Exceptions.toMessageString(expected)); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java index d4fcd09e249..1a7eb96483e 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionsTestCase.java @@ -23,8 +23,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { public void testFunctions() throws IOException, ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); Search search = SearchBuilder.createFromDirectory("src/test/examples/rankingexpressionfunction", - rankProfileRegistry, - new QueryProfileRegistry()).getSearch(); + rankProfileRegistry).getSearch(); RankProfile functionsRankProfile = rankProfileRegistry.get(search, "macros"); Map<String, RankProfile.RankingExpressionFunction> functions = functionsRankProfile.getFunctions(); assertEquals(2, functions.get("titlematch$").function().arguments().size()); @@ -62,9 +61,7 @@ public class RankingExpressionsTestCase extends SearchDefinitionTestCase { @Test(expected = IllegalArgumentException.class) public void testThatIncludingFileInSubdirFails() throws IOException, ParseException { RankProfileRegistry registry = new RankProfileRegistry(); - Search search = SearchBuilder.createFromDirectory("src/test/examples/rankingexpressioninfile", - registry, - new QueryProfileRegistry()).getSearch(); + Search search = SearchBuilder.createFromDirectory("src/test/examples/rankingexpressioninfile", registry).getSearch(); new DerivedConfiguration(search, registry, new QueryProfileRegistry(), new ImportedMlModels()); // rank profile parsing happens during deriving } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index fe150b51961..15c1d24ce33 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -58,8 +58,8 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { "max(attribute(tensor_field_1),x)"); assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)", "1 + max(attribute(tensor_field_1),x)"); - assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),attribute(tensor_field_1))", - "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),attribute(tensor_field_1))"); + assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),reduce(attribute(tensor_field_1),sum,x))", + "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),reduce(attribute(tensor_field_1), sum, x))"); assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)", "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)"); assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)", diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java index eb4a0ad6be4..210b4899c58 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/config/QueryProfileXMLReader.java @@ -23,7 +23,7 @@ import java.util.Collections; import java.util.List; /** - * A class which imports query profiles and types from XML files + * Importer of query profiles and types from XML files * * @author bratseth */ diff --git a/documentapi/src/main/java/com/yahoo/documentapi/VisitorParameters.java b/documentapi/src/main/java/com/yahoo/documentapi/VisitorParameters.java index 82e7a87e95b..df0e0f0abdd 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/VisitorParameters.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/VisitorParameters.java @@ -36,7 +36,7 @@ public class VisitorParameters extends Parameters { private String remoteDataHandler = null; private VisitorDataHandler localDataHandler; private VisitorControlHandler controlHandler; - private Map<String, byte []> libraryParameters = new TreeMap<String, byte []>(); + private Map<String, byte []> libraryParameters = new TreeMap<>(); private Route visitRoute = null; private float weight = 1; private long maxFirstPassHits = -1; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java index 25a24792432..c7210e6710a 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java @@ -40,7 +40,7 @@ public class VespaImportTestCase { assertEquals(2, model.expressions().size()); assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1, max, x) * constant2", model.expressions().get("foo1").getRoot().toString()); - assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1asLarge, max, x) * constant2", + assertEquals("reduce(reduce(input1 * input2, sum, name) * constant(constant1asLarge), max, x) * constant2", model.expressions().get("foo2").getRoot().toString()); List<ImportedMlFunction> functions = model.outputExpressions(); diff --git a/model-integration/src/test/models/vespa/example.model b/model-integration/src/test/models/vespa/example.model index 6d660732db9..269ed83b695 100644 --- a/model-integration/src/test/models/vespa/example.model +++ b/model-integration/src/test/models/vespa/example.model @@ -19,7 +19,7 @@ model example { } function foo2() { - expression: reduce(sum(input1 * input2, name) * constant1asLarge, max, x) * constant2 + expression: reduce(sum(input1 * input2, name) * constant(constant1asLarge), max, x) * constant2 } }
\ No newline at end of file diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java index c4f3a75f2f8..2aedec2109b 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.transform.TensorMaxMinTransformer; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Join; @@ -67,6 +68,11 @@ public final class FunctionNode extends CompositeNode { @Override public TensorType type(TypeContext<Reference> context) { + // Check if this node should be interpreted as tensor reduce, as this impacts the type + ExpressionNode thisTransformed = TensorMaxMinTransformer.transformFunctionNode(this, context); + if (thisTransformed != this) + return thisTransformed.type(context); + if (arguments.expressions().size() == 0) return TensorType.empty; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index 28dc623be72..92c6d6f8638 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -85,7 +85,9 @@ public final class IfNode extends CompositeNode { return trueType.dimensionwiseGeneralizationWith(falseType).orElseThrow(() -> new IllegalArgumentException("An if expression must produce compatible types in both " + "alternatives, but the 'true' type is " + trueType + " while the " + - "'false' type is " + falseType) + "'false' type is " + falseType + + "\n'true' branch: " + trueExpression + + "\n'false' branch: " + falseExpression) ); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index eb8d2229a6d..e15ce158e83 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -95,7 +95,13 @@ public final class ReferenceNode extends CompositeNode { @Override public TensorType type(TypeContext<Reference> context) { - TensorType type = context.getType(reference); + TensorType type = null; + try { + type = context.getType(reference); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException(reference + " is invalid", e); + } if (type == null) throw new IllegalArgumentException("Unknown feature '" + toString() + "'"); return type; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java index 22d314bcb28..31567ba120b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ExpressionTransformer.java @@ -10,7 +10,7 @@ import java.util.List; /** * Superclass of expression transformers. The scope (lifetime) of a transformer instance is a single compilation - * of alle the expressions in one rank profile. + * of all the expressions in one rank profile. * * @author bratseth */ diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java index 5d03c323803..979c5b0f88c 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java @@ -1,54 +1,40 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition.expressiontransforms; +package com.yahoo.searchlib.rankingexpression.transform; -import com.yahoo.searchdefinition.RankProfile; -import com.yahoo.searchdefinition.RankingConstant; -import com.yahoo.searchdefinition.document.Attribute; import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.NameNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; -import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Reduce; -import java.util.List; import java.util.Optional; /** - * Transforms and simplifies tensor expressions. - * - * Currently transforms min(tensor,dim) and max(tensor,dim) to + * Transforms min(tensor,dim) and max(tensor,dim) to * reduce(tensor,min/max,dim). This is necessary as the backend does * not recognize these forms of min and max. * * @author lesters */ -public class TensorTransformer extends ExpressionTransformer<RankProfileTransformContext> { +public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends ExpressionTransformer<CONTEXT> { @Override - public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { + public ExpressionNode transform(ExpressionNode node, CONTEXT context) { if (node instanceof CompositeNode) { node = transformChildren((CompositeNode) node, context); } if (node instanceof FunctionNode) { - node = transformFunctionNode((FunctionNode) node, context); + node = transformFunctionNode((FunctionNode) node, context.types()); } return node; } - private ExpressionNode transformFunctionNode(FunctionNode node, RankProfileTransformContext context) { + public static ExpressionNode transformFunctionNode(FunctionNode node, TypeContext<Reference> context) { switch (node.getFunction()) { case min: case max: @@ -62,14 +48,14 @@ public class TensorTransformer extends ExpressionTransformer<RankProfileTransfor * argument returns a tensor type and the second argument is a valid * dimension in the tensor. */ - private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, RankProfileTransformContext context) { + private static ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node, TypeContext<Reference> context) { if (node.children().size() != 2) { return node; } ExpressionNode arg1 = node.children().get(0); Optional<String> dimension = dimensionName(node.children().get(1)); if (dimension.isPresent()) { - TensorType type = arg1.type(context.types()); + TensorType type = arg1.type(context); if (type.dimension(dimension.get()).isPresent()) { return replaceMaxAndMinFunction(node); } @@ -77,7 +63,7 @@ public class TensorTransformer extends ExpressionTransformer<RankProfileTransfor return node; } - private Optional<String> dimensionName(ExpressionNode node) { + private static Optional<String> dimensionName(ExpressionNode node) { if (node instanceof ReferenceNode) { Reference reference = ((ReferenceNode)node).reference(); if (reference.isIdentifier()) @@ -93,7 +79,7 @@ public class TensorTransformer extends ExpressionTransformer<RankProfileTransfor } } - private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { + private static ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { ExpressionNode arg1 = node.children().get(0); ExpressionNode arg2 = node.children().get(1); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java index 7485ce69f98..0113a650277 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TransformContext.java @@ -1,7 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.transform; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Map; @@ -13,11 +15,19 @@ import java.util.Map; public class TransformContext { private final Map<String, Value> constants; + private final TypeContext<Reference> types; - public TransformContext(Map<String, Value> constants) { + public TransformContext(Map<String, Value> constants, TypeContext<Reference> types) { this.constants = constants; + this.types = types; } public Map<String, Value> constants() { return constants; } + /** + * Returns the types known in this context. We may have type information for references + * for which no value is available + */ + public TypeContext<Reference> types() { return types; } + } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java index a08d510eec4..88838b5aed0 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java @@ -53,7 +53,9 @@ public class TypeResolutionTestCase { } catch (IllegalArgumentException expected) { assertEquals("An if expression must produce compatible types in both alternatives, " + - "but the 'true' type is tensor(x[]) while the 'false' type is tensor(y[])", + "but the 'true' type is tensor(x[]) while the 'false' type is tensor(y[])" + + "\n'true' branch: query(x1)" + + "\n'false' branch: query(y1)", expected.getMessage()); } catch (ParseException e) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java index 1f28f0b0129..a41fb02f784 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.transform; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import org.junit.Test; @@ -24,7 +25,7 @@ public class ConstantDereferencerTestCase { constants.put("a", Value.parse("1.0")); constants.put("b", Value.parse("2")); constants.put("c", Value.parse("3.5")); - TransformContext context = new TransformContext(constants); + TransformContext context = new TransformContext(constants, new MapTypeContext()); assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c"), context).toString()); assertEquals("myFunction(1.0,2.0)", c.transform(new RankingExpression("myFunction(a, b)"), context).toString()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java index 8fac3395ac0..f4b1b0ceee2 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java @@ -1,9 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.transform; +import com.yahoo.log.event.Collection; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.MapTypeContext; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import org.junit.Test; @@ -20,7 +22,7 @@ public class SimplifierTestCase { @Test public void testSimplify() throws ParseException { Simplifier s = new Simplifier(); - TransformContext c = new TransformContext(Collections.emptyMap()); + TransformContext c = new TransformContext(Collections.emptyMap(), new MapTypeContext()); assertEquals("a + b", s.transform(new RankingExpression("a + b"), c).toString()); assertEquals("6.5", s.transform(new RankingExpression("1.0 + 2.0 + 3.5"), c).toString()); assertEquals("6.5", s.transform(new RankingExpression("1.0 + ( 2.0 + 3.5 )"), c).toString()); @@ -45,7 +47,7 @@ public class SimplifierTestCase { @Test public void testSimplifyComplexExpression() throws ParseException { RankingExpression initial = new RankingExpression("sqrt(if (if (INFERRED * 0.9 < INFERRED, GMP, (1 + 1.1) * INFERRED) < INFERRED * INFERRED - INFERRED, if (GMP < 85.80799542793133 * GMP, INFERRED, if (GMP < GMP, tanh(INFERRED), log(76.89956221113943))), tanh(tanh(INFERRED))) * sqrt(sqrt(GMP + INFERRED)) * GMP ) + 13.5 * (1 - GMP) * pow(GMP * 0.1, 2 + 1.1 * 0)"); - TransformContext c = new TransformContext(Collections.emptyMap()); + TransformContext c = new TransformContext(Collections.emptyMap(), new MapTypeContext()); RankingExpression simplified = new Simplifier().transform(initial, c); Context context = new MapContext(); @@ -70,7 +72,7 @@ public class SimplifierTestCase { @Test public void testParenthesisPreservation() throws ParseException { Simplifier s = new Simplifier(); - TransformContext c = new TransformContext(Collections.emptyMap()); + TransformContext c = new TransformContext(Collections.emptyMap(), new MapTypeContext()); CompositeNode transformed = (CompositeNode)s.transform(new RankingExpression("a + (b + c) / 100000000.0"), c).getRoot(); assertEquals("a + (b + c) / 100000000.0", transformed.toString()); } |