summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-02 10:01:36 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-02 10:01:36 +0200
commit8909fd9728591d8e00e7babc601c600b26d5acf4 (patch)
tree53231c4abb7857b8345c5125bb8539519f0d776e /searchlib
parent55236fc050998712ad6dc136e2b5e45c9d41538f (diff)
Be truthful about generated functions
Diffstat (limited to 'searchlib')
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java16
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java4
2 files changed, 17 insertions, 3 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index 787b857839d..674571ff73e 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -98,12 +98,24 @@ public class ExpressionFunction {
return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType));
}
- /** Returns a copy of this with the given argument and argument type added */
- public ExpressionFunction withArgument(String argument, TensorType type) {
+ /** Returns a copy of this with the given argument added (if not already present) */
+ public ExpressionFunction withArgument(String argument) {
+ if (arguments.contains(argument)) return this;
+
List<String> arguments = new ArrayList<>(this.arguments);
arguments.add(argument);
+ return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
+ }
+
+ /** Returns a copy of this with the given argument (if not present) and argument type added */
+ public ExpressionFunction withArgument(String argument, TensorType type) {
+ List<String> arguments = new ArrayList<>(this.arguments);
+ if ( ! arguments.contains(argument))
+ arguments.add(argument);
+
Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes);
argumentTypes.put(argument, type);
+
return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index 6235756d4e1..481b7f9397a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -196,7 +196,9 @@ public abstract class ModelImporter {
if (operation.rankingExpressionFunction().isPresent()) {
TensorFunction function = operation.rankingExpressionFunction().get();
try {
- model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString()));
+ model.function(operation.rankingExpressionFunctionName(),
+ new RankingExpression(operation.rankingExpressionFunctionName(),
+ function.toString()));
}
catch (ParseException e) {
throw new RuntimeException("Tensorflow function " + function +