summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2020-06-29 11:07:22 +0200
committerJon Bratseth <bratseth@gmail.com>2020-06-29 11:07:22 +0200
commit4289be15756bd05e880f41b1dd3e81cf054950f8 (patch)
tree82cc456ea30cb67604c32519c36079f86ca3d940
parent7dc5390309ccd905aec92e68d222c0b1783abcc5 (diff)
Make tensor generate inspectable
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java11
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java4
5 files changed, 35 insertions, 4 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index befe2179dc1..86541343edb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -25,6 +25,7 @@ import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.stream.Collectors;
/**
@@ -157,6 +158,11 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
+ public Optional<TensorFunction<Reference>> asTensorFunction() {
+ return Optional.of(new ExpressionTensorFunction(expression));
+ }
+
+ @Override
public String toString() {
return toString(ExpressionToStringContext.empty);
}
@@ -230,6 +236,11 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
+ public Optional<ScalarFunction<Reference>> asScalarFunction() {
+ return Optional.of(new ExpressionScalarFunction(expression));
+ }
+
+ @Override
public Tensor evaluate(EvaluationContext<Reference> context) {
return expression.evaluate((Context)context).asTensor();
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
index a41fb02f784..bfaff0712ee 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
@@ -29,6 +29,8 @@ public class ConstantDereferencerTestCase {
assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c"), context).toString());
assertEquals("myFunction(1.0,2.0)", c.transform(new RankingExpression("myFunction(a, b)"), context).toString());
+ assertEquals("tensor(x[2],y[3])((x + y == 1.0))", c.transform(new RankingExpression("tensor(x[2],y[3])(x+y==a)"), context).toString());
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index fa3d70a4ddf..1a12c7a6370 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -72,13 +72,23 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
}
@Override
- public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() {
+ return boundGenerator != null && boundGenerator.asTensorFunction().isPresent()
+ ? List.of(boundGenerator.asTensorFunction().get())
+ : List.of();
+ }
@Override
public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if ( arguments.size() != 0)
- throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size());
- return this;
+ if ( arguments.size() > 1)
+ throw new IllegalArgumentException("Generate must have 0 or 1 arguments, got " + arguments.size());
+ if (arguments.isEmpty()) return this;
+
+ if (arguments.get(0).asScalarFunction().isEmpty())
+ throw new IllegalArgumentException("The argument to generate must be convertible to a tensor function, " +
+ "but got " + arguments.get(0));
+
+ return new Generate<>(type, null, arguments.get(0).asScalarFunction().get());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
index ec579a90e4f..f8ab9dfa636 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
+import java.util.Optional;
import java.util.function.Function;
/**
@@ -16,6 +17,9 @@ public interface ScalarFunction<NAMETYPE extends Name> extends Function<Evaluati
@Override
Double apply(EvaluationContext<NAMETYPE> context);
+ /** Returns this as a tensor function, or empty if it cannot be represented as a tensor function */
+ default Optional<TensorFunction<NAMETYPE>> asTensorFunction() { return Optional.empty(); }
+
default String toString(ToStringContext context) { return toString(); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index b4c5dedbf4e..5c0d0a99441 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.List;
+import java.util.Optional;
/**
* A representation of a tensor function which is able to be translated to a set of primitive
@@ -61,6 +62,9 @@ public abstract class TensorFunction<NAMETYPE extends Name> {
*/
public abstract String toString(ToStringContext context);
+ /** Returns this as a scalar function, or empty if it cannot be represented as a scalar function */
+ public Optional<ScalarFunction<NAMETYPE>> asScalarFunction() { return Optional.empty(); }
+
@Override
public String toString() { return toString(ToStringContext.empty()); }