diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 26 |
1 files changed, 23 insertions, 3 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index d05027dda39..a36384ce6f2 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -1,17 +1,22 @@ package com.yahoo.searchdefinition.expressiontransforms; import com.google.common.base.Joiner; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowModel; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import java.io.File; import java.util.HashMap; import java.util.Map; import java.util.Optional; @@ -48,7 +53,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " + "the tensorflow model directory under [application]/models"); - String modelPath = asString(feature.getArguments().expressions().get(0)); + String modelPath = ApplicationPackage.MODELS_DIR + "/" + asString(feature.getArguments().expressions().get(0)); TensorFlowModel result = importedModels.computeIfAbsent(modelPath, k -> tensorFlowImporter.importModel(modelPath)); // Find the specified expression @@ -58,7 +63,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil optionalArgument(2, feature.getArguments())); // Add all constants (after finding outputs to fail faster when the output is not found) - result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v))); + if (1==1) + result.constants().forEach((k, v) -> context.rankProfile().addConstantTensor(k, new TensorValue(v))); + else // correct way, disabled for now + result.constants().forEach((k, v) -> transformConstant(modelPath, context.rankProfile(), k, v)); return expression.getRoot(); } @@ -120,6 +128,18 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } + private void transformConstant(String modelPath, RankProfile profile, String constantName, Tensor constantValue) { + File constantFilePath = new File(modelPath, "converted_variables"); + if ( ! constantFilePath.exists() ) { + if ( ! constantFilePath.mkdir() ) + throw new IllegalStateException("Could not create directory " + constantFilePath); + } + + File constantFile = new File(constantFilePath, constantName + ".json"); + // writeAsVespaTensor(constantValue, constantFile); + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), constantFilePath.getPath())); + } + private String skippedOutputsDescription(TensorFlowModel.Signature signature) { if (signature.skippedOutputs().isEmpty()) return ""; StringBuilder b = new StringBuilder(": "); |