diff options
Diffstat (limited to 'searchlib')
13 files changed, 108 insertions, 91 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index c6d8f70fde8..da34ab8822d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -3,13 +3,16 @@ package com.yahoo.searchlib.rankingexpression; import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.text.Utf8; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; -import java.util.*; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.List; +import java.util.Map; /** * A function defined by a ranking expression @@ -24,6 +27,16 @@ public class ExpressionFunction { private final RankingExpression body; /** + * Constructs a new function with no arguments + * + * @param name the name of this function + * @param body the ranking expression that defines this function + */ + public ExpressionFunction(String name, RankingExpression body) { + this(name, Collections.emptyList(), body); + } + + /** * Constructs a new function * * @param name the name of this function diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java index ed82ba20fbe..722520fea08 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java @@ -250,12 +250,12 @@ public class RankingExpression implements Serializable { /** * Creates the necessary rank properties required to implement this expression. * - * @param macros the expression macros to expand. - * @return a list of named rank properties required to implement this expression. + * @param functions the expression functions to expand + * @return a list of named rank properties required to implement this expression */ - public Map<String, String> getRankProperties(List<ExpressionFunction> macros) { + public Map<String, String> getRankProperties(List<ExpressionFunction> functions) { Deque<String> path = new LinkedList<>(); - SerializationContext context = new SerializationContext(macros); + SerializationContext context = new SerializationContext(functions); String serializedRoot = root.toString(new StringBuilder(), context, path, null).toString(); Map<String, String> serializedExpressions = context.serializedFunctions(); serializedExpressions.put(propertyName(name), serializedRoot); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index ac5eefcc5b2..282a4c5e0a9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -30,8 +30,8 @@ public class ImportedModel { private final Map<String, Tensor> smallConstants = new HashMap<>(); private final Map<String, Tensor> largeConstants = new HashMap<>(); private final Map<String, RankingExpression> expressions = new HashMap<>(); - private final Map<String, RankingExpression> macros = new HashMap<>(); - private final Map<String, TensorType> requiredMacros = new HashMap<>(); + private final Map<String, RankingExpression> functions = new HashMap<>(); + private final Map<String, TensorType> requiredFunctions = new HashMap<>(); /** * Creates a new imported model. @@ -77,13 +77,13 @@ public class ImportedModel { public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } /** - * Returns an immutable map of macros that are part of this model. - * Note that the macros themselves are *not* copies and *not* immutable - they must be copied before modification. + * Returns an immutable map of the functions that are part of this model. + * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification. */ - public Map<String, RankingExpression> macros() { return Collections.unmodifiableMap(macros); } + public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); } - /** Returns an immutable map of the macros that must be provided by the environment running this model */ - public Map<String, TensorType> requiredMacros() { return Collections.unmodifiableMap(requiredMacros); } + /** Returns an immutable map of the functions that must be provided by the environment running this model */ + public Map<String, TensorType> requiredFunctions() { return Collections.unmodifiableMap(requiredFunctions); } /** Returns an immutable map of the signatures of this */ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } @@ -100,8 +100,8 @@ public class ImportedModel { void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - void macro(String name, RankingExpression expression) { macros.put(name, expression); } - void requiredMacro(String name, TensorType type) { requiredMacros.put(name, type); } + void function(String name, RankingExpression expression) { functions.put(name, expression); } + void requiredFunction(String name, TensorType type) { requiredFunctions.put(name, type); } /** * Returns all the output expressions of this indexed by name. The names consist of one or two parts diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index 2ae107a5770..d25502fd149 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -24,7 +24,7 @@ import java.util.logging.Logger; * ranking expressions. The general mechanism for import is for the * specific ML platform import implementations to create an * IntermediateGraph. This class offers common code to convert the - * IntermediateGraph to Vespa ranking expressions and macros. + * IntermediateGraph to Vespa ranking expressions and functions. * * @author lesters */ @@ -122,7 +122,7 @@ public abstract class ModelImporter { importExpressionInputs(operation, model); importRankingExpression(operation, model); importArgumentExpression(operation, model); - importMacroExpression(operation, model); + importFunctionExpression(operation, model); return operation.function(); } @@ -188,15 +188,15 @@ public abstract class ModelImporter { // All inputs must have dimensions with standard naming convention: d0, d1, ... OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); model.argument(operation.vespaName(), standardNamingConvention.type()); - model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); + model.requiredFunction(operation.vespaName(), standardNamingConvention.type()); } } - private static void importMacroExpression(IntermediateOperation operation, ImportedModel model) { - if (operation.macro().isPresent()) { - TensorFunction function = operation.macro().get(); + private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) { + if (operation.rankingExpressionFunction().isPresent()) { + TensorFunction function = operation.rankingExpressionFunction().get(); try { - model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString())); + model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString())); } catch (ParseException e) { throw new RuntimeException("Tensorflow function " + function + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java index 43de29cedd5..34f5f1365a1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java @@ -29,7 +29,7 @@ import java.util.function.Function; */ public abstract class IntermediateOperation { - private final static String MACRO_PREFIX = "imported_ml_macro_"; + private final static String FUNCTION_PREFIX = "imported_ml_function_"; protected final String name; protected final String modelName; @@ -38,7 +38,7 @@ public abstract class IntermediateOperation { protected OrderedTensorType type; protected TensorFunction function; - protected TensorFunction macro = null; + protected TensorFunction rankingExpressionFunction = null; private final List<String> importWarnings = new ArrayList<>(); private Value constantValue = null; @@ -71,8 +71,8 @@ public abstract class IntermediateOperation { ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); } else if (outputs.size() > 1) { - macro = lazyGetFunction(); - function = new VariableTensor(macroName(), type.type()); + rankingExpressionFunction = lazyGetFunction(); + function = new VariableTensor(rankingExpressionFunctionName(), type.type()); } else { function = lazyGetFunction(); } @@ -86,11 +86,13 @@ public abstract class IntermediateOperation { /** Return unmodifiable list of inputs */ public List<IntermediateOperation> inputs() { return inputs; } - /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */ + /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a function. */ public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); } - /** Returns a Vespa ranking expression that should be added as a macro */ - public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); } + /** Returns a function that should be added as a ranking expression function */ + public Optional<TensorFunction> rankingExpressionFunction() { + return Optional.ofNullable(rankingExpressionFunction); + } /** Add dimension name constraints for this operation */ public void addDimensionNameConstraints(DimensionRenamer renamer) { } @@ -131,8 +133,10 @@ public abstract class IntermediateOperation { public String vespaName() { return vespaName(name); } public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } - /** Retrieve the valid Vespa name of this node if it is a macro */ - public String macroName() { return vespaName() != null ? MACRO_PREFIX + modelName + "_" + vespaName() : null; } + /** Retrieve the valid Vespa name of this node if it is a ranking expression function */ + public String rankingExpressionFunctionName() { + return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null; + } /** Retrieve the list of warnings produced during its lifetime */ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java index 9299ae9be12..b335fd7e1c5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/PlaceholderWithDefault.java @@ -26,13 +26,13 @@ public class PlaceholderWithDefault extends IntermediateOperation { if (!allInputFunctionsPresent(1)) { return null; } - // This should be a call to the macro we add below, but for now + // This should be a call to the function we add below, but for now // we treat this as as identity function and just pass the constant. return inputs.get(0).function().orElse(null); } @Override - public Optional<TensorFunction> macro() { + public Optional<TensorFunction> rankingExpressionFunction() { // For now, it is much more efficient to assume we always will return // the default value, as we can prune away large parts of the expression // tree by having it calculated as a constant. If a case arises where @@ -42,7 +42,7 @@ public class PlaceholderWithDefault extends IntermediateOperation { @Override public boolean isConstant() { - return true; // not true if we add to macro + return true; // not true if we add to function } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index e2dc170c168..eb8d2229a6d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -15,7 +15,7 @@ import java.util.Deque; import java.util.List; /** - * A node referring either to a value in the context or to a named ranking expression (function aka macro). + * A node referring either to a value in the context or to a named ranking expression function. * * @author bratseth */ @@ -64,7 +64,7 @@ public final class ReferenceNode extends CompositeNode { @Override public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { - // A reference to a macro argument? + // A reference to a function argument? if (reference.isIdentifier() && context.getBinding(getName()) != null) { // a bound identifier: replace by the value it is bound to return string.append(context.getBinding(getName())); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 7c929ae24b3..571e1f4d608 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -91,26 +91,26 @@ public class RankingExpressionTestCase { @Test public void testSelfRecursionSerialization() throws ParseException { - List<ExpressionFunction> macros = new ArrayList<>(); - macros.add(new ExpressionFunction("foo", null, new RankingExpression("foo"))); + List<ExpressionFunction> functions = new ArrayList<>(); + functions.add(new ExpressionFunction("foo", null, new RankingExpression("foo"))); RankingExpression exp = new RankingExpression("foo"); try { - exp.getRankProperties(macros); + exp.getRankProperties(functions); } catch (RuntimeException e) { assertEquals("Cycle in ranking expression function: [foo[]]", e.getMessage()); } } @Test - public void testMacroCycleSerialization() throws ParseException { - List<ExpressionFunction> macros = new ArrayList<>(); - macros.add(new ExpressionFunction("foo", null, new RankingExpression("bar"))); - macros.add(new ExpressionFunction("bar", null, new RankingExpression("foo"))); + public void testFunctionCycleSerialization() throws ParseException { + List<ExpressionFunction> funnctions = new ArrayList<>(); + funnctions.add(new ExpressionFunction("foo", null, new RankingExpression("bar"))); + funnctions.add(new ExpressionFunction("bar", null, new RankingExpression("foo"))); RankingExpression exp = new RankingExpression("foo"); try { - exp.getRankProperties(macros); + exp.getRankProperties(funnctions); } catch (RuntimeException e) { assertEquals("Cycle in ranking expression function: [foo[], bar[]]", e.getMessage()); } @@ -118,11 +118,11 @@ public class RankingExpressionTestCase { @Test public void testSerialization() throws ParseException { - List<ExpressionFunction> macros = new ArrayList<>(); - macros.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))"))); - macros.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2"))); - macros.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)"))); - macros.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977"))); + List<ExpressionFunction> functions = new ArrayList<>(); + functions.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))"))); + functions.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2"))); + functions.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)"))); + functions.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977"))); assertSerialization(Arrays.asList( "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)", @@ -130,19 +130,19 @@ public class RankingExpressionTestCase { "min(6,pow(7,2))", "min(1,pow(2,2))", "min(3,pow(4,2))", - "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros); + "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", functions); assertSerialization(Arrays.asList( "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)", "min(1,pow(2,2))", - "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", macros); + "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", functions); assertSerialization(Arrays.asList( "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)", "min(1,pow(2,2))", "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)", - "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", macros); + "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", functions); assertSerialization(Arrays.asList( "rankingExpression(cox)", - "10 + 08 * 1977"), "cox", macros + "10 + 08 * 1977"), "cox", functions ); } @@ -159,8 +159,8 @@ public class RankingExpressionTestCase { @Test public void testBug3464208() throws ParseException { - List<ExpressionFunction> macros = new ArrayList<>(); - macros.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69"))); + List<ExpressionFunction> functions = new ArrayList<>(); + functions.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69"))); String lhs = "log10(0.01+attribute(user_followers_count)) * log10(socialratio) * " + "log10(userage/(0.01+attribute(user_statuses_count)))"; @@ -172,8 +172,8 @@ public class RankingExpressionTestCase { String expRhs = "(rankingExpression(log10tweetage) * rankingExpression(log10tweetage) * " + "rankingExpression(log10tweetage)) + 5.0 * attribute(ythl)"; - assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, macros); - assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, macros); + assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, functions); + assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, functions); } @Test @@ -295,12 +295,12 @@ public class RankingExpressionTestCase { assertEquals(expected, new RankingExpression(expression).toString()); } - /** Test serialization with no macros */ + /** Test serialization with no functions */ private void assertSerialization(String expectedSerialization, String expressionString) { String serializedExpression; try { RankingExpression expression = new RankingExpression(expressionString); - // No macros -> expect one rank property + // No functions -> expect one rank property serializedExpression = expression.getRankProperties(Collections.emptyList()).values().iterator().next(); assertEquals(expectedSerialization, serializedExpression); } @@ -309,7 +309,7 @@ public class RankingExpressionTestCase { } try { - // No macros -> output should be parseable to a ranking expression + // No functions -> output should be parseable to a ranking expression // (but not the same one due to primitivization) RankingExpression reparsedExpression = new RankingExpression(serializedExpression); // Serializing the primitivized expression should yield the same expression again @@ -323,17 +323,17 @@ public class RankingExpressionTestCase { } private void assertSerialization(List<String> expectedSerialization, String expressionString, - List<ExpressionFunction> macros) { - assertSerialization(expectedSerialization, expressionString, macros, false); + List<ExpressionFunction> functions) { + assertSerialization(expectedSerialization, expressionString, functions, false); } private void assertSerialization(List<String> expectedSerialization, String expressionString, - List<ExpressionFunction> macros, boolean print) { + List<ExpressionFunction> functions, boolean print) { try { if (print) System.out.println("Parsing expression '" + expressionString + "'."); RankingExpression expression = new RankingExpression(expressionString); - Map<String, String> rankProperties = expression.getRankProperties(macros); + Map<String, String> rankProperties = expression.getRankProperties(functions); if (print) { for (String key : rankProperties.keySet()) System.out.println("Property '" + key + "': " + rankProperties.get(key)); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index a63c7346335..a8f7542f3a4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -18,11 +18,11 @@ public class DropoutImportTestCase { public void testDropoutImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved"); - // Check required macros - assertEquals(1, model.get().requiredMacros().size()); - assertTrue(model.get().requiredMacros().containsKey("X")); + // Check required functions + assertEquals(1, model.get().requiredFunctions().size()); + assertTrue(model.get().requiredFunctions().containsKey("X")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().requiredMacros().get("X")); + model.get().requiredFunctions().get("X")); ImportedModel.Signature signature = model.get().signature("serving_default"); @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/Maximum", output.getName()); - assertEquals("join(join(imported_ml_macro_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_macro_test_outputs_BiasAdd, f(a,b)(max(a,b)))", + assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index bcfc6ce0a04..e20ac16a691 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -36,11 +36,11 @@ public class OnnxMnistSoftmaxImportTestCase { constant1.type()); assertEquals(10, constant1.size()); - // Check required macros (inputs) - assertEquals(1, model.requiredMacros().size()); - assertTrue(model.requiredMacros().containsKey("Placeholder")); + // Check required functions (inputs) + assertEquals(1, model.requiredFunctions().size()); + assertTrue(model.requiredFunctions().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.requiredMacros().get("Placeholder")); + model.requiredFunctions().get("Placeholder")); // Check outputs RankingExpression output = model.defaultSignature().outputExpression("add"); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index dd6c8095e3c..ef28eb4678f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -34,14 +34,14 @@ public class TensorFlowMnistSoftmaxImportTestCase { constant1.type()); assertEquals(10, constant1.size()); - // Check (provided) macros - assertEquals(0, model.get().macros().size()); + // Check (provided) functions + assertEquals(0, model.get().functions().size()); - // Check required macros - assertEquals(1, model.get().requiredMacros().size()); - assertTrue(model.get().requiredMacros().containsKey("Placeholder")); + // Check required functions + assertEquals(1, model.get().requiredFunctions().size()); + assertTrue(model.get().requiredFunctions().containsKey("Placeholder")); assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(), - model.get().requiredMacros().get("Placeholder")); + model.get().requiredFunctions().get("Placeholder")); // Check signatures assertEquals(1, model.get().signatures().size()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java index 4de3aa5d635..5447e5240f7 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TestableTensorFlowModel.java @@ -48,7 +48,7 @@ public class TestableTensorFlowModel { Tensor placeholder = placeholderArgument(); context.put(inputName, new TensorValue(placeholder)); - model.macros().forEach((k,v) -> evaluateMacro(context, model, k)); + model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", @@ -62,7 +62,7 @@ public class TestableTensorFlowModel { Tensor placeholder = placeholderArgument(); context.put(inputName, new TensorValue(placeholder)); - model.macros().forEach((k,v) -> evaluateMacro(context, model, k)); + model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); @@ -96,24 +96,24 @@ public class TestableTensorFlowModel { return b.build(); } - private void evaluateMacro(Context context, ImportedModel model, String macroName) { - if (!context.names().contains(macroName)) { - RankingExpression e = model.macros().get(macroName); - evaluateMacroDependencies(context, model, e.getRoot()); - context.put(macroName, new TensorValue(e.evaluate(context).asTensor())); + private void evaluateFunction(Context context, ImportedModel model, String functionName) { + if (!context.names().contains(functionName)) { + RankingExpression e = model.functions().get(functionName); + evaluateFunctionDependencies(context, model, e.getRoot()); + context.put(functionName, new TensorValue(e.evaluate(context).asTensor())); } } - private void evaluateMacroDependencies(Context context, ImportedModel model, ExpressionNode node) { + private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) { if (node instanceof ReferenceNode) { String name = node.toString(); - if (model.macros().containsKey(name)) { - evaluateMacro(context, model, name); + if (model.functions().containsKey(name)) { + evaluateFunction(context, model, name); } } else if (node instanceof CompositeNode) { for (ExpressionNode child : ((CompositeNode)node).children()) { - evaluateMacroDependencies(context, model, child); + evaluateFunctionDependencies(context, model, child); } } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java index 84e51835458..1f28f0b0129 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencerTestCase.java @@ -27,7 +27,7 @@ public class ConstantDereferencerTestCase { TransformContext context = new TransformContext(constants); assertEquals("1.0 + 2.0 + 3.5", c.transform(new RankingExpression("a + b + c"), context).toString()); - assertEquals("myMacro(1.0,2.0)", c.transform(new RankingExpression("myMacro(a, b)"), context).toString()); + assertEquals("myFunction(1.0,2.0)", c.transform(new RankingExpression("myFunction(a, b)"), context).toString()); } } |