diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-17 13:23:16 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-17 13:23:16 +0200 |
commit | da8b4b3b9b5d197f9f2e6b4da9f33b111a32f0c6 (patch) | |
tree | cb06cf65d8757d2f7fb58cffe343576e0c8dd277 /searchlib/src | |
parent | 9853b3f674bda90615e78649693ff9b2a3a6ec63 (diff) |
Refactor: macro -> function
Diffstat (limited to 'searchlib/src')
10 files changed, 59 insertions, 55 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index 17352aea01e..d25502fd149 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -193,10 +193,10 @@ public abstract class ModelImporter { } private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) { - if (operation.macro().isPresent()) { - TensorFunction function = operation.macro().get(); + if (operation.rankingExpressionFunction().isPresent()) { + TensorFunction function = operation.rankingExpressionFunction().get(); try { - model.function(operation.macroName(), new RankingExpression(operation.macroName(), function.toString())); + model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString())); } catch (ParseException e) { throw new RuntimeException("Tensorflow function " + function + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java index 5daf5c69548..34f5f1365a1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java @@ -38,7 +38,7 @@ public abstract class IntermediateOperation { protected OrderedTensorType type; protected TensorFunction function; - protected TensorFunction macro = null; + protected TensorFunction rankingExpressionFunction = null; private final List<String> importWarnings = new ArrayList<>(); private Value constantValue = null; @@ -71,8 +71,8 @@ public abstract class IntermediateOperation { ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); } else if (outputs.size() > 1) { - macro = lazyGetFunction(); - function = new VariableTensor(macroName(), type.type()); + rankingExpressionFunction = lazyGetFunction(); + function = new VariableTensor(rankingExpressionFunctionName(), type.type()); } else { function = lazyGetFunction(); } @@ -89,8 +89,10 @@ public abstract class IntermediateOperation { /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a function. */ public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); } - /** Returns a Vespa ranking expression that should be added as a macro */ - public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); } + /** Returns a function that should be added as a ranking expression function */ + public Optional<TensorFunction> rankingExpressionFunction() { + return Optional.ofNullable(rankingExpressionFunction); + } /** Add dimension name constraints for this operation */ public void addDimensionNameConstraints(DimensionRenamer renamer) { } @@ -131,8 +133,10 @@ public abstract class IntermediateOperation { public String vespaName() { return vespaName(name); } public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } - /** Retrieve the valid Vespa name of this node if it is a macro */ - public String macroName() { return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null; } + /** Retrieve the valid Vespa name of this node if it is a ranking expression function */ + public String rankingExpressionFunctionName() { + return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null; + } /** Retrieve the list of warnings produced during its lifetime */ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java index 384c4b1c4a8..b335fd7e1c5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java @@ -32,7 +32,7 @@ public class PlaceholderWithDefault extends IntermediateOperation { } @Override - public Optional<TensorFunction> macro() { + public Optional<TensorFunction> rankingExpressionFunction() { // For now, it is much more efficient to assume we always will return // the default value, as we can prune away large parts of the expression // tree by having it calculated as a constant. If a case arises where diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 2a5f1646391..eb8d2229a6d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -64,7 +64,7 @@ public final class ReferenceNode extends CompositeNode { @Override public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { - // A reference to a macro argument? + // A reference to a function argument? if (reference.isIdentifier() && context.getBinding(getName()) != null) { // a bound identifier: replace by the value it is bound to return string.append(context.getBinding(getName())); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 7c929ae24b3..571e1f4d608 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -91,26 +91,26 @@ public class RankingExpressionTestCase { @Test public void testSelfRecursionSerialization() throws ParseException { - List<ExpressionFunction> macros = new ArrayList<>(); - macros.add(new ExpressionFunction("foo", null, new RankingExpression("foo"))); + List<ExpressionFunction> functions = new ArrayList<>(); + functions.add(new ExpressionFunction("foo", null, new RankingExpression("foo"))); RankingExpression exp = new RankingExpression("foo"); try { - exp.getRankProperties(macros); + exp.getRankProperties(functions); } catch (RuntimeException e) { assertEquals("Cycle in ranking expression function: [foo[]]", e.getMessage()); } } @Test - public void testMacroCycleSerialization() throws ParseException { - List<ExpressionFunction> macros = new ArrayList<>(); - macros.add(new ExpressionFunction("foo", null, new RankingExpression("bar"))); - macros.add(new ExpressionFunction("bar", null, new RankingExpression("foo"))); + public void testFunctionCycleSerialization() throws ParseException { + List<ExpressionFunction> funnctions = new ArrayList<>(); + funnctions.add(new ExpressionFunction("foo", null, new RankingExpression("bar"))); + funnctions.add(new ExpressionFunction("bar", null, new RankingExpression("foo"))); RankingExpression exp = new RankingExpression("foo"); try { - exp.getRankProperties(macros); + exp.getRankProperties(funnctions); } catch (RuntimeException e) { assertEquals("Cycle in ranking expression function: [foo[], bar[]]", e.getMessage()); } @@ -118,11 +118,11 @@ public class RankingExpressionTestCase { @Test public void testSerialization() throws ParseException { - List<ExpressionFunction> macros = new ArrayList<>(); - macros.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))"))); - macros.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2"))); - macros.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)"))); - macros.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977"))); + List<ExpressionFunction> functions = new ArrayList<>(); + functions.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))"))); + functions.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2"))); + functions.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)"))); + functions.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977"))); assertSerialization(Arrays.asList( "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)", @@ -130,19 +130,19 @@ public class RankingExpressionTestCase { "min(6,pow(7,2))", "min(1,pow(2,2))", "min(3,pow(4,2))", - "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros); + "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", functions); assertSerialization(Arrays.asList( "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)", "min(1,pow(2,2))", - "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", macros); + "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", functions); assertSerialization(Arrays.asList( "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)", "min(1,pow(2,2))", "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)", - "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", macros); + "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", functions); assertSerialization(Arrays.asList( "rankingExpression(cox)", - "10 + 08 * 1977"), "cox", macros + "10 + 08 * 1977"), "cox", functions ); } @@ -159,8 +159,8 @@ public class RankingExpressionTestCase { @Test public void testBug3464208() throws ParseException { - List<ExpressionFunction> macros = new ArrayList<>(); - macros.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69"))); + List<ExpressionFunction> functions = new ArrayList<>(); + functions.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69"))); String lhs = "log10(0.01+attribute(user_followers_count)) * log10(socialratio) * " + "log10(userage/(0.01+attribute(user_statuses_count)))"; @@ -172,8 +172,8 @@ public class RankingExpressionTestCase { String expRhs = "(rankingExpression(log10tweetage) * rankingExpression(log10tweetage) * " + "rankingExpression(log10tweetage)) + 5.0 * attribute(ythl)"; - assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, macros); - assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, macros); + assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, functions); + assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, functions); } @Test @@ -295,12 +295,12 @@ public class RankingExpressionTestCase { assertEquals(expected, new RankingExpression(expression).toString()); } - /** Test serialization with no macros */ + /** Test serialization with no functions */ private void assertSerialization(String expectedSerialization, String expressionString) { String serializedExpression; try { RankingExpression expression = new RankingExpression(expressionString); - // No macros -> expect one rank property + // No functions -> expect one rank property serializedExpression = expression.getRankProperties(Collections.emptyList()).values().iterator().next(); assertEquals(expectedSerialization, serializedExpression); } @@ -309,7 +309,7 @@ public class RankingExpressionTestCase { } try { - // No macros -> output should be parseable to a ranking expression + // No functions -> output should be parseable to a ranking expression // (but not the same one due to primitivization) RankingExpression reparsedExpression = new RankingExpression(serializedExpression); // Serializing the primitivized expression should yield the same expression again @@ -323,17 +323,17 @@ public class RankingExpressionTestCase { } private void assertSerialization(List<String> expectedSerialization, String expressionString, - List<ExpressionFunction> macros) { - assertSerialization(expectedSerialization, expressionString, macros, false); + List<ExpressionFunction> functions) { + assertSerialization(expectedSerialization, expressionString, functions, false); } private void assertSerialization(List<String> expectedSerialization, String expressionString, - List<ExpressionFunction> macros, boolean print) { + List<ExpressionFunction> functions, boolean print) { try { if (print) System.out.println("Parsing expression '" + expressionString + "'."); RankingExpression expression = new RankingExpression(expressionString); - Map<String, String> rankProperties = expression.getRankProperties(macros); + Map<String, String> rankProperties = expression.getRankProperties(functions); if (print) { for (String key : rankProperties.keySet()) System.out.println("Property '" + key + "': " + rankProperties.get(key)); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index f3cb93e929b..a8f7542f3a4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -18,7 +18,7 @@ public class DropoutImportTestCase { public void testDropoutImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved"); - // Check required macros + // Check required functions assertEquals(1, model.get().requiredFunctions().size()); assertTrue(model.get().requiredFunctions().containsKey("X")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index 6856434e2d0..e20ac16a691 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -36,7 +36,7 @@ public class OnnxMnistSoftmaxImportTestCase { constant1.type()); assertEquals(10, constant1.size()); - // Check required macros (inputs) + // Check required functions (inputs) assertEquals(1, model.requiredFunctions().size()); assertTrue(model.requiredFunctions().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index 97f4e938a4f..ef28eb4678f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -34,10 +34,10 @@ public class TensorFlowMnistSoftmaxImportTestCase { constant1.type()); assertEquals(10, constant1.size()); - // Check (provided) macros + // Check (provided) functions assertEquals(0, model.get().functions().size()); - // Check required macros + // Check required functions assertEquals(1, model.get().requiredFunctions().size()); assertTrue(model.get().requiredFunctions().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java index e688a983d27..5447e5240f7 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java @@ -48,7 +48,7 @@ public class TestableTensorFlowModel { Tensor placeholder = placeholderArgument(); context.put(inputName, new TensorValue(placeholder)); - model.functions().forEach((k, v) -> evaluateMacro(context, model, k)); + model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", @@ -62,7 +62,7 @@ public class TestableTensorFlowModel { Tensor placeholder = placeholderArgument(); context.put(inputName, new TensorValue(placeholder)); - model.functions().forEach((k, v) -> evaluateMacro(context, model, k)); + model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); @@ -96,24 +96,24 @@ public class TestableTensorFlowModel { return b.build(); } - private void evaluateMacro(Context context, ImportedModel model, String macroName) { - if (!context.names().contains(macroName)) { - RankingExpression e = model.functions().get(macroName); - evaluateMacroDependencies(context, model, e.getRoot()); - context.put(macroName, new TensorValue(e.evaluate(context).asTensor())); + private void evaluateFunction(Context context, ImportedModel model, String functionName) { + if (!context.names().contains(functionName)) { + RankingExpression e = model.functions().get(functionName); + evaluateFunctionDependencies(context, model, e.getRoot()); + context.put(functionName, new TensorValue(e.evaluate(context).asTensor())); } } - private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) { + private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) { if (node instanceof ReferenceNode) { String name = node.toString(); if (model.functions().containsKey(name)) { - evaluateMacro(context, model, name); + evaluateFunction(context, model, name); } } else if (node instanceof CompositeNode) { for (ExpressionNode child : ((CompositeNode)node).children()) { - evaluateMacroDependencies(context, model, child); + evaluateFunctionDependencies(context, model, child); } } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java index 84e51835458..1f28f0b0129 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java @@ -27,7 +27,7 @@ public class ConstantDereferencerTestCase { TransformContext context = new TransformContext(constants); assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c"), context).toString()); - assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)"), context).toString()); + assertEquals("myFunction(1.0,2.0)", c.transform(new RankingExpression("myFunction(a, b)"), context).toString()); } } |