diff options
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java | 63 |
1 files changed, 58 insertions, 5 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index bfd6342218a..b9e7a27c013 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,33 +1,57 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.yahoo.filedistribution.fileacquirer.FileAcquirer; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.serialization.TypedBinaryFormat; import com.yahoo.vespa.config.search.RankProfilesConfig; +import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import java.io.File; +import java.io.IOException; +import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; /** - * Converts RankProfilesConfig instances to RankingExpressions for evaluation + * Converts RankProfilesConfig instances to RankingExpressions for evaluation. + * This class can be used by a single thread only. * * @author bratseth */ class RankProfilesConfigImporter { /** + * Constants already imported in this while reading some expression. + * This is to avoid re-reading constants referenced + * multiple places, as that is potentially costly. + */ + private Map<String, Constant> globalImportedConstants = new HashMap<>(); + + /** * Returns a map of the models contained in this config, indexed on name. * The map is modifiable and owned by the caller. */ - Map<String, Model> importFrom(RankProfilesConfig config) { + Map<String, Model> importFrom(RankProfilesConfig config, RankingConstantsConfig constantsConfig) { + globalImportedConstants.clear(); try { Map<String, Model> models = new HashMap<>(); for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { - Model model = importProfile(profile); + Model model = importProfile(profile, constantsConfig); models.put(model.name(), model); } return models; @@ -37,11 +61,14 @@ class RankProfilesConfigImporter { } } - private Model importProfile(RankProfilesConfig.Rankprofile profile) throws ParseException { + private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) throws ParseException { List<ExpressionFunction> functions = new ArrayList<>(); Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>(); ExpressionFunction firstPhase = null; ExpressionFunction secondPhase = null; + + List<Constant> constants = readConstants(constantsConfig); + for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); if ( reference.isPresent()) { @@ -69,7 +96,7 @@ class RankProfilesConfigImporter { functions.add(secondPhase); try { - return new Model(profile.name(), functions, referencedFunctions); + return new Model(profile.name(), functions, referencedFunctions, constants); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); @@ -83,4 +110,30 @@ class RankProfilesConfigImporter { return null; } + private List<Constant> readConstants(RankingConstantsConfig constantsConfig) { + List<Constant> constants = new ArrayList<>(); + for (RankingConstantsConfig.Constant constantConfig : constantsConfig.constant()) { + constants.add(new Constant(constantConfig.name(), + readTensorFromFile(TensorType.fromSpec(constantConfig.type()), + constantConfig.fileref().value()))); + } + return constants; + } + + private Tensor readTensorFromFile(TensorType type, String fileName) { + try { + if (fileName.endsWith(".tbf")) + return TypedBinaryFormat.decode(Optional.of(type), + GrowableByteBuffer.wrap(IOUtils.readFileBytes(new File(fileName)))); + // TODO: Support json and json.lz4 + + if (fileName.isEmpty()) // this is the case in unit tests + return Tensor.from(type, "{}"); + throw new IllegalArgumentException("Unknown tensor file format (determined by file ending): " + fileName); + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + } |