aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2018-09-07 13:57:43 +0200
committerGitHub <noreply@github.com>2018-09-07 13:57:43 +0200
commitd81fcef66250f4ac0bae92564374058b0d8b4b7e (patch)
tree8617164d4ba7ab653b0b4bfd11cceb3903ae518a
parente2f4fe91f889153734cc7ac076a32cc531162784 (diff)
parent55b5c8c7c4b14303dfffb9ad017fd6bcea40e9b9 (diff)
Merge pull request #6847 from vespa-engine/bratseth/read-constants
Bratseth/read constants
-rw-r--r--config-lib/src/main/java/com/yahoo/config/PathNode.java1
-rw-r--r--container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java5
-rw-r--r--container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java2
-rw-r--r--docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java2
-rw-r--r--jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java18
-rw-r--r--messagebus/src/main/java/com/yahoo/messagebus/Message.java4
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java24
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java82
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java77
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java36
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java123
11 files changed, 230 insertions, 144 deletions
diff --git a/config-lib/src/main/java/com/yahoo/config/PathNode.java b/config-lib/src/main/java/com/yahoo/config/PathNode.java
index b63dad4d1a7..9d73b5e23c2 100644
--- a/config-lib/src/main/java/com/yahoo/config/PathNode.java
+++ b/config-lib/src/main/java/com/yahoo/config/PathNode.java
@@ -14,7 +14,6 @@ import java.util.Map;
* Represents a 'path' in a {@link ConfigInstance}, usually a filename.
*
* @author gjoranv
- * @since 5.1.30
*/
public class PathNode extends LeafNode<Path> {
diff --git a/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java b/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java
index c94dc30fd6f..ee12c7d4c9f 100644
--- a/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java
+++ b/container-core/src/main/java/com/yahoo/container/core/BundleLoaderProperties.java
@@ -3,12 +3,13 @@ package com.yahoo.container.core;
/**
* @author gjoranv
- * @since 5.46
*/
public interface BundleLoaderProperties {
+
// TODO: This should be removed. The prefix is used to separate the bundles in BundlesConfig
// into those that are transferred with filedistribution and those that are preinstalled
// on disk. Instead, the model should have put them in two different configs. I.e. create a new
// config 'preinstalled-bundles.def'.
- public static final String DISK_BUNDLE_PREFIX = "file:";
+ String DISK_BUNDLE_PREFIX = "file:";
+
}
diff --git a/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java b/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java
index eceb41f9739..c0a68086212 100644
--- a/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java
+++ b/container-core/src/main/java/com/yahoo/container/core/config/BundleLoader.java
@@ -95,7 +95,7 @@ public class BundleLoader {
log.info("Installing bundle from disk with reference '" + reference.value() + "'");
File file = new File(referenceFileName);
- if (!file.exists()) {
+ if ( ! file.exists()) {
throw new IllegalArgumentException("Reference '" + reference.value() + "' not found on disk.");
}
diff --git a/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java b/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java
index 80f1f003412..0f3f3938701 100644
--- a/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java
+++ b/docproc/src/main/java/com/yahoo/docproc/jdisc/messagebus/MbusRequestContext.java
@@ -34,7 +34,7 @@ import java.util.concurrent.atomic.AtomicBoolean;
import java.util.logging.Logger;
/**
- * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
+ * @author Einar M R Rosenvinge
*/
public class MbusRequestContext implements RequestContext, ResponseHandler {
diff --git a/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java b/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java
index d0f5de54b4f..bc3d1edda7c 100644
--- a/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java
+++ b/jdisc_messagebus_service/src/main/java/com/yahoo/messagebus/jdisc/MbusServer.java
@@ -32,7 +32,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider
private final ResourceReference sessionReference;
@Inject
- public MbusServer(final CurrentContainer container, final ServerSession session) {
+ public MbusServer(CurrentContainer container, ServerSession session) {
this.container = container;
this.session = session;
uri = URI.create("mbus://localhost/" + session.name());
@@ -60,7 +60,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider
}
@Override
- public void handleMessage(final Message msg) {
+ public void handleMessage(Message msg) {
if (!running.get()) {
dispatchErrorReply(msg, ErrorCode.SESSION_BUSY, "Session temporarily closed.");
return;
@@ -73,7 +73,7 @@ public final class MbusServer extends AbstractResource implements ServerProvider
try {
request = new MbusRequest(container, uri, msg);
content = request.connect(new ServerResponseHandler(msg));
- } catch (final RuntimeException e) {
+ } catch (RuntimeException e) {
dispatchErrorReply(msg, ErrorCode.APP_FATAL_ERROR, e.toString());
} finally {
if (request != null) {
@@ -89,8 +89,8 @@ public final class MbusServer extends AbstractResource implements ServerProvider
return session.connectionSpec();
}
- private void dispatchErrorReply(final Message msg, final int errCode, final String errMsg) {
- final Reply reply = new EmptyReply();
+ private void dispatchErrorReply(Message msg, int errCode, String errMsg) {
+ Reply reply = new EmptyReply();
reply.swapState(msg);
reply.addError(new Error(errCode, errMsg));
session.sendReply(reply);
@@ -100,20 +100,20 @@ public final class MbusServer extends AbstractResource implements ServerProvider
final Message msg;
- ServerResponseHandler(final Message msg) {
+ ServerResponseHandler(Message msg) {
this.msg = msg;
}
@Override
- public ContentChannel handleResponse(final Response response) {
- final Reply reply;
+ public ContentChannel handleResponse(Response response) {
+ Reply reply;
if (response instanceof MbusResponse) {
reply = ((MbusResponse)response).getReply();
} else {
reply = new EmptyReply();
reply.swapState(msg);
}
- final Error err = StatusCodes.toMbusError(response.getStatus());
+ Error err = StatusCodes.toMbusError(response.getStatus());
if (err != null) {
if (err.isFatal()) {
if (!reply.hasFatalErrors()) {
diff --git a/messagebus/src/main/java/com/yahoo/messagebus/Message.java b/messagebus/src/main/java/com/yahoo/messagebus/Message.java
index 22496487f61..43f5c8d2dfd 100644
--- a/messagebus/src/main/java/com/yahoo/messagebus/Message.java
+++ b/messagebus/src/main/java/com/yahoo/messagebus/Message.java
@@ -5,9 +5,9 @@ import com.yahoo.concurrent.SystemTimer;
import com.yahoo.messagebus.routing.Route;
/**
- * <p>A message is a child of Routable, it is not a reply, and it has a sequencing identifier. Furthermore, a message
+ * A message is a child of Routable, it is not a reply, and it has a sequencing identifier. Furthermore, a message
* contains a retry counter that holds what retry the message is currently on. See the method comment {@link #getRetry}
- * for more information.</p>
+ * for more information.
*
* @author Simon Thoresen Hult
*/
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 98c80ace047..cd21a0a6813 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -15,6 +15,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.vespa.defaults.Defaults;
import java.io.File;
import java.io.IOException;
@@ -119,18 +120,31 @@ class RankProfilesConfigImporter {
return constants;
}
- private Tensor readTensorFromFile(String name, TensorType type, String fileReference) {
+ Tensor readTensorFromFile(String name, TensorType type, String fileReference) {
try {
+ // TODO: Only allow these two fallbacks in testing mode
if (fileReference.isEmpty()) { // this may be the case in unit tests
log.warning("Got empty file reference for constant '" + name + "', using an empty tensor");
return Tensor.from(type, "{}");
}
- if ( ! new File(fileReference).exists()) { // this may be the case in unit tests
- log.warning("Got empty file reference for constant '" + name + "', using an empty tensor");
+ File dir = new File(Defaults.getDefaults().underVespaHome("var/db/vespa/filedistribution"), fileReference);
+ if ( ! dir.exists()) { // this may be the case in unit tests
+ log.warning("Got reference to nonexisting file " + dir + "e for constant '" + name +
+ "', using an empty tensor");
return Tensor.from(type, "{}");
}
- return TypedBinaryFormat.decode(Optional.of(type),
- GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileReference))));
+
+ // TODO: Move these 2 lines to FileReference
+
+ dir = new File(Defaults.getDefaults().underVespaHome("var/db/vespa/filedistribution"), fileReference);
+ File file = dir.listFiles()[0]; // directory contains one file having the original name
+
+ if (file.getName().endsWith(".tbf"))
+ return TypedBinaryFormat.decode(Optional.of(type),
+ GrowableByteBuffer.wrap(IOUtils.readFileBytes(file)));
+ else
+ throw new IllegalArgumentException("Constant files on other formats than .tbf are not supported, got " +
+ file + " for constant " + name);
// TODO: Support json and json.lz4
}
catch (IOException e) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
new file mode 100644
index 00000000000..a823f16d727
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -0,0 +1,82 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.FileSource;
+import com.yahoo.path.Path;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import org.junit.Test;
+
+import java.io.File;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.logging.Logger;
+import java.util.stream.Collectors;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+/**
+ * Tests instantiating models from rank-profiles configs.
+ *
+ * @author bratseth
+ */
+public class MlModelsImportingTest {
+
+ @Test
+ public void testImportingModels() {
+ ModelTester tester = new ModelTester("src/test/resources/config/models/");
+
+ assertEquals(4, tester.models().size());
+
+ // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that
+ {
+ Model xgboost = tester.models().get("xgboost_2_2");
+ tester.assertFunction("xgboost_2_2",
+ "(optimized sum of condition trees of size 192 bytes)",
+ xgboost);
+ FunctionEvaluator evaluator = xgboost.evaluatorOf();
+ assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+
+ Model onnxMnistSoftmax = tester.models().get("mnist_softmax");
+ tester.assertFunction("default.add",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
+ onnxMnistSoftmax);
+ assertEquals("tensor(d1[10],d2[784])",
+ onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
+ FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available
+ assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+ Model tfMnistSoftmax = tester.models().get("mnist_softmax_saved");
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
+ tfMnistSoftmax);
+ FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
+ assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+ Model tfMnist = tester.models().get("mnist_saved");
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ // Macro:
+ tester.assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add",
+ "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument
+ assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+ }
+
+}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
new file mode 100644
index 00000000000..63e17e37bde
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -0,0 +1,77 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import com.yahoo.config.subscription.ConfigGetter;
+import com.yahoo.config.subscription.FileSource;
+import com.yahoo.path.Path;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.vespa.config.search.RankProfilesConfig;
+import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+
+import java.util.HashMap;
+import java.util.Map;
+import java.util.logging.Logger;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * Helper for testing model import and evaluation
+ *
+ * @author bratseth
+ */
+public class ModelTester {
+
+ private final Map<String, Model> models;
+
+ public ModelTester(String modelConfigDirectory) {
+ models = createModels(modelConfigDirectory);
+ }
+
+ public Map<String, Model> models() { return models; }
+
+ private static Map<String, Model> createModels(String path) {
+ Path configDir = Path.fromString(path);
+ RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
+ RankProfilesConfig.class).getConfig("");
+ RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
+ RankingConstantsConfig.class).getConfig("");
+ return new RankProfilesConfigImporterWithMockedConstants().importFrom(config, constantsConfig);
+ }
+
+ public void assertFunction(String name, String expression, Model model) {
+ assertNotNull("Model is present in config", model);
+ ExpressionFunction function = model.function(name);
+ assertNotNull("Function '" + name + "' is in " + model, function);
+ assertEquals(name, function.getName());
+ assertEquals(expression, function.getBody().getRoot().toString());
+ }
+
+ public void assertBoundFunction(String name, String expression, Model model) {
+ ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get());
+ assertNotNull("Function '" + name + "' is present", function);
+ assertEquals(name, function.getName());
+ assertEquals(expression, function.getBody().getRoot().toString());
+ }
+
+ /** Allows us to provide canned tensor constants during import since file distribution does not work in tests */
+ private static class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter {
+
+ private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName());
+
+ Map<String, Tensor> constants = new HashMap<>();
+
+ @Override
+ Tensor readTensorFromFile(String name, TensorType type, String fileReference) {
+ if ( ! constants.containsKey(name)) {
+ log.warning("Missing a mocked tensor constant for '" + name + "': Returning an empty tensor");
+ return Tensor.from(type, "{}");
+ }
+ return constants.get(name);
+ }
+
+ }
+
+}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java
new file mode 100644
index 00000000000..210ffb823b2
--- /dev/null
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java
@@ -0,0 +1,36 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.models.evaluation;
+
+import org.junit.Test;
+
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class RankProfileImportingTest {
+
+ @Test
+ public void testImportingRankExpressions() {
+ ModelTester tester = new ModelTester("src/test/resources/config/rankexpression/");
+
+ assertEquals(18, tester.models().size());
+
+ Model macros = tester.models().get("macros");
+ assertEquals("macros", macros.name());
+ assertEquals(4, macros.functions().size());
+ tester.assertFunction("fourtimessum", "4 * (var1 + var2)", macros);
+ tester.assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros);
+ tester.assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros);
+ tester.assertFunction("myfeature",
+ "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " +
+ "30 * pow(0 - fieldMatch(description).earliness,2)",
+ macros);
+ assertEquals(4, macros.referencedFunctions().size());
+ tester.assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)",
+ "4 * (match + rankBoost)", macros);
+ }
+
+}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
deleted file mode 100644
index 2cb9602dfa7..00000000000
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
+++ /dev/null
@@ -1,123 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.models.evaluation;
-
-import com.yahoo.config.subscription.ConfigGetter;
-import com.yahoo.config.subscription.FileSource;
-import com.yahoo.path.Path;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.vespa.config.search.RankProfilesConfig;
-import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
-import org.junit.Test;
-
-import java.io.File;
-import java.util.Map;
-import java.util.stream.Collectors;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
-
-/**
- * Tests instantiating models from rank-profiles configs.
- *
- * @author bratseth
- */
-public class RankProfilesImporterTest {
-
- @Test
- public void testImportingModels() {
- Map<String, Model> models = createModels("src/test/resources/config/models/");
-
- assertEquals(4, models.size());
-
- // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that
- {
- Model xgboost = models.get("xgboost_2_2");
- assertFunction("xgboost_2_2",
- "(optimized sum of condition trees of size 192 bytes)",
- xgboost);
- FunctionEvaluator evaluator = xgboost.evaluatorOf();
- assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
- }
-
- {
-
- Model onnxMnistSoftmax = models.get("mnist_softmax");
- assertFunction("default.add",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
- onnxMnistSoftmax);
- assertEquals("tensor(d1[10],d2[784])",
- onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
- FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available
- assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
- }
-
- {
- Model tfMnistSoftmax = models.get("mnist_softmax_saved");
- assertFunction("serving_default.y",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
- tfMnistSoftmax);
- FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
- assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
- }
-
- {
- Model tfMnist = models.get("mnist_saved");
- assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
- tfMnist);
- // Macro:
- assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add",
- "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
- tfMnist);
- FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument
- assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
- }
- }
-
- @Test
- public void testImportingRankExpressions() {
- Map<String, Model> models = createModels("src/test/resources/config/rankexpression/");
-
- assertEquals(18, models.size());
-
- Model macros = models.get("macros");
- assertEquals("macros", macros.name());
- assertEquals(4, macros.functions().size());
- assertFunction("fourtimessum", "4 * (var1 + var2)", macros);
- assertFunction("firstphase", "match + fieldMatch(title) + rankingExpression(myfeature)", macros);
- assertFunction("secondphase", "rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)", macros);
- assertFunction("myfeature",
- "70 * fieldMatch(title).completeness * pow(0 - fieldMatch(title).earliness,2) + " +
- "30 * pow(0 - fieldMatch(description).earliness,2)",
- macros);
- assertEquals(4, macros.referencedFunctions().size());
- assertBoundFunction("rankingExpression(fourtimessum@5cf279212355b980.67f1e87166cfef86)",
- "4 * (match + rankBoost)", macros);
- }
-
- private void assertFunction(String name, String expression, Model model) {
- assertNotNull("Model is present in config", model);
- ExpressionFunction function = model.function(name);
- assertNotNull("Function '" + name + "' is in " + model, function);
- assertEquals(name, function.getName());
- assertEquals(expression, function.getBody().getRoot().toString());
- }
-
- private void assertBoundFunction(String name, String expression, Model model) {
- ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get());
- assertNotNull("Function '" + name + "' is present", function);
- assertEquals(name, function.getName());
- assertEquals(expression, function.getBody().getRoot().toString());
- }
-
- private Map<String, Model> createModels(String path) {
- Path configDir = Path.fromString(path);
- RankProfilesConfig config = new ConfigGetter<>(new FileSource(configDir.append("rank-profiles.cfg").toFile()),
- RankProfilesConfig.class).getConfig("");
- RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
- RankingConstantsConfig.class).getConfig("");
- return new RankProfilesConfigImporter().importFrom(config, constantsConfig);
- }
-
-}