summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-17 13:05:27 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-17 13:05:27 +0100
commitc84b8f952ef5857aa44fad479551eda1f3a4e106 (patch)
treee7bf28337efaa9bc02e7c13c2cd14777a46135b1 /config-model/src/main/java/com/yahoo
parent66b3a3ca7c14097f9a277431c19c169e3681a4de (diff)
Persist constant info in ZooKeeper
Diffstat (limited to 'config-model/src/main/java/com/yahoo')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java64
1 files changed, 47 insertions, 17 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 0324b9852df..0dd5b4166ef 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -2,13 +2,13 @@
package com.yahoo.searchdefinition.expressiontransforms;
import com.google.common.base.Joiner;
+import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature;
@@ -20,13 +20,16 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.File;
import java.io.IOException;
import java.io.StringReader;
import java.io.UncheckedIOException;
+import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -37,13 +40,10 @@ import java.util.Optional;
*
* @author bratseth
*/
-// TODO: - Verify types of macros
-// - Avoid name conflicts across models for constants
+// TODO: Verify types of macros
+// TODO: Avoid name conflicts across models for constants
public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> {
- // TODO: Make system test work with this set to true, then remove the "true" path
- private static final boolean constantsInConfig = true;
-
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
@@ -68,14 +68,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
if (store.hasTensorFlowModels())
return transformFromTensorFlowModel(store, context.rankProfile());
else // is should have previously stored model information instead
- return store.readConverted().getRoot();
+ return transformFromStoredModel(store, context.rankProfile());
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
}
}
- private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile rankProfile) {
+ private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile profile) {
TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
k -> tensorFlowImporter.importModel(store.tensorFlowModelDir()));
@@ -85,15 +85,18 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
RankingExpression expression = model.expressions().get(output);
store.writeConverted(expression);
- // Add all constants (after finding outputs to fail faster when the output is not found) TODO: Remove the first path
- if (constantsInConfig)
- model.constants().forEach((k, v) -> rankProfile.addConstantTensor(k, new TensorValue(v)));
- else // correct way, disabled for now
- model.constants().forEach((k, v) -> transformConstant(store, rankProfile, k, v));
-
+ model.constants().forEach((k, v) -> transformConstant(store, profile, k, v));
return expression.getRoot();
}
+ private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) {
+ for (RankingConstant constant : store.readRankingConstants()) {
+ if (!profile.getSearch().getRankingConstants().containsKey(constant.getName()))
+ profile.getSearch().addRankingConstant(constant);
+ }
+ return store.readConverted().getRoot();
+ }
+
/**
* Returns the specified, existing signature, or the only signature if none is specified.
* Throws IllegalArgumentException in all other cases.
@@ -216,6 +219,24 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
/**
+ * Reads the information about all the constants stored in the application package
+ * (the constant value itself is replicated with file distribution).
+ */
+ public List<RankingConstant> readRankingConstants() {
+ try {
+ List<RankingConstant> constants = new ArrayList<>();
+ for (ApplicationFile constantFile : application.getFile(arguments.rankingConstantsPath()).listFiles()) {
+ String[] parts = IOUtils.readAll(constantFile.createReader()).split(":");
+ constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2]));
+ }
+ return constants;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /**
* Adds this constant to the application package as a file,
* such that it can be distributed using file distribution.
*
@@ -223,11 +244,16 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
*/
public Path writeConstant(String name, Tensor constant) {
Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants");
- createIfNeeded(constantsPath);
// "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
Path constantPath = constantsPath.append(name + ".tbf");
- // Write explicitly as a file on the file system as this is distributed using file distribution
+
+ // Remember the constant in a file we replicate in ZooKeeper
+ application.getFile(arguments.rankingConstantsPath().append(name + ".constant"))
+ .writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPath));
+
+ // Write content explicitly as a file on the file system as this is distributed using file distribution
+ createIfNeeded(constantsPath);
IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant));
return constantPath;
}
@@ -267,8 +293,12 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
public Optional<String> signature() { return signature; }
public Optional<String> output() { return output; }
+ public Path rankingConstantsPath() {
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
+ }
+
public Path expressionPath() {
- return ApplicationPackage.MODELS_GENERATED_DIR
+ return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
.append(modelPath).append("expressions").append(expressionFileName());
}