diff options
Diffstat (limited to 'config-model')
6 files changed, 169 insertions, 94 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index bacff94d776..3b68e0199a9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -2,15 +2,26 @@ package com.yahoo.searchdefinition; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.io.reader.NamedReader; +import com.yahoo.processing.request.CompoundName; +import com.yahoo.search.query.profile.QueryProfile; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.config.QueryProfileXMLReader; +import com.yahoo.search.query.profile.types.FieldDescription; +import com.yahoo.search.query.profile.types.TensorFieldType; import com.yahoo.search.query.ranking.Diversity; +import com.yahoo.searchdefinition.document.SDField; import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.FeatureList; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TypeMapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.io.File; import java.io.IOException; @@ -276,12 +287,10 @@ public class RankProfile implements Serializable, Cloneable { addConstant(name, value); } - /** - * Returns an unmodifiable view of the constants to use in this. - */ + /** Returns an unmodifiable view of the constants available in this */ public Map<String, Value> getConstants() { if (constants.isEmpty()) - return getInherited() != null ? getInherited().getConstants() : Collections.<String,Value>emptyMap(); + return getInherited() != null ? getInherited().getConstants() : Collections.emptyMap(); if (getInherited() == null || getInherited().getConstants().isEmpty()) return Collections.unmodifiableMap(constants); @@ -433,8 +442,9 @@ public class RankProfile implements Serializable, Cloneable { properties.add(rankProperty); } + @Override public String toString() { - return "rank profile " + getName(); + return "rank profile '" + getName() + "'"; } public int getRerankCount() { @@ -508,6 +518,7 @@ public class RankProfile implements Serializable, Cloneable { /** * Returns the string form of the first phase ranking expression. + * * @return string form of first phase ranking expression */ public String getFirstPhaseRankingString() { @@ -727,6 +738,64 @@ public class RankProfile implements Serializable, Cloneable { } /** + * Creates a context containing the type information of all constants, attributes and query profiles + * referable from this rank profile. + */ + public TypeContext typeContext() { + TypeMapContext context = new TypeMapContext(); + + // Add constants + getConstants().forEach((k, v) -> context.setType(asConstantFeature(k), v.type())); + + // Add attributes + for (SDField field : getSearch().allConcreteFields()) + field.getAttributes().forEach((k, a) -> context.setType(asAttributeFeature(k), a.tensorType().orElse(TensorType.empty))); + + // Add query features from rank profile types reached from the "default" profile + QueryProfile profile = queryProfilesOf(getSearch().sourceApplication()).getComponent("default"); + if (profile != null && profile.getType() != null) { + profile.listTypes(CompoundName.empty, Collections.emptyMap()).forEach((prefix, queryProfileType) -> { + for (FieldDescription field : queryProfileType.declaredFields().values()) { + TensorType type = TensorType.empty; // assume the empty (aka double) type by default + if (field.getType() instanceof TensorFieldType) + type = ((TensorFieldType)field.getType()).type().get(); + + String feature = asQueryFeature(prefix.append(field.getName()).toString()); + context.setType(feature, type); + } + }); + } + + return context; + } + + private QueryProfileRegistry queryProfilesOf(ApplicationPackage applicationPackage) { + List<NamedReader> queryProfileFiles = null; + List<NamedReader> queryProfileTypeFiles = null; + try { + queryProfileFiles = applicationPackage.getQueryProfileFiles(); + queryProfileTypeFiles = applicationPackage.getQueryProfileTypeFiles(); + return new QueryProfileXMLReader().read(queryProfileFiles, queryProfileTypeFiles); + } + finally { + NamedReader.closeAll(queryProfileFiles); + NamedReader.closeAll(queryProfileTypeFiles); + } + } + + private String asConstantFeature(String constantName) { + return "constant(\"" + constantName + "\")"; + } + + private String asAttributeFeature(String constantName) { + return "attribute(\"" + constantName + "\")"; + } + + private String asQueryFeature(String constantName) { + return "query(\"" + constantName + "\")"; + } + + /** * A rank setting. The identity of a rank setting is its field name and type (not value). * A rank setting is immutable. */ diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index df5697de0d5..f4a0365e36e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -247,7 +247,7 @@ public class Search implements Serializable, ImmutableSearch { /** * Returns a list of all the fields of this search definition, that is all fields in all documents, in the documents * they inherit, and all extra fields. The caller receives ownership to the list - subsequent changes to it will not - * impact this Search + * impact this * * @return the list of fields in this searchdefinition */ @@ -546,7 +546,7 @@ public class Search implements Serializable, ImmutableSearch { } /** - * Returns the first occurence of an attribute having this name, or null if none + * Returns the first occurrence of an attribute having this name, or null if none * * @param name Name of attribute * @return The Attribute with given name. diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java index ce3ff7cc447..72ba6de7022 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/AttributeFields.java @@ -31,24 +31,19 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce private Map<String, Attribute> attributes = new java.util.LinkedHashMap<>(); private Map<String, Attribute> importedAttributes = new java.util.LinkedHashMap<>(); - /** - * Flag indicating if a position-attribute has been found - */ + /** Whether this has any position attribute */ private boolean hasPosition = false; public AttributeFields(Search search) { derive(search); } - /** - * Derives everything from a field - */ + /** Derives everything from a field */ @Override protected void derive(ImmutableSDField field, Search search) { if (field.usesStructOrMap() && !field.getDataType().equals(PositionDataType.INSTANCE) && - !field.getDataType().equals(DataType.getArray(PositionDataType.INSTANCE))) - { + !field.getDataType().equals(DataType.getArray(PositionDataType.INSTANCE))) { return; // Ignore struct fields for indexed search (only implemented for streaming search) } if (field.isImportedField()) { @@ -58,9 +53,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce } } - /** - * Return an attribute by name, or null if it doesn't exist - */ + /** Returns an attribute by name, or null if it doesn't exist */ public Attribute getAttribute(String attributeName) { return attributes.get(attributeName); } @@ -69,9 +62,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce return getAttribute(attributeName) != null; } - /** - * Derives one attribute. TODO: Support non-default named attributes - */ + /** Derives one attribute. TODO: Support non-default named attributes */ private void deriveAttributes(ImmutableSDField field) { for (Attribute fieldAttribute : field.getAttributes().values()) { deriveAttribute(field, fieldAttribute); @@ -107,9 +98,7 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce } } - /** - * Returns a read only attribute iterator - */ + /** Returns a read only attribute iterator */ public Iterator attributeIterator() { return attributes().iterator(); } @@ -201,4 +190,5 @@ public class AttributeFields extends Derived implements AttributesConfig.Produce } } } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 01d3449573c..d0f705d4c9a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -41,7 +41,6 @@ import java.util.Optional; * * @author bratseth */ -// TODO: Verify types of macros // TODO: Avoid name conflicts across models for constants public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { @@ -84,6 +83,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil Signature signature = chooseSignature(model, store.arguments().signature()); String output = chooseOutput(signature, store.arguments().output()); RankingExpression expression = model.expressions().get(output); + verifyRequiredMacros(expression, model.requiredMacros(), profile); store.writeConverted(expression); model.constants().forEach((k, v) -> transformConstant(store, profile, k, v)); @@ -168,6 +168,44 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } /** + * Verify that the macros referred in the given expression exists in the given rank profile, + * and return tensors of the types specified in requiredMacros. + */ + private void verifyRequiredMacros(RankingExpression expression, Map<String, TensorType> requiredMacros, + RankProfile profile) { + List<String> macroNames = new ArrayList<>(); + addMacroNamesIn(expression.getRoot(), macroNames); + for (String macroName : macroNames) { + TensorType requiredType = requiredMacros.get(macroName); + if (requiredType == null) continue; // Not a required macro + + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) + throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + "' of type " + requiredType + " but this macro is not present in " + + profile); + TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext()); + if ( ! actualType.isAssignableTo(requiredType)) + throw new IllegalArgumentException("Model refers Placeholder '" + macroName + + "' of type " + requiredType + + " which must be produced by a macro in the rank profile, but " + + "this macro produces type " + actualType + " in " + profile); + } + } + + private void addMacroNamesIn(ExpressionNode node, List<String> names) { + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode)node; + if (referenceNode.getOutput() == null) // macro references cannot specify outputs + names.add(referenceNode.getName()); + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) + addMacroNamesIn(child, names); + } + } + + /** * Provides read/write access to the correct directories of the application package given by the feature arguments */ private static class ModelStore { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java index d08092003f1..b85cb88bf2e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/QueryProfilesBuilder.java @@ -5,6 +5,7 @@ import com.yahoo.io.reader.NamedReader; import com.yahoo.search.query.profile.config.QueryProfileXMLReader; import com.yahoo.config.application.api.ApplicationPackage; +import java.util.Collections; import java.util.List; /** @@ -13,17 +14,16 @@ import java.util.List; * * @author bratseth */ -// TODO: Move into QueryProfiles public class QueryProfilesBuilder { /** Build the set of query profiles for an application package */ public QueryProfiles build(ApplicationPackage applicationPackage) { - List<NamedReader> queryProfileTypeFiles=null; - List<NamedReader> queryProfileFiles=null; + List<NamedReader> queryProfileTypeFiles = null; + List<NamedReader> queryProfileFiles = null; try { - queryProfileTypeFiles=applicationPackage.getQueryProfileTypeFiles(); - queryProfileFiles=applicationPackage.getQueryProfileFiles(); - return new QueryProfiles(new QueryProfileXMLReader().read(queryProfileTypeFiles,queryProfileFiles)); + queryProfileTypeFiles = applicationPackage.getQueryProfileTypeFiles(); + queryProfileFiles = applicationPackage.getQueryProfileFiles(); + return new QueryProfiles(new QueryProfileXMLReader().read(queryProfileTypeFiles, queryProfileFiles)); } finally { NamedReader.closeAll(queryProfileTypeFiles); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 89f1a9f785c..c24e886c83d 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -49,14 +49,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testMinimalTensorFlowReference() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); assertConstant("Variable_1", search, Optional.of(10L)); assertConstant("Variable", search, Optional.of(7840L)); @@ -64,14 +58,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testNestedTensorFlowReference() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: 5 + sum(tensorflow('mnist_softmax/saved'))" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "5 + sum(tensorflow('mnist_softmax/saved'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); assertConstant("Variable_1", search, Optional.of(10L)); assertConstant("Variable", search, Optional.of(7840L)); @@ -79,41 +67,23 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_default')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'y')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_default', 'y')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { try { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_defaultz')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_defaultz')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -128,14 +98,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException { try { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default', 'x')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved', 'serving_default', 'x')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); fail("Expecting exception"); } @@ -149,14 +113,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressions() throws ParseException, IOException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = new RankProfileSearchFixture( - application, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + - " }\n" + - " }"); + RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); assertConstant("Variable_1", search, Optional.of(10L)); assertConstant("Variable", search, Optional.of(7840L)); @@ -168,13 +126,9 @@ public class RankingExpressionWithTensorFlowTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = new RankProfileSearchFixture( - storedApplication, - " rank-profile my_profile {\n" + - " first-phase {\n" + - " expression: tensorflow('mnist_softmax/saved', 'serving_default')" + - " }\n" + - " }"); + RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved')", + storedApplication); searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test @@ -212,6 +166,30 @@ public class RankingExpressionWithTensorFlowTestCase { } } + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression) { + return fixtureWith(placeholderExpression, firstPhaseExpression, new StoringApplicationPackage(applicationDir)); + } + + private RankProfileSearchFixture fixtureWith(String placeholderExpression, + String firstPhaseExpression, + StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture( + application, + " rank-profile my_profile {\n" + + " macro Placeholder() {\n" + + " expression: " + placeholderExpression + + " }\n" + + " first-phase {\n" + + " expression: " + firstPhaseExpression + + " }\n" + + " }"); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + private static class StoringApplicationPackage extends MockApplicationPackage { private final File root; |