diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-15 14:48:22 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-15 14:48:22 +0100 |
commit | d6434be601768c2fd1f8a726101b340e48565daa (patch) | |
tree | 5323d794925dfa5522ad6e32f1b4df1ffcef4fe7 | |
parent | f3aaa08db00c9df1758fb1ab863ebba13ca043d3 (diff) |
Use Path. Save constants in models.generated/
3 files changed, 32 insertions, 12 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java index aca7b595249..83d12718b6a 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java @@ -53,7 +53,11 @@ public interface ApplicationPackage { String DOCPROCCHAINS_DIR = "docproc/chains"; String PROCESSORCHAINS_DIR = "processor/chains"; String ROUTINGTABLES_DIR = "routing/tables"; - String MODELS_DIR = "models"; + + /** Machine-learned models - only present in user-uploaded package instances */ + Path MODELS_DIR = Path.fromString("models"); + /** Files generated from machine-learned models - distributed to config servers over file distribution */ + Path MODELS_GENERATED_DIR = Path.fromString("models.generated"); // NOTE: this directory is created in serverdb during deploy, and should not exist in the original user application /** Do not use */ diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 32f8f4871df..606ae6b43e0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -4,6 +4,7 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; @@ -43,7 +44,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<String, TensorFlowModel> importedModels = new HashMap<>(); + private final Map<Path, TensorFlowModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -63,8 +64,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + "the tensorflow model directory under [application]/models"); - String modelPath = ApplicationPackage.MODELS_DIR + "/" + asString(feature.getArguments().expressions().get(0)); - TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath)); + Path modelPath = Path.fromString(asString(feature.getArguments().expressions().get(0))); + TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> importModel(modelPath)); // Find the specified expression TensorFlowModel.Signature signature = chooseSignature(result, @@ -85,6 +86,17 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } + private TensorFlowModel importModel(Path modelPath) { + try { + return tensorFlowImporter.importModel(new File(ApplicationPackage.MODELS_DIR.append(modelPath) + .getRelative()) + .getCanonicalPath()); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + /** * Returns the specified, existing signature, or the only signature if none is specified. * Throws IllegalArgumentException in all other cases. @@ -138,17 +150,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private void transformConstant(String modelPath, RankProfile profile, String constantName, Tensor constantValue) { + private void transformConstant(Path modelPath, RankProfile profile, String constantName, Tensor constantValue) { try { if (profile.getSearch().getRankingConstants().containsKey(constantName)) return; - File constantFilePath = new File(modelPath, "converted_variables").getCanonicalFile(); - if (!constantFilePath.exists()) { - if (!constantFilePath.mkdir()) + System.out.println("modelPath is " + modelPath); + File constantFilePath = new File(ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath) + .append("constants") + .getRelative()) + .getCanonicalFile(); + System.out.println("constant file path is " + constantFilePath); + if ( ! constantFilePath.exists()) + if ( ! constantFilePath.mkdir()) throw new IOException("Could not create directory " + constantFilePath); - } - // "tbf" ending for "typed binary format" - recognized by the nodes reciving the file: + // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: File constantFile = new File(constantFilePath, constantName + ".tbf"); IOUtils.writeFile(constantFile, TypedBinaryFormat.encode(constantValue)); profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFile.getPath())); diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 31f7511155b..aa47b0b3b81 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -33,7 +33,7 @@ public class RankingExpressionWithTensorFlowTestCase { @After public void removeGeneratedConstantTensorFiles() { - IOUtils.recursiveDeleteDir(new File(modelDirectory.substring(3), "converted_variables")); + IOUtils.recursiveDeleteDir(new File(modelDirectory.substring(3), "constants")); } @Test @@ -126,7 +126,7 @@ public class RankingExpressionWithTensorFlowTestCase { try { TensorValue constant = (TensorValue)search.rankProfile("my_profile").getConstants().get(name); // Old way. TODO: Remove if (constant == null) { // New way - File constantFile = new File(modelDirectory.substring(3) + "/converted_variables", name + ".tbf"); + File constantFile = new File(modelDirectory.substring(3) + "/constants", name + ".tbf"); RankingConstant rankingConstant = search.search().getRankingConstants().get(name); assertEquals(name, rankingConstant.getName()); assertEquals(constantFile.getAbsolutePath(), rankingConstant.getFileName()); |