summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Lester Solbakken <lesters@oath.com> 2020-06-30 11:56:33 +0200
committer Lester Solbakken <lesters@oath.com> 2020-06-30 11:56:33 +0200
commit 4418e0b8fedbf14e426e820a9bf500d8536e891c (patch)
tree 925c5397b54fdcf7661afd61a43c339180df7b8b
parent bf1d34d22df581902935bbf5daa52baf6e7b88d5 (diff)
Remove code for batch dimension removal
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java 141
-rw-r--r-- config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java 5
-rw-r--r-- config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java 16
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java 6
4 files changed, 9 insertions, 159 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index e618326eff5..7a47c5bae34 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -261,7 +261,6 @@ public class ConvertedModel {
QueryProfileRegistry queryProfiles,
Map<String, ExpressionFunction> expressions) {
expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions));
- reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles);
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
@@ -380,146 +379,6 @@ public class ConvertedModel {
}
/**
- * Check if batch dimensions of inputs can be reduced out. If the input
- * function specifies that a single exemplar should be evaluated, we can
- * reduce the batch dimension out.
- */
- private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model,
- RankProfile profile, QueryProfileRegistry queryProfiles) {
- MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles);
-
- // Add any missing inputs for type resolution
- Set<String> functionNames = new HashSet<>();
- addFunctionNamesIn(expression.getRoot(), functionNames, model);
- for (String functionName : functionNames) {
- Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
- if (requiredType.isPresent()) {
- Reference ref = Reference.fromIdentifier(functionName);
- if (typeContext.getType(ref).equals(TensorType.empty)) {
- typeContext.setType(ref, requiredType.get());
- }
- }
- }
- typeContext.forgetResolvedTypes();
-
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated functions for inputs to reduce
- for (String functionName : functionNames) {
- if ( ! model.functions().containsKey(functionName)) continue;
-
- RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
- if (rankingExpressionFunction == null) {
- throw new IllegalArgumentException("Model refers to generated function '" + functionName +
- "but this function is not present in " + profile);
- }
- RankingExpression functionExpression = rankingExpressionFunction.function().getBody();
- functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model,
- MapEvaluationTypeContext typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- // Modify any renames in expression to disregard batch dimension
- else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) {
- TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function());
- TensorType childType = childFunction.type(typeContext);
- Rename rename = (Rename) tensorFunction;
- List<String> from = new ArrayList<>();
- List<String> to = new ArrayList<>();
- for (TensorType.Dimension dimension : childType.dimensions()) {
- int i = rename.fromDimensions().indexOf(dimension.name());
- if (i < 0) {
- throw new IllegalArgumentException("Rename does not contain dimension '" +
- dimension + "' in child expression type: " + childType);
- }
- from.add((String)rename.fromDimensions().get(i));
- to.add((String)rename.toDimensions().get(i));
- }
- return new TensorFunctionNode(new Rename<>(childFunction, from, to));
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- context.forgetResolvedTypes(); // We changed types
- }
- }
- return new TensorFunctionNode(result);
- }
-
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- */
- // TODO: determine when this is not necessary!
- private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) return node;
-
- TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType());
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
-
- /**
* If a constant c is overridden by a function, we need to replace instances of "constant(c)" by "c" in expressions.
* This method does that for the given expression and returns the result.
*/
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 1fe1ebf2bb3..dffdc3b4a34 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -29,7 +29,6 @@ public class RankingExpressionWithOnnxTestCase {
private final static String name = "mnist_softmax";
private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
- private final static String vespaExpressionWithBatchReduce = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))";
@After
public void removeGeneratedModelFiles() {
@@ -97,7 +96,7 @@ public class RankingExpressionWithOnnxTestCase {
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
application);
- search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -115,7 +114,7 @@ public class RankingExpressionWithOnnxTestCase {
"field mytensor type tensor<float>(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
application);
- search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 126a41e14ad..610d92144ad 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -118,7 +118,7 @@ public class RankingExpressionWithTensorFlowTestCase {
"field mytensor type tensor(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
application);
- search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
@@ -136,7 +136,7 @@ public class RankingExpressionWithTensorFlowTestCase {
"field mytensor type tensor(d0[1],d1[784]) { indexing: attribute }",
"Placeholder",
application);
- search.assertFirstPhaseExpression(vespaExpressionWithBatchReduce, "my_profile");
+ search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@Test
@@ -310,18 +310,10 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
- public void testTensorFlowReduceBatchDimension() {
- final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(" + name + "_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
- RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
- "tensorflow('mnist_softmax/saved')");
- search.assertFirstPhaseExpression(expression, "my_profile");
- }
-
- @Test
public void testFunctionGeneration() {
final String name = "mnist_saved";
final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))";
- final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))";
+ final String functionExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))";
final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
@@ -351,7 +343,7 @@ public class RankingExpressionWithTensorFlowTestCase {
" }";
final String expression = "join(reduce(join(join(join(reduce(constant(" + name + "_dnn_hidden2_Const), sum, d2), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden2_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(" + name + "_dnn_outputs_bias_read), f(a,b)(a + b))";
- final String functionExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))";
+ final String functionExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(" + name + "_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(" + name + "_dnn_hidden1_bias_read), f(a,b)(a + b))";
final String functionExpression2 = "join(reduce(join(join(join(0.009999999776482582, imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(a * b)), imported_ml_function_" + name + "_dnn_hidden1_add, f(a,b)(max(a,b))), constant(" + name + "_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(" + name + "_dnn_hidden2_bias_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWithUncompiled(rankProfiles, new StoringApplicationPackage(applicationDir));
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
index 5b38e09537d..cad0ad7bae0 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
@@ -64,12 +64,12 @@ public class MlModelsTest {
private final String testProfile =
"rankingExpression(input).rankingScript: attribute(argument)\n" +
"rankingExpression(input).type: tensor<float>(d0[1],d1[784])\n" +
- "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(reduce(rename(rankingExpression(input), (d0, d1), (d0, d4)), sum, d0), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" +
+ "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(rankingExpression(input), (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" +
"rankingExpression(mnist_tensorflow).rankingScript: join(reduce(join(map(join(reduce(join(join(join(0.009999999776482582, rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.0507009873554805 * if (a >= 0, a, 1.6732632423543772 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" +
"rankingExpression(Placeholder).rankingScript: attribute(argument)\n" +
"rankingExpression(Placeholder).type: tensor<float>(d0[1],d1[784])\n" +
- "rankingExpression(mnist_softmax_tensorflow).rankingScript: join(join(reduce(join(reduce(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))\n" +
- "rankingExpression(mnist_softmax_onnx).rankingScript: join(join(reduce(join(reduce(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), sum, d0), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b)), tensor<float>(d0[1])(1.0), f(a,b)(a * b))\n" +
+ "rankingExpression(mnist_softmax_tensorflow).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))\n" +
+ "rankingExpression(mnist_softmax_onnx).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))\n" +
"rankingExpression(my_xgboost).rankingScript: if (f29 < -0.1234567, if (!(f56 >= -0.242398), 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (!(f60 >= -0.482947), if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)\n" +
"rankingExpression(my_lightgbm).rankingScript: if (!(numerical_2 >= 0.46643291586559305), 2.1594397038037663, if (categorical_2 in [\"k\", \"l\", \"m\"], 2.235297305276056, 2.1792953471546546)) + if (categorical_1 in [\"d\", \"e\"], 0.03070842919354316, if (!(numerical_1 >= 0.5102250691730842), -0.04439151147520909, 0.005117411709368601)) + if (!(numerical_2 >= 0.668665477622446), if (!(numerical_2 >= 0.008118820676863816), -0.15361238490967524, -0.01192330846157292), 0.03499044894987518) + if (!(numerical_1 >= 0.5201391072644542), -0.02141000620783247, if (categorical_1 in [\"a\", \"b\"], -0.004121485787596721, 0.04534090904886873)) + if (categorical_2 in [\"k\", \"l\", \"m\"], if (!(numerical_2 >= 0.27283279016959255), -0.01924803254356527, 0.03643772842347651), -0.02701711918923075)\n" +
"vespa.rank.firstphase: rankingExpression(firstphase)\n" +