diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2018-09-07 13:57:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-09-07 13:57:43 +0200 |
commit | d81fcef66250f4ac0bae92564374058b0d8b4b7e (patch) | |
tree | 8617164d4ba7ab653b0b4bfd11cceb3903ae518a | |
parent | e2f4fe91f889153734cc7ac076a32cc531162784 (diff) | |
parent | 55b5c8c7c4b14303dfffb9ad017fd6bcea40e9b9 (diff) |
Merge pull request #6847 from vespa-engine/bratseth/read-constants
Bratseth/read constants
11 files changed, 230 insertions, 144 deletions
diff --git a/config-lib/src/main/java/com/yahoo/config/PathNode.java b/config-lib/src/main/java/com/yahoo/config/PathNode.java index b63dad4d1a7..9d73b5e23c2 100644 --- a/config-lib/src/main/java/com/yahoo/config/PathNode.java +++ b/config-lib/src/main/java/com/yahoo/config/PathNode.java @@ -14,7 +14,6 @@ import java.util.Map; * Represents a 'path' in a {@link ConfigInstance}, usually a filename. * * @author gjoranv - * @since 5.1.30 */ public class PathNode extends LeafNode<Path> { diff --git a/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java b/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java index c94dc30fd6f..ee12c7d4c9f 100644 --- a/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java +++ b/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java @@ -3,12 +3,13 @@ package com.yahoo.container.core; /** * @author gjoranv - * @since 5.46 */ public interface BundleLoaderProperties { + // TODO: This should be removed. The prefix is used to separate the bundles in BundlesConfig // into those that are transferred with filedistribution and those that are preinstalled // on disk. Instead, the model should have put them in two different configs. I.e. create a new // config 'preinstalled-bundles.def'. - public static final String DISK_BUNDLE_PREFIX = "file:"; + String DISK_BUNDLE_PREFIX = "file:"; + } diff --git a/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java b/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java index eceb41f9739..c0a68086212 100644 --- a/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java +++ b/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java @@ -95,7 +95,7 @@ public class BundleLoader { log.info("Installing bundle from disk with reference '" + reference.value() + "'"); File file = new File(referenceFileName); - if (!file.exists()) { + if ( ! file.exists()) { throw new IllegalArgumentException("Reference '" + reference.value() + "' not found on disk."); } diff --git a/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java b/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java index 80f1f003412..0f3f3938701 100644 --- a/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java +++ b/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java @@ -34,7 +34,7 @@ import java.util.concurrent.atomic.AtomicBoolean; import java.util.logging.Logger; /** - * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ public class MbusRequestContext implements RequestContext, ResponseHandler { diff --git a/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java b/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java index d0f5de54b4f..bc3d1edda7c 100644 --- a/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java +++ b/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java @@ -32,7 +32,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider private final ResourceReference sessionReference; @Inject - public MbusServer(final CurrentContainer container, final ServerSession session) { + public MbusServer(CurrentContainer container, ServerSession session) { this.container = container; this.session = session; uri = URI.create("mbus://localhost/" + session.name()); @@ -60,7 +60,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider } @Override - public void handleMessage(final Message msg) { + public void handleMessage(Message msg) { if (!running.get()) { dispatchErrorReply(msg, ErrorCode.SESSION_BUSY, "Session temporarily closed."); return; @@ -73,7 +73,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider try { request = new MbusRequest(container, uri, msg); content = request.connect(new ServerResponseHandler(msg)); - } catch (final RuntimeException e) { + } catch (RuntimeException e) { dispatchErrorReply(msg, ErrorCode.APP_FATAL_ERROR, e.toString()); } finally { if (request != null) { @@ -89,8 +89,8 @@ public final class MbusServer extends AbstractResource implements ServerProvider return session.connectionSpec(); } - private void dispatchErrorReply(final Message msg, final int errCode, final String errMsg) { - final Reply reply = new EmptyReply(); + private void dispatchErrorReply(Message msg, int errCode, String errMsg) { + Reply reply = new EmptyReply(); reply.swapState(msg); reply.addError(new Error(errCode, errMsg)); session.sendReply(reply); @@ -100,20 +100,20 @@ public final class MbusServer extends AbstractResource implements ServerProvider final Message msg; - ServerResponseHandler(final Message msg) { + ServerResponseHandler(Message msg) { this.msg = msg; } @Override - public ContentChannel handleResponse(final Response response) { - final Reply reply; + public ContentChannel handleResponse(Response response) { + Reply reply; if (response instanceof MbusResponse) { reply = ((MbusResponse)response).getReply(); } else { reply = new EmptyReply(); reply.swapState(msg); } - final Error err = StatusCodes.toMbusError(response.getStatus()); + Error err = StatusCodes.toMbusError(response.getStatus()); if (err != null) { if (err.isFatal()) { if (!reply.hasFatalErrors()) { diff --git a/messagebus/src/main/java/com/yahoo/messagebus/Message.java b/messagebus/src/main/java/com/yahoo/messagebus/Message.java index 22496487f61..43f5c8d2dfd 100644 --- a/messagebus/src/main/java/com/yahoo/messagebus/Message.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/Message.java @@ -5,9 +5,9 @@ import com.yahoo.concurrent.SystemTimer; import com.yahoo.messagebus.routing.Route; /** - * <p>A message is a child of Routable, it is not a reply, and it has a sequencing identifier. Furthermore, a message + * A message is a child of Routable, it is not a reply, and it has a sequencing identifier. Furthermore, a message * contains a retry counter that holds what retry the message is currently on. See the method comment {@link #getRetry} - * for more information.</p> + * for more information. * * @author Simon Thoresen Hult */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 98c80ace047..cd21a0a6813 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -15,6 +15,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.vespa.defaults.Defaults; import java.io.File; import java.io.IOException; @@ -119,18 +120,31 @@ class RankProfilesConfigImporter { return constants; } - private Tensor readTensorFromFile(String name, TensorType type, String fileReference) { + Tensor readTensorFromFile(String name, TensorType type, String fileReference) { try { + // TODO: Only allow these two fallbacks in testing mode if (fileReference.isEmpty()) { // this may be the case in unit tests log.warning("Got empty file reference for constant '" + name + "', using an empty tensor"); return Tensor.from(type, "{}"); } - if ( ! new File(fileReference).exists()) { // this may be the case in unit tests - log.warning("Got empty file reference for constant '" + name + "', using an empty tensor"); + File dir = new File(Defaults.getDefaults().underVespaHome("var/db/vespa/filedistribution"), fileReference); + if ( ! dir.exists()) { // this may be the case in unit tests + log.warning("Got reference to nonexisting file " + dir + "e for constant '" + name + + "', using an empty tensor"); return Tensor.from(type, "{}"); } - return TypedBinaryFormat.decode(Optional.of(type), - GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileReference)))); + + // TODO: Move these 2 lines to FileReference + + dir = new File(Defaults.getDefaults().underVespaHome("var/db/vespa/filedistribution"), fileReference); + File file = dir.listFiles()[0]; // directory contains one file having the original name + + if (file.getName().endsWith(".tbf")) + return TypedBinaryFormat.decode(Optional.of(type), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(file))); + else + throw new IllegalArgumentException("Constant files on other formats than .tbf are not supported, got " + + file + " for constant " + name); // TODO: Support json and json.lz4 } catch (IOException e) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java new file mode 100644 index 00000000000..a823f16d727 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -0,0 +1,82 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.FileSource; +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import org.junit.Test; + +import java.io.File; +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Logger; +import java.util.stream.Collectors; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +/** + * Tests instantiating models from rank-profiles configs. + * + * @author bratseth + */ +public class MlModelsImportingTest { + + @Test + public void testImportingModels() { + ModelTester tester = new ModelTester("src/test/resources/config/models/"); + + assertEquals(4, tester.models().size()); + + // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that + { + Model xgboost = tester.models().get("xgboost_2_2"); + tester.assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + FunctionEvaluator evaluator = xgboost.evaluatorOf(); + assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + + Model onnxMnistSoftmax = tester.models().get("mnist_softmax"); + tester.assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10],d2[784])", + onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); + FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + Model tfMnistSoftmax = tester.models().get("mnist_softmax_saved"); + tester.assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + Model tfMnist = tester.models().get("mnist_saved"); + tester.assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + // Macro: + tester.assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add", + "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", + tfMnist); + FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument + assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java new file mode 100644 index 00000000000..63e17e37bde --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -0,0 +1,77 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import com.yahoo.config.subscription.ConfigGetter; +import com.yahoo.config.subscription.FileSource; +import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; + +import java.util.HashMap; +import java.util.Map; +import java.util.logging.Logger; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * Helper for testing model import and evaluation + * + * @author bratseth + */ +public class ModelTester { + + private final Map<String, Model> models; + + public ModelTester(String modelConfigDirectory) { + models = createModels(modelConfigDirectory); + } + + public Map<String, Model> models() { return models; } + + private static Map<String, Model> createModels(String path) { + Path configDir = Path.fromString(path); + RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), + RankProfilesConfig.class).getConfig(""); + RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), + RankingConstantsConfig.class).getConfig(""); + return new RankProfilesConfigImporterWithMockedConstants().importFrom(config, constantsConfig); + } + + public void assertFunction(String name, String expression, Model model) { + assertNotNull("Model is present in config", model); + ExpressionFunction function = model.function(name); + assertNotNull("Function '" + name + "' is in " + model, function); + assertEquals(name, function.getName()); + assertEquals(expression, function.getBody().getRoot().toString()); + } + + public void assertBoundFunction(String name, String expression, Model model) { + ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get()); + assertNotNull("Function '" + name + "' is present", function); + assertEquals(name, function.getName()); + assertEquals(expression, function.getBody().getRoot().toString()); + } + + /** Allows us to provide canned tensor constants during import since file distribution does not work in tests */ + private static class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter { + + private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName()); + + Map<String, Tensor> constants = new HashMap<>(); + + @Override + Tensor readTensorFromFile(String name, TensorType type, String fileReference) { + if ( ! constants.containsKey(name)) { + log.warning("Missing a mocked tensor constant for '" + name + "': Returning an empty tensor"); + return Tensor.from(type, "{}"); + } + return constants.get(name); + } + + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java new file mode 100644 index 00000000000..210ffb823b2 --- /dev/null +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java @@ -0,0 +1,36 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.models.evaluation; + +import org.junit.Test; + +import java.util.Map; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class RankProfileImportingTest { + + @Test + public void testImportingRankExpressions() { + ModelTester tester = new ModelTester("src/test/resources/config/rankexpression/"); + + assertEquals(18, tester.models().size()); + + Model macros = tester.models().get("macros"); + assertEquals("macros", macros.name()); + assertEquals(4, macros.functions().size()); + tester.assertFunction("fourtimessum", "4 * (var1 + var2)", macros); + tester.assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros); + tester.assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros); + tester.assertFunction("myfeature", + "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " + + "30 * pow(0 - fieldMatch(description).earliness,2)", + macros); + assertEquals(4, macros.referencedFunctions().size()); + tester.assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", + "4 * (match + rankBoost)", macros); + } + +} diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java deleted file mode 100644 index 2cb9602dfa7..00000000000 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ /dev/null @@ -1,123 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.models.evaluation; - -import com.yahoo.config.subscription.ConfigGetter; -import com.yahoo.config.subscription.FileSource; -import com.yahoo.path.Path; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.vespa.config.search.RankProfilesConfig; -import com.yahoo.vespa.config.search.core.RankingConstantsConfig; -import org.junit.Test; - -import java.io.File; -import java.util.Map; -import java.util.stream.Collectors; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertTrue; - -/** - * Tests instantiating models from rank-profiles configs. - * - * @author bratseth - */ -public class RankProfilesImporterTest { - - @Test - public void testImportingModels() { - Map<String, Model> models = createModels("src/test/resources/config/models/"); - - assertEquals(4, models.size()); - - // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that - { - Model xgboost = models.get("xgboost_2_2"); - assertFunction("xgboost_2_2", - "(optimized sum of condition trees of size 192 bytes)", - xgboost); - FunctionEvaluator evaluator = xgboost.evaluatorOf(); - assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); - } - - { - - Model onnxMnistSoftmax = models.get("mnist_softmax"); - assertFunction("default.add", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", - onnxMnistSoftmax); - assertEquals("tensor(d1[10],d2[784])", - onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); - FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available - assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); - } - - { - Model tfMnistSoftmax = models.get("mnist_softmax_saved"); - assertFunction("serving_default.y", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", - tfMnistSoftmax); - FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available - assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); - } - - { - Model tfMnist = models.get("mnist_saved"); - assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); - // Macro: - assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add", - "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", - tfMnist); - FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument - assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); - } - } - - @Test - public void testImportingRankExpressions() { - Map<String, Model> models = createModels("src/test/resources/config/rankexpression/"); - - assertEquals(18, models.size()); - - Model macros = models.get("macros"); - assertEquals("macros", macros.name()); - assertEquals(4, macros.functions().size()); - assertFunction("fourtimessum", "4 * (var1 + var2)", macros); - assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros); - assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros); - assertFunction("myfeature", - "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " + - "30 * pow(0 - fieldMatch(description).earliness,2)", - macros); - assertEquals(4, macros.referencedFunctions().size()); - assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", - "4 * (match + rankBoost)", macros); - } - - private void assertFunction(String name, String expression, Model model) { - assertNotNull("Model is present in config", model); - ExpressionFunction function = model.function(name); - assertNotNull("Function '" + name + "' is in " + model, function); - assertEquals(name, function.getName()); - assertEquals(expression, function.getBody().getRoot().toString()); - } - - private void assertBoundFunction(String name, String expression, Model model) { - ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get()); - assertNotNull("Function '" + name + "' is present", function); - assertEquals(name, function.getName()); - assertEquals(expression, function.getBody().getRoot().toString()); - } - - private Map<String, Model> createModels(String path) { - Path configDir = Path.fromString(path); - RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()), - RankProfilesConfig.class).getConfig(""); - RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()), - RankingConstantsConfig.class).getConfig(""); - return new RankProfilesConfigImporter().importFrom(config, constantsConfig); - } - -} |