diff options
Diffstat (limited to 'config-model')
4 files changed, 22 insertions, 10 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index d928f17e345..ab143f77b6a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -42,10 +42,9 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans try { FeatureArguments arguments = asFeatureArguments(feature.getArguments()); - // TODO: Put modelPath in FeatureArguments instead - Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); ConvertedModel convertedModel = - convertedOnnxModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, true, context)); + convertedOnnxModels.computeIfAbsent(arguments.path(), + path -> ConvertedModel.fromSourceOrStore(path, true, context)); return convertedModel.expression(arguments, context); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index c468d27457c..4a315420b0a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -41,9 +41,9 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { FeatureArguments arguments = asFeatureArguments(feature.getArguments()); - Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); // TODO: Put in FeatureArguments ConvertedModel convertedModel = - convertedTensorFlowModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, false, context)); + convertedTensorFlowModels.computeIfAbsent(arguments.path(), + path -> ConvertedModel.fromSourceOrStore(path, false, context)); return convertedModel.expression(arguments, context); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java index fcb05f7e5b6..663c5afbed6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/XgboostFeatureConverter.java @@ -42,9 +42,9 @@ public class XgboostFeatureConverter extends ExpressionTransformer<RankProfileTr try { FeatureArguments arguments = asFeatureArguments(feature.getArguments()); - Path modelPath = Path.fromString(FeatureArguments.asString(feature.getArguments().expressions().get(0))); // TODO: Keep in FeatureArguments ConvertedModel convertedModel = - convertedXGBoostModels.computeIfAbsent(modelPath, __ -> ConvertedModel.fromSourceOrStore(modelPath, true, context)); + convertedXGBoostModels.computeIfAbsent(arguments.path(), + path -> ConvertedModel.fromSourceOrStore(path, true, context)); return convertedModel.expression(arguments, context); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use XGBoost model from " + feature, e); diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java index fda49af6178..4a02dc97d19 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/FeatureArguments.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.ml; +import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -14,19 +15,24 @@ import java.util.Optional; */ public class FeatureArguments { + private final Path path; + /** Optional arguments */ private final Optional<String> signature, output; public FeatureArguments(Arguments arguments) { - this(optionalArgument(1, arguments), + this(Path.fromString(argument(0, arguments)), + optionalArgument(1, arguments), optionalArgument(2, arguments)); } - public FeatureArguments(Optional<String> signature, Optional<String> output) { + private FeatureArguments(Path path, Optional<String> signature, Optional<String> output) { + this.path = path; this.signature = signature; this.output = output; } + public Path path() { return path; } public Optional<String> signature() { return signature; } public Optional<String> output() { return output; } @@ -35,13 +41,20 @@ public class FeatureArguments { (output.isPresent() ? "." + output.get() : ""); } + private static String argument(int argumentIndex, Arguments arguments) { + if (argumentIndex >= arguments.expressions().size()) + throw new IllegalArgumentException("Requires at least " + argumentIndex + + " arguments, but got just " + arguments.size()); + return asString(arguments.expressions().get(argumentIndex)); + } + private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) { if (argumentIndex >= arguments.expressions().size()) return Optional.empty(); return Optional.of(asString(arguments.expressions().get(argumentIndex))); } - public static String asString(ExpressionNode node) { + private static String asString(ExpressionNode node) { if ( ! (node instanceof ConstantNode)) throw new IllegalArgumentException("Expected a constant string as argument, but got '" + node); return stripQuotes(((ConstantNode)node).sourceString()); |