diff options
Diffstat (limited to 'config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java | 260 |
1 files changed, 130 insertions, 130 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index c3d6f457ce8..9f649bc820a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -386,138 +386,138 @@ public class ConvertedModel { */ private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles) { - MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles); - - // Add any missing inputs for type resolution - Set<String> functionNames = new HashSet<>(); - addFunctionNamesIn(expression.getRoot(), functionNames, model); - for (String functionName : functionNames) { - Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); - if (requiredType.isPresent()) { - Reference ref = Reference.fromIdentifier(functionName); - if (typeContext.getType(ref).equals(TensorType.empty)) { - typeContext.setType(ref, requiredType.get()); - } - } - } - typeContext.forgetResolvedTypes(); - - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - - // Check generated functions for inputs to reduce - for (String functionName : functionNames) { - if ( ! model.functions().containsKey(functionName)) continue; - - RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); - if (rankingExpressionFunction == null) { - throw new IllegalArgumentException("Model refers to generated function '" + functionName + - "but this function is not present in " + profile); - } - RankingExpression functionExpression = rankingExpressionFunction.function().getBody(); - functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext)); - } - - // Check expression for inputs to reduce - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, model, typeContext); - TensorType typeAfterReducing = root.type(typeContext); - root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); - expression.setRoot(root); - } - - private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model, - MapEvaluationTypeContext typeContext) { - if (node instanceof TensorFunctionNode) { - TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); - if (tensorFunction instanceof Rename) { - List<ExpressionNode> children = ((TensorFunctionNode)node).children(); - if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { - return reduceBatchDimensionExpression(tensorFunction, typeContext); - } - } - // Modify any renames in expression to disregard batch dimension - else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) { - TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function()); - TensorType childType = childFunction.type(typeContext); - Rename rename = (Rename) tensorFunction; - List<String> from = new ArrayList<>(); - List<String> to = new ArrayList<>(); - for (TensorType.Dimension dimension : childType.dimensions()) { - int i = rename.fromDimensions().indexOf(dimension.name()); - if (i < 0) { - throw new IllegalArgumentException("Rename does not contain dimension '" + - dimension + "' in child expression type: " + childType); - } - from.add((String)rename.fromDimensions().get(i)); - to.add((String)rename.toDimensions().get(i)); - } - return new TensorFunctionNode(new Rename<>(childFunction, from, to)); - } - } - } - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext); - } - } - if (node instanceof CompositeNode) { - List<ExpressionNode> children = ((CompositeNode)node).children(); - List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); - for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); - } - return ((CompositeNode)node).setChildren(transformedChildren); - } - return node; - } - - private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) { - TensorFunction result = function; - TensorType type = function.type(context); - if (type.dimensions().size() > 1) { - List<String> reduceDimensions = new ArrayList<>(); - for (TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } - } - if (reduceDimensions.size() > 0) { - result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); - context.forgetResolvedTypes(); // We changed types - } - } - return new TensorFunctionNode(result); +// MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles); +// +// // Add any missing inputs for type resolution +// Set<String> functionNames = new HashSet<>(); +// addFunctionNamesIn(expression.getRoot(), functionNames, model); +// for (String functionName : functionNames) { +// Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); +// if (requiredType.isPresent()) { +// Reference ref = Reference.fromIdentifier(functionName); +// if (typeContext.getType(ref).equals(TensorType.empty)) { +// typeContext.setType(ref, requiredType.get()); +// } +// } +// } +// typeContext.forgetResolvedTypes(); +// +// TensorType typeBeforeReducing = expression.getRoot().type(typeContext); +// +// // Check generated functions for inputs to reduce +// for (String functionName : functionNames) { +// if ( ! model.functions().containsKey(functionName)) continue; +// +// RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); +// if (rankingExpressionFunction == null) { +// throw new IllegalArgumentException("Model refers to generated function '" + functionName + +// "but this function is not present in " + profile); +// } +// RankingExpression functionExpression = rankingExpressionFunction.function().getBody(); +// functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext)); +// } +// +// // Check expression for inputs to reduce +// ExpressionNode root = expression.getRoot(); +// root = reduceBatchDimensionsAtInput(root, model, typeContext); +// TensorType typeAfterReducing = root.type(typeContext); +// root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); +// expression.setRoot(root); } - /** - * If batch dimensions have been reduced away above, bring them back here - * for any following computation of the tensor. - */ - // TODO: determine when this is not necessary! - private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) return node; - - TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); - for (TensorType.Dimension dimension : before.dimensions()) { - if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { - typeBuilder.indexed(dimension.name(), 1); - } - } - TensorType expandDimensionsType = typeBuilder.build(); - if (expandDimensionsType.dimensions().size() > 0) { - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); - Generate generatedFunction = new Generate(expandDimensionsType, - new GeneratorLambdaFunctionNode(expandDimensionsType, - generatedExpression) - .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - return node; - } +// private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model, +// MapEvaluationTypeContext typeContext) { +// if (node instanceof TensorFunctionNode) { +// TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); +// if (tensorFunction instanceof Rename) { +// List<ExpressionNode> children = ((TensorFunctionNode)node).children(); +// if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { +// ReferenceNode referenceNode = (ReferenceNode) children.get(0); +// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { +// return reduceBatchDimensionExpression(tensorFunction, typeContext); +// } +// } +// // Modify any renames in expression to disregard batch dimension +// else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) { +// TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function()); +// TensorType childType = childFunction.type(typeContext); +// Rename rename = (Rename) tensorFunction; +// List<String> from = new ArrayList<>(); +// List<String> to = new ArrayList<>(); +// for (TensorType.Dimension dimension : childType.dimensions()) { +// int i = rename.fromDimensions().indexOf(dimension.name()); +// if (i < 0) { +// throw new IllegalArgumentException("Rename does not contain dimension '" + +// dimension + "' in child expression type: " + childType); +// } +// from.add((String)rename.fromDimensions().get(i)); +// to.add((String)rename.toDimensions().get(i)); +// } +// return new TensorFunctionNode(new Rename<>(childFunction, from, to)); +// } +// } +// } +// if (node instanceof ReferenceNode) { +// ReferenceNode referenceNode = (ReferenceNode) node; +// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { +// return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext); +// } +// } +// if (node instanceof CompositeNode) { +// List<ExpressionNode> children = ((CompositeNode)node).children(); +// List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); +// for (ExpressionNode child : children) { +// transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); +// } +// return ((CompositeNode)node).setChildren(transformedChildren); +// } +// return node; +// } +// +// private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) { +// TensorFunction result = function; +// TensorType type = function.type(context); +// if (type.dimensions().size() > 1) { +// List<String> reduceDimensions = new ArrayList<>(); +// for (TensorType.Dimension dimension : type.dimensions()) { +// if (dimension.size().orElse(-1L) == 1) { +// reduceDimensions.add(dimension.name()); +// } +// } +// if (reduceDimensions.size() > 0) { +// result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); +// context.forgetResolvedTypes(); // We changed types +// } +// } +// return new TensorFunctionNode(result); +// } +// +// /** +// * If batch dimensions have been reduced away above, bring them back here +// * for any following computation of the tensor. +// */ +// // TODO: determine when this is not necessary! +// private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { +// if (after.equals(before)) return node; +// +// TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); +// for (TensorType.Dimension dimension : before.dimensions()) { +// if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { +// typeBuilder.indexed(dimension.name(), 1); +// } +// } +// TensorType expandDimensionsType = typeBuilder.build(); +// if (expandDimensionsType.dimensions().size() > 0) { +// ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); +// Generate generatedFunction = new Generate(expandDimensionsType, +// new GeneratorLambdaFunctionNode(expandDimensionsType, +// generatedExpression) +// .asLongListToDoubleOperator()); +// Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply()); +// return new TensorFunctionNode(expand); +// } +// return node; +// } /** * If a constant c is overridden by a function, we need to replace instances of "constant(c)" by "c" in expressions. |