blob: 1beb441be2dc2d7bdcceb16d2ddc7c6a0d41a473 (
plain) (
blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.IOException;
import java.util.Optional;
import java.util.logging.Logger;
/**
 * A rank profiles config importer for use in tests: since file distribution does not work in tests,
 * tensor constants and expression files are read from beneath a path supplied at construction time
 * rather than through file distribution.
 */
public class RankProfilesConfigImporterWithMockedConstants extends RankProfilesConfigImporter {

    private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName());

    /** The path which file reference values are appended to when locating mocked files. */
    private final Path constantsPath;

    /**
     * Creates an importer which looks up referenced files beneath the given path.
     *
     * @param constantsPath the path that file reference values are appended to
     * @param fileAcquirer the file acquirer handed on to the superclass
     */
    public RankProfilesConfigImporterWithMockedConstants(Path constantsPath, FileAcquirer fileAcquirer) {
        super(fileAcquirer, new OnnxRuntime());
        this.constantsPath = constantsPath;
    }

    /**
     * Decodes the tensor constant stored in the file found by appending the file reference value
     * to the constants path.
     *
     * @param name the name of the constant, used only in the warning logged on failure
     * @param type the type passed to the decoder, and the type of the fallback tensor
     * @param fileReference the reference whose value is the location relative to the constants path
     * @return the decoded tensor, or a tensor created from "{}" with the given type
     *         if reading the file fails with an IOException
     */
    @Override
    protected Tensor readTensorFromFile(String name, TensorType type, FileReference fileReference) {
        try {
            var constantFile = mockedFilePath(fileReference).toFile();
            var expectedType = Optional.of(type);
            var content = IOUtils.readFileBytes(constantFile);
            return TypedBinaryFormat.decode(expectedType, GrowableByteBuffer.wrap(content));
        }
        catch (IOException e) {
            // Best effort: a missing mock is reported, and an empty tensor takes its place
            String warning = "Missing a mocked tensor constant for '" + name + "': " + e.getMessage() +
                             ". Returning an empty tensor";
            log.warning(warning);
            return Tensor.from(type, "{}");
        }
    }

    /**
     * Creates a ranking expression from the content of the file found by appending the file reference
     * value to the constants path.
     *
     * @param name the name given to the resulting expression
     * @param fileReference the reference whose value is the location relative to the constants path
     * @return the ranking expression built from the file content
     * @throws ParseException as declared by the overridden method
     * @throws IllegalArgumentException if reading the file fails with an IOException (kept as the cause)
     */
    @Override
    protected RankingExpression readExpressionFromFile(String name, FileReference fileReference) throws ParseException {
        try {
            var expressionFile = mockedFilePath(fileReference).toFile();
            var expressionBody = readExpressionFromFile(expressionFile);
            return new RankingExpression(name, expressionBody);
        }
        catch (IOException e) {
            String problem = "Missing expression file '" + fileReference.value() + "' for expression '" + name + "'.";
            throw new IllegalArgumentException(problem, e);
        }
    }

    /** Returns the constants path with the value of the given file reference appended. */
    private Path mockedFilePath(FileReference fileReference) {
        return constantsPath.append(fileReference.value());
    }

}
|