diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-17 13:05:27 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-17 13:05:27 +0100 |
commit | c84b8f952ef5857aa44fad479551eda1f3a4e106 (patch) | |
tree | e7bf28337efaa9bc02e7c13c2cd14777a46135b1 /config-model/src/main/java/com/yahoo | |
parent | 66b3a3ca7c14097f9a277431c19c169e3681a4de (diff) |
Persist constant info in ZooKeeper
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 64 |
1 files changed, 47 insertions, 17 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 0324b9852df..0dd5b4166ef 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -2,13 +2,13 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; +import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature; @@ -20,13 +20,16 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.File; import java.io.IOException; import java.io.StringReader; import java.io.UncheckedIOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -37,13 +40,10 @@ import java.util.Optional; * * @author bratseth */ -// TODO: - Verify types of macros -// - Avoid name conflicts across models for constants +// TODO: Verify types of macros +// TODO: Avoid name conflicts across models for constants public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - // TODO: Make system test work with this set to true, then remove the "true" path - private static final boolean constantsInConfig = true; - private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ @@ -68,14 +68,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if (store.hasTensorFlowModels()) return transformFromTensorFlowModel(store, context.rankProfile()); else // is should have previously stored model information instead - return store.readConverted().getRoot(); + return transformFromStoredModel(store, context.rankProfile()); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); } } - private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile rankProfile) { + private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile profile) { TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), k -> tensorFlowImporter.importModel(store.tensorFlowModelDir())); @@ -85,15 +85,18 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil RankingExpression expression = model.expressions().get(output); store.writeConverted(expression); - // Add all constants (after finding outputs to fail faster when the output is not found) TODO: Remove the first path - if (constantsInConfig) - model.constants().forEach((k, v) -> rankProfile.addConstantTensor(k, new TensorValue(v))); - else // correct way, disabled for now - model.constants().forEach((k, v) -> transformConstant(store, rankProfile, k, v)); - + model.constants().forEach((k, v) -> transformConstant(store, profile, k, v)); return expression.getRoot(); } + private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + for (RankingConstant constant : store.readRankingConstants()) { + if (!profile.getSearch().getRankingConstants().containsKey(constant.getName())) + profile.getSearch().addRankingConstant(constant); + } + return store.readConverted().getRoot(); + } + /** * Returns the specified, existing signature, or the only signature if none is specified. * Throws IllegalArgumentException in all other cases. @@ -216,6 +219,24 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } /** + * Reads the information about all the constants stored in the application package + * (the constant value itself is replicated with file distribution). + */ + public List<RankingConstant> readRankingConstants() { + try { + List<RankingConstant> constants = new ArrayList<>(); + for (ApplicationFile constantFile : application.getFile(arguments.rankingConstantsPath()).listFiles()) { + String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); + constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** * Adds this constant to the application package as a file, * such that it can be distributed using file distribution. * @@ -223,11 +244,16 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil */ public Path writeConstant(String name, Tensor constant) { Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); - createIfNeeded(constantsPath); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: Path constantPath = constantsPath.append(name + ".tbf"); - // Write explicitly as a file on the file system as this is distributed using file distribution + + // Remember the constant in a file we replicate in ZooKeeper + application.getFile(arguments.rankingConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPath)); + + // Write content explicitly as a file on the file system as this is distributed using file distribution + createIfNeeded(constantsPath); IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); return constantPath; } @@ -267,8 +293,12 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil public Optional<String> signature() { return signature; } public Optional<String> output() { return output; } + public Path rankingConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + } + public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_DIR + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR .append(modelPath).append("expressions").append(expressionFileName()); } |