summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-15 14:48:22 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-15 14:48:22 +0100
commitd6434be601768c2fd1f8a726101b340e48565daa (patch)
tree5323d794925dfa5522ad6e32f1b4df1ffcef4fe7
parentf3aaa08db00c9df1758fb1ab863ebba13ca043d3 (diff)
Use Path. Save constants in models.generated/
-rw-r--r--config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java6
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java34
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java4
3 files changed, 32 insertions, 12 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
index aca7b595249..83d12718b6a 100644
--- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
+++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java
@@ -53,7 +53,11 @@ public interface ApplicationPackage {
String DOCPROCCHAINS_DIR = "docproc/chains";
String PROCESSORCHAINS_DIR = "processor/chains";
String ROUTINGTABLES_DIR = "routing/tables";
- String MODELS_DIR = "models";
+
+ /** Machine-learned models - only present in user-uploaded package instances */
+ Path MODELS_DIR = Path.fromString("models");
+ /** Files generated from machine-learned models - distributed to config servers over file distribution */
+ Path MODELS_GENERATED_DIR = Path.fromString("models.generated");
// NOTE: this directory is created in serverdb during deploy, and should not exist in the original user application
/** Do not use */
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 32f8f4871df..606ae6b43e0 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -4,6 +4,7 @@ package com.yahoo.searchdefinition.expressiontransforms;
import com.google.common.base.Joiner;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
@@ -43,7 +44,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<String, TensorFlowModel> importedModels = new HashMap<>();
+ private final Map<Path, TensorFlowModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -63,8 +64,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
"the tensorflow model directory under [application]/models");
- String modelPath = ApplicationPackage.MODELS_DIR + "/" + asString(feature.getArguments().expressions().get(0));
- TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath));
+ Path modelPath = Path.fromString(asString(feature.getArguments().expressions().get(0)));
+ TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> importModel(modelPath));
// Find the specified expression
TensorFlowModel.Signature signature = chooseSignature(result,
@@ -85,6 +86,17 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
+ private TensorFlowModel importModel(Path modelPath) {
+ try {
+ return tensorFlowImporter.importModel(new File(ApplicationPackage.MODELS_DIR.append(modelPath)
+ .getRelative())
+ .getCanonicalPath());
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
/**
* Returns the specified, existing signature, or the only signature if none is specified.
* Throws IllegalArgumentException in all other cases.
@@ -138,17 +150,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
- private void transformConstant(String modelPath, RankProfile profile, String constantName, Tensor constantValue) {
+ private void transformConstant(Path modelPath, RankProfile profile, String constantName, Tensor constantValue) {
try {
if (profile.getSearch().getRankingConstants().containsKey(constantName)) return;
- File constantFilePath = new File(modelPath, "converted_variables").getCanonicalFile();
- if (!constantFilePath.exists()) {
- if (!constantFilePath.mkdir())
+ System.out.println("modelPath is " + modelPath);
+ File constantFilePath = new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath)
+ .append("constants")
+ .getRelative())
+ .getCanonicalFile();
+ System.out.println("constant file path is " + constantFilePath);
+ if ( ! constantFilePath.exists())
+ if ( ! constantFilePath.mkdir())
throw new IOException("Could not create directory " + constantFilePath);
- }
- // "tbf" ending for "typed binary format" - recognized by the nodes reciving the file:
+ // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file:
File constantFile = new File(constantFilePath, constantName + ".tbf");
IOUtils.writeFile(constantFile, TypedBinaryFormat.encode(constantValue));
profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFile.getPath()));
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 31f7511155b..aa47b0b3b81 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -33,7 +33,7 @@ public class RankingExpressionWithTensorFlowTestCase {
@After
public void removeGeneratedConstantTensorFiles() {
- IOUtils.recursiveDeleteDir(new File(modelDirectory.substring(3), "converted_variables"));
+ IOUtils.recursiveDeleteDir(new File(modelDirectory.substring(3), "constants"));
}
@Test
@@ -126,7 +126,7 @@ public class RankingExpressionWithTensorFlowTestCase {
try {
TensorValue constant = (TensorValue)search.rankProfile("my_profile").getConstants().get(name); // Old way. TODO: Remove
if (constant == null) { // New way
- File constantFile = new File(modelDirectory.substring(3) + "/converted_variables", name + ".tbf");
+ File constantFile = new File(modelDirectory.substring(3) + "/constants", name + ".tbf");
RankingConstant rankingConstant = search.search().getRankingConstants().get(name);
assertEquals(name, rankingConstant.getName());
assertEquals(constantFile.getAbsolutePath(), rankingConstant.getFileName());