aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/test/java/com/yahoo/vespa/model/ml/ImportedModelTester.java
blob: 97c222e75d3f2282ea849b02a7aec35bb91dcd94 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.model.ml;

import com.yahoo.config.FileReference;
import com.yahoo.config.model.ApplicationPackageTester;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import com.yahoo.config.model.deploy.DeployState;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
import ai.vespa.rankingexpression.importer.vespa.VespaImporter;
import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.model.VespaModel;
import org.xml.sax.SAXException;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.List;
import java.util.Optional;

import static org.junit.jupiter.api.Assertions.*;

/**
 * Helper for testing of imported models.
 * More duplicated functionality across tests on imported models should be moved here.
 *
 * @author bratseth
 */
public class ImportedModelTester {

    /** The model importers passed to the deploy state built by this tester. */
    private final List<MlModelImporter> importers = List.of(new TensorFlowImporter(),
                                                            new OnnxImporter(),
                                                            new LightGBMImporter(),
                                                            new XGBoostImporter(),
                                                            new VespaImporter());

    private final String modelName;
    private final Path applicationDir;
    private final DeployState deployState;

    /**
     * Creates a tester for the given model in the application package at the given directory,
     * starting from a fresh {@link DeployState.Builder}.
     *
     * @param modelName the name of the imported model; its generated constants are looked up
     *                  under {@code models.generated/<modelName>/constants}
     * @param applicationDir the directory of the application package to deploy
     */
    public ImportedModelTester(String modelName, Path applicationDir) {
        this(modelName, applicationDir, new DeployState.Builder());
    }

    /**
     * Creates a tester for the given model in the application package at the given directory.
     *
     * @param modelName the name of the imported model; its generated constants are looked up
     *                  under {@code models.generated/<modelName>/constants}
     * @param applicationDir the directory of the application package to deploy
     * @param deployStateBuilder the builder to complete with the application package and the
     *                           model importers of this tester before building the deploy state
     */
    public ImportedModelTester(String modelName, Path applicationDir, DeployState.Builder deployStateBuilder) {
        this.modelName = modelName;
        this.applicationDir = applicationDir;
        deployState = deployStateBuilder.applicationPackage(ApplicationPackageTester.create(applicationDir.toString()).app())
                                        .modelImporters(importers)
                                        .build();
    }

    /**
     * Creates a Vespa model from the deploy state of this tester.
     *
     * @return the created model
     * @throws UncheckedIOException wrapping any IOException thrown while creating the model
     * @throws RuntimeException wrapping any SAXException thrown while creating the model
     */
    public VespaModel createVespaModel() {
        try {
            return new VespaModel(deployState);
        }
        catch (IOException e) {
            // Same wrapping as assertLargeConstant; the message (cause.toString()) is unchanged
            throw new UncheckedIOException(e);
        }
        catch (SAXException e) {
            throw new RuntimeException(e);
        }
    }

    /**
     * Verifies that the constant with the given name exists, and - only if an expected size is given -
     * that the content of the constant is available and has the expected size.
     *
     * @param constantName the name of the constant, also used as the base name of its {@code .tbf} file
     * @param model the model to look the constant up in
     * @param expectedSize the expected number of cells in the constant, or empty to skip reading
     *                     and checking the constant file content
     * @throws UncheckedIOException if the constant file cannot be read
     */
    public void assertLargeConstant(String constantName, VespaModel model, Optional<Long> expectedSize) {
        try {
            Path constantApplicationPackagePath = Path.fromString("models.generated/" + modelName + "/constants")
                                                      .append(constantName + ".tbf");
            var constant = model.rankProfileList().constants().asMap().get(constantName);
            assertNotNull(constant, "Constant '" + constantName + "' exists");
            assertEquals(constantName, constant.getName());
            assertTrue(constant.getFileName().endsWith(constantApplicationPackagePath.toString()),
                       "File name '" + constant.getFileName() + "' of constant '" + constantName +
                       "' ends with '" + constantApplicationPackagePath + "'");

            assertTrue(model.fileReferences().contains(new FileReference(constant.getFileName())),
                       "File '" + constant.getFileName() + "' of constant '" + constantName +
                       "' is among the file references of the model");

            if (expectedSize.isPresent()) {
                Path constantPath = applicationDir.append(constantApplicationPackagePath);
                assertTrue(constantPath.toFile().exists(),
                           "Constant file '" + constantPath + "' has been written");
                Tensor deserializedConstant = TypedBinaryFormat.decode(Optional.empty(),
                                                                       GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantPath.toFile())));
                assertEquals(expectedSize.get().longValue(), deserializedConstant.size(),
                             "Size of constant '" + constantName + "'");
            }
        }
        catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

}