diff options
4 files changed, 108 insertions, 52 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java index 9a7e1960696..2e1c810675c 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java @@ -56,8 +56,10 @@ public interface ApplicationPackage { /** Machine-learned models - only present in user-uploaded package instances */ Path MODELS_DIR = Path.fromString("models"); - /** Files generated from machine-learned models - distributed to config servers over file distribution */ + /** Files generated from machine-learned models */ Path MODELS_GENERATED_DIR = Path.fromString("models.generated"); + /** Files generated from machine-learned models which should be replicated in ZooKeeper */ + Path MODELS_GENERATED_REPLICATED_DIR = MODELS_GENERATED_DIR.append("replicated"); // NOTE: this directory is created in serverdb during deploy, and should not exist in the original user application /** Do not use */ diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 0324b9852df..0dd5b4166ef 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -2,13 +2,13 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; +import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankProfile; import 
com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature; @@ -20,13 +20,16 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.File; import java.io.IOException; import java.io.StringReader; import java.io.UncheckedIOException; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Optional; @@ -37,13 +40,10 @@ import java.util.Optional; * * @author bratseth */ -// TODO: - Verify types of macros -// - Avoid name conflicts across models for constants +// TODO: Verify types of macros +// TODO: Avoid name conflicts across models for constants public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfileTransformContext> { - // TODO: Make system test work with this set to true, then remove the "true" path - private static final boolean constantsInConfig = true; - private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. 
*/ @@ -68,14 +68,14 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil if (store.hasTensorFlowModels()) return transformFromTensorFlowModel(store, context.rankProfile()); else // is should have previously stored model information instead - return store.readConverted().getRoot(); + return transformFromStoredModel(store, context.rankProfile()); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); } } - private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile rankProfile) { + private ExpressionNode transformFromTensorFlowModel(ModelStore store, RankProfile profile) { TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), k -> tensorFlowImporter.importModel(store.tensorFlowModelDir())); @@ -85,15 +85,18 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil RankingExpression expression = model.expressions().get(output); store.writeConverted(expression); - // Add all constants (after finding outputs to fail faster when the output is not found) TODO: Remove the first path - if (constantsInConfig) - model.constants().forEach((k, v) -> rankProfile.addConstantTensor(k, new TensorValue(v))); - else // correct way, disabled for now - model.constants().forEach((k, v) -> transformConstant(store, rankProfile, k, v)); - + model.constants().forEach((k, v) -> transformConstant(store, profile, k, v)); return expression.getRoot(); } + private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { + for (RankingConstant constant : store.readRankingConstants()) { + if (!profile.getSearch().getRankingConstants().containsKey(constant.getName())) + profile.getSearch().addRankingConstant(constant); + } + return store.readConverted().getRoot(); + } + /** * Returns the specified, existing signature, or the only signature if none is specified. 
* Throws IllegalArgumentException in all other cases. @@ -216,6 +219,24 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } /** + * Reads the information about all the constants stored in the application package + * (the constant value itself is replicated with file distribution). + */ + public List<RankingConstant> readRankingConstants() { + try { + List<RankingConstant> constants = new ArrayList<>(); + for (ApplicationFile constantFile : application.getFile(arguments.rankingConstantsPath()).listFiles()) { + String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); + constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** * Adds this constant to the application package as a file, * such that it can be distributed using file distribution. * @@ -223,11 +244,16 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil */ public Path writeConstant(String name, Tensor constant) { Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); - createIfNeeded(constantsPath); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: Path constantPath = constantsPath.append(name + ".tbf"); - // Write explicitly as a file on the file system as this is distributed using file distribution + + // Remember the constant in a file we replicate in ZooKeeper + application.getFile(arguments.rankingConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPath)); + + // Write content explicitly as a file on the file system as this is distributed using file distribution + createIfNeeded(constantsPath); IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); return constantPath; } @@ -267,8 
+293,12 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil public Optional<String> signature() { return signature; } public Optional<String> output() { return output; } + public Path rankingConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); + } + public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_DIR + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR .append(modelPath).append("expressions").append(expressionFileName()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 79d679b43b1..0354173f365 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -9,7 +9,6 @@ import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.yolean.Exceptions; @@ -26,12 +25,12 @@ import java.io.Reader; import java.io.UncheckedIOException; import java.util.Arrays; import java.util.Collections; +import java.util.Iterator; import java.util.List; import java.util.Optional; import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -40,17 +39,17 @@ import static org.junit.Assert.fail; */ public class RankingExpressionWithTensorFlowTestCase { - private final Path 
applicationDirectory = Path.fromString("src/test/integration/tensorflow/"); + private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/"); private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(Variable_1), d0, d1), f(a,b)(a + b))"; @After public void removeGeneratedConstantTensorFiles() { - IOUtils.recursiveDeleteDir(applicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); } @Test public void testMinimalTensorFlowReference() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDirectory); + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( application, " rank-profile my_profile {\n" + @@ -59,13 +58,13 @@ public class RankingExpressionWithTensorFlowTestCase { " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertConstant(10, "Variable_1", search); - assertConstant(7840, "Variable", search); + assertConstant("Variable_1", search, Optional.of(10L)); + assertConstant("Variable", search, Optional.of(7840L)); } @Test public void testNestedTensorFlowReference() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDirectory); + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( application, " rank-profile my_profile {\n" + @@ -74,13 +73,13 @@ public class RankingExpressionWithTensorFlowTestCase { " }\n" + " }"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertConstant(10, "Variable_1", search); - 
assertConstant(7840, "Variable", search); + assertConstant("Variable_1", search, Optional.of(10L)); + assertConstant("Variable", search, Optional.of(7840L)); } @Test public void testTensorFlowReferenceSpecifyingSignature() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDirectory); + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( application, " rank-profile my_profile {\n" + @@ -93,7 +92,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingSignatureAndOutput() throws ParseException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDirectory); + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( application, " rank-profile my_profile {\n" + @@ -107,7 +106,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingSignature() throws ParseException { try { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDirectory); + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( application, " rank-profile my_profile {\n" + @@ -129,7 +128,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReferenceSpecifyingNonExistingOutput() throws ParseException { try { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDirectory); + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( application, " rank-profile my_profile {\n" + @@ -150,7 +149,7 @@ public class 
RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressions() throws ParseException, IOException { - StoringApplicationPackage application = new StoringApplicationPackage(applicationDirectory); + StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = new RankProfileSearchFixture( application, " rank-profile my_profile {\n" + @@ -159,12 +158,14 @@ public class RankingExpressionWithTensorFlowTestCase { " }\n" + " }"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); + assertConstant("Variable_1", search, Optional.of(10L)); + assertConstant("Variable", search, Optional.of(7840L)); // At this point the expression is stored - copy application to another location which do not have a models dir - Path storedApplicationDirectory = applicationDirectory.getParentPath().append("copy"); + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); try { storedApplicationDirectory.toFile().mkdirs(); - IOUtils.copyDirectory(applicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); RankProfileSearchFixture searchFromStored = new RankProfileSearchFixture( @@ -175,6 +176,10 @@ public class RankingExpressionWithTensorFlowTestCase { " }\n" + " }"); searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); + // Verify that the constants exist, but don't verify the content as we are not + // simulating file distribution in this test + assertConstant("Variable_1", searchFromStored, Optional.empty()); + assertConstant("Variable", searchFromStored, Optional.empty()); } finally { 
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -182,24 +187,24 @@ public class RankingExpressionWithTensorFlowTestCase { } - private void assertConstant(int expectedSize, String name, RankProfileSearchFixture search) { + /** + * Verifies that the constant with the given name exists, and - only if an expected size is given - + * that the content of the constant is available and has the expected size. + */ + private void assertConstant(String name, RankProfileSearchFixture search, Optional<Long> expectedSize) { try { - TensorValue constant = (TensorValue)search.rankProfile("my_profile").getConstants().get(name); // Old way. TODO: Remove - if (constant == null) { // New way - Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf"); - RankingConstant rankingConstant = search.search().getRankingConstants().get(name); - assertEquals(name, rankingConstant.getName()); - assertEquals(constantApplicationPackagePath.toString(), rankingConstant.getFileName()); - - Path constantPath = applicationDirectory.append(constantApplicationPackagePath); + Path constantApplicationPackagePath = Path.fromString("models.generated/mnist_softmax/saved/constants").append(name + ".tbf"); + RankingConstant rankingConstant = search.search().getRankingConstants().get(name); + assertEquals(name, rankingConstant.getName()); + assertEquals(constantApplicationPackagePath.toString(), rankingConstant.getFileName()); + + if (expectedSize.isPresent()) { + Path constantPath = applicationDir.append(constantApplicationPackagePath); assertTrue("Constant file '" + constantPath + "' has been written", constantPath.toFile().exists()); Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile()))); - assertEquals(expectedSize, deserializedConstant.size()); - } else { // Old way. 
TODO: Remove - assertNotNull(name + " is imported", constant); - assertEquals(expectedSize, constant.asTensor().size()); + assertEquals(expectedSize.get().longValue(), deserializedConstant.size()); } } catch (IOException e) { @@ -300,8 +305,8 @@ public class RankingExpressionWithTensorFlowTestCase { public List<ApplicationFile> listFiles(PathFilter filter) { if ( ! isDirectory()) return Collections.emptyList(); return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString()))) - .map(f -> new StoringApplicationPackageFile(Path.fromString(f.toString()), - root)) + .map(f -> new StoringApplicationPackageFile(asApplicationRelativePath(f), + root)) .collect(Collectors.toList()); } @@ -320,6 +325,25 @@ public class RankingExpressionWithTensorFlowTestCase { public int compareTo(ApplicationFile other) { return this.getPath().getName().compareTo((other).getPath().getName()); } + + /** Strips the application package root path prefix from the path of the given file */ + private Path asApplicationRelativePath(File file) { + Path path = Path.fromString(file.toString()); + + Iterator<String> pathIterator = path.iterator(); + // Skip the path elements this shares with the root + for (Iterator<String> rootIterator = root.iterator(); rootIterator.hasNext(); ) { + String rootElement = rootIterator.next(); + String pathElement = pathIterator.next(); + if ( ! 
rootElement.equals(pathElement)) throw new RuntimeException("Assumption broken"); + } + // Build a path from the remaining + Path relative = Path.fromString(""); + while (pathIterator.hasNext()) + relative = relative.append(pathIterator.next()); + return relative; + } + } } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClient.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClient.java index 5e87c6c0f6b..c9c6ef4b428 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClient.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ZooKeeperClient.java @@ -202,8 +202,8 @@ public class ZooKeeperClient { writeDir(app.getFile(Path.fromString(ApplicationPackage.ROUTINGTABLES_DIR)), getZooKeeperAppPath(ConfigCurator.USERAPP_ZK_SUBPATH).append(ApplicationPackage.ROUTINGTABLES_DIR), xmlFilter, true); - writeDir(app.getFile(ApplicationPackage.MODELS_GENERATED_DIR), - getZooKeeperAppPath(ConfigCurator.USERAPP_ZK_SUBPATH).append(ApplicationPackage.MODELS_GENERATED_DIR), + writeDir(app.getFile(ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR), + getZooKeeperAppPath(ConfigCurator.USERAPP_ZK_SUBPATH).append(ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR), true); } |