aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfilesConfigImporterWithMockedConstants.java
blob: 1beb441be2dc2d7bdcceb16d2ddc7c6a0d41a473 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;

import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;

import java.io.IOException;
import java.util.Optional;
import java.util.logging.Logger;

/**
 * A {@link RankProfilesConfigImporter} for tests which serves canned tensor constants and
 * expression files from a local directory, since file distribution does not work in tests.
 * File references are resolved by appending their value to a fixed constants directory.
 */
public class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter {

    private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName());

    /** Directory against which every {@link FileReference} value is resolved. */
    private final Path constantsPath;

    /**
     * Creates an importer reading mocked files from the given directory.
     *
     * @param constantsPath directory holding the canned constant and expression files
     * @param fileAcquirer  file acquirer handed on to the superclass
     */
    public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) {
        super(fileAcquirer, new OnnxRuntime());
        this.constantsPath = constantsPath;
    }

    /**
     * Reads a tensor constant from the mocked constants directory.
     * If the file cannot be read, a warning is logged and an empty tensor of the requested type is
     * returned instead of failing, so tests lacking a given constant can still proceed.
     *
     * @param name          name of the constant, used only in the warning message
     * @param type          the tensor type to decode into
     * @param fileReference reference whose value is the file name below the constants directory
     * @return the decoded tensor, or an empty tensor of {@code type} if the file could not be read
     */
    @Override
    protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) {
        try {
            var serialized = IOUtils.readFileBytes(mockedFileFor(fileReference).toFile());
            var buffer = GrowableByteBuffer.wrap(serialized);
            return TypedBinaryFormat.decode(Optional.of(type), buffer);
        }
        catch (IOException e) {
            String warning = "Missing a mocked tensor constant for '" + name + "': " + e.getMessage()
                             + ". Returning an empty tensor";
            log.warning(warning);
            return Tensor.from(type, "{}");
        }
    }

    /**
     * Reads a ranking expression from the mocked constants directory.
     *
     * @param name          name given to the resulting expression
     * @param fileReference reference whose value is the file name below the constants directory
     * @return the parsed expression
     * @throws ParseException           if the file contents cannot be parsed as an expression
     * @throws IllegalArgumentException if the expression file cannot be read (the IOException is kept as cause)
     */
    @Override
    protected RankingExpression readExpressionFromFile(String name, FileReference fileReference) throws ParseException {
        try {
            var expressionFile = mockedFileFor(fileReference).toFile();
            // Delegates to the file-based readExpressionFromFile overload (not visible in this file) for the raw contents.
            var expressionText = readExpressionFromFile(expressionFile);
            return new RankingExpression(name, expressionText);
        }
        catch (IOException e) {
            String message = "Missing expression file '" + fileReference.value()
                             + "' for expression '" + name + "'.";
            throw new IllegalArgumentException(message, e);
        }
    }

    /** Resolves a file reference to its location below the mocked constants directory. */
    private Path mockedFileFor(FileReference fileReference) {
        return constantsPath.append(fileReference.value());
    }
}