summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java6
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java1
4 files changed, 9 insertions, 6 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 59d2d95b879..e5a9e6a5ef1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -33,7 +33,7 @@ public abstract class Value {
/** Returns this as a tensor value */
public abstract Tensor asTensor();
- /** A utility method for wrapping a sdouble in a rank 0 tensor */
+ /** A utility method for wrapping a double in a rank 0 tensor */
protected Tensor doubleAsTensor(double value) {
return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 55782c36d18..ef82045e771 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -156,6 +156,12 @@ class OperationMapper {
private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) {
Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value");
+ if (value.type().rank() == 0) {
+ TypedTensorFunction output = new TypedTensorFunction(value.type(),
+ new TensorFunctionNode.TensorFunctionExpressionNode(
+ new ConstantNode(new DoubleValue(value.asDouble()))));
+ return Optional.of(output);
+ }
return createConstant(params, value);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index e4c381972e9..ec6af4bb413 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -117,11 +117,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public Tensor evaluate(EvaluationContext context) {
- Value result = expression.evaluate((Context)context);
- if ( ! ( result instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
- "but this returns " + result + ", not a tensor");
- return result.asTensor();
+ return expression.evaluate((Context)context).asTensor();
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 6c7643b37b3..e9030cf5852 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -294,6 +294,7 @@ public class EvaluationTestCase {
"tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:0 }",
"tensor0 in [1,2,3]", "{ {x:0}:3, {x:1}:7 }");
+ tester.assertEvaluates("{ {x:0}:0.1 }", "join(tensor0, 0.1, f(x,y) (x*y))", "{ {x:0}:1 }");
// TODO
// argmax