diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-10-01 12:52:30 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-10-01 12:52:30 +0200 |
commit | d410a107a048070d07afc199a12dff85bdea139e (patch) | |
tree | e26f33065ebfb612a8c83d840de5661523590d1d /model-evaluation | |
parent | 50bc3b3c198d29374448cc3eac73fbb26e42cab0 (diff) |
Validate all bindings
Diffstat (limited to 'model-evaluation')
6 files changed, 65 insertions, 29 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 1412936d4a0..8ce44ef5ed2 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -3,10 +3,14 @@ package ai.vespa.models.evaluation; import com.google.common.annotations.Beta; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.Map; +import java.util.stream.Collectors; + /** * An evaluator which can be used to evaluate a single function once. * @@ -35,6 +39,14 @@ public class FunctionEvaluator { public FunctionEvaluator bind(String name, Tensor value) { if (evaluated) throw new IllegalStateException("You cannot bind a value in a used evaluator"); + TensorType requiredType = function.argumentTypes().get(name); + if (requiredType == null) + throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function + + ". Expected arguments: " + function.argumentTypes().entrySet().stream() + .map(e -> e.getKey() + ": " + e.getValue()) + .collect(Collectors.joining(", "))); + if ( ! value.type().isAssignableTo(requiredType)) + throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type()); context.put(name, new TensorValue(value)); return this; } @@ -52,10 +64,19 @@ public class FunctionEvaluator { } public Tensor evaluate() { + for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { + if (argument.getValue().rank() == 0) continue; // Scalar argumentds can be skipped (defaults to 0) + if (context.get(argument.getKey()) == LazyArrayContext.defaultContextValue) + throw new IllegalStateException("Missing argument '" + argument.getKey() + + "': Must be bound to a value of type " + argument.getValue()); + } evaluated = true; return function.getBody().evaluate(context).asTensor(); } + /** Returns the function evaluated by this */ + public ExpressionFunction function() { return function; } + public LazyArrayContext context() { return context; } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index d144411127e..093d487c31f 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -32,6 +32,8 @@ import java.util.Set; */ public final class LazyArrayContext extends Context implements ContextIndex { + public final static Value defaultContextValue = DoubleValue.zero; + private final IndexedBindings indexedBindings; private LazyArrayContext(IndexedBindings indexedBindings) { @@ -167,7 +169,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { this.arguments = ImmutableSet.copyOf(arguments); values = new Value[bindTargets.size()]; - Arrays.fill(values, DoubleValue.zero); + Arrays.fill(values, defaultContextValue); int i = 0; ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index 683a1f345d8..6edcd84272e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -10,13 +10,16 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.slime.Cursor; import com.yahoo.slime.Slime; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.JsonFormat; +import com.yahoo.yolean.Exceptions; import java.io.IOException; import java.io.OutputStream; import java.net.URI; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.Map; import java.util.Optional; import java.util.concurrent.Executor; @@ -60,14 +63,17 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { return listModelInformation(request, model, function); } catch (IllegalArgumentException e) { - return new ErrorResponse(404, e.getMessage()); + return new ErrorResponse(404, Exceptions.toMessageString(e)); + } catch (IllegalStateException e) { // On missing bindings + return new ErrorResponse(400, Exceptions.toMessageString(e)); } } private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) { FunctionEvaluator evaluator = model.evaluatorOf(function); - for (String bindingName : evaluator.context().names()) { - property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s))); + for (Map.Entry<String, TensorType> argument : evaluator.function().argumentTypes().entrySet()) { + property(request, argument.getKey()).ifPresent(value -> evaluator.bind(argument.getKey(), + Tensor.from(argument.getValue(), value))); } Tensor result = evaluator.evaluate(); return new Response(200, JsonFormat.encode(result)); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index c4b163e89c0..68c3b954675 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -65,7 +65,7 @@ public class MlModelsImportingTest { assertEquals("tensor(d1[10],d2[784])", onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available - assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + evaluator.bind("Placeholder", inputTensor()); assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); } @@ -85,7 +85,7 @@ public class MlModelsImportingTest { // Evaluator FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available - assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + evaluator.bind("Placeholder", inputTensor()); assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); } @@ -109,11 +109,18 @@ public class MlModelsImportingTest { // Evaluator FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); - assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + evaluator.bind("input", inputTensor()); assertEquals(-0.714629131972222, evaluator.evaluate().sum().asDouble(), delta); } } + private Tensor inputTensor() { + Tensor.Builder b = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[],d1[784])")); + for (int i = 0; i < 784; i++) + b.cell(0.0, 0, i); + return b.build(); + } + private String commaSeparated(List<?> items) { return items.stream().map(item -> item.toString()).sorted().collect(Collectors.joining(", ")); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index bd1ff6b8ed7..57a415e8894 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -20,16 +20,7 @@ public class ModelsEvaluatorTest { private static final double delta = 0.00000000001; @Test - public void testTensorEvaluation() { - ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); - FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum"); - function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}")); - function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}")); - assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), function.evaluate()); - } - - @Test - public void testEvaluationDependingOnMacroTakingArguments() { + public void testEvaluationDependingFunctionTakingArguments() { ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); function.bind("match", 3); @@ -40,7 +31,7 @@ public class ModelsEvaluatorTest { // TODO: Test argument-less function // TODO: Test that binding nonexisting variable doesn't work // TODO: Test that rebinding doesn't work - // TODO: Test with nested macros + // TODO: Test with nested functions private ModelsEvaluator createModels(String path) { Path configDir = Path.fromString(path); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index b92e13b640f..b915ee72a79 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -9,6 +9,8 @@ import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; import com.yahoo.path.Path; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.BeforeClass; @@ -94,32 +96,32 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; - String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; - assertResponse(url, 200, expected); + String expected = "{\"error\":\"Missing argument 'Placeholder': Must be bound to a value of type tensor(d0[],d1[784])\"}"; + assertResponse(url, 400, expected); } @Test public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; - String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; - assertResponse(url, 200, expected); + String expected = "{\"error\":\"Missing argument 'Placeholder': Must be bound to a value of type tensor(d0[],d1[784])\"}"; + assertResponse(url, 400, expected); } @Test public void testMnistSoftmaxEvaluateDefaultFunctionWithBindings() { Map<String, String> properties = new HashMap<>(); - properties.put("Placeholder", "{1.0}"); + properties.put("Placeholder", inputTensor()); String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; - String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}"; + String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; assertResponse(url, properties, 200, expected); } @Test public void testMnistSoftmaxEvaluateSpecificFunctionWithBindings() { Map<String, String> properties = new HashMap<>(); - properties.put("Placeholder", "{1.0}"); + properties.put("Placeholder", inputTensor()); String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; - String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}"; + String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; assertResponse(url, properties, 200, expected); } @@ -147,9 +149,9 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSavedEvaluateSpecificFunction() { Map<String, String> properties = new HashMap<>(); - properties.put("input", "-1.0"); + properties.put("input", inputTensor()); String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval"; - String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-2.72208123403445},{\"address\":{\"d1\":\"1\"},\"value\":6.465137496457595},{\"address\":{\"d1\":\"2\"},\"value\":-7.078050386283122},{\"address\":{\"d1\":\"3\"},\"value\":-10.485296462655546},{\"address\":{\"d1\":\"4\"},\"value\":0.19508378636937004},{\"address\":{\"d1\":\"5\"},\"value\":6.348870746681019},{\"address\":{\"d1\":\"6\"},\"value\":10.756191852397258},{\"address\":{\"d1\":\"7\"},\"value\":1.476101533270058},{\"address\":{\"d1\":\"8\"},\"value\":-17.778398655804875},{\"address\":{\"d1\":\"9\"},\"value\":-2.0597690508530295}]}"; + String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.6319251673007533},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":-7.577770600619843E-4},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":-0.010707969042025622},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.6344759233540788},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":-0.17529455385847528},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":0.7490809723192187},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.022790284182901716},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.26799240657608936},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-0.3152438845465862},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":0.05949304847735276}]}"; assertResponse(url, properties, 200, expected); } @@ -171,10 +173,10 @@ public class ModelsEvaluationHandlerTest { static private void assertResponse(HttpRequest request, int expectedCode, String expectedResult) { HttpResponse response = handler.handle(request); assertEquals("application/json", response.getContentType()); - assertEquals(expectedCode, response.getStatus()); if (expectedResult != null) { assertEquals(expectedResult, getContents(response)); } + assertEquals(expectedCode, response.getStatus()); } static private String getContents(HttpResponse response) { @@ -198,4 +200,11 @@ public class ModelsEvaluationHandlerTest { return new ModelsEvaluator(importer.importFrom(config, constantsConfig)); } + private String inputTensor() { + Tensor.Builder b = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[],d1[784])")); + for (int i = 0; i < 784; i++) + b.cell(0.0, 0, i); + return b.build().toString(); + } + } |