aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-06-28 10:24:00 +0200
committerLester Solbakken <lesters@oath.com>2021-06-28 10:24:00 +0200
commitbe64d42b11f3c922e17a7c8ed3c627936a2e98cb (patch)
treeb1ab55ccf6329badf003f59942d3f87db1c29fae /model-evaluation
parentd713569989c88b541305e79ac531b0fc8a8bceaa (diff)
Remove onnx feature as argument for stateless evaluation
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java3
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java1
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java3
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java7
4 files changed, 3 insertions, 11 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 910aca8aa98..72fefe62e61 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -101,11 +101,11 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
- evaluateOnnxModels();
for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
checkArgument(argument.getKey(), argument.getValue());
}
evaluated = true;
+ evaluateOnnxModels();
return function.getBody().evaluate(context).asTensor();
}
@@ -126,7 +126,6 @@ public class FunctionEvaluator {
if (context.get(onnxFeature).equals(context.defaultValue())) {
Map<String, Tensor> inputs = new HashMap<>();
for (Map.Entry<String, TensorType> input: onnxModel.inputs().entrySet()) {
- checkArgument(input.getKey(), input.getValue());
inputs.put(input.getKey(), context.get(input.getKey()).asTensor());
}
Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index a5dcd2719c9..5754977bb46 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -274,7 +274,6 @@ public final class LazyArrayContext extends Context implements ContextIndex {
if (onnxModel.name().equals(modelName.get())) {
String onnxFeature = node.toString();
bindTargets.add(onnxFeature);
- arguments.add(onnxFeature);
// Load the model (if not already loaded) to extract inputs
onnxModel.load();
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 5590b9f0242..506a07be8d1 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -72,13 +72,10 @@ public class Model {
}
for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) {
- String onnxFeature = entry.getKey();
OnnxModel onnxModel = entry.getValue();
for(Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet()) {
functions.put(function.getKey(), function.getValue().withArgument(input.getKey(), input.getValue()));
}
- TensorType onnxOutputType = onnxModel.outputs().get(function.getKey().functionName());
- functions.put(function.getKey(), function.getValue().withArgument(onnxFeature, onnxOutputType));
}
for (String argument : context.arguments()) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
index 97692de56ef..2d419751db6 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
@@ -46,7 +46,6 @@ public class OnnxEvaluationHandlerTest {
"\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output1/eval\"," +
"\"arguments\":[" +
"{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," +
- "{\"name\":\"onnxModel(add_mul).output1\",\"type\":\"tensor<float>(d0[1])\"}," +
"{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" +
"]}," +
"{\"function\":\"output2\"," +
@@ -54,7 +53,6 @@ public class OnnxEvaluationHandlerTest {
"\"eval\":\"http://localhost/model-evaluation/v1/add_mul/output2/eval\"," +
"\"arguments\":[" +
"{\"name\":\"input1\",\"type\":\"tensor<float>(d0[1])\"}," +
- "{\"name\":\"onnxModel(add_mul).output2\",\"type\":\"tensor<float>(d0[1])\"}," +
"{\"name\":\"input2\",\"type\":\"tensor<float>(d0[1])\"}" +
"]}]}";
handler.assertResponse(url, 200, expected);
@@ -70,7 +68,7 @@ public class OnnxEvaluationHandlerTest {
@Test
public void testEvaluationWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval";
- String expected = "{\"error\":\"Argument 'input2' must be bound to a value of type tensor<float>(d0[1])\"}";
+ String expected = "{\"error\":\"Argument 'input1' must be bound to a value of type tensor<float>(d0[1])\"}";
handler.assertResponse(url, 400, expected);
}
@@ -102,8 +100,7 @@ public class OnnxEvaluationHandlerTest {
"\"info\":\"http://localhost/model-evaluation/v1/one_layer/output\"," +
"\"eval\":\"http://localhost/model-evaluation/v1/one_layer/output/eval\"," +
"\"arguments\":[" +
- "{\"name\":\"input\",\"type\":\"tensor<float>(d0[],d1[3])\"}," +
- "{\"name\":\"onnxModel(one_layer)\",\"type\":\"tensor<float>(d0[],d1[1])\"}" +
+ "{\"name\":\"input\",\"type\":\"tensor<float>(d0[],d1[3])\"}" +
"]}]}";
handler.assertResponse(url, 200, expected);
}