aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
blob: b128f052d53eee6368af1e3796937d4cf4e62aef (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.vespa;

import ai.vespa.rankingexpression.importer.ImportedModel;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import org.junit.Test;

import java.util.Map;
import java.util.stream.Collectors;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
 * @author bratseth
 */
public class VespaImportTestCase {

    /** Imports the "example" model and checks its full structure and evaluation result. */
    @Test
    public void testExample() {
        ImportedModel model = importModel("example");
        assertModel(model);
    }

    /** Imports the "legacy_syntax" model, which must yield exactly the same model as "example". */
    @Test
    public void testLegacySyntax() {
        ImportedModel model = importModel("legacy_syntax");
        assertModel(model);
    }

    /**
     * Asserts the inputs, constants, expressions and output functions expected of the
     * "example"-shaped model, and evaluates its "foo1" output function.
     *
     * @param model the imported model to verify
     */
    private void assertModel(ImportedModel model) {
        assertEquals(2, model.inputs().size());
        assertEquals("tensor(name{},x[3])", model.inputs().get("input1").toString());
        assertEquals("tensor(x[3])", model.inputs().get("input2").toString());

        assertEquals(2, model.smallConstantTensors().size());
        assertEquals("tensor(x[3]):[0.5, 1.5, 2.5]", model.smallConstantTensors().get("constant1").toString());
        assertEquals("tensor():{3.0}", model.smallConstantTensors().get("constant2").toString());

        assertEquals(1, model.largeConstantTensors().size());
        assertEquals("tensor(x[3]):[0.5, 1.5, 2.5]", model.largeConstantTensors().get("constant1asLarge").toString());

        assertEquals(2, model.expressions().size());
        assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1, max, x) * constant2",
                     model.expressions().get("foo1").getRoot().toString());
        assertEquals("reduce(reduce(input1 * input2, sum, name) * constant(constant1asLarge), max, x) * constant2",
                     model.expressions().get("foo2").getRoot().toString());

        Map<String, ImportedMlFunction> byName = model.outputExpressions().stream()
                .collect(Collectors.toUnmodifiableMap(ImportedMlFunction::name, f -> f));
        assertEquals(2, byName.size());
        assertTrue(byName.containsKey("foo1"));
        assertTrue(byName.containsKey("foo2"));
        ImportedMlFunction foo1Function = byName.get("foo1");
        assertEquals("foo1", foo1Function.name());
        assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1, max, x) * constant2", foo1Function.expression());
        // input1 * input2 summed over name = [1*3, 2*6, 3*9] = [3, 12, 27];
        // times constant1 = [1.5, 18, 67.5]; max over x = 67.5; times constant2 (3) = 202.5
        assertEquals("tensor():{202.5}", evaluate(foo1Function, "{{name:a, x:0}: 1, {name:a, x:1}: 2, {name:a, x:2}: 3}").toString());
        assertEquals(2, foo1Function.arguments().size());
        assertTrue(foo1Function.arguments().contains("input1"));
        assertTrue(foo1Function.arguments().contains("input2"));
        assertEquals(2, foo1Function.argumentTypes().size());
        assertEquals("tensor(name{},x[3])", foo1Function.argumentTypes().get("input1"));
        assertEquals("tensor(x[3])", foo1Function.argumentTypes().get("input2"));
    }

    /** Imports the "empty" model and checks that it contains nothing. */
    @Test
    public void testEmpty() {
        ImportedModel model = importModel("empty");
        assertTrue(model.expressions().isEmpty());
        assertTrue(model.functions().isEmpty());
        assertTrue(model.inputs().isEmpty());
        assertTrue(model.largeConstantTensors().isEmpty());
        assertTrue(model.smallConstantTensors().isEmpty());
    }

    /**
     * Importing a model file whose name differs from the model name declared inside it
     * must fail with an IllegalArgumentException carrying a descriptive message.
     */
    @Test
    public void testWrongName() {
        try {
            importModel("misnamed");
            fail("Expected exception");
        }
        catch (IllegalArgumentException e) {
            assertEquals("Unexpected model name 'misnamed': " +
                         "Model 'expectedname' must be saved in a file named 'expectedname.model'", e.getMessage());
        }
    }

    /**
     * Imports the model stored at {@code src/test/models/vespa/<name>.model}, asserting that
     * the importer accepts the path and that the resulting model reports the given name and source.
     *
     * @param name the model name, which is also the file's base name
     * @return the imported model
     */
    private ImportedModel importModel(String name) {
        String modelPath = "src/test/models/vespa/" + name + ".model";

        VespaImporter importer = new VespaImporter();
        assertTrue(importer.canImport(modelPath));
        // Import with the same importer instance that was just checked with canImport
        ImportedModel model = importer.importModel(name, modelPath);
        assertEquals(name, model.name());
        assertEquals(modelPath, model.source());
        return model;
    }

    /**
     * Evaluates an imported function with the given value for "input1". Fixed test values are
     * bound for "input2", "constant1" and "constant2".
     *
     * @param function the function whose expression is parsed and evaluated
     * @param input1Argument tensor literal for "input1", parsed using the function's declared input1 type
     * @return the evaluation result as a tensor
     * @throws IllegalArgumentException if the function's expression cannot be parsed
     */
    private Tensor evaluate(ImportedMlFunction function, String input1Argument) {
        try {
            MapContext context = new MapContext();
            context.put("input1", new TensorValue(Tensor.from(function.argumentTypes().get("input1"), input1Argument)));
            context.put("input2", new TensorValue(Tensor.from(function.argumentTypes().get("input2"), "{{x:0}:3, {x:1}:6, {x:2}:9}")));
            context.put("constant1", new TensorValue(Tensor.from("tensor(x[3]):{{x:0}:0.5, {x:1}:1.5, {x:2}:2.5}")));
            context.put("constant2", new TensorValue(Tensor.from("tensor():{{}:3}")));
            return new RankingExpression(function.expression()).evaluate(context).asTensor();
        }
        catch (ParseException e) {
            throw new IllegalArgumentException(e);
        }
    }
}