summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-10-01 10:42:16 +0200
committerJon Bratseth <bratseth@oath.com>2018-10-01 10:42:16 +0200
commit50bc3b3c198d29374448cc3eac73fbb26e42cab0 (patch)
tree668c2fdcf18b25fda38e1faa10bd479b76e1ecb6
parent0ff988ecf9704faac33f6201cb59349e48846457 (diff)
Fill in missing types
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java4
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java30
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java24
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java7
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java42
-rw-r--r--model-evaluation/src/test/resources/config/models/rank-profiles.cfg6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java9
7 files changed, 78 insertions, 44 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index daae2dbc496..10de10bcdfe 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -90,7 +90,7 @@ public class ModelEvaluationTest {
RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
cluster.getConfig(b);
RankProfilesConfig config = new RankProfilesConfig(b);
- // System.out.println(config);
+ System.out.println(config);
RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder();
cluster.getConfig(cb);
@@ -147,7 +147,7 @@ public class ModelEvaluationTest {
"rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" +
"rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).type: tensor(d3[300])\n" +
"rankingExpression(serving_default.y).rankingScript: join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" +
- "rankingExpression(serving_default.y).x.type: tensor(d0[],d1[784])\n" +
+ "rankingExpression(serving_default.y).input.type: tensor(d0[],d1[784])\n" +
"rankingExpression(serving_default.y).type: tensor(d1[10])\n";
private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index c7d0cbd8f30..d144411127e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -1,7 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -15,7 +17,9 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@@ -110,6 +114,9 @@ public final class LazyArrayContext extends Context implements ContextIndex {
@Override
public Set<String> names() { return indexedBindings.names(); }
+ /** Returns the (immutable) subset of names in this which must be bound when invoking */
+ public Set<String> arguments() { return indexedBindings.arguments(); }
+
private Integer requireIndexOf(String name) {
Integer index = indexedBindings.indexOf(name);
if (index == null)
@@ -130,12 +137,18 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** The mapping from variable name to index */
private final ImmutableMap<String, Integer> nameToIndex;
+ /** The names which neeeds to be bound externally when envoking this (i.e not constant or invocation */
+ private final ImmutableSet<String> arguments;
+
/** The current values set, pre-converted to doubles */
private final Value[] values;
- private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values) {
+ private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
+ Value[] values,
+ ImmutableSet<String> arguments) {
this.nameToIndex = nameToIndex;
this.values = values;
+ this.arguments = arguments;
}
/**
@@ -149,8 +162,10 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Model model) {
// 1. Determine and prepare bind targets
Set<String> bindTargets = new LinkedHashSet<>();
- extractBindTargets(expression.getRoot(), functions, bindTargets);
+ Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation
+ extractBindTargets(expression.getRoot(), functions, bindTargets, arguments);
+ this.arguments = ImmutableSet.copyOf(arguments);
values = new Value[bindTargets.size()];
Arrays.fill(values, DoubleValue.zero);
@@ -178,23 +193,25 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private void extractBindTargets(ExpressionNode node,
Map<FunctionReference, ExpressionFunction> functions,
- Set<String> bindTargets) {
+ Set<String> bindTargets,
+ Set<String> arguments) {
if (isFunctionReference(node)) {
FunctionReference reference = FunctionReference.fromSerial(node.toString()).get();
bindTargets.add(reference.serialForm());
- extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets);
+ extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets, arguments);
}
else if (isConstant(node)) {
bindTargets.add(node.toString());
}
else if (node instanceof ReferenceNode) {
bindTargets.add(node.toString());
+ arguments.add(node.toString());
}
else if (node instanceof CompositeNode) {
CompositeNode cNode = (CompositeNode)node;
for (ExpressionNode child : cNode.children())
- extractBindTargets(child, functions, bindTargets);
+ extractBindTargets(child, functions, bindTargets, arguments);
}
}
@@ -215,13 +232,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Value get(int index) { return values[index]; }
void set(int index, Value value) { values[index] = value; }
Set<String> names() { return nameToIndex.keySet(); }
+ Set<String> arguments() { return arguments; }
Integer indexOf(String name) { return nameToIndex.get(name); }
IndexedBindings copy(Context context) {
Value[] valueCopy = new Value[values.length];
for (int i = 0; i < values.length; i++)
valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue)values[i]).copyFor(context) : values[i];
- return new IndexedBindings(nameToIndex, valueCopy);
+ return new IndexedBindings(nameToIndex, valueCopy, arguments);
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index ac8f28677a4..fda1ae935ca 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
@@ -38,28 +39,39 @@ public class Model {
/** Programmatically create a model containing functions without constant of function references only */
public Model(String name, Collection<ExpressionFunction> functions) {
- this(name, functions, Collections.emptyMap(), Collections.emptyList());
+ this(name,
+ functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
+ Collections.emptyMap(),
+ Collections.emptyList());
}
Model(String name,
- Collection<ExpressionFunction> functions,
+ Map<FunctionReference, ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants) {
this.name = name;
- this.functions = ImmutableList.copyOf(functions);
+ // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
- for (ExpressionFunction function : functions) {
+ for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
- contextBuilder.put(function.getName(),
- new LazyArrayContext(function.getBody(), referencedFunctions, constants, this));
+ LazyArrayContext context = new LazyArrayContext(function.getValue().getBody(), referencedFunctions, constants, this);
+ contextBuilder.put(function.getValue().getName(), context);
+ for (String argument : context.arguments()) {
+ if (function.getValue().argumentTypes().get(argument) == null)
+ functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
+ }
+ if ( ! function.getValue().returnType().isPresent())
+ functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
}
}
this.contextPrototypes = contextBuilder.build();
+ this.functions = ImmutableList.copyOf(functions.values());
+ // Optimize functions
ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) {
ExpressionFunction optimizedFunction = optimize(function.getValue(),
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 648c6d931a9..fb424439592 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -78,7 +78,9 @@ public class RankProfilesConfigImporter {
Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
if ( reference.isPresent()) {
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
- ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), Collections.emptyList(), expression);
+ ExpressionFunction function = new ExpressionFunction(reference.get().functionName(),
+ Collections.emptyList(),
+ expression);
if (reference.get().isFree()) // make available in model under configured name
functions.put(reference.get(), function);
@@ -120,14 +122,13 @@ public class RankProfilesConfigImporter {
constants.addAll(smallConstantsInfo.asConstants());
try {
- return new Model(profile.name(), functions.values(), referencedFunctions, constants);
+ return new Model(profile.name(), functions, referencedFunctions, constants);
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not load model '" + profile.name() + "'", e);
}
}
- // TODO: Replace by lookup in map
private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) {
for (ExpressionFunction function : functions)
if (function.getName().equals(name))
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 287a2387b34..c4b163e89c0 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -29,15 +29,16 @@ public class MlModelsImportingTest {
// TODO: When we get type information in Models, replace the evaluator.context().names() check below by that
{
Model xgboost = tester.models().get("xgboost_2_2");
- tester.assertFunction("xgboost_2_2",
- "(optimized sum of condition trees of size 192 bytes)",
- xgboost);
// Function
assertEquals(1, xgboost.functions().size());
+ tester.assertFunction("xgboost_2_2",
+ "(optimized sum of condition trees of size 192 bytes)",
+ xgboost);
ExpressionFunction function = xgboost.functions().get(0);
- assertEquals("xgboost_2_2", function.getName());
- // assertEquals("f109, f29, f56, f60", commaSeparated(xgboost.functions().get(0).arguments())); TODO
+ assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get());
+ assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments()));
+ function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
// Evaluator
FunctionEvaluator evaluator = xgboost.evaluatorOf();
@@ -48,14 +49,12 @@ public class MlModelsImportingTest {
{
Model onnxMnistSoftmax = tester.models().get("mnist_softmax");
- tester.assertFunction("default.add",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
- onnxMnistSoftmax);
- assertEquals("tensor(d1[10],d2[784])",
- onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
// Function
assertEquals(1, onnxMnistSoftmax.functions().size());
+ tester.assertFunction("default.add",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
+ onnxMnistSoftmax);
ExpressionFunction function = onnxMnistSoftmax.functions().get(0);
assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
assertEquals(1, function.arguments().size());
@@ -63,6 +62,8 @@ public class MlModelsImportingTest {
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
// Evaluator
+ assertEquals("tensor(d1[10],d2[784])",
+ onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available
assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta);
@@ -70,17 +71,17 @@ public class MlModelsImportingTest {
{
Model tfMnistSoftmax = tester.models().get("mnist_softmax_saved");
- tester.assertFunction("serving_default.y",
- "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
- tfMnistSoftmax);
// Function
assertEquals(1, tfMnistSoftmax.functions().size());
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
+ tfMnistSoftmax);
ExpressionFunction function = tfMnistSoftmax.functions().get(0);
assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
assertEquals(1, function.arguments().size());
- assertEquals("x", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x"));
+ assertEquals("Placeholder", function.arguments().get(0));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
// Evaluator
FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
@@ -90,10 +91,6 @@ public class MlModelsImportingTest {
{
Model tfMnist = tester.models().get("mnist_saved");
- tester.assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
- tfMnist);
-
// Generated function
tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
"join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
@@ -101,11 +98,14 @@ public class MlModelsImportingTest {
// Function
assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
ExpressionFunction function = tfMnist.functions().get(1);
assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
assertEquals(1, function.arguments().size());
- assertEquals("x", function.arguments().get(0));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("x"));
+ assertEquals("input", function.arguments().get(0));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input"));
// Evaluator
FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default");
diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
index 9175b60315b..c25c5ba555b 100644
--- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
+++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
@@ -11,7 +11,7 @@ rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398,
rankprofile[2].name "mnist_softmax_saved"
rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript"
rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))"
-rankprofile[2].fef.property[1].name "rankingExpression(serving_default.y).x.type"
+rankprofile[2].fef.property[1].name "rankingExpression(serving_default.y).Placeholder.type"
rankprofile[2].fef.property[1].value "tensor(d0[],d1[784])"
rankprofile[2].fef.property[2].name "rankingExpression(serving_default.y).type"
rankprofile[2].fef.property[2].value "tensor(d1[10])"
@@ -22,7 +22,7 @@ rankprofile[3].fef.property[1].name "rankingExpression(imported_ml_function_mnis
rankprofile[3].fef.property[1].value "tensor(d3[300])"
rankprofile[3].fef.property[2].name "rankingExpression(serving_default.y).rankingScript"
rankprofile[3].fef.property[2].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))"
-rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).x.type"
+rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).input.type"
rankprofile[3].fef.property[3].value "tensor(d0[],d1[784])"
rankprofile[3].fef.property[4].name "rankingExpression(serving_default.y).type"
-rankprofile[3].fef.property[4].value "tensor(d1[10])" \ No newline at end of file
+rankprofile[3].fef.property[4].value "tensor(d1[10])"
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 9ff391a5cfe..f26f2dea04f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -121,7 +121,7 @@ public class ImportedModel {
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
expressions.add(new Pair<>(signatureEntry.getKey(),
new ExpressionFunction(signatureEntry.getKey(),
- new ArrayList<>(signatureEntry.getValue().inputs().keySet()),
+ new ArrayList<>(signatureEntry.getValue().inputs().values()),
expressions().get(signatureEntry.getKey()),
signatureEntry.getValue().inputMap(),
Optional.empty())));
@@ -182,8 +182,11 @@ public class ImportedModel {
/** Returns the name and type of all inputs in this signature as an immutable map */
public Map<String, TensorType> inputMap() {
ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
+ // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to*
+ // in the model, as these are the names which must actually be bound, if we are to avoid creating an
+ // "input mapping" to accomodate this complexity in
for (Map.Entry<String, String> inputEntry : inputs().entrySet())
- inputs.put(inputEntry.getKey(), owner().inputs().get(inputEntry.getValue()));
+ inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue()));
return inputs.build();
}
@@ -207,7 +210,7 @@ public class ImportedModel {
/** Returns the expression this output references */
public ExpressionFunction outputExpression(String outputName) {
return new ExpressionFunction(outputName,
- new ArrayList<>(inputs.keySet()),
+ new ArrayList<>(inputs.values()),
owner().expressions().get(outputs.get(outputName)),
inputMap(),
Optional.empty());