summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 12:50:08 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-10 12:50:08 +0100
commitd479ea2ed063832297286e60e9ffb2b7f248be59 (patch)
treee1d27ccc652d2be336159fafb37ea94d14cb1d2e /config-model
parent9ef5fd6f9edf47d48c34cd6a8623ac38daa933f5 (diff)
ImportResult -> TensorFlowModel
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java18
1 files changed, 9 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index b905b282973..d05027dda39 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -3,7 +3,7 @@ package com.yahoo.searchdefinition.expressiontransforms;
import com.google.common.base.Joiner;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
@@ -28,7 +28,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();
/** A cache of imported models indexed by model path. This avoids importing the same model multiple times. */
- private final Map<String, ImportResult> importedModels = new HashMap<>();
+ private final Map<String, TensorFlowModel> importedModels = new HashMap<>();
@Override
public ExpressionNode transform(ExpressionNode node, RankProfileTransformContext context) {
@@ -49,11 +49,11 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
"the tensorflow model directory under [application]/models");
String modelPath = asString(feature.getArguments().expressions().get(0));
- ImportResult result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath));
+ TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath));
// Find the specified expression
- ImportResult.Signature signature = chooseSignature(result,
- optionalArgument(1, feature.getArguments()));
+ TensorFlowModel.Signature signature = chooseSignature(result,
+ optionalArgument(1, feature.getArguments()));
RankingExpression expression = chooseOutput(signature,
optionalArgument(2, feature.getArguments()));
@@ -71,7 +71,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
* Returns the specified, existing signature, or the only signature if none is specified.
* Throws IllegalArgumentException in all other cases.
*/
- private ImportResult.Signature chooseSignature(ImportResult importResult, Optional<String> signatureName) {
+ private TensorFlowModel.Signature chooseSignature(TensorFlowModel importResult, Optional<String> signatureName) {
if ( ! signatureName.isPresent()) {
if (importResult.signatures().size() == 0)
throw new IllegalArgumentException("No signatures are available");
@@ -83,7 +83,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
return importResult.signatures().values().stream().findFirst().get();
}
else {
- ImportResult.Signature signature = importResult.signatures().get(signatureName.get());
+ TensorFlowModel.Signature signature = importResult.signatures().get(signatureName.get());
if (signature == null)
throw new IllegalArgumentException("Model does not have the specified signature '" +
signatureName.get() + "'");
@@ -95,7 +95,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
* Returns the specified, existing output expression, or the only output expression if no output name is specified.
* Throws IllegalArgumentException in all other cases.
*/
- private RankingExpression chooseOutput(ImportResult.Signature signature, Optional<String> outputName) {
+ private RankingExpression chooseOutput(TensorFlowModel.Signature signature, Optional<String> outputName) {
if ( ! outputName.isPresent()) {
if (signature.outputs().size() == 0)
throw new IllegalArgumentException("No outputs are available" + skippedOutputsDescription(signature));
@@ -120,7 +120,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
- private String skippedOutputsDescription(ImportResult.Signature signature) {
+ private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
if (signature.skippedOutputs().isEmpty()) return "";
StringBuilder b = new StringBuilder(": ");
signature.skippedOutputs().forEach((k, v) -> b.append("Skipping output '").append(k).append("': ").append(v));