diff options
author | Lester Solbakken <lesters@oath.com> | 2021-06-28 10:24:00 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-06-28 10:24:00 +0200 |
commit | be64d42b11f3c922e17a7c8ed3c627936a2e98cb (patch) | |
tree | b1ab55ccf6329badf003f59942d3f87db1c29fae /model-evaluation | |
parent | d713569989c88b541305e79ac531b0fc8a8bceaa (diff) |
Remove onnx feature as argument for stateless evaluation
Diffstat (limited to 'model-evaluation')
4 files changed, 3 insertions, 11 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 910aca8aa98..72fefe62e61 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -101,11 +101,11 @@ public class FunctionEvaluator { } public Tensor evaluate() { - evaluateOnnxModels(); for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { checkArgument(argument.getKey(), argument.getValue()); } evaluated = true; + evaluateOnnxModels(); return function.getBody().evaluate(context).asTensor(); } @@ -126,7 +126,6 @@ public class FunctionEvaluator { if (context.get(onnxFeature).equals(context.defaultValue())) { Map<String, Tensor> inputs = new HashMap<>(); for (Map.Entry<String, TensorType> input: onnxModel.inputs().entrySet()) { - checkArgument(input.getKey(), input.getValue()); inputs.put(input.getKey(), context.get(input.getKey()).asTensor()); } Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index a5dcd2719c9..5754977bb46 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -274,7 +274,6 @@ public final class LazyArrayContext extends Context implements ContextIndex { if (onnxModel.name().equals(modelName.get())) { String onnxFeature = node.toString(); bindTargets.add(onnxFeature); - arguments.add(onnxFeature); // Load the model (if not already loaded) to extract inputs onnxModel.load(); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 5590b9f0242..506a07be8d1 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -72,13 +72,10 @@ public class Model { } for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) { - String onnxFeature = entry.getKey(); OnnxModel onnxModel = entry.getValue(); for(Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet()) { functions.put(function.getKey(), function.getValue().withArgument(input.getKey(), input.getValue())); } - TensorType onnxOutputType = onnxModel.outputs().get(function.getKey().functionName()); - functions.put(function.getKey(), function.getValue().withArgument(onnxFeature, onnxOutputType)); } for (String argument : context.arguments()) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java index 97692de56ef..2d419751db6 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java @@ -46,7 +46,6 @@ public class OnnxEvaluationHandlerTest { "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output1/eval\"," + "\"arguments\":[" + "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," + - "{\"name\":\"onnxModel(add_mul).output1\",\"type\":\"tensor<float>(d0[1])\"}," + "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" + "]}," + "{\"function\":\"output2\"," + @@ -54,7 +53,6 @@ public class OnnxEvaluationHandlerTest { "\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output2/eval\"," + "\"arguments\":[" + "{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," + - "{\"name\":\"onnxModel(add_mul).output2\",\"type\":\"tensor<float>(d0[1])\"}," + "{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" + "]}]}"; handler.assertResponse(url, 200, expected); @@ -70,7 +68,7 @@ public class OnnxEvaluationHandlerTest { @Test public void testEvaluationWithoutBindings() { String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval"; - String expected = "{\"error\":\"Argument 'input2' must be bound to a value of type tensor<float>(d0[1])\"}"; + String expected = "{\"error\":\"Argument 'input1' must be bound to a value of type tensor<float>(d0[1])\"}"; handler.assertResponse(url, 400, expected); } @@ -102,8 +100,7 @@ public class OnnxEvaluationHandlerTest { "\"info\":\"http://localhost/model-evaluation/v1/one_layer/output\"," + "\"eval\":\"http://localhost/model-evaluation/v1/one_layer/output/eval\"," + "\"arguments\":[" + - "{\"name\":\"input\",\"type\":\"tensor<float>(d0[],d1[3])\"}," + - "{\"name\":\"onnxModel(one_layer)\",\"type\":\"tensor<float>(d0[],d1[1])\"}" + + "{\"name\":\"input\",\"type\":\"tensor<float>(d0[],d1[3])\"}" + "]}]}"; handler.assertResponse(url, 200, expected); } |