diff options
author | Jon Bratseth <bratseth@gmail.com> | 2020-06-29 11:07:22 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2020-06-29 11:07:22 +0200 |
commit | 4289be15756bd05e880f41b1dd3e81cf054950f8 (patch) | |
tree | 82cc456ea30cb67604c32519c36079f86ca3d940 /vespajlib | |
parent | 7dc5390309ccd905aec92e68d222c0b1783abcc5 (diff) |
Make tensor generate inspectable
Diffstat (limited to 'vespajlib')
3 files changed, 22 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index fa3d70a4ddf..1a12c7a6370 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -72,13 +72,23 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } @Override - public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { + return boundGenerator != null && boundGenerator.asTensorFunction().isPresent() + ? List.of(boundGenerator.asTensorFunction().get()) + : List.of(); + } @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if ( arguments.size() != 0) - throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size()); - return this; + if ( arguments.size() > 1) + throw new IllegalArgumentException("Generate must have 0 or 1 arguments, got " + arguments.size()); + if (arguments.isEmpty()) return this; + + if (arguments.get(0).asScalarFunction().isEmpty()) + throw new IllegalArgumentException("The argument to generate must be convertible to a tensor function, " + + "but got " + arguments.get(0)); + + return new Generate<>(type, null, arguments.get(0).asScalarFunction().get()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java index ec579a90e4f..f8ab9dfa636 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -4,6 +4,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; +import java.util.Optional; import java.util.function.Function; /** @@ -16,6 +17,9 @@ public interface ScalarFunction<NAMETYPE extends Name> extends Function<Evaluati @Override Double apply(EvaluationContext<NAMETYPE> context); + /** Returns this as a tensor function, or empty if it cannot be represented as a tensor function */ + default Optional<TensorFunction<NAMETYPE>> asTensorFunction() { return Optional.empty(); } + default String toString(ToStringContext context) { return toString(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index b4c5dedbf4e..5c0d0a99441 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; +import java.util.Optional; /** * A representation of a tensor function which is able to be translated to a set of primitive @@ -61,6 +62,9 @@ public abstract class TensorFunction<NAMETYPE extends Name> { */ public abstract String toString(ToStringContext context); + /** Returns this as a scalar function, or empty if it cannot be represented as a scalar function */ + public Optional<ScalarFunction<NAMETYPE>> asScalarFunction() { return Optional.empty(); } + @Override public String toString() { return toString(ToStringContext.empty()); } |