diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2021-09-14 09:16:46 +0200 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2021-09-14 09:17:14 +0200 |
commit | 7b3242d789ba40b854130706a010f26af125328f (patch) | |
tree | 758fdec6285944b0ab2453358856075b92ef00da | |
parent | a72175250295f12c1b7d8c77ce8174096ca6b551 (diff) |
Make the LargeConstants usable concurrently from many threads
6 files changed, 28 insertions, 27 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java index 8ccc0ef429e..3d19cba78b6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java @@ -6,8 +6,9 @@ import com.yahoo.vespa.model.AbstractService; import java.util.Collection; import java.util.Collections; -import java.util.HashMap; import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.Function; /** * Constant values for ranking/model execution tied to a search definition, or globally to an application @@ -17,7 +18,7 @@ import java.util.Map; */ public class RankingConstants { - private final Map<String, RankingConstant> constants = new HashMap<>(); + private final Map<String, RankingConstant> constants = new ConcurrentHashMap<>(); private final FileRegistry fileRegistry; public RankingConstants(FileRegistry fileRegistry) { @@ -28,9 +29,23 @@ public class RankingConstants { constant.validate(); constant.register(fileRegistry); String name = constant.getName(); - if (constants.containsKey(name)) + RankingConstant prev = constants.putIfAbsent(name, constant); + if ( prev != null ) throw new IllegalArgumentException("Ranking constant '" + name + "' defined twice"); - constants.put(name, constant); + } + public void putIfAbsent(RankingConstant constant) { + constant.validate(); + constant.register(fileRegistry); + String name = constant.getName(); + constants.putIfAbsent(name, constant); + } + public void computeIfAbsent(String name, Function<? super String, ? extends RankingConstant> createConstant) { + constants.computeIfAbsent(name, key -> { + RankingConstant constant = createConstant.apply(key); + constant.validate(); + constant.register(fileRegistry); + return constant; + }); } /** Returns the ranking constant with the given name, or null if not present */ diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 4a315420b0a..3bf4bdb8e01 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -26,6 +26,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>(); + public TensorFlowFeatureConverter() {} + @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { if (node instanceof ReferenceNode) diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java index 032341297bf..e57d67abb15 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java @@ -205,7 +205,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform return new TensorType.Builder(TensorType.Value.FLOAT).indexed("d0", 1).indexed("d1", length).build(); } catch (NumberFormatException ex) { throw new IllegalArgumentException("Invalid argument to " + featureName + ": the first argument must be " + - "the length to the token sequence to generate. Got " + argument.toString()); + "the length to the token sequence to generate. Got " + argument); } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index c6c2fea5900..64e606dd7d3 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -265,8 +265,6 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri /** Returns the global ranking constants of this */ public RankingConstants rankingConstants() { return rankingConstants; } - public LargeRankExpressions rankExpressionFiles() { return largeRankExpressions; } - /** Creates a mutable model with no services instantiated */ public static VespaModel createIncomplete(DeployState deployState) throws IOException, SAXException { return new VespaModel(new NullConfigModelRegistry(), deployState, false); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 62f911c9f1a..ac97dbdbcee 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -219,10 +219,8 @@ public class ConvertedModel { for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) { profile.addInputFeature(input.getKey(), input.getValue()); } - addExpression(expression, expression.getName(), - constantsReplacedByFunctions, - model, store, profile, queryProfiles, - expressions); + addExpression(expression, expression.getName(), constantsReplacedByFunctions, + store, profile, queryProfiles, expressions); } // Transform and save function - must come after reading expressions due to optimization transforms @@ -254,7 +252,6 @@ public class ConvertedModel { private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, - ImportedMlModel model, ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, @@ -276,9 +273,7 @@ public class ConvertedModel { } for (RankingConstant constant : store.readLargeConstants()) { - if ( ! profile.rankingConstants().asMap().containsKey(constant.getName())) { - profile.rankingConstants().add(constant); - } + profile.rankingConstants().putIfAbsent(constant); } for (Pair<String, RankingExpression> function : store.readFunctions()) { @@ -325,10 +320,7 @@ public class ConvertedModel { } else { Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.rankingConstants().asMap().containsKey(constantName)) { - profile.rankingConstants().add(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); - } + profile.rankingConstants().computeIfAbsent(constantName, name -> new RankingConstant(name, constantValue.type(), constantPath.toString())); } } @@ -365,7 +357,7 @@ public class ConvertedModel { addFunctionNamesIn(expression.getRoot(), functionNames, model); for (String functionName : functionNames) { Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); - if ( ! requiredType.isPresent()) continue; // Not a required function + if ( requiredType.isEmpty()) continue; // Not a required function RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); if (rankingExpressionFunction == null) @@ -634,7 +626,7 @@ public class ConvertedModel { // Secret file format for remembering constants: application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" + constant.type().toString() + "\t" + - constant.toString() + "\n"); + constant + "\n"); } /** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */ diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 96bf2c64485..f77374651ff 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -2,13 +2,7 @@ package ai.vespa.rankingexpression.importer.onnx; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.Context; -import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import ai.vespa.rankingexpression.importer.ImportedModel; -import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; |