diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-06 16:38:28 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-06 16:38:28 +0200 |
commit | ef211a55b4a343ad8bcd8ae34a202f3c61828a7a (patch) | |
tree | 4f62a9363a9b48bd8875e6868fc9f974d37b9b5b /model-evaluation | |
parent | c1fdecf3cb26f1a3aef2caf290916a4f533c6c58 (diff) |
Send global constants
Diffstat (limited to 'model-evaluation')
2 files changed, 65 insertions, 39 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index b9e7a27c013..98c80ace047 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -26,6 +26,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Logger; /** * Converts RankProfilesConfig instances to RankingExpressions for evaluation. @@ -35,19 +36,13 @@ import java.util.Optional; */ class RankProfilesConfigImporter { - /** - * Constants already imported in this while reading some expression. - * This is to avoid re-reading constants referenced - * multiple places, as that is potentially costly. - */ - private Map<String, Constant> globalImportedConstants = new HashMap<>(); + private static final Logger log = Logger.getLogger("CONSTANTS"); /** * Returns a map of the models contained in this config, indexed on name. * The map is modifiable and owned by the caller. */ Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { - globalImportedConstants.clear(); try { Map<String, Model> models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { @@ -61,7 +56,8 @@ class RankProfilesConfigImporter { } } - private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) throws ParseException { + private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) + throws ParseException { List<ExpressionFunction> functions = new ArrayList<>(); Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>(); ExpressionFunction firstPhase = null; @@ -79,7 +75,8 @@ class RankProfilesConfigImporter { functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); // // Make all functions, bound or not available under the name they are referenced by in expressions - referencedFunctions.put(reference.get(), new ExpressionFunction(reference.get().serialForm(), arguments, expression)); + referencedFunctions.put(reference.get(), + new ExpressionFunction(reference.get().serialForm(), arguments, expression)); } else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to macros firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(), @@ -112,24 +109,29 @@ class RankProfilesConfigImporter { private List<Constant> readConstants(RankingConstantsConfig constantsConfig) { List<Constant> constants = new ArrayList<>(); + for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) { constants.add(new Constant(constantConfig.name(), - readTensorFromFile(TensorType.fromSpec(constantConfig.type()), + readTensorFromFile(constantConfig.name(), + TensorType.fromSpec(constantConfig.type()), constantConfig.fileref().value()))); } return constants; } - private Tensor readTensorFromFile(TensorType type, String fileName) { + private Tensor readTensorFromFile(String name, TensorType type, String fileReference) { try { - if (fileName.endsWith(".tbf")) - return TypedBinaryFormat.decode(Optional.of(type), - GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileName)))); - // TODO: Support json and json.lz4 - - if (fileName.isEmpty()) // this is the case in unit tests + if (fileReference.isEmpty()) { // this may be the case in unit tests + log.warning("Got empty file reference for constant '" + name + "', using an empty tensor"); return Tensor.from(type, "{}"); - throw new IllegalArgumentException("Unknown tensor file format (determined by file ending): " + fileName); + } + if ( ! new File(fileReference).exists()) { // this may be the case in unit tests + log.warning("Got empty file reference for constant '" + name + "', using an empty tensor"); + return Tensor.from(type, "{}"); + } + return TypedBinaryFormat.decode(Optional.of(type), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileReference)))); + // TODO: Support json and json.lz4 } catch (IOException e) { throw new UncheckedIOException(e); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java index 84e01e58280..2cb9602dfa7 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesImporterTest.java @@ -11,9 +11,11 @@ import org.junit.Test; import java.io.File; import java.util.Map; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; /** * Tests instantiating models from rank-profiles configs. @@ -28,27 +30,49 @@ public class RankProfilesImporterTest { assertEquals(4, models.size()); - Model xgboost = models.get("xgboost_2_2"); - assertFunction("xgboost_2_2", - "(optimized sum of condition trees of size 192 bytes)", - xgboost); - - Model onnxMnistSoftmax = models.get("mnist_softmax"); - assertFunction("default.add", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", - onnxMnistSoftmax); - assertEquals("tensor(d1[10],d2[784])", - onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); - - Model tfMnistSoftmax = models.get("mnist_softmax_saved"); - assertFunction("serving_default.y", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", - tfMnistSoftmax); - - Model tfMnist = models.get("mnist_saved"); - assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); + // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that + { + Model xgboost = models.get("xgboost_2_2"); + assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); + FunctionEvaluator evaluator = xgboost.evaluatorOf(); + assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + + Model onnxMnistSoftmax = models.get("mnist_softmax"); + assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); + assertEquals("tensor(d1[10],d2[784])", + onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); + FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + Model tfMnistSoftmax = models.get("mnist_softmax_saved"); + assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); + FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available + assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } + + { + Model tfMnist = models.get("mnist_saved"); + assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + // Macro: + assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add", + "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", + tfMnist); + FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument + assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + } } @Test |