diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-02-05 22:54:13 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-02-05 22:54:13 +0100 |
commit | 30a2d3e88529bc5a86ad6c53c8de35e4a71fbac3 (patch) | |
tree | 4fc17d3e36f507efea78adc856228eec5f144019 /config-model/src | |
parent | 62de95451cf663f3f43532d2c4746eaa1b678d95 (diff) |
Handle small constants
Diffstat (limited to 'config-model/src')
4 files changed, 99 insertions, 19 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 2f28d9adb8b..bcbc7cc99e2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -750,7 +750,9 @@ public class RankProfile implements Serializable, Cloneable { public TypeContext typeContext(QueryProfileRegistry queryProfiles) { TypeMapContext context = new TypeMapContext(); - // Add constants + // Add small constants + getConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.type())); + // Add large constants getSearch().getRankingConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.getTensorType())); // Add attributes diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 495ca7dd14a..2b997aa25f2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -2,6 +2,7 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; +import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.application.provider.FilesApplicationPackage; @@ -11,6 +12,9 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature; @@ -25,11 +29,13 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; +import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.StringReader; import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -92,16 +98,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles); store.writeConverted(expression); - model.constants().forEach((k, v) -> transformConstant(store, profile, k, v)); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v)); return expression.getRoot(); } private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { - for (RankingConstant constant : store.readRankingConstants()) { + for (Pair<String, Tensor> constant : store.readSmallConstants()) + profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); + + for (RankingConstant constant : store.readLargeConstants()) { if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) profile.getSearch().addRankingConstant(constant); } + return store.readConverted().getRoot(); } @@ -158,8 +169,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private void transformConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - Path constantPath = store.writeConstant(constantName, constantValue); + private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + store.writeSmallConstant(constantName, constantValue); + profile.addConstant(constantName, asValue(constantValue)); + } + + private void transformLargeConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + Path constantPath = store.writeLargeConstant(constantName, constantValue); if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { log.info("Adding constant '" + constantName + "' of type " + constantValue.type()); @@ -218,6 +234,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } + private Value asValue(Tensor tensor) { + if (tensor.type().rank() == 0) + return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors + else + return new TensorValue(tensor); + } + /** * Provides read/write access to the correct directories of the application package given by the feature arguments */ @@ -272,13 +295,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } /** - * Reads the information about all the constants stored in the application package + * Reads the information about all the large (aka ranking) constants stored in the application package * (the constant value itself is replicated with file distribution). */ - public List<RankingConstant> readRankingConstants() { + public List<RankingConstant> readLargeConstants() { try { List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(arguments.rankingConstantsPath()).listFiles()) { + for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); } @@ -295,25 +318,63 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil * * @return the path to the stored constant, relative to the application package root */ - public Path writeConstant(String name, Tensor constant) { + public Path writeLargeConstant(String name, Tensor constant) { Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: Path constantPath = constantsPath.append(name + ".tbf"); - Path constantPathCorrected = constantPath; - if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) - && ! constantPath.elements().contains(FilesApplicationPackage.preprocessed)) { - constantPathCorrected = Path.fromString(FilesApplicationPackage.preprocessed).append(constantPath); - } // Remember the constant in a file we replicate in ZooKeeper - application.getFile(arguments.rankingConstantsPath().append(name + ".constant")) - .writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPathCorrected)); + application.getFile(arguments.largeConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); // Write content explicitly as a file on the file system as this is distributed using file distribution createIfNeeded(constantsPath); IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); - return constantPathCorrected; + return correct(constantPath); + } + + private List<Pair<String, Tensor>> readSmallConstants() { + try { + ApplicationFile file = application.getFile(arguments.smallConstantsPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, Tensor>> constants = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + TensorType type = TensorType.fromSpec(parts[1]); + Tensor tensor = Tensor.from(type, parts[2]); + constants.add(new Pair<>(name, tensor)); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Append this constant to the single file used for small constants distributed as config + */ + public void writeSmallConstant(String name, Tensor constant) { + // Secret file format for remembering constants: + application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + + constant.type().toString() + "\t" + + constant.toString() + "\n"); + } + + /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ + private Path correct(Path path) { + if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) + && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { + return Path.fromString(FilesApplicationPackage.preprocessed).append(path); + } + else { + return path; + } } private void createIfNeeded(Path path) { @@ -351,7 +412,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil public Optional<String> signature() { return signature; } public Optional<String> output() { return output; } - public Path rankingConstantsPath() { + /** Path to the small constants file */ + public Path smallConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); + } + + /** Path to the large (ranking) constants directory */ + public Path largeConstantsPath() { return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java index 5255cdaeba1..0334012e8d9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java @@ -183,7 +183,7 @@ public class TensorTransformer extends ExpressionTransformer<RankProfileTransfor } private void addIfConstant(ReferenceNode node, Context context, RankProfile profile) { - if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) { + if ( ! node.getName().equals(ConstantTensorTransformer.CONSTANT)) { return; } if (node.children().size() != 1) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 5203e686681..7246b22b0f8 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -399,6 +399,17 @@ public class RankingExpressionWithTensorFlowTestCase { } @Override + public ApplicationFile appendFile(String value) { + try { + IOUtils.writeFile(file, value, true); + return this; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override public List<ApplicationFile> listFiles(PathFilter filter) { if ( ! isDirectory()) return Collections.emptyList(); return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString()))) |