diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-30 11:56:33 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-30 11:56:33 +0200 |
commit | 4418e0b8fedbf14e426e820a9bf500d8536e891c (patch) | |
tree | 925c5397b54fdcf7661afd61a43c339180df7b8b | |
parent | bf1d34d22df581902935bbf5daa52baf6e7b88d5 (diff) |
Remove code for batch dimension removal
4 files changed, 9 insertions, 159 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index e618326eff5..7a47c5bae34 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -261,7 +261,6 @@ public class ConvertedModel { QueryProfileRegistry queryProfiles, Map<String, ExpressionFunction> expressions) { expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions)); - reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles); store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); } @@ -380,146 +379,6 @@ public class ConvertedModel { } /** - * Check if batch dimensions of inputs can be reduced out. If the input - * function specifies that a single exemplar should be evaluated, we can - * reduce the batch dimension out. - */ - private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model, - RankProfile profile, QueryProfileRegistry queryProfiles) { - MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles); - - // Add any missing inputs for type resolution - Set<String> functionNames = new HashSet<>(); - addFunctionNamesIn(expression.getRoot(), functionNames, model); - for (String functionName : functionNames) { - Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); - if (requiredType.isPresent()) { - Reference ref = Reference.fromIdentifier(functionName); - if (typeContext.getType(ref).equals(TensorType.empty)) { - typeContext.setType(ref, requiredType.get()); - } - } - } - typeContext.forgetResolvedTypes(); - - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - - // Check generated functions for inputs to reduce - for (String functionName : functionNames) { - if ( ! model.functions().containsKey(functionName)) continue; - - RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); - if (rankingExpressionFunction == null) { - throw new IllegalArgumentException("Model refers to generated function '" + functionName + - "but this function is not present in " + profile); - } - RankingExpression functionExpression = rankingExpressionFunction.function().getBody(); - functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext)); - } - - // Check expression for inputs to reduce - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, model, typeContext); - TensorType typeAfterReducing = root.type(typeContext); - root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); - expression.setRoot(root); - } - - private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model, - MapEvaluationTypeContext typeContext) { - if (node instanceof TensorFunctionNode) { - TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); - if (tensorFunction instanceof Rename) { - List<ExpressionNode> children = ((TensorFunctionNode)node).children(); - if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { - return reduceBatchDimensionExpression(tensorFunction, typeContext); - } - } - // Modify any renames in expression to disregard batch dimension - else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) { - TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function()); - TensorType childType = childFunction.type(typeContext); - Rename rename = (Rename) tensorFunction; - List<String> from = new ArrayList<>(); - List<String> to = new ArrayList<>(); - for (TensorType.Dimension dimension : childType.dimensions()) { - int i = rename.fromDimensions().indexOf(dimension.name()); - if (i < 0) { - throw new IllegalArgumentException("Rename does not contain dimension '" + - dimension + "' in child expression type: " + childType); - } - from.add((String)rename.fromDimensions().get(i)); - to.add((String)rename.toDimensions().get(i)); - } - return new TensorFunctionNode(new Rename<>(childFunction, from, to)); - } - } - } - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext); - } - } - if (node instanceof CompositeNode) { - List<ExpressionNode> children = ((CompositeNode)node).children(); - List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); - for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); - } - return ((CompositeNode)node).setChildren(transformedChildren); - } - return node; - } - - private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) { - TensorFunction result = function; - TensorType type = function.type(context); - if (type.dimensions().size() > 1) { - List<String> reduceDimensions = new ArrayList<>(); - for (TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } - } - if (reduceDimensions.size() > 0) { - result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); - context.forgetResolvedTypes(); // We changed types - } - } - return new TensorFunctionNode(result); - } - - /** - * If batch dimensions have been reduced away above, bring them back here - * for any following computation of the tensor. - */ - // TODO: determine when this is not necessary! - private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) return node; - - TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); - for (TensorType.Dimension dimension : before.dimensions()) { - if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { - typeBuilder.indexed(dimension.name(), 1); - } - } - TensorType expandDimensionsType = typeBuilder.build(); - if (expandDimensionsType.dimensions().size() > 0) { - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); - Generate generatedFunction = new Generate(expandDimensionsType, - new GeneratorLambdaFunctionNode(expandDimensionsType, - generatedExpression) - .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - return node; - } - - /** * If a constant c is overridden by a function, we need to replace instances of "constant(c)" by "c" in expressions. * This method does that for the given expression and returns the result. */ diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 1fe1ebf2bb3..dffdc3b4a34 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -29,7 +29,6 @@ public class RankingExpressionWithOnnxTestCase { private final static String name = "mnist_softmax"; private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))"; - private final static String vespaExpressionWithBatchReduce = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))"; @After public void removeGeneratedModelFiles() { @@ -97,7 +96,7 @@ public class RankingExpressionWithOnnxTestCase { "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); - search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -115,7 +114,7 @@ public class RankingExpressionWithOnnxTestCase { "field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); - search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 126a41e14ad..610d92144ad 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -118,7 +118,7 @@ public class RankingExpressionWithTensorFlowTestCase { "field mytensor type tensor(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); - search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test @@ -136,7 +136,7 @@ public class RankingExpressionWithTensorFlowTestCase { "field mytensor type tensor(d0[1],d1[784]) { indexing: attribute }", "Placeholder", application); - search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile"); + search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @Test @@ -310,18 +310,10 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test - public void testTensorFlowReduceBatchDimension() { - final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(" + name + "_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; - RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "tensorflow('mnist_softmax/saved')"); - search.assertFirstPhaseExpression(expression, "my_profile"); - } - - @Test public void testFunctionGeneration() { final String name = "mnist_saved"; final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))"; - final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))"; + final String functionExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))"; final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", @@ -351,7 +343,7 @@ public class RankingExpressionWithTensorFlowTestCase { " }"; final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))"; - final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))"; + final String functionExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))"; final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir)); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java index 5b38e09537d..cad0ad7bae0 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java @@ -64,12 +64,12 @@ public class MlModelsTest { private final String testProfile = "rankingExpression(input).rankingScript: attribute(argument)\n" + "rankingExpression(input).type: tensor<float>(d0[1],d1[784])\n" + - "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(reduce(rename(rankingExpression(input), (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" + + "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(rankingExpression(input), (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" + "rankingExpression(mnist_tensorflow).rankingScript: join(reduce(join(map(join(reduce(join(join(join(0.009999999776482582, rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.0507009873554805 * if (a >= 0, a, 1.6732632423543772 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" + "rankingExpression(Placeholder).rankingScript: attribute(argument)\n" + "rankingExpression(Placeholder).type: tensor<float>(d0[1],d1[784])\n" + - "rankingExpression(mnist_softmax_tensorflow).rankingScript: join(join(reduce(join(reduce(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))\n" + - "rankingExpression(mnist_softmax_onnx).rankingScript: join(join(reduce(join(reduce(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))\n" + + "rankingExpression(mnist_softmax_tensorflow).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))\n" + + "rankingExpression(mnist_softmax_onnx).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))\n" + "rankingExpression(my_xgboost).rankingScript: if (f29 < -0.1234567, if (!(f56 >= -0.242398), 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (!(f60 >= -0.482947), if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)\n" + "rankingExpression(my_lightgbm).rankingScript: if (!(numerical_2 >= 0.46643291586559305), 2.1594397038037663, if (categorical_2 in [\"k\", \"l\", \"m\"], 2.235297305276056, 2.1792953471546546)) + if (categorical_1 in [\"d\", \"e\"], 0.03070842919354316, if (!(numerical_1 >= 0.5102250691730842), -0.04439151147520909, 0.005117411709368601)) + if (!(numerical_2 >= 0.668665477622446), if (!(numerical_2 >= 0.008118820676863816), -0.15361238490967524, -0.01192330846157292), 0.03499044894987518) + if (!(numerical_1 >= 0.5201391072644542), -0.02141000620783247, if (categorical_1 in [\"a\", \"b\"], -0.004121485787596721, 0.04534090904886873)) + if (categorical_2 in [\"k\", \"l\", \"m\"], if (!(numerical_2 >= 0.27283279016959255), -0.01924803254356527, 0.03643772842347651), -0.02701711918923075)\n" + "vespa.rank.firstphase: rankingExpression(firstphase)\n" + |