aboutsummaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
diff options
context:
space:
mode:
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java260
1 files changed, 130 insertions, 130 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index c3d6f457ce8..9f649bc820a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -386,138 +386,138 @@ public class ConvertedModel {
*/
private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model,
RankProfile profile, QueryProfileRegistry queryProfiles) {
- MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles);
-
- // Add any missing inputs for type resolution
- Set<String> functionNames = new HashSet<>();
- addFunctionNamesIn(expression.getRoot(), functionNames, model);
- for (String functionName : functionNames) {
- Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
- if (requiredType.isPresent()) {
- Reference ref = Reference.fromIdentifier(functionName);
- if (typeContext.getType(ref).equals(TensorType.empty)) {
- typeContext.setType(ref, requiredType.get());
- }
- }
- }
- typeContext.forgetResolvedTypes();
-
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated functions for inputs to reduce
- for (String functionName : functionNames) {
- if ( ! model.functions().containsKey(functionName)) continue;
-
- RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
- if (rankingExpressionFunction == null) {
- throw new IllegalArgumentException("Model refers to generated function '" + functionName +
- "but this function is not present in " + profile);
- }
- RankingExpression functionExpression = rankingExpressionFunction.function().getBody();
- functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model,
- MapEvaluationTypeContext typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- // Modify any renames in expression to disregard batch dimension
- else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) {
- TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function());
- TensorType childType = childFunction.type(typeContext);
- Rename rename = (Rename) tensorFunction;
- List<String> from = new ArrayList<>();
- List<String> to = new ArrayList<>();
- for (TensorType.Dimension dimension : childType.dimensions()) {
- int i = rename.fromDimensions().indexOf(dimension.name());
- if (i < 0) {
- throw new IllegalArgumentException("Rename does not contain dimension '" +
- dimension + "' in child expression type: " + childType);
- }
- from.add((String)rename.fromDimensions().get(i));
- to.add((String)rename.toDimensions().get(i));
- }
- return new TensorFunctionNode(new Rename<>(childFunction, from, to));
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- context.forgetResolvedTypes(); // We changed types
- }
- }
- return new TensorFunctionNode(result);
+// MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles);
+//
+// // Add any missing inputs for type resolution
+// Set<String> functionNames = new HashSet<>();
+// addFunctionNamesIn(expression.getRoot(), functionNames, model);
+// for (String functionName : functionNames) {
+// Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
+// if (requiredType.isPresent()) {
+// Reference ref = Reference.fromIdentifier(functionName);
+// if (typeContext.getType(ref).equals(TensorType.empty)) {
+// typeContext.setType(ref, requiredType.get());
+// }
+// }
+// }
+// typeContext.forgetResolvedTypes();
+//
+// TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+//
+// // Check generated functions for inputs to reduce
+// for (String functionName : functionNames) {
+// if ( ! model.functions().containsKey(functionName)) continue;
+//
+// RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
+// if (rankingExpressionFunction == null) {
+// throw new IllegalArgumentException("Model refers to generated function '" + functionName +
+// "but this function is not present in " + profile);
+// }
+// RankingExpression functionExpression = rankingExpressionFunction.function().getBody();
+// functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext));
+// }
+//
+// // Check expression for inputs to reduce
+// ExpressionNode root = expression.getRoot();
+// root = reduceBatchDimensionsAtInput(root, model, typeContext);
+// TensorType typeAfterReducing = root.type(typeContext);
+// root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+// expression.setRoot(root);
}
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- */
- // TODO: determine when this is not necessary!
- private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) return node;
-
- TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType());
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
+// private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model,
+// MapEvaluationTypeContext typeContext) {
+// if (node instanceof TensorFunctionNode) {
+// TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
+// if (tensorFunction instanceof Rename) {
+// List<ExpressionNode> children = ((TensorFunctionNode)node).children();
+// if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
+// ReferenceNode referenceNode = (ReferenceNode) children.get(0);
+// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
+// return reduceBatchDimensionExpression(tensorFunction, typeContext);
+// }
+// }
+// // Modify any renames in expression to disregard batch dimension
+// else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) {
+// TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function());
+// TensorType childType = childFunction.type(typeContext);
+// Rename rename = (Rename) tensorFunction;
+// List<String> from = new ArrayList<>();
+// List<String> to = new ArrayList<>();
+// for (TensorType.Dimension dimension : childType.dimensions()) {
+// int i = rename.fromDimensions().indexOf(dimension.name());
+// if (i < 0) {
+// throw new IllegalArgumentException("Rename does not contain dimension '" +
+// dimension + "' in child expression type: " + childType);
+// }
+// from.add((String)rename.fromDimensions().get(i));
+// to.add((String)rename.toDimensions().get(i));
+// }
+// return new TensorFunctionNode(new Rename<>(childFunction, from, to));
+// }
+// }
+// }
+// if (node instanceof ReferenceNode) {
+// ReferenceNode referenceNode = (ReferenceNode) node;
+// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
+// return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext);
+// }
+// }
+// if (node instanceof CompositeNode) {
+// List<ExpressionNode> children = ((CompositeNode)node).children();
+// List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
+// for (ExpressionNode child : children) {
+// transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
+// }
+// return ((CompositeNode)node).setChildren(transformedChildren);
+// }
+// return node;
+// }
+//
+// private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) {
+// TensorFunction result = function;
+// TensorType type = function.type(context);
+// if (type.dimensions().size() > 1) {
+// List<String> reduceDimensions = new ArrayList<>();
+// for (TensorType.Dimension dimension : type.dimensions()) {
+// if (dimension.size().orElse(-1L) == 1) {
+// reduceDimensions.add(dimension.name());
+// }
+// }
+// if (reduceDimensions.size() > 0) {
+// result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+// context.forgetResolvedTypes(); // We changed types
+// }
+// }
+// return new TensorFunctionNode(result);
+// }
+//
+// /**
+// * If batch dimensions have been reduced away above, bring them back here
+// * for any following computation of the tensor.
+// */
+// // TODO: determine when this is not necessary!
+// private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+// if (after.equals(before)) return node;
+//
+// TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType());
+// for (TensorType.Dimension dimension : before.dimensions()) {
+// if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
+// typeBuilder.indexed(dimension.name(), 1);
+// }
+// }
+// TensorType expandDimensionsType = typeBuilder.build();
+// if (expandDimensionsType.dimensions().size() > 0) {
+// ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
+// Generate generatedFunction = new Generate(expandDimensionsType,
+// new GeneratorLambdaFunctionNode(expandDimensionsType,
+// generatedExpression)
+// .asLongListToDoubleOperator());
+// Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply());
+// return new TensorFunctionNode(expand);
+// }
+// return node;
+// }
/**
* If a constant c is overridden by a function, we need to replace instances of "constant(c)" by "c" in expressions.