summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-25 20:07:56 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-25 20:07:56 +0100
commit1d88554bd513783715425120e76fc5f2a86f439f (patch)
tree166c86107d3620014cc7e26d85118c311e1b8cf0 /config-model
parenta01bc21d9bcbc417a9fb2591079561f59f76865e (diff)
Java type only interface between imported-models and config models
This avoids class incompatibility problems when passing an imported model across bundle boundaries to a config model. Tensor string parsing has been sped up as this relies on it more.
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java51
1 files changed, 36 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 59aa5b3ba53..259ac5227ae 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -202,8 +202,9 @@ public class ConvertedModel {
// Add expressions
Map<String, ExpressionFunction> expressions = new HashMap<>();
- for (Pair<String, ExpressionFunction> output : model.outputExpressions()) {
- addExpression(output.getSecond(), output.getFirst(),
+ for (ImportedModel.ImportedFunction outputFunction : model.outputExpressions()) {
+ ExpressionFunction expression = asExpressionFunction(outputFunction);
+ addExpression(expression, expression.getName(),
constantsReplacedByFunctions,
model, store, profile, queryProfiles,
expressions);
@@ -218,6 +219,23 @@ public class ConvertedModel {
return expressions;
}
+ private static ExpressionFunction asExpressionFunction(ImportedModel.ImportedFunction function) {
+ try {
+ Map<String, TensorType> argumentTypes = new HashMap<>();
+ for (Map.Entry<String, String> entry : function.argumentTypes().entrySet())
+ argumentTypes.put(entry.getKey(), TensorType.fromSpec(entry.getValue()));
+
+ return new ExpressionFunction(function.name(),
+ function.arguments(),
+ new RankingExpression(function.expression()),
+ argumentTypes,
+ function.returnType().map(TensorType::fromSpec));
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException("Gor an illegal argument from importing " + function.name(), e);
+ }
+ }
+
private static void addExpression(ExpressionFunction expression,
String expressionName,
Set<String> constantsReplacedByFunctions,
@@ -248,7 +266,9 @@ public class ConvertedModel {
return store.readExpressions();
}
- private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
+ private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName,
+ String constantValueString) {
+ Tensor constantValue = Tensor.from(constantValueString);
store.writeSmallConstant(constantName, constantValue);
profile.addConstant(constantName, asValue(constantValue));
}
@@ -258,7 +278,8 @@ public class ConvertedModel {
QueryProfileRegistry queryProfiles,
Set<String> constantsReplacedByFunctions,
String constantName,
- Tensor constantValue) {
+ String constantValueString) {
+ Tensor constantValue = Tensor.from(constantValueString);
RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName);
if (rankingExpressionFunctionOverridingConstant != null) {
TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles));
@@ -306,14 +327,14 @@ public class ConvertedModel {
Set<String> functionNames = new HashSet<>();
addFunctionNamesIn(expression.getRoot(), functionNames, model);
for (String functionName : functionNames) {
- TensorType requiredType = model.inputs().get(functionName);
- if (requiredType == null) continue; // Not a required function
+ Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
+ if ( ! requiredType.isPresent()) continue; // Not a required function
RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
if (rankingExpressionFunction == null)
throw new IllegalArgumentException("Model refers input '" + functionName +
- "' of type " + requiredType + " but this function is not present in " +
- profile);
+ "' of type " + requiredType.get() +
+ " but this function is not present in " + profile);
// TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second
// phase and summary features), as it may only resolve correctly given those bindings
// Or, probably better, annotate the functions with type constraints here and verify during general
@@ -321,12 +342,12 @@ public class ConvertedModel {
TensorType actualType = rankingExpressionFunction.function().getBody().getRoot().type(profile.typeContext(queryProfiles));
if ( actualType == null)
throw new IllegalArgumentException("Model refers input '" + functionName +
- "' of type " + requiredType +
+ "' of type " + requiredType.get() +
" which must be produced by a function in the rank profile, but " +
"this function references a feature which is not declared");
- if ( ! actualType.isAssignableTo(requiredType))
+ if ( ! actualType.isAssignableTo(requiredType.get()))
throw new IllegalArgumentException("Model refers input '" + functionName + "'. " +
- typeMismatchExplanation(requiredType, actualType));
+ typeMismatchExplanation(requiredType.get(), actualType));
}
}
@@ -339,7 +360,7 @@ public class ConvertedModel {
/** Add the generated functions to the rank profile */
private static void addGeneratedFunctions(ImportedModel model, RankProfile profile) {
- model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, v.copy()));
+ model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, RankingExpression.from(v)));
}
/**
@@ -383,7 +404,7 @@ public class ConvertedModel {
List<ExpressionNode> children = ((TensorFunctionNode)node).children();
if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.inputs().containsKey(referenceNode.getName())) {
+ if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
return reduceBatchDimensionExpression(tensorFunction, typeContext);
}
}
@@ -391,7 +412,7 @@ public class ConvertedModel {
}
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.inputs().containsKey(referenceNode.getName())) {
+ if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
}
}
@@ -487,7 +508,7 @@ public class ConvertedModel {
if (referenceNode.getOutput() == null) { // function references cannot specify outputs
names.add(referenceNode.getName());
if (model.functions().containsKey(referenceNode.getName())) {
- addFunctionNamesIn(model.functions().get(referenceNode.getName()).getRoot(), names, model);
+ addFunctionNamesIn(RankingExpression.from(model.functions().get(referenceNode.getName())).getRoot(), names, model);
}
}
}