aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test/java/ai/vespa/models/handler/OnnxEvaluationHandlerTest.java
blob: d4909e981b32b5e21b9c056075f401b0d848da92 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.handler;

import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import ai.vespa.models.evaluation.ModelsEvaluator;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.tensor.Tensor;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.OnnxModelsConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import com.yahoo.vespa.config.search.core.RankingExpressionsConfig;
import org.junit.BeforeClass;
import org.junit.Test;

import java.io.File;
import java.util.HashMap;
import java.util.Map;

import static org.junit.Assume.assumeTrue;

/**
 * Tests the model-evaluation handler against the models configured in
 * {@code src/test/resources/config/onnx/} ({@code one_layer}, {@code add_mul} and {@code no_model}).
 *
 * <p>Every test is skipped when the ONNX runtime is not available.
 */
public class OnnxEvaluationHandlerTest {

    /** Handler under test, shared by all tests and created once in {@link #setUp()}. */
    private static HandlerTester handler;
    /** Directory holding the config files and the ONNX model files referenced by them. */
    private static final String CONFIG_DIR = "src/test/resources/config/onnx/";

    /**
     * Creates the shared handler. A failed assumption here makes JUnit skip the whole class
     * instead of failing it on machines without the ONNX runtime.
     */
    @BeforeClass
    public static void setUp() {
        assumeTrue(OnnxRuntime.isRuntimeAvailable());
        handler = new HandlerTester(createModels());
    }

    @Test
    public void testListModels() {
        String url = "http://localhost/model-evaluation/v1";
        String expected = "{\"one_layer\":\"http://localhost/model-evaluation/v1/one_layer\"," +
                           "\"add_mul\":\"http://localhost/model-evaluation/v1/add_mul\"," +
                           "\"no_model\":\"http://localhost/model-evaluation/v1/no_model\"}";
        handler.checkResponse(url, 200, HandlerTester.matchJson(expected));
    }

    @Test
    public void testModelInfo() {
        String url = "http://localhost/model-evaluation/v1/add_mul";
        var check = HandlerTester.matchJson(
                "{'model':'add_mul','functions':[",
                " {'function':'output1',",
                "  'info':'http://localhost/model-evaluation/v1/add_mul/output1',",
                "  'eval':'http://localhost/model-evaluation/v1/add_mul/output1/eval',",
                "  'arguments':[",
                "   {'name':'input1','type':'tensor<float>(d0[1])'},",
                "   {'name':'input2','type':'tensor<float>(d0[1])'}",
                "  ]},",
                " {'function':'output2',",
                "  'info':'http://localhost/model-evaluation/v1/add_mul/output2',",
                "  'eval':'http://localhost/model-evaluation/v1/add_mul/output2/eval',",
                "  'arguments':[",
                "   {'name':'input1','type':'tensor<float>(d0[1])'},",
                "   {'name':'input2','type':'tensor<float>(d0[1])'}",
                "  ]}]}");
        handler.checkResponse(url, 200, check);
    }

    /** A model with several outputs cannot be evaluated without naming one of them. */
    @Test
    public void testEvaluationWithoutSpecifyingOutput() {
        String url = "http://localhost/model-evaluation/v1/add_mul/eval";
        String expected = "{\"error\":\"More than one function is available in model 'add_mul', but no name is given. Available functions: output1, output2\"}";
        handler.assertResponse(url, 404, expected);
    }

    /** Evaluating a function without binding its arguments is rejected as a bad request. */
    @Test
    public void testEvaluationWithoutBindings() {
        String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval";
        String expected = "{\"error\":\"Argument 'input1' must be bound to a value of type tensor<float>(d0[1])\"}";
        handler.assertResponse(url, 400, expected);
    }

    @Test
    public void testEvaluationOutput1() {
        String url = "http://localhost/model-evaluation/v1/add_mul/output1/eval";
        String expected = "{\"type\":\"tensor<float>(d0[1])\",\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":6.0}]}";  // output1 is a mul
        handler.assertResponse(url, addMulProperties(), 200, expected);
    }

    @Test
    public void testEvaluationOutput2() {
        String url = "http://localhost/model-evaluation/v1/add_mul/output2/eval";
        String expected = "{\"type\":\"tensor<float>(d0[1])\",\"cells\":[{\"address\":{\"d0\":\"0\"},\"value\":5.0}]}";  // output2 is an add
        handler.assertResponse(url, addMulProperties(), 200, expected);
    }

    /** The model info for {@code one_layer} reports its input with an unbound first dimension. */
    @Test
    public void testBatchDimensionModelInfo() {
        String url = "http://localhost/model-evaluation/v1/one_layer";
        String expected = "{\"model\":\"one_layer\",\"functions\":[" +
                "{\"function\":\"output\"," +
                "\"info\":\"http://localhost/model-evaluation/v1/one_layer/output\"," +
                "\"eval\":\"http://localhost/model-evaluation/v1/one_layer/output/eval\"," +
                "\"arguments\":[" +
                "{\"name\":\"input\",\"type\":\"tensor<float>(d0[],d1[3])\"}" +
                "]}]}";
        handler.assertResponse(url, 200, expected);
    }

    /** Two rows along the unbound d0 dimension produce a two-row result. */
    @Test
    public void testBatchDimensionEvaluation() {
        Map<String, String> properties = new HashMap<>();
        properties.put("input", "tensor<float>(d0[],d1[3]):{{d0:0,d1:0}:0.1,{d0:0,d1:1}:0.2,{d0:0,d1:2}:0.3,{d0:1,d1:0}:0.4,{d0:1,d1:1}:0.5,{d0:1,d1:2}:0.6}");
        properties.put("format.tensors", "long");
        String url = "http://localhost/model-evaluation/v1/one_layer/eval";  // output not specified
        Tensor expected = Tensor.from("tensor<float>(d0[2],d1[1]):[0.6393113,0.67574286]");
        handler.assertResponse(url, properties, 200, expected);
    }

    /**
     * Returns a fresh, mutable set of request properties for the {@code add_mul} model.
     * It binds {@code input1} to 2 and {@code input2} to 3, each as a single-cell float tensor,
     * and requests the {@code long} tensor format.
     */
    private static Map<String, String> addMulProperties() {
        Map<String, String> properties = new HashMap<>();
        properties.put("input1", "tensor<float>(d0[1]):[2]");
        properties.put("input2", "tensor<float>(d0[1]):[3]");
        properties.put("format.tensors", "long");
        return properties;
    }

    /**
     * Creates the evaluator under test from the rank-profile, ranking-constant,
     * ranking-expression and ONNX-model config files in {@link #CONFIG_DIR}.
     * Each ONNX model's file reference is served from the file of the same name in that
     * directory through a {@link MockFileAcquirer}.
     */
    // NOTE(review): the suppression presumably covers ConfigGetter.getConfig — confirm.
    @SuppressWarnings("deprecation")
    private static ModelsEvaluator createModels() {
        RankProfilesConfig config = ConfigGetter.getConfig(RankProfilesConfig.class, fileConfigId("rank-profiles.cfg"));
        RankingConstantsConfig constantsConfig = ConfigGetter.getConfig(RankingConstantsConfig.class, fileConfigId("ranking-constants.cfg"));
        RankingExpressionsConfig expressionsConfig = ConfigGetter.getConfig(RankingExpressionsConfig.class, fileConfigId("ranking-expressions.cfg"));
        OnnxModelsConfig onnxModelsConfig = ConfigGetter.getConfig(OnnxModelsConfig.class, fileConfigId("onnx-models.cfg"));

        Map<String, File> fileMap = new HashMap<>();
        for (OnnxModelsConfig.Model onnxModel : onnxModelsConfig.model()) {
            String fileName = onnxModel.fileref().value();
            fileMap.put(fileName, new File(CONFIG_DIR + fileName));
        }
        FileAcquirer fileAcquirer = MockFileAcquirer.returnFiles(fileMap);

        return new ModelsEvaluator(config, constantsConfig, expressionsConfig, onnxModelsConfig, fileAcquirer);
    }

    /** Returns a {@code file:} config id for the given file name inside {@link #CONFIG_DIR}. */
    private static String fileConfigId(String filename) {
        return "file:" + CONFIG_DIR + filename;
    }

}