diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-10-02 10:01:36 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-10-02 10:01:36 +0200 |
commit | 8909fd9728591d8e00e7babc601c600b26d5acf4 (patch) | |
tree | 53231c4abb7857b8345c5125bb8539519f0d776e /searchlib | |
parent | 55236fc050998712ad6dc136e2b5e45c9d41538f (diff) |
Be truthful about generated functions
Diffstat (limited to 'searchlib')
2 files changed, 17 insertions, 3 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 787b857839d..674571ff73e 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -98,12 +98,24 @@ public class ExpressionFunction { return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType)); } - /** Returns a copy of this with the given argument and argument type added */ - public ExpressionFunction withArgument(String argument, TensorType type) { + /** Returns a copy of this with the given argument added (if not already present) */ + public ExpressionFunction withArgument(String argument) { + if (arguments.contains(argument)) return this; + List<String> arguments = new ArrayList<>(this.arguments); arguments.add(argument); + return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); + } + + /** Returns a copy of this with the given argument (if not present) and argument type added */ + public ExpressionFunction withArgument(String argument, TensorType type) { + List<String> arguments = new ArrayList<>(this.arguments); + if ( ! arguments.contains(argument)) + arguments.add(argument); + Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes); argumentTypes.put(argument, type); + return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index 6235756d4e1..481b7f9397a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -196,7 +196,9 @@ public abstract class ModelImporter { if (operation.rankingExpressionFunction().isPresent()) { TensorFunction function = operation.rankingExpressionFunction().get(); try { - model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString())); + model.function(operation.rankingExpressionFunctionName(), + new RankingExpression(operation.rankingExpressionFunctionName(), + function.toString())); } catch (ParseException e) { throw new RuntimeException("Tensorflow function " + function + |