author    Jon Bratseth <bratseth@oath.com>    2018-09-06 16:38:28 +0200
committer Jon Bratseth <bratseth@oath.com>    2018-09-06 16:38:28 +0200
commit    ef211a55b4a343ad8bcd8ae34a202f3c61828a7a (patch)
tree      4f62a9363a9b48bd8875e6868fc9f974d37b9b5b /model-evaluation
parent    c1fdecf3cb26f1a3aef2caf290916a4f533c6c58 (diff)
Send global constants
Diffstat (limited to 'model-evaluation')
-rw-r--r--  model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java  38
-rw-r--r--  model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java    66
2 files changed, 65 insertions, 39 deletions
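
Note: a minimal usage sketch, not part of this commit, showing how the imported models are typically consumed; it relies only on the importFrom, evaluatorOf and context().names() calls exercised in the diff below. The config import packages, the importer's no-arg constructor and the ImportUsageSketch class itself are illustration-only assumptions.

package ai.vespa.models.evaluation; // the importer is package-private, so the sketch lives in its package

import com.yahoo.vespa.config.search.RankProfilesConfig;          // assumed package
import com.yahoo.vespa.config.search.core.RankingConstantsConfig; // assumed package

import java.util.Map;

class ImportUsageSketch {

    void listRequiredNames(RankProfilesConfig config, RankingConstantsConfig constantsConfig) {
        // Import every model declared in the config; global constants are resolved
        // from the file references carried by RankingConstantsConfig.
        Map<String, Model> models = new RankProfilesConfigImporter().importFrom(config, constantsConfig);

        // For a model with exactly one output, list the inputs and constants its evaluator expects.
        // "xgboost_2_2" is one of the models used by the test changes further down.
        FunctionEvaluator evaluator = models.get("xgboost_2_2").evaluatorOf();
        System.out.println(evaluator.context().names());
    }
}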
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index b9e7a27c013..98c80ace047 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -26,6 +26,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.logging.Logger;
/**
* Converts RankProfilesConfig instances to RankingExpressions for evaluation.
@@ -35,19 +36,13 @@ import java.util.Optional;
*/
class RankProfilesConfigImporter {
- /**
- * Constants already imported in this while reading some expression.
- * This is to avoid re-reading constants referenced
- * multiple places, as that is potentially costly.
- */
- private Map<String, Constant> globalImportedConstants = new HashMap<>();
+ private static final Logger log = Logger.getLogger(RankProfilesConfigImporter.class.getName());
/**
* Returns a map of the models contained in this config, indexed on name.
* The map is modifiable and owned by the caller.
*/
Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) {
- globalImportedConstants.clear();
try {
Map<String, Model> models = new HashMap<>();
for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
@@ -61,7 +56,8 @@ class RankProfilesConfigImporter {
}
}
- private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) throws ParseException {
+ private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig)
+ throws ParseException {
List<ExpressionFunction> functions = new ArrayList<>();
Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>();
ExpressionFunction firstPhase = null;
@@ -79,7 +75,8 @@ class RankProfilesConfigImporter {
functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); //
// Make all functions, bound or not available under the name they are referenced by in expressions
- referencedFunctions.put(reference.get(), new ExpressionFunction(reference.get().serialForm(), arguments, expression));
+ referencedFunctions.put(reference.get(),
+ new ExpressionFunction(reference.get().serialForm(), arguments, expression));
}
else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to macros
firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(),
@@ -112,24 +109,29 @@ class RankProfilesConfigImporter {
private List<Constant> readConstants(RankingConstantsConfig constantsConfig) {
List<Constant> constants = new ArrayList<>();
+
for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) {
constants.add(new Constant(constantConfig.name(),
- readTensorFromFile(TensorType.fromSpec(constantConfig.type()),
+ readTensorFromFile(constantConfig.name(),
+ TensorType.fromSpec(constantConfig.type()),
constantConfig.fileref().value())));
}
return constants;
}
- private Tensor readTensorFromFile(TensorType type, String fileName) {
+ private Tensor readTensorFromFile(String name, TensorType type, String fileReference) {
try {
- if (fileName.endsWith(".tbf"))
- return TypedBinaryFormat.decode(Optional.of(type),
- GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileName))));
- // TODO: Support json and json.lz4
-
- if (fileName.isEmpty()) // this is the case in unit tests
+ if (fileReference.isEmpty()) { // this may be the case in unit tests
+ log.warning("Got empty file reference for constant '" + name + "', using an empty tensor");
return Tensor.from(type, "{}");
- throw new IllegalArgumentException("Unknown tensor file format (determined by file ending): " + fileName);
+ }
+ if ( ! new File(fileReference).exists()) { // this may be the case in unit tests
+ log.warning("Could not find file '" + fileReference + "' for constant '" + name + "', using an empty tensor");
+ return Tensor.from(type, "{}");
+ }
+ // TODO: Support json and json.lz4
+ return TypedBinaryFormat.decode(Optional.of(type),
+ GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileReference))));
}
catch (IOException e) {
throw new UncheckedIOException(e);
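
Note: a standalone sketch, not part of this commit, of the fallback logic the patched readTensorFromFile above implements: an empty or unresolvable file reference yields an empty tensor of the declared type (the unit-test case), otherwise the file is decoded with the typed binary format. The com.yahoo import package names are assumptions; the individual calls are the ones used in the hunk above.

import com.yahoo.io.GrowableByteBuffer;                   // assumed package
import com.yahoo.io.IOUtils;                              // assumed package
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;  // assumed package

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Optional;

class ConstantTensorReaderSketch {

    Tensor read(TensorType type, String fileReference) {
        try {
            // Empty or missing file reference: fall back to an empty tensor of the declared type
            if (fileReference.isEmpty() || ! new File(fileReference).exists())
                return Tensor.from(type, "{}");
            // Otherwise decode the file contents as a binary-format tensor
            return TypedBinaryFormat.decode(Optional.of(type),
                                            GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileReference))));
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}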
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
index 84e01e58280..2cb9602dfa7 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java
@@ -11,9 +11,11 @@ import org.junit.Test;
import java.io.File;
import java.util.Map;
+import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
/**
* Tests instantiating models from rank-profiles configs.
@@ -28,27 +30,49 @@ public class RankProfilesImporterTest {
assertEquals(4, models.size());
- Model xgboost = models.get("xgboost_2_2");
- assertFunction("xgboost_2_2",
- "(optimized sum of condition trees of size 192 bytes)",
- xgboost);
-
- Model onnxMnistSoftmax = models.get("mnist_softmax");
- assertFunction("default.add",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
- onnxMnistSoftmax);
- assertEquals("tensor(d1[10],d2[784])",
- onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
-
- Model tfMnistSoftmax = models.get("mnist_softmax_saved");
- assertFunction("serving_default.y",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
- tfMnistSoftmax);
-
- Model tfMnist = models.get("mnist_saved");
- assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
- tfMnist);
+ // TODO: When we get type information in Models, replace the evaluator.context().names() checks below with that
+ {
+ Model xgboost = models.get("xgboost_2_2");
+ assertFunction("xgboost_2_2",
+ "(optimized sum of condition trees of size 192 bytes)",
+ xgboost);
+ FunctionEvaluator evaluator = xgboost.evaluatorOf();
+ assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+
+ Model onnxMnistSoftmax = models.get("mnist_softmax");
+ assertFunction("default.add",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
+ onnxMnistSoftmax);
+ assertEquals("tensor(d1[10],d2[784])",
+ onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
+ FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available
+ assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+ Model tfMnistSoftmax = models.get("mnist_softmax_saved");
+ assertFunction("serving_default.y",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
+ tfMnistSoftmax);
+ FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
+ assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
+
+ {
+ Model tfMnist = models.get("mnist_saved");
+ assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ // Macro:
+ assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add",
+ "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: The macro is currently offered as an alternative output, so we must specify which output to evaluate
+ assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ }
}
@Test