diff options
Diffstat (limited to 'model-evaluation')
5 files changed, 70 insertions, 39 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index c7d0cbd8f30..d144411127e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -1,7 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -15,7 +17,9 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -110,6 +114,9 @@ public final class LazyArrayContext extends Context implements ContextIndex { @Override public Set<String> names() { return indexedBindings.names(); } + /** Returns the (immutable) subset of names in this which must be bound when invoking */ + public Set<String> arguments() { return indexedBindings.arguments(); } + private Integer requireIndexOf(String name) { Integer index = indexedBindings.indexOf(name); if (index == null) @@ -130,12 +137,18 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** The mapping from variable name to index */ private final ImmutableMap<String, Integer> nameToIndex; + /** The names which need to be bound externally when invoking this (i.e. not constant or invocation) */ + private 
final ImmutableSet<String> arguments; + /** The current values set, pre-converted to doubles */ private final Value[] values; - private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values) { + private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, + Value[] values, + ImmutableSet<String> arguments) { this.nameToIndex = nameToIndex; this.values = values; + this.arguments = arguments; } /** @@ -149,8 +162,10 @@ public final class LazyArrayContext extends Context implements ContextIndex { Model model) { // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); - extractBindTargets(expression.getRoot(), functions, bindTargets); + Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation + extractBindTargets(expression.getRoot(), functions, bindTargets, arguments); + this.arguments = ImmutableSet.copyOf(arguments); values = new Value[bindTargets.size()]; Arrays.fill(values, DoubleValue.zero); @@ -178,23 +193,25 @@ public final class LazyArrayContext extends Context implements ContextIndex { private void extractBindTargets(ExpressionNode node, Map<FunctionReference, ExpressionFunction> functions, - Set<String> bindTargets) { + Set<String> bindTargets, + Set<String> arguments) { if (isFunctionReference(node)) { FunctionReference reference = FunctionReference.fromSerial(node.toString()).get(); bindTargets.add(reference.serialForm()); - extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets); + extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets, arguments); } else if (isConstant(node)) { bindTargets.add(node.toString()); } else if (node instanceof ReferenceNode) { bindTargets.add(node.toString()); + arguments.add(node.toString()); } else if (node instanceof CompositeNode) { CompositeNode cNode = (CompositeNode)node; for (ExpressionNode child : cNode.children()) - 
extractBindTargets(child, functions, bindTargets); + extractBindTargets(child, functions, bindTargets, arguments); } } @@ -215,13 +232,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { Value get(int index) { return values[index]; } void set(int index, Value value) { values[index] = value; } Set<String> names() { return nameToIndex.keySet(); } + Set<String> arguments() { return arguments; } Integer indexOf(String name) { return nameToIndex.get(name); } IndexedBindings copy(Context context) { Value[] valueCopy = new Value[values.length]; for (int i = 0; i < values.length; i++) valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue)values[i]).copyFor(context) : values[i]; - return new IndexedBindings(nameToIndex, valueCopy); + return new IndexedBindings(nameToIndex, valueCopy, arguments); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index ac8f28677a4..fda1ae935ca 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; @@ -38,28 +39,39 @@ public class Model { /** Programmatically create a model containing functions without constant of function references only */ public Model(String name, Collection<ExpressionFunction> functions) { - this(name, functions, Collections.emptyMap(), Collections.emptyList()); + this(name, + functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), + Collections.emptyMap(), + 
Collections.emptyList()); } Model(String name, - Collection<ExpressionFunction> functions, + Map<FunctionReference, ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants) { this.name = name; - this.functions = ImmutableList.copyOf(functions); + // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); - for (ExpressionFunction function : functions) { + for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { - contextBuilder.put(function.getName(), - new LazyArrayContext(function.getBody(), referencedFunctions, constants, this)); + LazyArrayContext context = new LazyArrayContext(function.getValue().getBody(), referencedFunctions, constants, this); + contextBuilder.put(function.getValue().getName(), context); + for (String argument : context.arguments()) { + if (function.getValue().argumentTypes().get(argument) == null) + functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); + } + if ( ! 
function.getValue().returnType().isPresent()) + functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); } } this.contextPrototypes = contextBuilder.build(); + this.functions = ImmutableList.copyOf(functions.values()); + // Optimize functions ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) { ExpressionFunction optimizedFunction = optimize(function.getValue(), diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 648c6d931a9..fb424439592 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -78,7 +78,9 @@ public class RankProfilesConfigImporter { Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name()); if ( reference.isPresent()) { RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value()); - ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), expression); + ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), + Collections.emptyList(), + expression); if (reference.get().isFree()) // make available in model under configured name functions.put(reference.get(), function); @@ -120,14 +122,13 @@ public class RankProfilesConfigImporter { constants.addAll(smallConstantsInfo.asConstants()); try { - return new Model(profile.name(), functions.values(), referencedFunctions, constants); + return new 
Model(profile.name(), functions, referencedFunctions, constants); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e); } } - // TODO: Replace by lookup in map private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) { for (ExpressionFunction function : functions) if (function.getName().equals(name)) diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 287a2387b34..c4b163e89c0 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -29,15 +29,16 @@ public class MlModelsImportingTest { // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that { Model xgboost = tester.models().get("xgboost_2_2"); - tester.assertFunction("xgboost_2_2", - "(optimized sum of condition trees of size 192 bytes)", - xgboost); // Function assertEquals(1, xgboost.functions().size()); + tester.assertFunction("xgboost_2_2", + "(optimized sum of condition trees of size 192 bytes)", + xgboost); ExpressionFunction function = xgboost.functions().get(0); - assertEquals("xgboost_2_2", function.getName()); - // assertEquals("f109, f29, f56, f60", commaSeparated(xgboost.functions().get(0).arguments())); TODO + assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); + assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments())); + function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg))); // Evaluator FunctionEvaluator evaluator = xgboost.evaluatorOf(); @@ -48,14 +49,12 @@ public class MlModelsImportingTest { { Model onnxMnistSoftmax = tester.models().get("mnist_softmax"); - 
tester.assertFunction("default.add", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", - onnxMnistSoftmax); - assertEquals("tensor(d1[10],d2[784])", - onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); // Function assertEquals(1, onnxMnistSoftmax.functions().size()); + tester.assertFunction("default.add", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", + onnxMnistSoftmax); ExpressionFunction function = onnxMnistSoftmax.functions().get(0); assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); assertEquals(1, function.arguments().size()); @@ -63,6 +62,8 @@ public class MlModelsImportingTest { assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); // Evaluator + assertEquals("tensor(d1[10],d2[784])", + onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); @@ -70,17 +71,17 @@ public class MlModelsImportingTest { { Model tfMnistSoftmax = tester.models().get("mnist_softmax_saved"); - tester.assertFunction("serving_default.y", - "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", - tfMnistSoftmax); // Function 
assertEquals(1, tfMnistSoftmax.functions().size()); + tester.assertFunction("serving_default.y", + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", + tfMnistSoftmax); ExpressionFunction function = tfMnistSoftmax.functions().get(0); assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); assertEquals(1, function.arguments().size()); - assertEquals("x", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x")); + assertEquals("Placeholder", function.arguments().get(0)); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); // Evaluator FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available @@ -90,10 +91,6 @@ public class MlModelsImportingTest { { Model tfMnist = tester.models().get("mnist_saved"); - tester.assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); - // Generated function tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add", "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", @@ -101,11 +98,14 @@ public 
class MlModelsImportingTest { // Function assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function + tester.assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); ExpressionFunction function = tfMnist.functions().get(1); assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); assertEquals(1, function.arguments().size()); - assertEquals("x", function.arguments().get(0)); - assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x")); + assertEquals("input", function.arguments().get(0)); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input")); // Evaluator FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg index 9175b60315b..c25c5ba555b 100644 --- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg +++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg @@ -11,7 +11,7 @@ rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, rankprofile[2].name "mnist_softmax_saved" rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript" rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), 
constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))" -rankprofile[2].fef.property[1].name "rankingExpression(serving_default.y).x.type" +rankprofile[2].fef.property[1].name "rankingExpression(serving_default.y).Placeholder.type" rankprofile[2].fef.property[1].value "tensor(d0[],d1[784])" rankprofile[2].fef.property[2].name "rankingExpression(serving_default.y).type" rankprofile[2].fef.property[2].value "tensor(d1[10])" @@ -22,7 +22,7 @@ rankprofile[3].fef.property[1].name "rankingExpression(imported_ml_function_mnis rankprofile[3].fef.property[1].value "tensor(d3[300])" rankprofile[3].fef.property[2].name "rankingExpression(serving_default.y).rankingScript" rankprofile[3].fef.property[2].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))" -rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).x.type" +rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).input.type" rankprofile[3].fef.property[3].value "tensor(d0[],d1[784])" rankprofile[3].fef.property[4].name "rankingExpression(serving_default.y).type" -rankprofile[3].fef.property[4].value "tensor(d1[10])"
\ No newline at end of file +rankprofile[3].fef.property[4].value "tensor(d1[10])"