summaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-17 13:23:16 +0200
committerJon Bratseth <bratseth@oath.com>2018-09-17 13:23:16 +0200
commitda8b4b3b9b5d197f9f2e6b4da9f33b111a32f0c6 (patch)
treecb06cf65d8757d2f7fb58cffe343576e0c8dd277 /searchlib/src
parent9853b3f674bda90615e78649693ff9b2a3a6ec63 (diff)
Refactor: macro -> function
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java18
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java2
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java2
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java56
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java20
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java2
10 files changed, 59 insertions, 55 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index 17352aea01e..d25502fd149 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -193,10 +193,10 @@ public abstract class ModelImporter {
}
private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) {
- if (operation.macro().isPresent()) {
- TensorFunction function = operation.macro().get();
+ if (operation.rankingExpressionFunction().isPresent()) {
+ TensorFunction function = operation.rankingExpressionFunction().get();
try {
- model.function(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
+ model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString()));
}
catch (ParseException e) {
throw new RuntimeException("Tensorflow function " + function +
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index 5daf5c69548..34f5f1365a1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -38,7 +38,7 @@ public abstract class IntermediateOperation {
protected OrderedTensorType type;
protected TensorFunction function;
- protected TensorFunction macro = null;
+ protected TensorFunction rankingExpressionFunction = null;
private final List<String> importWarnings = new ArrayList<>();
private Value constantValue = null;
@@ -71,8 +71,8 @@ public abstract class IntermediateOperation {
ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
} else if (outputs.size() > 1) {
- macro = lazyGetFunction();
- function = new VariableTensor(macroName(), type.type());
+ rankingExpressionFunction = lazyGetFunction();
+ function = new VariableTensor(rankingExpressionFunctionName(), type.type());
} else {
function = lazyGetFunction();
}
@@ -89,8 +89,10 @@ public abstract class IntermediateOperation {
/** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a function. */
public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); }
- /** Returns a Vespa ranking expression that should be added as a macro */
- public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); }
+ /** Returns a function that should be added as a ranking expression function */
+ public Optional<TensorFunction> rankingExpressionFunction() {
+ return Optional.ofNullable(rankingExpressionFunction);
+ }
/** Add dimension name constraints for this operation */
public void addDimensionNameConstraints(DimensionRenamer renamer) { }
@@ -131,8 +133,10 @@ public abstract class IntermediateOperation {
public String vespaName() { return vespaName(name); }
public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
- /** Retrieve the valid Vespa name of this node if it is a macro */
- public String macroName() { return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null; }
+ /** Retrieve the valid Vespa name of this node if it is a ranking expression function */
+ public String rankingExpressionFunctionName() {
+ return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null;
+ }
/** Retrieve the list of warnings produced during its lifetime */
public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
index 384c4b1c4a8..b335fd7e1c5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java
@@ -32,7 +32,7 @@ public class PlaceholderWithDefault extends IntermediateOperation {
}
@Override
- public Optional<TensorFunction> macro() {
+ public Optional<TensorFunction> rankingExpressionFunction() {
// For now, it is much more efficient to assume we always will return
// the default value, as we can prune away large parts of the expression
// tree by having it calculated as a constant. If a case arises where
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index 2a5f1646391..eb8d2229a6d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -64,7 +64,7 @@ public final class ReferenceNode extends CompositeNode {
@Override
public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
- // A reference to a macro argument?
+ // A reference to a function argument?
if (reference.isIdentifier() && context.getBinding(getName()) != null) {
// a bound identifier: replace by the value it is bound to
return string.append(context.getBinding(getName()));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index 7c929ae24b3..571e1f4d608 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -91,26 +91,26 @@ public class RankingExpressionTestCase {
@Test
public void testSelfRecursionSerialization() throws ParseException {
- List<ExpressionFunction> macros = new ArrayList<>();
- macros.add(new ExpressionFunction("foo", null, new RankingExpression("foo")));
+ List<ExpressionFunction> functions = new ArrayList<>();
+ functions.add(new ExpressionFunction("foo", null, new RankingExpression("foo")));
RankingExpression exp = new RankingExpression("foo");
try {
- exp.getRankProperties(macros);
+ exp.getRankProperties(functions);
} catch (RuntimeException e) {
assertEquals("Cycle in ranking expression function: [foo[]]", e.getMessage());
}
}
@Test
- public void testMacroCycleSerialization() throws ParseException {
- List<ExpressionFunction> macros = new ArrayList<>();
- macros.add(new ExpressionFunction("foo", null, new RankingExpression("bar")));
- macros.add(new ExpressionFunction("bar", null, new RankingExpression("foo")));
+ public void testFunctionCycleSerialization() throws ParseException {
+ List<ExpressionFunction> funnctions = new ArrayList<>();
+ funnctions.add(new ExpressionFunction("foo", null, new RankingExpression("bar")));
+ funnctions.add(new ExpressionFunction("bar", null, new RankingExpression("foo")));
RankingExpression exp = new RankingExpression("foo");
try {
- exp.getRankProperties(macros);
+ exp.getRankProperties(funnctions);
} catch (RuntimeException e) {
assertEquals("Cycle in ranking expression function: [foo[], bar[]]", e.getMessage());
}
@@ -118,11 +118,11 @@ public class RankingExpressionTestCase {
@Test
public void testSerialization() throws ParseException {
- List<ExpressionFunction> macros = new ArrayList<>();
- macros.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))")));
- macros.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2")));
- macros.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)")));
- macros.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977")));
+ List<ExpressionFunction> functions = new ArrayList<>();
+ functions.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))")));
+ functions.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2")));
+ functions.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)")));
+ functions.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977")));
assertSerialization(Arrays.asList(
"rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)",
@@ -130,19 +130,19 @@ public class RankingExpressionTestCase {
"min(6,pow(7,2))",
"min(1,pow(2,2))",
"min(3,pow(4,2))",
- "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros);
+ "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", functions);
assertSerialization(Arrays.asList(
"rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)",
"min(1,pow(2,2))",
- "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", macros);
+ "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", functions);
assertSerialization(Arrays.asList(
"rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)",
"min(1,pow(2,2))",
"rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)",
- "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", macros);
+ "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", functions);
assertSerialization(Arrays.asList(
"rankingExpression(cox)",
- "10 + 08 * 1977"), "cox", macros
+ "10 + 08 * 1977"), "cox", functions
);
}
@@ -159,8 +159,8 @@ public class RankingExpressionTestCase {
@Test
public void testBug3464208() throws ParseException {
- List<ExpressionFunction> macros = new ArrayList<>();
- macros.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69")));
+ List<ExpressionFunction> functions = new ArrayList<>();
+ functions.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69")));
String lhs = "log10(0.01+attribute(user_followers_count)) * log10(socialratio) * " +
"log10(userage/(0.01+attribute(user_statuses_count)))";
@@ -172,8 +172,8 @@ public class RankingExpressionTestCase {
String expRhs = "(rankingExpression(log10tweetage) * rankingExpression(log10tweetage) * " +
"rankingExpression(log10tweetage)) + 5.0 * attribute(ythl)";
- assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, macros);
- assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, macros);
+ assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, functions);
+ assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, functions);
}
@Test
@@ -295,12 +295,12 @@ public class RankingExpressionTestCase {
assertEquals(expected, new RankingExpression(expression).toString());
}
- /** Test serialization with no macros */
+ /** Test serialization with no functions */
private void assertSerialization(String expectedSerialization, String expressionString) {
String serializedExpression;
try {
RankingExpression expression = new RankingExpression(expressionString);
- // No macros -> expect one rank property
+ // No functions -> expect one rank property
serializedExpression = expression.getRankProperties(Collections.emptyList()).values().iterator().next();
assertEquals(expectedSerialization, serializedExpression);
}
@@ -309,7 +309,7 @@ public class RankingExpressionTestCase {
}
try {
- // No macros -> output should be parseable to a ranking expression
+ // No functions -> output should be parseable to a ranking expression
// (but not the same one due to primitivization)
RankingExpression reparsedExpression = new RankingExpression(serializedExpression);
// Serializing the primitivized expression should yield the same expression again
@@ -323,17 +323,17 @@ public class RankingExpressionTestCase {
}
private void assertSerialization(List<String> expectedSerialization, String expressionString,
- List<ExpressionFunction> macros) {
- assertSerialization(expectedSerialization, expressionString, macros, false);
+ List<ExpressionFunction> functions) {
+ assertSerialization(expectedSerialization, expressionString, functions, false);
}
private void assertSerialization(List<String> expectedSerialization, String expressionString,
- List<ExpressionFunction> macros, boolean print) {
+ List<ExpressionFunction> functions, boolean print) {
try {
if (print)
System.out.println("Parsing expression '" + expressionString + "'.");
RankingExpression expression = new RankingExpression(expressionString);
- Map<String, String> rankProperties = expression.getRankProperties(macros);
+ Map<String, String> rankProperties = expression.getRankProperties(functions);
if (print) {
for (String key : rankProperties.keySet())
System.out.println("Property '" + key + "': " + rankProperties.get(key));
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index f3cb93e929b..a8f7542f3a4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -18,7 +18,7 @@ public class DropoutImportTestCase {
public void testDropoutImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved");
- // Check required macros
+ // Check required functions
assertEquals(1, model.get().requiredFunctions().size());
assertTrue(model.get().requiredFunctions().containsKey("X"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index 6856434e2d0..e20ac16a691 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -36,7 +36,7 @@ public class OnnxMnistSoftmaxImportTestCase {
constant1.type());
assertEquals(10, constant1.size());
- // Check required macros (inputs)
+ // Check required functions (inputs)
assertEquals(1, model.requiredFunctions().size());
assertTrue(model.requiredFunctions().containsKey("Placeholder"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
index 97f4e938a4f..ef28eb4678f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
@@ -34,10 +34,10 @@ public class TensorFlowMnistSoftmaxImportTestCase {
constant1.type());
assertEquals(10, constant1.size());
- // Check (provided) macros
+ // Check (provided) functions
assertEquals(0, model.get().functions().size());
- // Check required macros
+ // Check required functions
assertEquals(1, model.get().requiredFunctions().size());
assertTrue(model.get().requiredFunctions().containsKey("Placeholder"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
index e688a983d27..5447e5240f7 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java
@@ -48,7 +48,7 @@ public class TestableTensorFlowModel {
Tensor placeholder = placeholderArgument();
context.put(inputName, new TensorValue(placeholder));
- model.functions().forEach((k, v) -> evaluateMacro(context, model, k));
+ model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results",
@@ -62,7 +62,7 @@ public class TestableTensorFlowModel {
Tensor placeholder = placeholderArgument();
context.put(inputName, new TensorValue(placeholder));
- model.functions().forEach((k, v) -> evaluateMacro(context, model, k));
+ model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
@@ -96,24 +96,24 @@ public class TestableTensorFlowModel {
return b.build();
}
- private void evaluateMacro(Context context, ImportedModel model, String macroName) {
- if (!context.names().contains(macroName)) {
- RankingExpression e = model.functions().get(macroName);
- evaluateMacroDependencies(context, model, e.getRoot());
- context.put(macroName, new TensorValue(e.evaluate(context).asTensor()));
+ private void evaluateFunction(Context context, ImportedModel model, String functionName) {
+ if (!context.names().contains(functionName)) {
+ RankingExpression e = model.functions().get(functionName);
+ evaluateFunctionDependencies(context, model, e.getRoot());
+ context.put(functionName, new TensorValue(e.evaluate(context).asTensor()));
}
}
- private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) {
+ private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) {
if (node instanceof ReferenceNode) {
String name = node.toString();
if (model.functions().containsKey(name)) {
- evaluateMacro(context, model, name);
+ evaluateFunction(context, model, name);
}
}
else if (node instanceof CompositeNode) {
for (ExpressionNode child : ((CompositeNode)node).children()) {
- evaluateMacroDependencies(context, model, child);
+ evaluateFunctionDependencies(context, model, child);
}
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
index 84e51835458..1f28f0b0129 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java
@@ -27,7 +27,7 @@ public class ConstantDereferencerTestCase {
TransformContext context = new TransformContext(constants);
assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c"), context).toString());
- assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)"), context).toString());
+ assertEquals("myFunction(1.0,2.0)", c.transform(new RankingExpression("myFunction(a, b)"), context).toString());
}
}