diff options
5 files changed, 35 insertions, 4 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index befe2179dc1..86541343edb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -25,6 +25,7 @@ import java.util.Deque; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; /** @@ -157,6 +158,11 @@ public class TensorFunctionNode extends CompositeNode { } @Override + public Optional<TensorFunction<Reference>> asTensorFunction() { + return Optional.of(new ExpressionTensorFunction(expression)); + } + + @Override public String toString() { return toString(ExpressionToStringContext.empty); } @@ -230,6 +236,11 @@ public class TensorFunctionNode extends CompositeNode { } @Override + public Optional<ScalarFunction<Reference>> asScalarFunction() { + return Optional.of(new ExpressionScalarFunction(expression)); + } + + @Override public Tensor evaluate(EvaluationContext<Reference> context) { return expression.evaluate((Context)context).asTensor(); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java index a41fb02f784..bfaff0712ee 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java @@ -29,6 +29,8 @@ public class ConstantDereferencerTestCase { assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c"), context).toString()); assertEquals("myFunction(1.0,2.0)", c.transform(new RankingExpression("myFunction(a, b)"), context).toString()); + assertEquals("tensor(x[2],y[3])((x + y == 1.0))", c.transform(new RankingExpression("tensor(x[2],y[3])(x+y==a)"), context).toString()); + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index fa3d70a4ddf..1a12c7a6370 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -72,13 +72,23 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } @Override - public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { + return boundGenerator != null && boundGenerator.asTensorFunction().isPresent() + ? List.of(boundGenerator.asTensorFunction().get()) + : List.of(); + } @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if ( arguments.size() != 0) - throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size()); - return this; + if ( arguments.size() > 1) + throw new IllegalArgumentException("Generate must have 0 or 1 arguments, got " + arguments.size()); + if (arguments.isEmpty()) return this; + + if (arguments.get(0).asScalarFunction().isEmpty()) + throw new IllegalArgumentException("The argument to generate must be convertible to a tensor function, " + + "but got " + arguments.get(0)); + + return new Generate<>(type, null, arguments.get(0).asScalarFunction().get()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java index ec579a90e4f..f8ab9dfa636 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -4,6 +4,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; +import java.util.Optional; import java.util.function.Function; /** @@ -16,6 +17,9 @@ public interface ScalarFunction<NAMETYPE extends Name> extends Function<Evaluati @Override Double apply(EvaluationContext<NAMETYPE> context); + /** Returns this as a tensor function, or empty if it cannot be represented as a tensor function */ + default Optional<TensorFunction<NAMETYPE>> asTensorFunction() { return Optional.empty(); } + default String toString(ToStringContext context) { return toString(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index b4c5dedbf4e..5c0d0a99441 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; +import java.util.Optional; /** * A representation of a tensor function which is able to be translated to a set of primitive @@ -61,6 +62,9 @@ public abstract class TensorFunction<NAMETYPE extends Name> { */ public abstract String toString(ToStringContext context); + /** Returns this as a scalar function, or empty if it cannot be represented as a scalar function */ + public Optional<ScalarFunction<NAMETYPE>> asScalarFunction() { return Optional.empty(); } + @Override public String toString() { return toString(ToStringContext.empty()); } |