summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java')
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java26
1 files changed, 20 insertions, 6 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index 63e17e37bde..d2729ab7b6e 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -3,15 +3,20 @@ package ai.vespa.models.evaluation;
import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.config.subscription.FileSource;
+import com.yahoo.io.GrowableByteBuffer;
+import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import java.io.IOException;
import java.util.HashMap;
import java.util.Map;
+import java.util.Optional;
import java.util.logging.Logger;
import static org.junit.Assert.assertEquals;
@@ -38,7 +43,8 @@ public class ModelTester {
RankProfilesConfig.class).getConfig("");
RankingConstantsConfig constantsConfig = new ConfigGetter<>(new FileSource(configDir.append("ranking-constants.cfg").toFile()),
RankingConstantsConfig.class).getConfig("");
- return new RankProfilesConfigImporterWithMockedConstants().importFrom(config, constantsConfig);
+ return new RankProfilesConfigImporterWithMockedConstants(Path.fromString(path).append("constants"))
+ .importFrom(config, constantsConfig);
}
public void assertFunction(String name, String expression, Model model) {
@@ -61,15 +67,23 @@ public class ModelTester {
private static final Logger log = Logger.getLogger(RankProfilesConfigImporterWithMockedConstants.class.getName());
- Map<String, Tensor> constants = new HashMap<>();
+ private final Path constantsPath;
+
+ public RankProfilesConfigImporterWithMockedConstants(Path constantsPath) {
+ this.constantsPath = constantsPath;
+ }
@Override
- Tensor readTensorFromFile(String name, TensorType type, String fileReference) {
- if ( ! constants.containsKey(name)) {
- log.warning("Missing a mocked tensor constant for '" + name + "': Returning an empty tensor");
+ protected Tensor readTensorFromFile(String name, TensorType type, String fileReference) {
+ try {
+ return TypedBinaryFormat.decode(Optional.of(type),
+ GrowableByteBuffer.wrap(IOUtils.readFileBytes(constantsPath.append(name).toFile())));
+ }
+ catch (IOException e) {
+ log.warning("Missing a mocked tensor constant for '" + name + "': " + e.getMessage() +
+ ". Returning an empty tensor");
return Tensor.from(type, "{}");
}
- return constants.get(name);
}
}