diff options
Diffstat (limited to 'config-model')
5 files changed, 18 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 4cd8c6ac92b..8634d51c418 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -44,7 +44,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans // TODO: Put modelPath in FeatureArguments instead Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); + convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, true, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 72cfde0a566..5139d041f00 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -42,7 +42,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); + convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, false, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java index 8591bf16d07..f21248b6d74 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -43,7 +43,7 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr try { Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context)); + convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, true, context)); return convertedModel.expression(asFeatureArguments(feature.getArguments()), context); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 1f27b9843cd..e2236feb336 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -84,10 +84,13 @@ public class ConvertedModel { /** * Create and store a converted model for a rank profile given from either an imported model, * or (if unavailable) from stored application package data. + * + * @param modelPath the path to the model + * @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory */ - public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) { + public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) { File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath); - ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath); + ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile); if (sourceModel.exists()) return fromSource(modelName, modelPath.toString(), diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java index 5e22fefd093..2c7dc6b337d 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java @@ -19,8 +19,9 @@ public class ModelName { this(null, name); } - public ModelName(String namespace, Path modelPath) { - this(namespace, modelPath.toString().replace("/", "_")); + public ModelName(String namespace, Path modelPath, boolean pathIsFile) { + this(namespace, + stripFileEndingIfFile(modelPath, pathIsFile).toString().replace("/", "_")); } private ModelName(String namespace, String name) { @@ -29,6 +30,13 @@ public class ModelName { this.fullName = (namespace != null ? namespace + "." : "") + name; } + private static Path stripFileEndingIfFile(Path path, boolean pathIsFile) { + if ( ! pathIsFile) return path; + int dotIndex = path.last().lastIndexOf("."); + if (dotIndex <= 0) return path; + return path.withLast(path.last().substring(0, dotIndex)); + } + /** Returns true if the local name of this is not in a namespace */ public boolean isGlobal() { return namespace == null; } |