summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-02 10:01:36 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-02 10:01:36 +0200
commit8909fd9728591d8e00e7babc601c600b26d5acf4 (patch)
tree53231c4abb7857b8345c5125bb8539519f0d776e
parent55236fc050998712ad6dc136e2b5e45c9d41538f (diff)
Be truthful about generated functions
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java18
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java57
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java3
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java16
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java4
6 files changed, 65 insertions, 35 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index 10de10bcdfe..adf1770284b 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -90,7 +90,7 @@ public class ModelEvaluationTest {
RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
cluster.getConfig(b);
RankProfilesConfig config = new RankProfilesConfig(b);
- System.out.println(config);
+ // System.out.println(config);
RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder();
cluster.getConfig(cb);
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 5c8a53c9e83..e001204f650 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -61,12 +61,22 @@ public class Model {
try {
LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this);
contextBuilder.put(function.getValue().getName(), context);
+ if ( ! function.getValue().returnType().isPresent()) {
+ functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
+ }
+
for (String argument : context.arguments()) {
- if (function.getValue().argumentTypes().get(argument) == null)
- functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
+ if (function.getValue().getName().startsWith(IntermediateOperation.FUNCTION_PREFIX)) {
+ // Internal (generated) functions do not have type info - add arguments
+ if (!function.getValue().arguments().contains(argument))
+ functions.put(function.getKey(), function.getValue().withArgument(argument));
+ }
+ else {
+ // External functions have type info (when not scalar) - add argument types
+ if (function.getValue().argumentTypes().get(argument) == null)
+ functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
+ }
}
- if ( ! function.getValue().returnType().isPresent())
- functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index cf7d208ed25..db892dce593 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -10,6 +10,7 @@ import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
/**
* Tests instantiating models from rank-profiles configs.
@@ -32,11 +33,10 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, xgboost.functions().size());
- tester.assertFunction("xgboost_2_2",
- "(optimized sum of condition trees of size 192 bytes)",
- xgboost);
- ExpressionFunction function = xgboost.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get());
+ ExpressionFunction function = tester.assertFunction("xgboost_2_2",
+ "(optimized sum of condition trees of size 192 bytes)",
+ xgboost);
+ assertEquals("tensor()", function.returnType().get().toString());
assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments()));
function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
@@ -52,14 +52,14 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, onnxMnistSoftmax.functions().size());
- tester.assertFunction("default.add",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
- onnxMnistSoftmax);
- ExpressionFunction function = onnxMnistSoftmax.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("default.add",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
+ onnxMnistSoftmax);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
// Evaluator
assertEquals("tensor(d1[10],d2[784])",
@@ -74,14 +74,14 @@ public class MlModelsImportingTest {
// Function
assertEquals(1, tfMnistSoftmax.functions().size());
- tester.assertFunction("serving_default.y",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
- tfMnistSoftmax);
- ExpressionFunction function = tfMnistSoftmax.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
+ tfMnistSoftmax);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("Placeholder", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("Placeholder").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
@@ -92,20 +92,25 @@ public class MlModelsImportingTest {
{
Model tfMnist = tester.models().get("mnist_saved");
// Generated function
- tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
- "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
- tfMnist);
+ ExpressionFunction generatedFunction =
+ tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
+ "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ assertEquals("tensor(d3[300])", generatedFunction.returnType().get().toString());
+ assertEquals(1, generatedFunction.arguments().size());
+ assertEquals("input", generatedFunction.arguments().get(0));
+ assertNull(null, generatedFunction.argumentTypes().get("input")); // TODO: Not available until we resolve all argument types
// Function
assertEquals(1, tfMnist.functions().size());
- tester.assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
- tfMnist);
- ExpressionFunction function = tfMnist.functions().get(0);
- assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ ExpressionFunction function =
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ assertEquals("tensor(d1[10])", function.returnType().get().toString());
assertEquals(1, function.arguments().size());
assertEquals("input", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input"));
+ assertEquals("tensor(d0[],d1[784])", function.argumentTypes().get("input").toString());
// Evaluator
FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index 50dd1d1d05f..bacdb52a201 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -49,12 +49,13 @@ public class ModelTester {
.importFrom(config, constantsConfig);
}
- public void assertFunction(String name, String expression, Model model) {
+ public ExpressionFunction assertFunction(String name, String expression, Model model) {
assertNotNull("Model is present in config", model);
ExpressionFunction function = model.function(name);
assertNotNull("Function '" + name + "' is in " + model, function);
assertEquals(name, function.getName());
assertEquals(expression, function.getBody().getRoot().toString());
+ return function;
}
public void assertBoundFunction(String name, String expression, Model model) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index 787b857839d..674571ff73e 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -98,12 +98,24 @@ public class ExpressionFunction {
return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType));
}
- /** Returns a copy of this with the given argument and argument type added */
- public ExpressionFunction withArgument(String argument, TensorType type) {
+ /** Returns a copy of this with the given argument added (if not already present) */
+ public ExpressionFunction withArgument(String argument) {
+ if (arguments.contains(argument)) return this;
+
List<String> arguments = new ArrayList<>(this.arguments);
arguments.add(argument);
+ return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
+ }
+
+ /** Returns a copy of this with the given argument (if not present) and argument type added */
+ public ExpressionFunction withArgument(String argument, TensorType type) {
+ List<String> arguments = new ArrayList<>(this.arguments);
+ if ( ! arguments.contains(argument))
+ arguments.add(argument);
+
Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes);
argumentTypes.put(argument, type);
+
return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index 6235756d4e1..481b7f9397a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -196,7 +196,9 @@ public abstract class ModelImporter {
if (operation.rankingExpressionFunction().isPresent()) {
TensorFunction function = operation.rankingExpressionFunction().get();
try {
- model.function(operation.rankingExpressionFunctionName(), new RankingExpression(operation.rankingExpressionFunctionName(), function.toString()));
+ model.function(operation.rankingExpressionFunctionName(),
+ new RankingExpression(operation.rankingExpressionFunctionName(),
+ function.toString()));
}
catch (ParseException e) {
throw new RuntimeException("Tensorflow function " + function +