diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-03-07 15:49:00 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-03-07 15:49:00 +0100 |
commit | e3f5b2812b51896d4a4d5304e6e8c7060e60f68a (patch) | |
tree | 50f86020fe0b4c295bffd14bb877251293593f28 /searchlib | |
parent | 584dcf0b4a54b5e5a70696c15ee0c2bfe63ab656 (diff) |
Allow macros to replace TenorFlow variables
Also, remove quoting of constant arguments generated
in TensorFlow as that is unnecessary now and is
interpreted as a string constant argument to a macro.
Diffstat (limited to 'searchlib')
9 files changed, 52 insertions, 31 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java index 0fe73fad8ce..ee358f45b22 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java @@ -34,9 +34,7 @@ public class OperationMapper { public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) { switch (node.getOp().toLowerCase()) { - /* - * array ops - */ + // array ops case "const": return new Const(node, inputs, port); case "expanddims": return new ExpandDims(node, inputs, port); case "identity": return new Identity(node, inputs, port); @@ -46,15 +44,11 @@ public class OperationMapper { case "shape": return new Shape(node, inputs, port); case "squeeze": return new Squeeze(node, inputs, port); - /* - * control flow - */ + // control flow case "merge": return new Merge(node, inputs, port); case "switch": return new Switch(node, inputs, port); - /* - * math ops - */ + // math ops case "add": return new Join(node, inputs, port, ScalarFunctions.add()); case "add_n": return new Join(node, inputs, port, ScalarFunctions.add()); case "acos": return new Map(node, inputs, port, ScalarFunctions.acos()); @@ -75,27 +69,17 @@ public class OperationMapper { case "sub": return new Join(node, inputs, port, ScalarFunctions.subtract()); case "subtract": return new Join(node, inputs, port, ScalarFunctions.subtract()); - /* - * nn ops - */ + // nn ops case "biasadd": return new Join(node, inputs, port, ScalarFunctions.add()); case "elu": return new Map(node, inputs, port, ScalarFunctions.elu()); case "relu": return new Map(node, inputs, port, ScalarFunctions.relu()); case "selu": return new Map(node, inputs, port, ScalarFunctions.selu()); - /* - * random ops - */ - - /* - * state ops - */ + // state ops case "variable": return new Variable(node, inputs, port); case "variablev2": return new Variable(node, inputs, port); - /* - * evaluation no-ops - */ + // evaluation no-ops case "stopgradient":return new Identity(node, inputs, port); case "noop": return new NoOp(node, inputs, port); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java index 7decef51ab7..d06d7b48def 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -46,7 +47,7 @@ public class Const extends TensorFlowOperation { if (type.type().rank() == 0 && getConstantValue().isPresent()) { expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue()); } else { - expressionNode = new ReferenceNode("constant(\"" + vespaName() + "\")"); + expressionNode = new ReferenceNode(Reference.simple("constant", vespaName())); } return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); } @@ -72,7 +73,7 @@ public class Const extends TensorFlowOperation { private Value value() { if (!node.getAttrMap().containsKey("value")) { throw new IllegalArgumentException("Node '" + node.getName() + "' of type " + - "const has missing 'value' attribute"); + "const has missing 'value' attribute"); } AttrValue attrValue = node.getAttrMap().get("value"); if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java index 9e8f6df3e2c..5d711aac100 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; @@ -67,7 +68,7 @@ public abstract class TensorFlowOperation { public Optional<TensorFunction> function() { if (function == null) { if (isConstant()) { - ExpressionNode constant = new ReferenceNode("constant(\"" + vespaName() + "\")"); + ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); } else if (outputs.size() > 1) { macro = lazyGetFunction(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 822d6055815..639c5d22d9e 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; @@ -105,7 +106,8 @@ public final class ReferenceNode extends CompositeNode { // TODO: Context should accept a Reference instead. if (reference.isIdentifier()) return context.get(reference.name()); - return context.get(getName(), getArguments(), getOutput()); + else + return context.get(getName(), getArguments(), getOutput()); } @Override diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/ReferenceTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/ReferenceTestCase.java new file mode 100644 index 00000000000..f275f95ca8e --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/ReferenceTestCase.java @@ -0,0 +1,33 @@ +package com.yahoo.searchlib.rankingexpression; + +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.NameNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import org.junit.Test; + +import static junit.framework.TestCase.assertTrue; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +/** + * @author bratseth + */ +public class ReferenceTestCase { + + @Test + public void testSimple() { + assertTrue(new Reference("foo", new Arguments(new ReferenceNode("arg")), null).isSimple()); + assertTrue(new Reference("foo", new Arguments(new ReferenceNode("arg")), "out").isSimple()); + assertTrue(new Reference("foo", new Arguments(new NameNode("arg")), "out").isSimple()); + assertFalse(new Reference("foo", new Arguments(), null).isSimple()); + } + + @Test + public void testToString() { + assertEquals("foo(arg_1)", new Reference("foo", new Arguments(new ReferenceNode("arg_1")), null).toString()); + assertEquals("foo(arg_1).out", new Reference("foo", new Arguments(new ReferenceNode("arg_1")), "out").toString()); + assertEquals("foo(arg_1).out", new Reference("foo", new Arguments(new NameNode("arg_1")), "out").toString()); + assertEquals("foo", new Reference("foo", new Arguments(), null).toString()); + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index c09b1f2b606..a13ff3147c8 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/BiasAdd", output.getName()); - assertEquals("join(reduce(join(tf_macro_X, constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))", + assertEquals("join(reduce(join(tf_macro_X, constant(outputs_kernel_read), f(a,b)(a * b)), sum, d2), constant(outputs_bias_read), f(a,b)(a + b))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index 60dd3865aa1..0deac3f8216 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -59,7 +59,7 @@ public class MnistSoftmaxImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("add", output.getName()); - assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"Variable_1_read\"), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(Variable_read), f(a,b)(a * b)), sum, d2), constant(Variable_1_read), f(a,b)(a + b))", output.getRoot().toString()); // Test execution diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index 9f372d8d6f5..daacd014b63 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -68,8 +68,8 @@ public class TestableTensorFlowModel { private Context contextFrom(TensorFlowModel result) { MapContext context = new MapContext(); - result.largeConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); - result.smallConstants().forEach((name, tensor) -> context.put("constant(\"" + name + "\")", new TensorValue(tensor))); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); return context; } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java index 75b8e1122c1..135cc95a209 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNodeTestCase.java @@ -9,7 +9,7 @@ import java.util.List; import static org.junit.Assert.assertEquals; /** - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ public class ReferenceNodeTestCase { |