diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java | 66 |
1 files changed, 35 insertions, 31 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 3abf4ec3596..24f6b1390fd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -130,7 +130,6 @@ public class RankProfile implements Cloneable { /** Global onnx models not tied to a schema */ private final OnnxModels onnxModels; - private final RankingConstants rankingConstants; private final ApplicationPackage applicationPackage; private final DeployLogger deployLogger; @@ -142,11 +141,10 @@ public class RankProfile implements Cloneable { * @param rankProfileRegistry the {@link com.yahoo.searchdefinition.RankProfileRegistry} to use for storing * and looking up rank profiles. */ - public RankProfile(String name, Schema schema, RankProfileRegistry rankProfileRegistry, RankingConstants rankingConstants) { + public RankProfile(String name, Schema schema, RankProfileRegistry rankProfileRegistry) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.schema = Objects.requireNonNull(schema, "schema cannot be null"); this.onnxModels = null; - this.rankingConstants = rankingConstants; this.rankProfileRegistry = rankProfileRegistry; this.applicationPackage = schema.applicationPackage(); this.deployLogger = schema.getDeployLogger(); @@ -158,11 +156,10 @@ public class RankProfile implements Cloneable { * @param name the name of the new profile */ public RankProfile(String name, ApplicationPackage applicationPackage, DeployLogger deployLogger, - RankProfileRegistry rankProfileRegistry, RankingConstants rankingConstants, OnnxModels onnxModels) { + RankProfileRegistry rankProfileRegistry, OnnxModels onnxModels) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.schema = null; this.rankProfileRegistry = rankProfileRegistry; - this.rankingConstants = rankingConstants; this.onnxModels = onnxModels; this.applicationPackage = applicationPackage; this.deployLogger = deployLogger; @@ -178,11 +175,6 @@ public class RankProfile implements Cloneable { return applicationPackage; } - /** Returns the ranking constants of the owner of this */ - public RankingConstants rankingConstants() { - return rankingConstants; - } - public Map<String, OnnxModel> onnxModels() { return schema != null ? schema.onnxModels().asMap() : onnxModels.asMap(); } @@ -415,24 +407,25 @@ public class RankProfile implements Cloneable { return finalSettings; } - public void addConstant(Reference name, Constant value) { - constants.put(name, value); + public void add(Constant constant) { + constants.put(constant.name(), constant); } /** Returns an unmodifiable view of the constants available in this */ public Map<Reference, Constant> getConstants() { - if (inherited().isEmpty()) return new HashMap<>(constants); - Map<Reference, Constant> allConstants = new HashMap<>(); for (var inheritedProfile : inherited()) { - for (var constant : inheritedProfile.getConstants().entrySet()) { - if (allConstants.containsKey(constant.getKey())) - throw new IllegalArgumentException("Constant '" + constant.getKey() + "' is present in " + + for (var constant : inheritedProfile.getConstants().values()) { + if (allConstants.containsKey(constant.name())) + throw new IllegalArgumentException(constant + "' is present in " + inheritedProfile + " inherited by " + this + ", but is also present in another profile inherited by it"); - allConstants.put(constant.getKey(), constant.getValue()); + allConstants.put(constant.name(), constant); } } + + if (schema != null) + allConstants.putAll(schema.constants()); allConstants.putAll(constants); return allConstants; } @@ -1046,9 +1039,7 @@ public class RankProfile implements Cloneable { Map<Reference, TensorType> featureTypes) { MapEvaluationTypeContext context = new MapEvaluationTypeContext(getExpressionFunctions(), featureTypes); - // Add small and large constants, respectively getConstants().forEach((k, v) -> context.setType(k, v.type())); - rankingConstants().asMap().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.getTensorType())); // Add query features from all rank profile types for (QueryProfileType queryProfileType : queryProfiles.getTypeRegistry().allComponents()) { @@ -1427,23 +1418,32 @@ public class RankProfile implements Cloneable { private final Optional<Tensor> value; private final Optional<String> valuePath; + // Always set only if valuePath is set + private final Optional<DistributableResource.PathType> pathType; + public Constant(Reference name, Tensor value) { - this(name, value.type(), Optional.of(value), Optional.empty()); + this(name, value.type(), Optional.of(value), Optional.empty(), Optional.empty()); } public Constant(Reference name, TensorType type, String valuePath) { - this(name, type, Optional.empty(), Optional.of(valuePath)); + this(name, type, Optional.empty(), Optional.of(valuePath), Optional.of(DistributableResource.PathType.FILE)); + } + + public Constant(Reference name, TensorType type, String valuePath, DistributableResource.PathType pathType) { + this(name, type, Optional.empty(), Optional.of(valuePath), Optional.of(pathType)); } - private Constant(Reference name, TensorType type, Optional<Tensor> value, Optional<String> valuePath) { + private Constant(Reference name, TensorType type, Optional<Tensor> value, + Optional<String> valuePath, Optional<DistributableResource.PathType> pathType) { + this.name = Objects.requireNonNull(name); + this.type = Objects.requireNonNull(type); + this.value = Objects.requireNonNull(value); + this.valuePath = Objects.requireNonNull(valuePath); + this.pathType = Objects.requireNonNull(pathType); + if (type.dimensions().stream().anyMatch(d -> d.isIndexed() && d.size().isEmpty())) throw new IllegalArgumentException("Illegal type of constant " + name + " type " + type + ": Dense tensor dimensions must have a size"); - - this.name = name; - this.type = type; - this.value = value; - this.valuePath = valuePath; } public Reference name() { return name; } @@ -1455,6 +1455,9 @@ public class RankProfile implements Cloneable { /** Returns the path to the value of this, if its value is empty. */ public Optional<String> valuePath() { return valuePath; } + /** Returns the path type, if valuePath is set. */ + public Optional<DistributableResource.PathType> pathType() { return pathType; } + @Override public boolean equals(Object o) { if (o == this) return true; @@ -1464,18 +1467,19 @@ public class RankProfile implements Cloneable { if ( ! other.type().equals(this.type())) return false; if ( ! other.value().equals(this.value())) return false; if ( ! other.valuePath().equals(this.valuePath())) return false; + if ( ! other.pathType().equals(this.pathType())) return false; return true; } @Override public int hashCode() { - return Objects.hash(name, type, value, valuePath); + return Objects.hash(name, type, value, valuePath, pathType); } @Override public String toString() { - return "constant '" + name + "' " + type + - (value().isPresent() ? ":" + value.get().toAbbreviatedString() : "file:" + valuePath.get()); + return "constant '" + name + "' " + type + ":" + + (value().isPresent() ? value.get().toAbbreviatedString() : " file:" + valuePath.get()); } } |