aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-02-08 09:15:56 +0100
committerLester Solbakken <lesters@oath.com>2018-02-08 09:15:56 +0100
commitff500ab6c72887f64cfbf0e6b40748c7c6e9dd08 (patch)
tree08941eb8a4b630e447e4209b519f3aa713f94a73
parent74b3ef7b54e8ac8b0473c016185f1476a3fd3db4 (diff)
Inline small tensor constants imported from tensorflow
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java40
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java11
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java1
6 files changed, 15 insertions, 47 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 83cc3ae418a..3e11eb72a30 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -250,46 +250,6 @@ public class RankingExpressionWithTensorFlowTestCase {
}
}
- @Test
- public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(rename(reduce(join(map(join(rename(reduce(join(join(join(constant(\"dnn_hidden1_mul_x\"), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))";
- StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "tensorflow('mnist/saved')",
- null,
- null,
- "input",
- application);
- search.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("dnn_hidden1_mul_x", TensorType.empty, search);
-
- // At this point the expression is stored - copy application to another location which do not have a models dir
- Path storedApplicationDirectory = applicationDir.getParentPath().append("copy");
- try {
- storedApplicationDirectory.toFile().mkdirs();
- IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
- storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
- StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)",
- "tensorflow('mnist/saved')",
- null,
- null,
- "input",
- storedApplication);
- searchFromStored.assertFirstPhaseExpression(expression, "my_profile");
- assertSmallConstant("dnn_hidden1_mul_x", TensorType.empty, search);
- }
- finally {
- IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
- }
- }
-
- private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
- Value value = search.rankProfile("my_profile").getConstants().get(name);
- assertNotNull(value);
- assertEquals(type, value.type());
- }
-
/**
* Verifies that the constant with the given name exists, and - only if an expected size is given -
* that the content of the constant is available and has the expected size.
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
index c18cfcfe1aa..b001db69768 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java
@@ -64,7 +64,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase {
assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)");
assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)");
assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...)
- assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)");
+ assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)");
}
@Test
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
index 59d2d95b879..e5a9e6a5ef1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java
@@ -33,7 +33,7 @@ public abstract class Value {
/** Returns this as a tensor value */
public abstract Tensor asTensor();
- /** A utility method for wrapping a sdouble in a rank 0 tensor */
+ /** A utility method for wrapping a double in a rank 0 tensor */
protected Tensor doubleAsTensor(double value) {
return Tensor.Builder.of(TensorType.empty).cell(TensorAddress.of(), value).build();
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 55782c36d18..ef82045e771 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -156,6 +156,12 @@ class OperationMapper {
private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) {
Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value");
+ if (value.type().rank() == 0) {
+ TypedTensorFunction output = new TypedTensorFunction(value.type(),
+ new TensorFunctionNode.TensorFunctionExpressionNode(
+ new ConstantNode(new DoubleValue(value.asDouble()))));
+ return Optional.of(output);
+ }
return createConstant(params, value);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index e4c381972e9..927bb4c0ea2 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -117,11 +117,12 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public Tensor evaluate(EvaluationContext context) {
- Value result = expression.evaluate((Context)context);
- if ( ! ( result instanceof TensorValue))
- throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
- "but this returns " + result + ", not a tensor");
- return result.asTensor();
+// Value result = expression.evaluate((Context)context);
+// if ( ! ( result instanceof TensorValue))
+// throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
+// "but this returns " + result + ", not a tensor");
+// return result.asTensor();
+ return expression.evaluate((Context)context).asTensor();
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 6c7643b37b3..e9030cf5852 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -294,6 +294,7 @@ public class EvaluationTestCase {
"tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }");
tester.assertEvaluates("{ {x:0}:1, {x:1}:0 }",
"tensor0 in [1,2,3]", "{ {x:0}:3, {x:1}:7 }");
+ tester.assertEvaluates("{ {x:0}:0.1 }", "join(tensor0, 0.1, f(x,y) (x*y))", "{ {x:0}:1 }");
// TODO
// argmax