diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-09 16:07:43 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-09 16:07:43 +0100 |
commit | dc0f70fac9167acf487453daf565636c675934df (patch) | |
tree | aaccfae7aaf4a48e35655a66c75ea57412ede6a6 /config-model/src/main/java/com | |
parent | fa9fe82c82d6a562e3ae02b9577f536a16c72c92 (diff) |
Basic TensorFlow integration
This wil replace any occurrence of tensorflow(...)
in ranking expressions with the corresponding translated expression.
It is functional but these tings are outstanding
- Propagate warnings
- Import a model just once even if referred multiple times
- Add constants as tensor files rather than config
Diffstat (limited to 'config-model/src/main/java/com')
13 files changed, 189 insertions, 27 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplier.java b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplier.java index cdd089d9bf7..f9d71f03972 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplier.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplier.java @@ -6,9 +6,10 @@ import com.yahoo.searchdefinition.document.SDDocumentType; import com.yahoo.searchdefinition.document.SDField; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class FieldOperationApplier { + public void process(SDDocumentType sdoc) { if (!sdoc.isStruct()) { apply(sdoc); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForSearch.java b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForSearch.java index addaa4bc632..1019b794cdd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForSearch.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForSearch.java @@ -5,7 +5,7 @@ import com.yahoo.document.Field; import com.yahoo.searchdefinition.document.SDDocumentType; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class FieldOperationApplierForSearch extends FieldOperationApplier { @Override diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java index 30fdcd01dd4..04b0fc6e331 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FieldOperationApplierForStructs.java @@ -11,9 +11,10 @@ import java.util.Iterator; import java.util.List; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class FieldOperationApplierForStructs extends FieldOperationApplier { + @Override public void process(SDDocumentType sdoc) { for (SDDocumentType type : sdoc.getAllTypes()) { @@ -45,4 +46,5 @@ public class FieldOperationApplierForStructs extends FieldOperationApplier { } } } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 1021227b0e6..cf92d1f979b 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -1,7 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.search.query.ranking.Diversity; +import com.yahoo.searchdefinition.expressiontransforms.ExpressionTransforms; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.FeatureList; @@ -9,13 +11,22 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.SetMembershipNode; -import com.yahoo.searchlib.rankingexpression.transform.ConstantDereferencer; -import com.yahoo.searchlib.rankingexpression.transform.Simplifier; -import com.yahoo.config.application.api.ApplicationPackage; -import java.io.*; -import java.util.*; +import java.io.File; +import java.io.IOException; +import java.io.Reader; +import java.io.Serializable; +import java.io.StringReader; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; /** * Represents a rank profile - a named set of ranking settings @@ -40,7 +51,7 @@ public class RankProfile implements Serializable, Cloneable { protected Set<RankSetting> rankSettings = new java.util.LinkedHashSet<>(); /** The ranking expression to be used for first phase */ - private RankingExpression firstPhaseRanking= null; + private RankingExpression firstPhaseRanking= null; /** The ranking expression to be used for second phase */ private RankingExpression secondPhaseRanking = null; @@ -485,7 +496,7 @@ public class RankProfile implements Serializable, Cloneable { /** * Returns the string form of the second phase ranking expression. - * + * * @return string form of second phase ranking expression */ public String getSecondPhaseRankingString() { @@ -702,12 +713,7 @@ public class RankProfile implements Serializable, Cloneable { Map<String, Macro> inlineMacros) { if (expression == null) return null; Map<String, String> rankPropertiesOutput = new HashMap<>(); - expression = new ConstantDereferencer(constants).transform(expression); - expression = new ConstantTensorTransformer(constants, rankPropertiesOutput).transform(expression); - expression = new MacroInliner(inlineMacros).transform(expression); - expression = new MacroShadower(getMacros()).transform(expression); - expression = new TensorTransformer(this).transform(expression); - expression = new Simplifier().transform(expression); + expression = new ExpressionTransforms().transform(expression, this, constants, inlineMacros, rankPropertiesOutput); for (Map.Entry<String, String> rankProperty : rankPropertiesOutput.entrySet()) { addRankProperty(rankProperty.getKey(), rankProperty.getValue()); } @@ -975,7 +981,7 @@ public class RankProfile implements Serializable, Cloneable { throw new IllegalArgumentException("match-phase did not set max-hits > 0"); } } - + } public static class TypeSettings { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/UnprocessingSearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/UnprocessingSearchBuilder.java index 1b292007ef3..b448005c6a5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/UnprocessingSearchBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/UnprocessingSearchBuilder.java @@ -10,7 +10,6 @@ import java.io.IOException; /** * A SearchBuilder that does not run the processing chain for searches - * */ public class UnprocessingSearchBuilder extends SearchBuilder { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/UnrankedRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/UnrankedRankProfile.java index d8f7e56539e..b58f696cbdf 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/UnrankedRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/UnrankedRankProfile.java @@ -25,4 +25,5 @@ public class UnrankedRankProfile extends RankProfile { this.setKeepRankCount(0); this.setRerankCount(0); } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java index c75864f81b7..e061ead465e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/ConstantTensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ConstantTensorTransformer.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; +package com.yahoo.searchdefinition.expressiontransforms; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -19,7 +19,7 @@ import java.util.Map; * * @author geirst */ -class ConstantTensorTransformer extends ExpressionTransformer { +public class ConstantTensorTransformer extends ExpressionTransformer { public static final String CONSTANT = "constant"; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java new file mode 100644 index 00000000000..ee5cccccb29 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/ExpressionTransforms.java @@ -0,0 +1,34 @@ +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.transform.ConstantDereferencer; +import com.yahoo.searchlib.rankingexpression.transform.Simplifier; + +import java.util.Map; + +/** + * The transformations done on ranking expressions done at config time before passing them on to the Vespa + * engine for execution. + * + * @author bratseth + */ +public class ExpressionTransforms { + + public RankingExpression transform(RankingExpression expression, + RankProfile rankProfile, + Map<String, Value> constants, + Map<String, RankProfile.Macro> inlineMacros, + Map<String, String> rankPropertiesOutput) { + expression = new TensorFlowFeatureConverter(rankProfile).transform(expression); + expression = new ConstantDereferencer(constants).transform(expression); + expression = new ConstantTensorTransformer(constants, rankPropertiesOutput).transform(expression); + expression = new MacroInliner(inlineMacros).transform(expression); + expression = new MacroShadower(rankProfile.getMacros()).transform(expression); + expression = new TensorTransformer(rankProfile).transform(expression); + expression = new Simplifier().transform(expression); + return expression; + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MacroInliner.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java index 4702fac30a8..a3933e6f8e2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MacroInliner.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroInliner.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; +package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; @@ -13,7 +14,7 @@ import java.util.Map; * * @author bratseth */ -class MacroInliner extends ExpressionTransformer { +public class MacroInliner extends ExpressionTransformer { private final Map<String, RankProfile.Macro> macros; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MacroShadower.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java index edf0ce69819..1d9769d0d78 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MacroShadower.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/MacroShadower.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; +package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.rule.*; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; @@ -20,7 +21,7 @@ import java.util.Map; * * @author lesters */ -class MacroShadower extends ExpressionTransformer { +public class MacroShadower extends ExpressionTransformer { private final Map<String, RankProfile.Macro> macros; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java new file mode 100644 index 00000000000..e5886030d44 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -0,0 +1,114 @@ +package com.yahoo.searchdefinition.expressiontransforms; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; + +import java.util.Map; +import java.util.Optional; + +/** + * Replaces instances of the tensorflow(model-path, signature, output) + * pseudofeature with the native Vespa ranking expression implementing + * the same computation. + * + * @author bratseth + */ +public class TensorFlowFeatureConverter extends ExpressionTransformer { + + private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); + private final RankProfile profile; + + public TensorFlowFeatureConverter(RankProfile profile) { + this.profile = profile; + } + + @Override + public ExpressionNode transform(ExpressionNode node) { + if (node instanceof ReferenceNode) + return transformFeature((ReferenceNode) node); + else if (node instanceof CompositeNode) + return super.transformChildren((CompositeNode) node); + else + return node; + } + + private ExpressionNode transformFeature(ReferenceNode feature) { + try { + if ( ! feature.getName().equals("tensorflow")) return feature; + + if (feature.getArguments().isEmpty()) + throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + + "the tensorflow model directory under [application]/models"); + + // Find the specified expression + ImportResult result = tensorFlowImporter.importModel(asString(feature.getArguments().expressions().get(0))); + ImportResult.Signature signature = chooseOrDefault("signatures", result.signatures(), + optionalArgument(1, feature.getArguments())); + String output = chooseOrDefault("outputs", signature.outputs(), + optionalArgument(2, feature.getArguments())); + + // Add all constants + result.constants().forEach((k, v) -> profile.addConstantTensor(k, new TensorValue(v))); + + return result.expressions().get(output).getRoot(); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not import tensorflow model from " + feature, e); + } + } + + /** + * Returns the specified, existing map value, or the only map value if no key is specified. + * Throws IllegalArgumentException in all other cases. + */ + private <T> T chooseOrDefault(String valueDescription, Map<String, T> map, Optional<String> key) { + if ( ! key.isPresent()) { + if (map.size() == 0) + throw new IllegalArgumentException("No " + valueDescription + " are present"); + if (map.size() > 1) + throw new IllegalArgumentException("Model has multiple " + valueDescription + ", but no " + + valueDescription + " argument is specified"); + return map.values().stream().findFirst().get(); + } + else { + T value = map.get(key.get()); + if (value == null) + throw new IllegalArgumentException("Model does not have the specified " + + valueDescription + " '" + key.get() + "'"); + return value; + } + } + + private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + return Optional.empty(); + return Optional.of(asString(arguments.expressions().get(argumentIndex))); + } + + private String asString(ExpressionNode node) { + if ( ! (node instanceof ConstantNode)) + throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node); + return stripQuotes(((ConstantNode)node).sourceString()); + } + + private String stripQuotes(String s) { + if ( ! isQuoteSign(s.codePointAt(0))) return s; + if ( ! isQuoteSign(s.codePointAt(s.length() - 1 ))) + throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote"); + return s.substring(1, s.length()-1); + } + + private boolean isQuoteSign(int c) { + return c == '\'' || c == '"'; + } + + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java index 65176006a2a..70a7372dbe9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java @@ -1,6 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; +package com.yahoo.searchdefinition.expressiontransforms; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.Search; import com.yahoo.searchdefinition.document.Attribute; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/PredicateProcessor.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/PredicateProcessor.java index 450c24d8e3e..4b9b090cdc5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/PredicateProcessor.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/PredicateProcessor.java @@ -24,7 +24,7 @@ import java.util.List; /** * Validates the predicate fields. * - * @author <a href="mailto:lesters@yahoo-inc.com">Lester Solbakken</a> + * @author Lester Solbakken */ public class PredicateProcessor extends Processor { |