diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2018-02-05 22:56:23 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-05 22:56:23 +0100 |
commit | 6e6e9c71e11268a7badd2297341a0937cbad2d1f (patch) | |
tree | 4fc17d3e36f507efea78adc856228eec5f144019 | |
parent | 3632387ab3bf56688d54c0714bcefe6f0f6d999f (diff) | |
parent | 30a2d3e88529bc5a86ad6c53c8de35e4a71fbac3 (diff) |
Merge pull request #4924 from vespa-engine/bratseth/support-small-constants
Support small constants
11 files changed, 170 insertions, 33 deletions
diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationFile.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationFile.java index 60524fbca8d..a8e1256e032 100644 --- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationFile.java +++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/FilesApplicationFile.java @@ -111,8 +111,8 @@ public class FilesApplicationFile extends ApplicationFile { file.getParentFile().mkdirs(); } try { - String data = com.yahoo.io.IOUtils.readAll(input); String status = file.exists() ? ApplicationFile.ContentStatusChanged : ApplicationFile.ContentStatusNew; + String data = com.yahoo.io.IOUtils.readAll(input); IOUtils.writeFile(file, data, false); writeMetaFile(data, status); } catch (IOException e) { @@ -122,6 +122,21 @@ public class FilesApplicationFile extends ApplicationFile { } @Override + public ApplicationFile appendFile(String value) { + if (file.getParentFile() != null) { + file.getParentFile().mkdirs(); + } + try { + String status = file.exists() ? ApplicationFile.ContentStatusChanged : ApplicationFile.ContentStatusNew; + IOUtils.writeFile(file, value, true); + writeMetaFile(value, status); + } catch (IOException e) { + throw new RuntimeException(e); + } + return this; + } + + @Override public List<ApplicationFile> listFiles(final PathFilter filter) { List<ApplicationFile> files = new ArrayList<>(); if (!file.isDirectory()) { diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationFile.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationFile.java index 0384a5c7a1c..33b7807aac5 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationFile.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationFile.java @@ -75,6 +75,13 @@ public abstract class ApplicationFile implements Comparable<ApplicationFile> { public abstract ApplicationFile writeFile(Reader input); /** + * Appends the given string to this text file. + * + * @return this + */ + public abstract ApplicationFile appendFile(String value); + + /** * List the files under this directory. If this is file, an empty list is returned. * Only immediate files/subdirectories are returned. * diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 2f28d9adb8b..bcbc7cc99e2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -750,7 +750,9 @@ public class RankProfile implements Serializable, Cloneable { public TypeContext typeContext(QueryProfileRegistry queryProfiles) { TypeMapContext context = new TypeMapContext(); - // Add constants + // Add small constants + getConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.type())); + // Add large constants getSearch().getRankingConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.getTensorType())); // Add attributes diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 495ca7dd14a..2b997aa25f2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -2,6 +2,7 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; +import com.yahoo.collections.Pair; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.application.provider.FilesApplicationPackage; @@ -11,6 +12,9 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel.Signature; @@ -25,11 +29,13 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; +import java.io.BufferedReader; import java.io.File; import java.io.IOException; import java.io.StringReader; import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -92,16 +98,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles); store.writeConverted(expression); - model.constants().forEach((k, v) -> transformConstant(store, profile, k, v)); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v)); return expression.getRoot(); } private ExpressionNode transformFromStoredModel(ModelStore store, RankProfile profile) { - for (RankingConstant constant : store.readRankingConstants()) { + for (Pair<String, Tensor> constant : store.readSmallConstants()) + profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); + + for (RankingConstant constant : store.readLargeConstants()) { if ( ! profile.getSearch().getRankingConstants().containsKey(constant.getName())) profile.getSearch().addRankingConstant(constant); } + return store.readConverted().getRoot(); } @@ -158,8 +169,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private void transformConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - Path constantPath = store.writeConstant(constantName, constantValue); + private void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + store.writeSmallConstant(constantName, constantValue); + profile.addConstant(constantName, asValue(constantValue)); + } + + private void transformLargeConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + Path constantPath = store.writeLargeConstant(constantName, constantValue); if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { log.info("Adding constant '" + constantName + "' of type " + constantValue.type()); @@ -218,6 +234,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } + private Value asValue(Tensor tensor) { + if (tensor.type().rank() == 0) + return new DoubleValue(tensor.asDouble()); // the backend gets offended by dimensionless tensors + else + return new TensorValue(tensor); + } + /** * Provides read/write access to the correct directories of the application package given by the feature arguments */ @@ -272,13 +295,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } /** - * Reads the information about all the constants stored in the application package + * Reads the information about all the large (aka ranking) constants stored in the application package * (the constant value itself is replicated with file distribution). */ - public List<RankingConstant> readRankingConstants() { + public List<RankingConstant> readLargeConstants() { try { List<RankingConstant> constants = new ArrayList<>(); - for (ApplicationFile constantFile : application.getFile(arguments.rankingConstantsPath()).listFiles()) { + for (ApplicationFile constantFile : application.getFile(arguments.largeConstantsPath()).listFiles()) { String[] parts = IOUtils.readAll(constantFile.createReader()).split(":"); constants.add(new RankingConstant(parts[0], TensorType.fromSpec(parts[1]), parts[2])); } @@ -295,25 +318,63 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil * * @return the path to the stored constant, relative to the application package root */ - public Path writeConstant(String name, Tensor constant) { + public Path writeLargeConstant(String name, Tensor constant) { Path constantsPath = ApplicationPackage.MODELS_GENERATED_DIR.append(arguments.modelPath).append("constants"); // "tbf" ending for "typed binary format" - recognized by the nodes receiving the file: Path constantPath = constantsPath.append(name + ".tbf"); - Path constantPathCorrected = constantPath; - if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) - && ! constantPath.elements().contains(FilesApplicationPackage.preprocessed)) { - constantPathCorrected = Path.fromString(FilesApplicationPackage.preprocessed).append(constantPath); - } // Remember the constant in a file we replicate in ZooKeeper - application.getFile(arguments.rankingConstantsPath().append(name + ".constant")) - .writeFile(new StringReader(name + ":" + constant.type() + ":" + constantPathCorrected)); + application.getFile(arguments.largeConstantsPath().append(name + ".constant")) + .writeFile(new StringReader(name + ":" + constant.type() + ":" + correct(constantPath))); // Write content explicitly as a file on the file system as this is distributed using file distribution createIfNeeded(constantsPath); IOUtils.writeFile(application.getFileReference(constantPath), TypedBinaryFormat.encode(constant)); - return constantPathCorrected; + return correct(constantPath); + } + + private List<Pair<String, Tensor>> readSmallConstants() { + try { + ApplicationFile file = application.getFile(arguments.smallConstantsPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, Tensor>> constants = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + TensorType type = TensorType.fromSpec(parts[1]); + Tensor tensor = Tensor.from(type, parts[2]); + constants.add(new Pair<>(name, tensor)); + } + return constants; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + /** + * Append this constant to the single file used for small constants distributed as config + */ + public void writeSmallConstant(String name, Tensor constant) { + // Secret file format for remembering constants: + application.getFile(arguments.smallConstantsPath()).appendFile(name + "\t" + + constant.type().toString() + "\t" + + constant.toString() + "\n"); + } + + /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ + private Path correct(Path path) { + if (application.getFileReference(Path.fromString("")).getAbsolutePath().endsWith(FilesApplicationPackage.preprocessed) + && ! path.elements().contains(FilesApplicationPackage.preprocessed)) { + return Path.fromString(FilesApplicationPackage.preprocessed).append(path); + } + else { + return path; + } } private void createIfNeeded(Path path) { @@ -351,7 +412,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil public Optional<String> signature() { return signature; } public Optional<String> output() { return output; } - public Path rankingConstantsPath() { + /** Path to the small constants file */ + public Path smallConstantsPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("constants.txt"); + } + + /** Path to the large (ranking) constants directory */ + public Path largeConstantsPath() { return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java index 5255cdaeba1..0334012e8d9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorTransformer.java @@ -183,7 +183,7 @@ public class TensorTransformer extends ExpressionTransformer<RankProfileTransfor } private void addIfConstant(ReferenceNode node, Context context, RankProfile profile) { - if (!node.getName().equals(ConstantTensorTransformer.CONSTANT)) { + if ( ! node.getName().equals(ConstantTensorTransformer.CONSTANT)) { return; } if (node.children().size() != 1) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 5203e686681..7246b22b0f8 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -399,6 +399,17 @@ public class RankingExpressionWithTensorFlowTestCase { } @Override + public ApplicationFile appendFile(String value) { + try { + IOUtils.writeFile(file, value, true); + return this; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + @Override public List<ApplicationFile> listFiles(PathFilter filter) { if ( ! isDirectory()) return Collections.emptyList(); return Arrays.stream(file.listFiles()).filter(f -> filter.accept(Path.fromString(f.toString()))) diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationFile.java b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationFile.java index 717fb88e5dc..affc2e03e2b 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationFile.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/zookeeper/ZKApplicationFile.java @@ -95,7 +95,6 @@ class ZKApplicationFile extends ApplicationFile { @Override public ApplicationFile writeFile(Reader input) { - // foo/bar/baz.txt String zkPath = getZKPath(path); try { String data = IOUtils.readAll(input); @@ -112,6 +111,21 @@ class ZKApplicationFile extends ApplicationFile { } @Override + public ApplicationFile appendFile(String value) { + String zkPath = getZKPath(path); + String status = ContentStatusNew; + if (zkApp.exists(zkPath)) { + status = ContentStatusChanged; + } + String existingData = zkApp.getData(zkPath); + if (existingData == null) + existingData = ""; + zkApp.putData(zkPath, existingData + value); + writeMetaFile(value, status); + return this; + } + + @Override public List<ApplicationFile> listFiles(PathFilter filter) { String userPath = getZKPath(path); List<ApplicationFile> ret = new ArrayList<>(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index b8f8e288257..55782c36d18 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -219,7 +219,7 @@ class OperationMapper { private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) { String name = toVespaName(params.node().getInput(0)); Tensor defaultValue = getConstantTensor(params, params.node().getInput(0)); - params.result().constant(name, defaultValue); + params.result().largeConstant(name, defaultValue); params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")"))); // The default value will be provided by the macro. Users can override macro to change value. TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name)); @@ -544,7 +544,11 @@ class OperationMapper { private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) { String name = toVespaName(params.node().getName()); - params.result().constant(name, constant); + if (constant.type().rank() == 0 || constant.size() <= 1) { + params.result().smallConstant(name, constant); + } else { + params.result().largeConstant(name, constant); + } TypedTensorFunction output = new TypedTensorFunction(constant.type(), new TensorFunctionNode.TensorFunctionExpressionNode( new ReferenceNode("constant(\"" + name + "\")"))); @@ -553,8 +557,11 @@ class OperationMapper { private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) { String vespaName = toVespaName(name); - if (params.result().constants().containsKey(vespaName)) { - return params.result().constants().get(vespaName); + if (params.result().smallConstants().containsKey(vespaName)) { + return params.result().smallConstants().get(vespaName); + } + if (params.result().largeConstants().containsKey(vespaName)) { + return params.result().largeConstants().get(vespaName); } Session.Runner fetched = params.model().session().runner().fetch(name); List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 530f4793b62..351aa417f9c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -24,13 +24,15 @@ public class TensorFlowModel { private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> arguments = new HashMap<>(); - private final Map<String, Tensor> constants = new HashMap<>(); + private final Map<String, Tensor> smallConstants = new HashMap<>(); + private final Map<String, Tensor> largeConstants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); private final Map<String, RankingExpression> macros = new HashMap<>(); private final Map<String, TensorType> requiredMacros = new HashMap<>(); void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } - void constant(String name, Tensor constant) { constants.put(name, constant); } + void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } + void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } void macro(String name, RankingExpression expression) { macros.put(name, expression); } void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } @@ -43,8 +45,19 @@ public class TensorFlowModel { /** Returns an immutable map of the arguments ("Placeholders") of this */ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } - /** Returns an immutable map of the constants of this */ - public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); } + /** + * Returns an immutable map of the small constants of this. + * These should have sizes up to a few kb at most, and correspond to constant + * values given in the TensorFlow source. + */ + public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } + + /** + * Returns an immutable map of the large constants of this. + * These can have sizes in gigabytes and must be distributed to nodes separately from configuration, + * and correspond to Variable files stored separately in TensorFlow. + */ + public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } /** * Returns an immutable map of the expressions of this - corresponding to TensorFlow nodes diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index 01dd15d5fa0..ad5abd4c03d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -20,15 +20,15 @@ public class MnistSoftmaxImportTestCase { TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/mnist_softmax/saved"); // Check constants - assertEquals(2, model.get().constants().size()); + assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().constants().get("Variable"); + Tensor constant0 = model.get().largeConstants().get("Variable"); assertNotNull(constant0); assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().constants().get("Variable_1"); + Tensor constant1 = model.get().largeConstants().get("Variable_1"); assertNotNull(constant1); assertEquals(new TensorType.Builder().indexed("d0", 10).build(), constant1.type()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index 2c621fd2e92..ae7714b271a 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -57,7 +57,8 @@ public class TestableTensorFlowModel { private Context contextFrom(TensorFlowModel result) { MapContext context = new MapContext(); - result.constants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); return context; } |