summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java81
1 files changed, 68 insertions, 13 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index e81d22cefe9..2c177633590 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -9,9 +9,11 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.FeatureNames;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankingConstant;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -51,6 +53,7 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.logging.Logger;
+import java.util.stream.Collectors;
/**
* Replaces instances of the tensorflow(model-path, signature, output)
@@ -85,10 +88,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
try {
ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(),
feature.getArguments());
- if (store.hasStoredModel())
- return transformFromStoredModel(store, context.rankProfile());
- else // not converted yet - access TensorFlow model files
+ if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files
return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles());
+ else
+ return transformFromStoredModel(store, context.rankProfile());
}
catch (IllegalArgumentException | UncheckedIOException e) {
throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e);
@@ -101,16 +104,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(),
k -> tensorFlowImporter.importModel(store.tensorFlowModelDir()));
+ // Add constants
+ Set<String> constantsReplacedByMacros = new HashSet<>();
+ model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
+ model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles,
+ constantsReplacedByMacros, k, v));
+
// Find the specified expression
Signature signature = chooseSignature(model, store.arguments().signature());
String output = chooseOutput(signature, store.arguments().output());
RankingExpression expression = model.expressions().get(output);
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles);
store.writeConverted(expression);
- model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
- model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v));
- model.macros().forEach((k, v) -> transformMacro(store, profile, k, v));
+ model.macros().forEach((k, v) -> transformMacro(store, profile, constantsReplacedByMacros, k, v));
return expression.getRoot();
}
@@ -189,17 +197,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
profile.addConstant(constantName, asValue(constantValue));
}
- private void transformLargeConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) {
- Path constantPath = store.writeLargeConstant(constantName, constantValue);
+ private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles,
+ Set<String> constantsReplacedByMacros,
+ String constantName, Tensor constantValue) {
+ RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName);
+ if (macroOverridingConstant != null) {
+ TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles));
+ if ( ! macroType.equals(constantValue.type()))
+ throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " +
+ "The required type of this is " + constantValue.type() +
+ ", but the macro returns " + macroType);
+ constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later
+ }
+ else {
+
+ Path constantPath = store.writeLargeConstant(constantName, constantValue);
- if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) {
- log.info("Adding constant '" + constantName + "' of type " + constantValue.type());
- profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
- constantPath.toString()));
+ if (!profile.getSearch().getRankingConstants().containsKey(constantName)) {
+ log.info("Adding constant '" + constantName + "' of type " + constantValue.type());
+ profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(),
+ constantPath.toString()));
+ }
}
}
- private void transformMacro(ModelStore store, RankProfile profile, String macroName, RankingExpression expression) {
+ private void transformMacro(ModelStore store, RankProfile profile,
+ Set<String> constantsReplacedByMacros,
+ String macroName, RankingExpression expression) {
+
+ expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
store.writeMacro(macroName, expression);
addMacroToProfile(profile, macroName, expression);
}
@@ -312,6 +338,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
return node;
}
+ /**
+ * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
+ * This method does that for the given expression and returns the result.
+ */
+ private RankingExpression replaceConstantsByMacros(RankingExpression expression,
+ Set<String> constantsReplacedByMacros) {
+ if (constantsReplacedByMacros.isEmpty()) return expression;
+ return new RankingExpression(expression.getName(),
+ replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros));
+ }
+
+ private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) {
+ if (node instanceof ReferenceNode) {
+ Reference reference = ((ReferenceNode)node).reference();
+ if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) {
+ String argument = reference.simpleArgument().get();
+ if (constantsReplacedByMacros.contains(argument))
+ return new ReferenceNode(argument);
+ }
+ }
+ if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above
+ CompositeNode composite = (CompositeNode)node;
+ return composite.setChildren(composite.children().stream()
+ .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros))
+ .collect(Collectors.toList()));
+ }
+ return node;
+ }
+
private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List<String> reduceDimensions) {
return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, reduceDimensions));
}