diff options
Diffstat (limited to 'searchlib')
2 files changed, 17 insertions, 3 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 787b857839d..674571ff73e 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -98,12 +98,24 @@ public class ExpressionFunction { return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType)); } - /** Returns a copy of this with the given argument and argument type added */ - public ExpressionFunction withArgument(String argument, TensorType type) { + /** Returns a copy of this with the given argument added (if not already present) */ + public ExpressionFunction withArgument(String argument) { + if (arguments.contains(argument)) return this; + List<String> arguments = new ArrayList<>(this.arguments); arguments.add(argument); + return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); + } + + /** Returns a copy of this with the given argument (if not present) and argument type added */ + public ExpressionFunction withArgument(String argument, TensorType type) { + List<String> arguments = new ArrayList<>(this.arguments); + if ( ! arguments.contains(argument)) + arguments.add(argument); + Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes); argumentTypes.put(argument, type); + return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index 6235756d4e1..481b7f9397a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -196,7 +196,9 @@ public abstract class ModelImporter { if (operation.rankingExpressionFunction().isPresent()) { TensorFunction function = operation.rankingExpressionFunction().get(); try { - model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString())); + model.function(operation.rankingExpressionFunctionName(), + new RankingExpression(operation.rankingExpressionFunctionName(), + function.toString())); } catch (ParseException e) { throw new RuntimeException("Tensorflow function " + function + |