aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2021-09-14 09:16:46 +0200
committerHenning Baldersheim <balder@yahoo-inc.com>2021-09-14 09:17:14 +0200
commit7b3242d789ba40b854130706a010f26af125328f (patch)
tree758fdec6285944b0ab2453358856075b92ef00da
parenta72175250295f12c1b7d8c77ce8174096ca6b551 (diff)
Make the LargeConstants usable concurrently from many threads
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java23
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java20
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java6
6 files changed, 28 insertions, 27 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java
index 8ccc0ef429e..3d19cba78b6 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java
@@ -6,8 +6,9 @@ import com.yahoo.vespa.model.AbstractService;
import java.util.Collection;
import java.util.Collections;
-import java.util.HashMap;
import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.Function;
/**
* Constant values for ranking/model execution tied to a search definition, or globally to an application
@@ -17,7 +18,7 @@ import java.util.Map;
*/
public class RankingConstants {
- private final Map<String, RankingConstant> constants = new HashMap<>();
+ private final Map<String, RankingConstant> constants = new ConcurrentHashMap<>();
private final FileRegistry fileRegistry;
public RankingConstants(FileRegistry fileRegistry) {
@@ -28,9 +29,23 @@ public class RankingConstants {
constant.validate();
constant.register(fileRegistry);
String name = constant.getName();
- if (constants.containsKey(name))
+ RankingConstant prev = constants.putIfAbsent(name, constant);
+ if ( prev != null )
throw new IllegalArgumentException("Ranking constant '" + name + "' defined twice");
- constants.put(name, constant);
+ }
+ public void putIfAbsent(RankingConstant constant) {
+ constant.validate();
+ constant.register(fileRegistry);
+ String name = constant.getName();
+ constants.putIfAbsent(name, constant);
+ }
+ public void computeIfAbsent(String name, Function<? super String, ? extends RankingConstant> createConstant) {
+ constants.computeIfAbsent(name, key -> {
+ RankingConstant constant = createConstant.apply(key);
+ constant.validate();
+ constant.register(fileRegistry);
+ return constant;
+ });
}
/** Returns the ranking constant with the given name, or null if not present */
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 4a315420b0a..3bf4bdb8e01 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -26,6 +26,8 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
private final Map<Path, ConvertedModel> convertedTensorFlowModels = new HashMap<>();
+ public TensorFlowFeatureConverter() {}
+
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
if (node instanceof ReferenceNode)
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
index 032341297bf..e57d67abb15 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TokenTransformer.java
@@ -205,7 +205,7 @@ public class TokenTransformer extends ExpressionTransformer<RankProfileTransform
return new TensorType.Builder(TensorType.Value.FLOAT).indexed("d0", 1).indexed("d1", length).build();
} catch (NumberFormatException ex) {
throw new IllegalArgumentException("Invalid argument to " + featureName + ": the first argument must be " +
- "the length to the token sequence to generate. Got " + argument.toString());
+ "the length to the token sequence to generate. Got " + argument);
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index c6c2fea5900..64e606dd7d3 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -265,8 +265,6 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
/** Returns the global ranking constants of this */
public RankingConstants rankingConstants() { return rankingConstants; }
- public LargeRankExpressions rankExpressionFiles() { return largeRankExpressions; }
-
/** Creates a mutable model with no services instantiated */
public static VespaModel createIncomplete(DeployState deployState) throws IOException, SAXException {
return new VespaModel(new NullConfigModelRegistry(), deployState, false);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 62f911c9f1a..ac97dbdbcee 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -219,10 +219,8 @@ public class ConvertedModel {
for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) {
profile.addInputFeature(input.getKey(), input.getValue());
}
- addExpression(expression, expression.getName(),
- constantsReplacedByFunctions,
- model, store, profile, queryProfiles,
- expressions);
+ addExpression(expression, expression.getName(), constantsReplacedByFunctions,
+ store, profile, queryProfiles, expressions);
}
// Transform and save function - must come after reading expressions due to optimization transforms
@@ -254,7 +252,6 @@ public class ConvertedModel {
private static void addExpression(ExpressionFunction expression,
String expressionName,
Set<String> constantsReplacedByFunctions,
- ImportedMlModel model,
ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles,
@@ -276,9 +273,7 @@ public class ConvertedModel {
}
for (RankingConstant constant : store.readLargeConstants()) {
- if ( ! profile.rankingConstants().asMap().containsKey(constant.getName())) {
- profile.rankingConstants().add(constant);
- }
+ profile.rankingConstants().putIfAbsent(constant);
}
for (Pair<String, RankingExpression> function : store.readFunctions()) {
@@ -325,10 +320,7 @@ public class ConvertedModel {
}
else {
Path constantPath = store.writeLargeConstant(constantName, constantValue);
- if ( ! profile.rankingConstants().asMap().containsKey(constantName)) {
- profile.rankingConstants().add(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
- }
+ profile.rankingConstants().computeIfAbsent(constantName, name -> new RankingConstant(name, constantValue.type(), constantPath.toString()));
}
}
@@ -365,7 +357,7 @@ public class ConvertedModel {
addFunctionNamesIn(expression.getRoot(), functionNames, model);
for (String functionName : functionNames) {
Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
- if ( ! requiredType.isPresent()) continue; // Not a required function
+ if ( requiredType.isEmpty()) continue; // Not a required function
RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
if (rankingExpressionFunction == null)
@@ -634,7 +626,7 @@ public class ConvertedModel {
// Secret file format for remembering constants:
application.getFile(modelFiles.smallConstantsPath()).appendFile(name + "\t" +
constant.type().toString() + "\t" +
- constant.toString() + "\n");
+ constant + "\n");
}
/** Workaround for being constructed with the .preprocessed dir as root while later being used outside it */
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 96bf2c64485..f77374651ff 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -2,13 +2,7 @@
package ai.vespa.rankingexpression.importer.onnx;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.Context;
-import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import ai.vespa.rankingexpression.importer.ImportedModel;
-import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;