diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 12:50:08 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-10 12:50:08 +0100 |
commit | d479ea2ed063832297286e60e9ffb2b7f248be59 (patch) | |
tree | e1d27ccc652d2be336159fafb37ea94d14cb1d2e /config-model | |
parent | 9ef5fd6f9edf47d48c34cd6a8623ac38daa933f5 (diff) |
ImportResult -> TensorFlowModel
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 18 |
1 files changed, 9 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index b905b282973..d05027dda39 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -3,7 +3,7 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -28,7 +28,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter(); /** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */ - private final Map<String, ImportResult> importedModels = new HashMap<>(); + private final Map<String, TensorFlowModel> importedModels = new HashMap<>(); @Override public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) { @@ -49,11 +49,11 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil "the tensorflow model directory under [application]/models"); String modelPath = asString(feature.getArguments().expressions().get(0)); - ImportResult result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath)); + TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath)); // Find the specified expression - ImportResult.Signature signature = chooseSignature(result, - optionalArgument(1, feature.getArguments())); + TensorFlowModel.Signature signature = chooseSignature(result, + optionalArgument(1, feature.getArguments())); RankingExpression expression = chooseOutput(signature, optionalArgument(2, feature.getArguments())); @@ -71,7 +71,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil * Returns the specified, existing signature, or the only signature if none is specified. * Throws IllegalArgumentException in all other cases. */ - private ImportResult.Signature chooseSignature(ImportResult importResult, Optional<String> signatureName) { + private TensorFlowModel.Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) { if ( ! signatureName.isPresent()) { if (importResult.signatures().size() == 0) throw new IllegalArgumentException("No signatures are available"); @@ -83,7 +83,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil return importResult.signatures().values().stream().findFirst().get(); } else { - ImportResult.Signature signature = importResult.signatures().get(signatureName.get()); + TensorFlowModel.Signature signature = importResult.signatures().get(signatureName.get()); if (signature == null) throw new IllegalArgumentException("Model does not have the specified signature '" + signatureName.get() + "'"); @@ -95,7 +95,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil * Returns the specified, existing output expression, or the only output expression if no output name is specified. * Throws IllegalArgumentException in all other cases. */ - private RankingExpression chooseOutput(ImportResult.Signature signature, Optional<String> outputName) { + private RankingExpression chooseOutput(TensorFlowModel.Signature signature, Optional<String> outputName) { if ( ! outputName.isPresent()) { if (signature.outputs().size() == 0) throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature)); @@ -120,7 +120,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private String skippedOutputsDescription(ImportResult.Signature signature) { + private String skippedOutputsDescription(TensorFlowModel.Signature signature) { if (signature.skippedOutputs().isEmpty()) return ""; StringBuilder b = new StringBuilder(": "); signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v)); |