diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-25 20:07:56 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-25 20:07:56 +0100 |
commit | 1d88554bd513783715425120e76fc5f2a86f439f (patch) | |
tree | 166c86107d3620014cc7e26d85118c311e1b8cf0 /config-model | |
parent | a01bc21d9bcbc417a9fb2591079561f59f76865e (diff) |
Java type only interface between imported-models and config models
This avoids class incompatibility problems when passing an
imported model across bundle boundaries to a config model.
Tensor string parsing has been sped up as this relies on it more.
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java | 51 |
1 files changed, 36 insertions, 15 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 59aa5b3ba53..259ac5227ae 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -202,8 +202,9 @@ public class ConvertedModel { // Add expressions Map<String, ExpressionFunction> expressions = new HashMap<>(); - for (Pair<String, ExpressionFunction> output : model.outputExpressions()) { - addExpression(output.getSecond(), output.getFirst(), + for (ImportedModel.ImportedFunction outputFunction : model.outputExpressions()) { + ExpressionFunction expression = asExpressionFunction(outputFunction); + addExpression(expression, expression.getName(), constantsReplacedByFunctions, model, store, profile, queryProfiles, expressions); @@ -218,6 +219,23 @@ public class ConvertedModel { return expressions; } + private static ExpressionFunction asExpressionFunction(ImportedModel.ImportedFunction function) { + try { + Map<String, TensorType> argumentTypes = new HashMap<>(); + for (Map.Entry<String, String> entry : function.argumentTypes().entrySet()) + argumentTypes.put(entry.getKey(), TensorType.fromSpec(entry.getValue())); + + return new ExpressionFunction(function.name(), + function.arguments(), + new RankingExpression(function.expression()), + argumentTypes, + function.returnType().map(TensorType::fromSpec)); + } + catch (ParseException e) { + throw new IllegalArgumentException("Gor an illegal argument from importing " + function.name(), e); + } + } + private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, @@ -248,7 +266,9 @@ public class ConvertedModel { return store.readExpressions(); } - private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { + private static void transformSmallConstant(ModelStore store, RankProfile profile, String constantName, + String constantValueString) { + Tensor constantValue = Tensor.from(constantValueString); store.writeSmallConstant(constantName, constantValue); profile.addConstant(constantName, asValue(constantValue)); } @@ -258,7 +278,8 @@ public class ConvertedModel { QueryProfileRegistry queryProfiles, Set<String> constantsReplacedByFunctions, String constantName, - Tensor constantValue) { + String constantValueString) { + Tensor constantValue = Tensor.from(constantValueString); RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName); if (rankingExpressionFunctionOverridingConstant != null) { TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles)); @@ -306,14 +327,14 @@ public class ConvertedModel { Set<String> functionNames = new HashSet<>(); addFunctionNamesIn(expression.getRoot(), functionNames, model); for (String functionName : functionNames) { - TensorType requiredType = model.inputs().get(functionName); - if (requiredType == null) continue; // Not a required function + Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); + if ( ! requiredType.isPresent()) continue; // Not a required function RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); if (rankingExpressionFunction == null) throw new IllegalArgumentException("Model refers input '" + functionName + - "' of type " + requiredType + " but this function is not present in " + - profile); + "' of type " + requiredType.get() + + " but this function is not present in " + profile); // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second // phase and summary features), as it may only resolve correctly given those bindings // Or, probably better, annotate the functions with type constraints here and verify during general @@ -321,12 +342,12 @@ public class ConvertedModel { TensorType actualType = rankingExpressionFunction.function().getBody().getRoot().type(profile.typeContext(queryProfiles)); if ( actualType == null) throw new IllegalArgumentException("Model refers input '" + functionName + - "' of type " + requiredType + + "' of type " + requiredType.get() + " which must be produced by a function in the rank profile, but " + "this function references a feature which is not declared"); - if ( ! actualType.isAssignableTo(requiredType)) + if ( ! actualType.isAssignableTo(requiredType.get())) throw new IllegalArgumentException("Model refers input '" + functionName + "'. " + - typeMismatchExplanation(requiredType, actualType)); + typeMismatchExplanation(requiredType.get(), actualType)); } } @@ -339,7 +360,7 @@ public class ConvertedModel { /** Add the generated functions to the rank profile */ private static void addGeneratedFunctions(ImportedModel model, RankProfile profile) { - model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, v.copy())); + model.functions().forEach((k, v) -> addGeneratedFunctionToProfile(profile, k, RankingExpression.from(v))); } /** @@ -383,7 +404,7 @@ public class ConvertedModel { List<ExpressionNode> children = ((TensorFunctionNode)node).children(); if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.inputs().containsKey(referenceNode.getName())) { + if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { return reduceBatchDimensionExpression(tensorFunction, typeContext); } } @@ -391,7 +412,7 @@ public class ConvertedModel { } if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - if (model.inputs().containsKey(referenceNode.getName())) { + if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); } } @@ -487,7 +508,7 @@ public class ConvertedModel { if (referenceNode.getOutput() == null) { // function references cannot specify outputs names.add(referenceNode.getName()); if (model.functions().containsKey(referenceNode.getName())) { - addFunctionNamesIn(model.functions().get(referenceNode.getName()).getRoot(), names, model); + addFunctionNamesIn(RankingExpression.from(model.functions().get(referenceNode.getName())).getRoot(), names, model); } } } |