diff options
15 files changed, 133 insertions, 71 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java index 164cb7f808e..5ac1418c0c7 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankingConstants.java @@ -1,6 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchdefinition; +import com.yahoo.config.FileReference; +import com.yahoo.vespa.model.AbstractService; +import com.yahoo.vespa.model.utils.FileSender; + +import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -33,4 +38,14 @@ public class RankingConstants { return Collections.unmodifiableMap(constants); } + /** Initiate sending of these constants to some services over file distribution */ + public void sendTo(Collection<? extends AbstractService> services) { + for (RankingConstant constant : constants.values()) { + FileReference reference = (constant.getPathType() == RankingConstant.PathType.FILE) + ? FileSender.sendFileToServices(constant.getFileName(), services) + : FileSender.sendUriToServices(constant.getUri(), services); + constant.setFileReference(reference.value()); + } + } + } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java index e58b3da4f72..fcbfb47c597 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RankProfileList.java @@ -10,7 +10,9 @@ import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.Search; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.model.AbstractService; +import java.util.Collection; import java.util.Map; import java.util.logging.Logger; @@ -79,6 +81,10 @@ public class RankProfileList extends Derived implements RankProfilesConfig.Produ return rankProfiles.get(name); } + public void sendConstantsTo(Collection<? extends AbstractService> services) { + rankingConstants.sendTo(services); + } + @Override public String getDerivedName() { return "rank-profiles"; } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java index 0b434fd0c49..fbe86d26b02 100755 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ContainerCluster.java @@ -87,7 +87,6 @@ import java.util.Map; import java.util.Objects; import java.util.Optional; import java.util.Set; -import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -361,6 +360,7 @@ public final class ContainerCluster public void prepare() { addAndSendApplicationBundles(); + rankProfileList.sendConstantsTo(containers); sendUserConfiguredFiles(); setApplicationMetaData(); for (RestApi restApi : restApiGroup.getComponents()) diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java b/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java index 19d014e0a1d..ceb48732116 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/search/searchchain/FederationSearcher.java @@ -15,6 +15,7 @@ import java.util.*; /** * Config producer for the FederationSearcher. + * * @author Tony Vaagenes */ public class FederationSearcher extends Searcher<FederationSearcherModel> implements FederationConfig.Producer { diff --git a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java index 91d7fd436f3..a16c32d47ab 100644 --- a/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/ModelEvaluationTest.java @@ -70,31 +70,41 @@ public class ModelEvaluationTest { assertEquals(4, config.rankprofile().size()); Set<String> modelNames = config.rankprofile().stream().map(v -> v.name()).collect(Collectors.toSet()); assertTrue(modelNames.contains("xgboost_2_2")); + assertTrue(modelNames.contains("mnist_saved")); assertTrue(modelNames.contains("mnist_softmax")); assertTrue(modelNames.contains("mnist_softmax_saved")); ModelsEvaluator evaluator = new ModelsEvaluator(config, constantsConfig); assertEquals(4, evaluator.models().size()); + Model xgboost = evaluator.models().get("xgboost_2_2"); assertNotNull(xgboost); assertNotNull(xgboost.evaluatorOf()); assertNotNull(xgboost.evaluatorOf("xgboost_2_2")); - Model onnx = evaluator.models().get("mnist_softmax"); - assertNotNull(onnx); - assertNotNull(onnx.evaluatorOf()); - assertNotNull(onnx.evaluatorOf("default")); - assertNotNull(onnx.evaluatorOf("default", "add")); - assertNotNull(onnx.evaluatorOf("default.add")); + Model tensorflow_mnist = evaluator.models().get("mnist_saved"); + assertNotNull(tensorflow_mnist); + assertNotNull(tensorflow_mnist.evaluatorOf("serving_default")); + assertNotNull(tensorflow_mnist.evaluatorOf("serving_default", "y")); + assertNotNull(tensorflow_mnist.evaluatorOf("serving_default.y")); + assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default.y")); + assertNotNull(evaluator.evaluatorOf("mnist_saved", "serving_default", "y")); + + Model onnx_mnist_softmax = evaluator.models().get("mnist_softmax"); + assertNotNull(onnx_mnist_softmax); + assertNotNull(onnx_mnist_softmax.evaluatorOf()); + assertNotNull(onnx_mnist_softmax.evaluatorOf("default")); + assertNotNull(onnx_mnist_softmax.evaluatorOf("default", "add")); + assertNotNull(onnx_mnist_softmax.evaluatorOf("default.add")); assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default.add")); assertNotNull(evaluator.evaluatorOf("mnist_softmax", "default", "add")); - Model tensorflow = evaluator.models().get("mnist_softmax_saved"); - assertNotNull(tensorflow); - assertNotNull(tensorflow.evaluatorOf()); - assertNotNull(tensorflow.evaluatorOf("serving_default")); - assertNotNull(tensorflow.evaluatorOf("serving_default", "y")); + Model tensorflow_mnist_softmax = evaluator.models().get("mnist_softmax_saved"); + assertNotNull(tensorflow_mnist_softmax); + assertNotNull(tensorflow_mnist_softmax.evaluatorOf()); + assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default")); + assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default", "y")); } } diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java index 7423792693b..dbf68106e07 100755 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ANDPolicy.java @@ -20,7 +20,7 @@ import java.util.List; public class ANDPolicy implements DocumentProtocolRoutingPolicy { // A list of hops that are to always be selected when select() is invoked. - private final List<Hop> hops = new ArrayList<Hop>(); + private final List<Hop> hops = new ArrayList<>(); /** * Constructs a new AND policy that requires all recipients to be ok for it to merge their replies to an ok reply. diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java index 82679e17990..a5b3accac68 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/MessageTypePolicy.java @@ -51,7 +51,7 @@ public class MessageTypePolicy implements DocumentProtocolRoutingPolicy, ConfigS @Override public void configure(MessagetyperouteselectorpolicyConfig cfg) { - Map<Integer, Route> h = new HashMap<Integer, Route>(); + Map<Integer, Route> h = new HashMap<>(); for (MessagetyperouteselectorpolicyConfig.Route selector : cfg.route()) { h.put(selector.messagetype(), Route.parse(selector.name())); } diff --git a/linguistics/src/main/java/com/yahoo/language/detect/Detection.java b/linguistics/src/main/java/com/yahoo/language/detect/Detection.java index c08bdc14cfb..4b816335154 100644 --- a/linguistics/src/main/java/com/yahoo/language/detect/Detection.java +++ b/linguistics/src/main/java/com/yahoo/language/detect/Detection.java @@ -7,7 +7,7 @@ import java.nio.charset.Charset; import java.nio.charset.UnsupportedCharsetException; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class Detection { diff --git a/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java b/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java index 12de309a2d3..7451a7f2c9c 100644 --- a/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java +++ b/linguistics/src/main/java/com/yahoo/language/opennlp/OpenNlpLinguistics.java @@ -4,8 +4,10 @@ import com.yahoo.language.process.Tokenizer; import com.yahoo.language.simple.SimpleLinguistics; public class OpenNlpLinguistics extends SimpleLinguistics { + @Override public Tokenizer getTokenizer() { return new OpenNlpTokenizer(getNormalizer(), getTransformer()); } + } diff --git a/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java b/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java index 2b31f95675b..f29ec691c60 100644 --- a/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java +++ b/linguistics/src/main/java/com/yahoo/language/simple/SimpleDetector.java @@ -37,6 +37,7 @@ import java.util.Locale; * @author bjorncs */ public class SimpleDetector implements Detector { + static private TextObjectFactory textObjectFactory; static private LanguageDetector languageDetector; diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java index 63514eca6dd..e21aeef1ee2 100755 --- a/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/Hop.java @@ -20,15 +20,15 @@ public class Hop { private String cache = null; /** - * <p>Constructs an empty hop. You will need to add directives to the - * selector to make this usable.</p> + * Constructs an empty hop. You will need to add directives to the + * selector to make this usable. */ public Hop() { // empty } /** - * <p>Implements the copy constructor.</p> + * Implements the copy constructor. * * @param hop The hop to copy. */ @@ -38,8 +38,8 @@ public class Hop { } /** - * <p>Constructs a fully populated hop. This is package private and used by - * the {@link HopBlueprint#create()} method.</p> + * Constructs a fully populated hop. This is package private and used by + * the {@link HopBlueprint#create()} method. * * @param selector The selector to copy. * @param ignoreResult Whether or not to ignore the result of this hop. @@ -50,8 +50,8 @@ public class Hop { } /** - * <p>Parses the given string as a single hop. The {@link #toString()} - * method is compatible with this parser.</p> + * Parses the given string as a single hop. The {@link #toString()} + * method is compatible with this parser. * * @param str The string to parse. * @return A hop that corresponds to the string. @@ -65,8 +65,7 @@ public class Hop { } /** - * <p>Returns whether or not there are any directives contained in this - * hop.</p> + * Returns whether or not there are any directives contained in this hop. * * @return True if there is at least one directive. */ @@ -75,7 +74,7 @@ public class Hop { } /** - * <p>Returns the number of directives contained in this hop.</p> + * Returns the number of directives contained in this hop. * * @return The number of directives. */ @@ -84,7 +83,7 @@ public class Hop { } /** - * <p>Returns the directive at the given index.</p> + * Returns the directive at the given index. * * @param i The index of the directive to return. * @return The item. @@ -94,7 +93,7 @@ public class Hop { } /** - * <p>Adds a new directive to this hop.</p> + * Adds a new directive to this hop. * * @param directive The directive to add. * @return This, to allow chaining. @@ -106,7 +105,7 @@ public class Hop { } /** - * <p>Sets the directive at a given index.</p> + * Sets the directive at a given index. * * @param i The index at which to set the directive. * @param directive The directive to set. @@ -283,9 +282,10 @@ public class Hop { @Override public int hashCode() { - int result = selector != null ? selector.hashCode() : 0; + int result = selector.hashCode(); result = 31 * result + (ignoreResult ? 1 : 0); result = 31 * result + (cache != null ? cache.hashCode() : 0); return result; } + } diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java index 809b2da69c4..838b11e7a02 100755 --- a/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/HopDirective.java @@ -14,13 +14,14 @@ public interface HopDirective { * @param dir The directive to compare this to. * @return True if this matches the argument. */ - public boolean matches(HopDirective dir); + boolean matches(HopDirective dir); /** * Returns a string representation of this that can be debugged but not parsed. * * @return The debug string. */ - public String toDebugString(); + String toDebugString(); + } diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java index a07c6e16100..9190b680ebf 100755 --- a/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/Route.java @@ -20,7 +20,7 @@ import java.util.List; */ public class Route { - private final List<Hop> hops = new ArrayList<Hop>(); + private final List<Hop> hops = new ArrayList<>(); private String cache = null; /** diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index b9e7a27c013..98c80ace047 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -26,6 +26,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Logger; /** * Converts RankProfilesConfig instances to RankingExpressions for evaluation. @@ -35,19 +36,13 @@ import java.util.Optional; */ class RankProfilesConfigImporter { - /** - * Constants already imported in this while reading some expression. - * This is to avoid re-reading constants referenced - * multiple places, as that is potentially costly. - */ - private Map<String, Constant> globalImportedConstants = new HashMap<>(); + private static final Logger log = Logger.getLogger("CONSTANTS"); /** * Returns a map of the models contained in this config, indexed on name. * The map is modifiable and owned by the caller. */ Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { - globalImportedConstants.clear(); try { Map<String, Model> models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { @@ -61,7 +56,8 @@ class RankProfilesConfigImporter { } } - private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) throws ParseException { + private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) + throws ParseException { List<ExpressionFunction> functions = new ArrayList<>(); Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>(); ExpressionFunction firstPhase = null; @@ -79,7 +75,8 @@ class RankProfilesConfigImporter { functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); // // Make all functions, bound or not available under the name they are referenced by in expressions - referencedFunctions.put(reference.get(), new ExpressionFunction(reference.get().serialForm(), arguments, expression)); + referencedFunctions.put(reference.get(), + new ExpressionFunction(reference.get().serialForm(), arguments, expression)); } else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to macros firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(), @@ -112,24 +109,29 @@ class RankProfilesConfigImporter { private List<Constant> readConstants(RankingConstantsConfig constantsConfig) { List<Constant> constants = new ArrayList<>(); + for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) { constants.add(new Constant(constantConfig.name(), - readTensorFromFile(TensorType.fromSpec(constantConfig.type()), + readTensorFromFile(constantConfig.name(), + TensorType.fromSpec(constantConfig.type()), constantConfig.fileref().value()))); } return constants; } - private Tensor readTensorFromFile(TensorType type, String fileName) { + private Tensor readTensorFromFile(String name, TensorType type, String fileReference) { try { - if (fileName.endsWith(".tbf")) - return TypedBinaryFormat.decode(Optional.of(type), - GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileName)))); - // TODO: Support json and json.lz4 - - if (fileName.isEmpty()) // this is the case in unit tests + if (fileReference.isEmpty()) { // this may be the case in unit tests + log.warning("Got empty file reference for constant '" + name + "', using an empty tensor"); return Tensor.from(type, "{}"); - throw new IllegalArgumentException("Unknown tensor file format (determined by file ending): " + fileName); + } + if ( ! new File(fileReference).exists()) { // this may be the case in unit tests + log.warning("Got empty file reference for constant '" + name + "', using an empty tensor"); + return Tensor.from(type, "{}"); + } + return TypedBinaryFormat.decode(Optional.of(type), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileReference)))); + // TODO: Support json and json.lz4 } catch (IOException e) { throw new UncheckedIOException(e); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java index 84e01e58280..2cb9602dfa7 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java @@ -11,9 +11,11 @@ import org.junit.Test; import java.io.File; import java.util.Map; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * Tests instantiating models from rank-profiles configs. @@ -28,27 +30,49 @@ public class RankProfilesImporterTest { assertEquals(4, models.size()); - Model xgboost = models.get("xgboost_2_2"); - assertFunction("xgboost_2_2", - "(optimized sum of condition trees of size 192 bytes)", - xgboost); - - Model onnxMnistSoftmax = models.get("mnist_softmax"); - assertFunction("default.add", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", - onnxMnistSoftmax); - assertEquals("tensor(d1[10],d2[784])", - onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); - - Model tfMnistSoftmax = models.get("mnist_softmax_saved"); - assertFunction("serving_default.y", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", - tfMnistSoftmax); - - Model tfMnist = models.get("mnist_saved"); - assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); + // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that + { + Model xgboost = models.get("xgboost_2_2"); + assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + FunctionEvaluator evaluator = xgboost.evaluatorOf(); + assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + + Model onnxMnistSoftmax = models.get("mnist_softmax"); + assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10],d2[784])", + onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); + FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + Model tfMnistSoftmax = models.get("mnist_softmax_saved"); + assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + Model tfMnist = models.get("mnist_saved"); + assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + // Macro: + assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add", + "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", + tfMnist); + FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument + assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } } @Test |