summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java18
1 files changed, 14 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index fa3d70a4ddf..1a12c7a6370 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -72,13 +72,23 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
}
@Override
- public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() {
+ return boundGenerator != null && boundGenerator.asTensorFunction().isPresent()
+ ? List.of(boundGenerator.asTensorFunction().get())
+ : List.of();
+ }
@Override
public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if ( arguments.size() != 0)
- throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size());
- return this;
+ if ( arguments.size() > 1)
+ throw new IllegalArgumentException("Generate must have 0 or 1 arguments, got " + arguments.size());
+ if (arguments.isEmpty()) return this;
+
+ if (arguments.get(0).asScalarFunction().isEmpty())
+ throw new IllegalArgumentException("The argument to generate must be convertible to a tensor function, " +
+ "but got " + arguments.get(0));
+
+ return new Generate<>(type, null, arguments.get(0).asScalarFunction().get());
}
@Override