diff options
author | Lester Solbakken <lesters@oath.com> | 2018-02-08 09:15:56 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-02-08 09:15:56 +0100 |
commit | ff500ab6c72887f64cfbf0e6b40748c7c6e9dd08 (patch) | |
tree | 08941eb8a4b630e447e4209b519f3aa713f94a73 /searchlib | |
parent | 74b3ef7b54e8ac8b0473c016185f1476a3fd3db4 (diff) |
Inline small tensor constants imported from tensorflow
Diffstat (limited to 'searchlib')
4 files changed, 14 insertions, 6 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index 59d2d95b879..e5a9e6a5ef1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -33,7 +33,7 @@ public abstract class Value { /** Returns this as a tensor value */ public abstract Tensor asTensor(); - /** A utility method for wrapping a sdouble in a rank 0 tensor */ + /** A utility method for wrapping a double in a rank 0 tensor */ protected Tensor doubleAsTensor(double value) { return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 55782c36d18..ef82045e771 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -156,6 +156,12 @@ class OperationMapper { private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) { Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value"); + if (value.type().rank() == 0) { + TypedTensorFunction output = new TypedTensorFunction(value.type(), + new TensorFunctionNode.TensorFunctionExpressionNode( + new ConstantNode(new DoubleValue(value.asDouble())))); + return Optional.of(output); + } return createConstant(params, value); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index e4c381972e9..927bb4c0ea2 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -117,11 +117,12 @@ public class TensorFunctionNode extends CompositeNode { @Override public Tensor evaluate(EvaluationContext context) { - Value result = expression.evaluate((Context)context); - if ( ! ( result instanceof TensorValue)) - throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + - "but this returns " + result + ", not a tensor"); - return result.asTensor(); +// Value result = expression.evaluate((Context)context); +// if ( ! ( result instanceof TensorValue)) +// throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " + +// "but this returns " + result + ", not a tensor"); +// return result.asTensor(); + return expression.evaluate((Context)context).asTensor(); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 6c7643b37b3..e9030cf5852 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -294,6 +294,7 @@ public class EvaluationTestCase { "tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); tester.assertEvaluates("{ {x:0}:1, {x:1}:0 }", "tensor0 in [1,2,3]", "{ {x:0}:3, {x:1}:7 }"); + tester.assertEvaluates("{ {x:0}:0.1 }", "join(tensor0, 0.1, f(x,y) (x*y))", "{ {x:0}:1 }"); // TODO // argmax |