summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java7
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java12
5 files changed, 18 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 4cd8c6ac92b..8634d51c418 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -44,7 +44,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
// TODO: Put modelPath in FeatureArguments instead
Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
+ convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, true, context));
return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
}
catch (IllegalArgumentException | UncheckedIOException e) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 72cfde0a566..5139d041f00 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -42,7 +42,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
+ convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, false, context));
return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
}
catch (IllegalArgumentException | UncheckedIOException e) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
index 8591bf16d07..f21248b6d74 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java
@@ -43,7 +43,7 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr
try {
Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0)));
ConvertedModel convertedModel =
- convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, context));
+ convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, true, context));
return convertedModel.expression(asFeatureArguments(feature.getArguments()), context);
} catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e);
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 1f27b9843cd..e2236feb336 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -84,10 +84,13 @@ public class ConvertedModel {
/**
* Create and store a converted model for a rank profile given from either an imported model,
* or (if unavailable) from stored application package data.
+ *
+ * @param modelPath the path to the model
+ * @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory
*/
- public static ConvertedModel fromSourceOrStore(Path modelPath, RankProfileTransformContext context) {
+ public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) {
File sourceModel = sourceModelFile(context.rankProfile().applicationPackage(), modelPath);
- ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath);
+ ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile);
if (sourceModel.exists())
return fromSource(modelName,
modelPath.toString(),
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java
index 5e22fefd093..2c7dc6b337d 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ModelName.java
@@ -19,8 +19,9 @@ public class ModelName {
this(null, name);
}
- public ModelName(String namespace, Path modelPath) {
- this(namespace, modelPath.toString().replace("/", "_"));
+ public ModelName(String namespace, Path modelPath, boolean pathIsFile) {
+ this(namespace,
+ stripFileEndingIfFile(modelPath, pathIsFile).toString().replace("/", "_"));
}
private ModelName(String namespace, String name) {
@@ -29,6 +30,13 @@ public class ModelName {
this.fullName = (namespace != null ? namespace + "." : "") + name;
}
+ private static Path stripFileEndingIfFile(Path path, boolean pathIsFile) {
+ if ( ! pathIsFile) return path;
+ int dotIndex = path.last().lastIndexOf(".");
+ if (dotIndex <= 0) return path;
+ return path.withLast(path.last().substring(0, dotIndex));
+ }
+
/** Returns true if the local name of this is not in a namespace */
public boolean isGlobal() { return namespace == null; }