summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2018-10-01 14:22:55 +0200
committerGitHub <noreply@github.com>2018-10-01 14:22:55 +0200
commitfbca8fc6115fbf924cc688d927c50d8e9d99a321 (patch)
treecaf9b072fdaf5b7aff2c6dad5056402caed3a393
parente317da1b538ced3dd49d7f582a1c942a4a00d772 (diff)
parentda1a20ab27fff180baf3f574774c3bbb57488fee (diff)
Merge pull request #7155 from vespa-engine/bratseth/expose-type-information
Bratseth/expose type information
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java7
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java2
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java64
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java23
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java38
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java31
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java30
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java53
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java12
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java72
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java1
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java85
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java35
-rw-r--r--model-evaluation/src/test/resources/config/models/rank-profiles.cfg30
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java9
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java12
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java2
21 files changed, 409 insertions, 121 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
index 4c8b5910b78..3d1ef48c9dd 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
@@ -6,6 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchdefinition.RankProfileRegistry;
import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
@@ -60,7 +61,7 @@ public class RankingExpressionTypeResolver extends Processor {
private void resolveTypesIn(RankProfile profile, boolean validate) {
TypeContext<Reference> context = profile.typeContext(queryProfiles);
for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) {
- if ( ! function.getValue().function().arguments().isEmpty()) continue;
+ if (hasUntypedArguments(function.getValue().function())) continue;
TensorType type = resolveType(function.getValue().function().getBody(),
"function '" + function.getKey() + "'",
context);
@@ -74,6 +75,10 @@ public class RankingExpressionTypeResolver extends Processor {
}
}
+ private boolean hasUntypedArguments(ExpressionFunction function) {
+ return function.arguments().size() > function.argumentTypes().size();
+ }
+
private TensorType resolveType(RankingExpression expression, String expressionDescription, TypeContext context) {
if (expression == null) return null;
return resolveType(expression.getRoot(), expressionDescription, context);
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
index 9ed82b9eef5..d2d5ecbd5aa 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
@@ -54,7 +54,7 @@ public class MlModelsTest {
assertEquals("test", config.rankprofile(2).name());
RankProfilesConfig.Rankprofile.Fef test = config.rankprofile(2).fef();
- // Compare string content in a denser for that config:
+ // Compare profile content in a denser format than config:
StringBuilder b = new StringBuilder();
for (RankProfilesConfig.Rankprofile.Fef.Property p : test.property())
b.append(p.name()).append(": ").append(p.value()).append("\n");
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index 22bba9dd079..10de10bcdfe 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -36,6 +36,27 @@ import static org.junit.Assert.assertTrue;
*/
public class ModelEvaluationTest {
+ /** Tests that we do not load models (which would waste memory) when not requested */
+ @Test
+ public void testMl_serving_not_activated() {
+ Path appDir = Path.fromString("src/test/cfg/application/ml_serving_not_activated");
+ try {
+ ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir);
+ VespaModel model = tester.createVespaModel();
+ ContainerCluster cluster = model.getContainerClusters().get("container");
+ assertNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));
+
+ RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
+ cluster.getConfig(b);
+ RankProfilesConfig config = new RankProfilesConfig(b);
+
+ assertEquals(0, config.rankprofile().size());
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+ }
+
@Test
public void testMl_serving() throws IOException {
Path appDir = Path.fromString("src/test/cfg/application/ml_serving");
@@ -58,27 +79,6 @@ public class ModelEvaluationTest {
}
}
- /** Tests that we do not load models (which will waste memory) when not requested */
- @Test
- public void testMl_serving_not_activated() {
- Path appDir = Path.fromString("src/test/cfg/application/ml_serving_not_activated");
- try {
- ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir);
- VespaModel model = tester.createVespaModel();
- ContainerCluster cluster = model.getContainerClusters().get("container");
- assertNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));
-
- RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
- cluster.getConfig(b);
- RankProfilesConfig config = new RankProfilesConfig(b);
-
- assertEquals(0, config.rankprofile().size());
- }
- finally {
- IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
- }
- }
-
private void assertHasMlModels(VespaModel model) {
ContainerCluster cluster = model.getContainerClusters().get("container");
assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName())));
@@ -90,6 +90,7 @@ public class ModelEvaluationTest {
RankProfilesConfig.Builder b = new RankProfilesConfig.Builder();
cluster.getConfig(b);
RankProfilesConfig config = new RankProfilesConfig(b);
+ System.out.println(config);
RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder();
cluster.getConfig(cb);
@@ -102,6 +103,12 @@ public class ModelEvaluationTest {
assertTrue(modelNames.contains("mnist_softmax"));
assertTrue(modelNames.contains("mnist_softmax_saved"));
+ // Compare profile content in a denser format than config:
+ StringBuilder sb = new StringBuilder();
+ for (RankProfilesConfig.Rankprofile.Fef.Property p : findProfile("mnist_saved", config).property())
+ sb.append(p.name()).append(": ").append(p.value()).append("\n");
+ assertEquals(mnistProfile, sb.toString());
+
ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null))
.importFrom(config, constantsConfig));
@@ -136,6 +143,21 @@ public class ModelEvaluationTest {
assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default", "y"));
}
+ private final String mnistProfile =
+ "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" +
+ "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).type: tensor(d3[300])\n" +
+ "rankingExpression(serving_default.y).rankingScript: join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" +
+ "rankingExpression(serving_default.y).input.type: tensor(d0[],d1[784])\n" +
+ "rankingExpression(serving_default.y).type: tensor(d1[10])\n";
+
+ private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) {
+ for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) {
+ if (profile.name().equals(name))
+ return profile.fef();
+ }
+ throw new IllegalArgumentException("No profile named " + name);
+ }
+
// We don't have function file distribution so just return empty tensor constants
private static class ToleratingMissingConstantFilesRankProfilesConfigImporter extends RankProfilesConfigImporter {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 1412936d4a0..8c728867f45 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -3,10 +3,14 @@ package ai.vespa.models.evaluation;
import com.google.common.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.Map;
+import java.util.stream.Collectors;
+
/**
* An evaluator which can be used to evaluate a single function once.
*
@@ -34,7 +38,15 @@ public class FunctionEvaluator {
*/
public FunctionEvaluator bind(String name, Tensor value) {
if (evaluated)
- throw new IllegalStateException("You cannot bind a value in a used evaluator");
+ throw new IllegalStateException("Cannot bind a new value in a used evaluator");
+ TensorType requiredType = function.argumentTypes().get(name);
+ if (requiredType == null)
+ throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function +
+ ". Expected arguments: " + function.argumentTypes().entrySet().stream()
+ .map(e -> e.getKey() + ": " + e.getValue())
+ .collect(Collectors.joining(", ")));
+ if ( ! value.type().isAssignableTo(requiredType))
+ throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type());
context.put(name, new TensorValue(value));
return this;
}
@@ -52,10 +64,19 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
+ for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
+ if (argument.getValue().rank() == 0) continue; // Scalar argumentds can be skipped (defaults to 0)
+ if (context.get(argument.getKey()) == LazyArrayContext.defaultContextValue)
+ throw new IllegalStateException("Missing argument '" + argument.getKey() +
+ "': Must be bound to a value of type " + argument.getValue());
+ }
evaluated = true;
return function.getBody().evaluate(context).asTensor();
}
+ /** Returns the function evaluated by this */
+ public ExpressionFunction function() { return function; }
+
public LazyArrayContext context() { return context; }
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
index 00fcad94ce8..fa45920f3c8 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java
@@ -1,6 +1,9 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.yahoo.collections.Pair;
+import com.yahoo.tensor.TensorType;
+
import java.util.Objects;
import java.util.Optional;
import java.util.regex.Matcher;
@@ -23,6 +26,10 @@ class FunctionReference {
private static final Pattern referencePattern =
Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?");
+ private static final Pattern argumentTypePattern =
+ Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?");
+ private static final Pattern returnTypePattern =
+ Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.type?");
/** The name of the function referenced */
private final String name;
@@ -73,4 +80,35 @@ class FunctionReference {
return Optional.of(new FunctionReference(name, instance));
}
+ /**
+ * Returns a function reference and argument name string from the given serial form,
+ * or empty if the string is not a valid function argument serial form
+ */
+ static Optional<Pair<FunctionReference, String>> fromTypeArgumentSerial(String serialForm) {
+ Matcher expressionMatcher = argumentTypePattern.matcher(serialForm);
+ if ( ! expressionMatcher.matches()) return Optional.empty();
+
+ String name = expressionMatcher.group(1);
+ String instance = expressionMatcher.group(2);
+ String argument = expressionMatcher.group(3);
+ return Optional.of(new Pair<>(new FunctionReference(name, instance), argument));
+ }
+
+ /**
+ * Returns a function reference from the given return type serial form,
+ * or empty if the string is not a valid function return typoe serial form
+ */
+ static Optional<FunctionReference> fromReturnTypeSerial(String serialForm) {
+ Matcher expressionMatcher = returnTypePattern.matcher(serialForm);
+ if ( ! expressionMatcher.matches()) return Optional.empty();
+
+ String name = expressionMatcher.group(1);
+ String instance = expressionMatcher.group(2);
+ return Optional.of(new FunctionReference(name, instance));
+ }
+
+ public static FunctionReference fromName(String name) {
+ return new FunctionReference(name, null);
+ }
+
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index c7d0cbd8f30..78b30f0c873 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -2,6 +2,7 @@
package ai.vespa.models.evaluation;
import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
@@ -28,6 +29,8 @@ import java.util.Set;
*/
public final class LazyArrayContext extends Context implements ContextIndex {
+ public final static Value defaultContextValue = DoubleValue.zero;
+
private final IndexedBindings indexedBindings;
private LazyArrayContext(IndexedBindings indexedBindings) {
@@ -110,6 +113,9 @@ public final class LazyArrayContext extends Context implements ContextIndex {
@Override
public Set<String> names() { return indexedBindings.names(); }
+ /** Returns the (immutable) subset of names in this which must be bound when invoking */
+ public Set<String> arguments() { return indexedBindings.arguments(); }
+
private Integer requireIndexOf(String name) {
Integer index = indexedBindings.indexOf(name);
if (index == null)
@@ -130,12 +136,18 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** The mapping from variable name to index */
private final ImmutableMap<String, Integer> nameToIndex;
+ /** The names which neeeds to be bound externally when envoking this (i.e not constant or invocation */
+ private final ImmutableSet<String> arguments;
+
/** The current values set, pre-converted to doubles */
private final Value[] values;
- private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values) {
+ private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
+ Value[] values,
+ ImmutableSet<String> arguments) {
this.nameToIndex = nameToIndex;
this.values = values;
+ this.arguments = arguments;
}
/**
@@ -149,10 +161,12 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Model model) {
// 1. Determine and prepare bind targets
Set<String> bindTargets = new LinkedHashSet<>();
- extractBindTargets(expression.getRoot(), functions, bindTargets);
+ Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation
+ extractBindTargets(expression.getRoot(), functions, bindTargets, arguments);
+ this.arguments = ImmutableSet.copyOf(arguments);
values = new Value[bindTargets.size()];
- Arrays.fill(values, DoubleValue.zero);
+ Arrays.fill(values, defaultContextValue);
int i = 0;
ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
@@ -178,23 +192,25 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private void extractBindTargets(ExpressionNode node,
Map<FunctionReference, ExpressionFunction> functions,
- Set<String> bindTargets) {
+ Set<String> bindTargets,
+ Set<String> arguments) {
if (isFunctionReference(node)) {
FunctionReference reference = FunctionReference.fromSerial(node.toString()).get();
bindTargets.add(reference.serialForm());
- extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets);
+ extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets, arguments);
}
else if (isConstant(node)) {
bindTargets.add(node.toString());
}
else if (node instanceof ReferenceNode) {
bindTargets.add(node.toString());
+ arguments.add(node.toString());
}
else if (node instanceof CompositeNode) {
CompositeNode cNode = (CompositeNode)node;
for (ExpressionNode child : cNode.children())
- extractBindTargets(child, functions, bindTargets);
+ extractBindTargets(child, functions, bindTargets, arguments);
}
}
@@ -215,13 +231,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Value get(int index) { return values[index]; }
void set(int index, Value value) { values[index] = value; }
Set<String> names() { return nameToIndex.keySet(); }
+ Set<String> arguments() { return arguments; }
Integer indexOf(String name) { return nameToIndex.get(name); }
IndexedBindings copy(Context context) {
Value[] valueCopy = new Value[values.length];
for (int i = 0; i < values.length; i++)
valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue)values[i]).copyFor(context) : values[i];
- return new IndexedBindings(nameToIndex, valueCopy);
+ return new IndexedBindings(nameToIndex, valueCopy, arguments);
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 3fb43d73187..fda1ae935ca 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
@@ -38,29 +39,39 @@ public class Model {
/** Programmatically create a model containing functions without constant of function references only */
public Model(String name, Collection<ExpressionFunction> functions) {
- this(name, functions, Collections.emptyMap(), Collections.emptyList());
+ this(name,
+ functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
+ Collections.emptyMap(),
+ Collections.emptyList());
}
Model(String name,
- Collection<ExpressionFunction> functions,
+ Map<FunctionReference, ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants) {
- // TODO: Optimize functions
this.name = name;
- this.functions = ImmutableList.copyOf(functions);
+ // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
- for (ExpressionFunction function : functions) {
+ for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
- contextBuilder.put(function.getName(),
- new LazyArrayContext(function.getBody(), referencedFunctions, constants, this));
+ LazyArrayContext context = new LazyArrayContext(function.getValue().getBody(), referencedFunctions, constants, this);
+ contextBuilder.put(function.getValue().getName(), context);
+ for (String argument : context.arguments()) {
+ if (function.getValue().argumentTypes().get(argument) == null)
+ functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
+ }
+ if ( ! function.getValue().returnType().isPresent())
+ functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
}
catch (RuntimeException e) {
throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
}
}
this.contextPrototypes = contextBuilder.build();
+ this.functions = ImmutableList.copyOf(functions.values());
+ // Optimize functions
ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) {
ExpressionFunction optimizedFunction = optimize(function.getValue(),
@@ -79,7 +90,10 @@ public class Model {
public String name() { return name; }
- /** Returns an immutable list of the free functions of this */
+ /**
+ * Returns an immutable list of the free functions of this.
+ * The functions returned always specifies types of all arguments and the return value
+ */
public List<ExpressionFunction> functions() { return functions; }
/** Returns the given function, or throws a IllegalArgumentException if it does not exist */
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
index 7bea2d0825a..fb424439592 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.yahoo.collections.Pair;
import com.yahoo.config.FileReference;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.io.GrowableByteBuffer;
@@ -18,7 +19,10 @@ import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -60,26 +64,43 @@ public class RankProfilesConfigImporter {
private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig)
throws ParseException {
- List<ExpressionFunction> functions = new ArrayList<>();
- Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>();
- SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
- ExpressionFunction firstPhase = null;
- ExpressionFunction secondPhase = null;
List<Constant> constants = readLargeConstants(constantsConfig);
+ Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>();
+ Map<FunctionReference, ExpressionFunction> referencedFunctions = new LinkedHashMap<>();
+ SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo();
+ ExpressionFunction firstPhase = null;
+ ExpressionFunction secondPhase = null;
for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) {
Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name());
+ Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name());
+ Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name());
if ( reference.isPresent()) {
- List<String> arguments = new ArrayList<>(); // TODO: Arguments?
RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value());
+ ExpressionFunction function = new ExpressionFunction(reference.get().functionName(),
+ Collections.emptyList(),
+ expression);
if (reference.get().isFree()) // make available in model under configured name
- functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); //
-
- // Make all functions, bound or not available under the name they are referenced by in expressions
- referencedFunctions.put(reference.get(),
- new ExpressionFunction(reference.get().serialForm(), arguments, expression));
+ functions.put(reference.get(), function);
+ // Make all functions, bound or not, available under the name they are referenced by in expressions
+ referencedFunctions.put(reference.get(), function);
+ }
+ else if (argumentType.isPresent()) { // Arguments always follows the function in properties
+ FunctionReference argReference = argumentType.get().getFirst();
+ ExpressionFunction function = referencedFunctions.get(argReference);
+ function = function.withArgument(argumentType.get().getSecond(), TensorType.fromSpec(property.value()));
+ if (argReference.isFree())
+ functions.put(argReference, function);
+ referencedFunctions.put(argReference, function);
+ }
+ else if (returnType.isPresent()) { // Return type always follows the function in properties
+ ExpressionFunction function = referencedFunctions.get(returnType.get());
+ function = function.withReturnType(TensorType.fromSpec(property.value()));
+ if (returnType.get().isFree())
+ functions.put(returnType.get(), function);
+ referencedFunctions.put(returnType.get(), function);
}
else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to functions
firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(),
@@ -93,10 +114,10 @@ public class RankProfilesConfigImporter {
smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value());
}
}
- if (functionByName("firstphase", functions) == null && firstPhase != null) // may be already included, depending on body
- functions.add(firstPhase);
- if (functionByName("secondphase", functions) == null && secondPhase != null) // may be already included, depending on body
- functions.add(secondPhase);
+ if (functionByName("firstphase", functions.values()) == null && firstPhase != null) // may be already included, depending on body
+ functions.put(FunctionReference.fromName("firstphase"), firstPhase);
+ if (functionByName("secondphase", functions.values()) == null && secondPhase != null) // may be already included, depending on body
+ functions.put(FunctionReference.fromName("secondphase"), secondPhase);
constants.addAll(smallConstantsInfo.asConstants());
@@ -108,7 +129,7 @@ public class RankProfilesConfigImporter {
}
}
- private ExpressionFunction functionByName(String name, List<ExpressionFunction> functions) {
+ private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) {
for (ExpressionFunction function : functions)
if (function.getName().equals(name))
return function;
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index 683a1f345d8..6edcd84272e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -10,13 +10,16 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.slime.Cursor;
import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.JsonFormat;
+import com.yahoo.yolean.Exceptions;
import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.Arrays;
+import java.util.Map;
import java.util.Optional;
import java.util.concurrent.Executor;
@@ -60,14 +63,17 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
return listModelInformation(request, model, function);
} catch (IllegalArgumentException e) {
- return new ErrorResponse(404, e.getMessage());
+ return new ErrorResponse(404, Exceptions.toMessageString(e));
+ } catch (IllegalStateException e) { // On missing bindings
+ return new ErrorResponse(400, Exceptions.toMessageString(e));
}
}
private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) {
FunctionEvaluator evaluator = model.evaluatorOf(function);
- for (String bindingName : evaluator.context().names()) {
- property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s)));
+ for (Map.Entry<String, TensorType> argument : evaluator.function().argumentTypes().entrySet()) {
+ property(request, argument.getKey()).ifPresent(value -> evaluator.bind(argument.getKey(),
+ Tensor.from(argument.getValue(), value)));
}
Tensor result = evaluator.evaluate();
return new Response(200, JsonFormat.encode(result));
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 6e55c0c9a53..68c3b954675 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -1,8 +1,12 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import org.junit.Test;
+import java.util.List;
import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
@@ -25,9 +29,18 @@ public class MlModelsImportingTest {
// TODO: When we get type information in Models, replace the evaluator.context().names() check below by that
{
Model xgboost = tester.models().get("xgboost_2_2");
+
+ // Function
+ assertEquals(1, xgboost.functions().size());
tester.assertFunction("xgboost_2_2",
"(optimized sum of condition trees of size 192 bytes)",
xgboost);
+ ExpressionFunction function = xgboost.functions().get(0);
+ assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get());
+ assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments()));
+ function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg)));
+
+ // Evaluator
FunctionEvaluator evaluator = xgboost.evaluatorOf();
assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta);
@@ -36,39 +49,80 @@ public class MlModelsImportingTest {
{
Model onnxMnistSoftmax = tester.models().get("mnist_softmax");
+
+ // Function
+ assertEquals(1, onnxMnistSoftmax.functions().size());
tester.assertFunction("default.add",
"join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))",
onnxMnistSoftmax);
+ ExpressionFunction function = onnxMnistSoftmax.functions().get(0);
+ assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ assertEquals(1, function.arguments().size());
+ assertEquals("Placeholder", function.arguments().get(0));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+
+ // Evaluator
assertEquals("tensor(d1[10],d2[784])",
onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString());
FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available
- assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ evaluator.bind("Placeholder", inputTensor());
assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta);
}
{
Model tfMnistSoftmax = tester.models().get("mnist_softmax_saved");
+
+ // Function
+ assertEquals(1, tfMnistSoftmax.functions().size());
tester.assertFunction("serving_default.y",
"join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))",
tfMnistSoftmax);
+ ExpressionFunction function = tfMnistSoftmax.functions().get(0);
+ assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ assertEquals(1, function.arguments().size());
+ assertEquals("Placeholder", function.arguments().get(0));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder"));
+
+ // Evaluator
FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available
- assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+ evaluator.bind("Placeholder", inputTensor());
assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta);
}
{
Model tfMnist = tester.models().get("mnist_saved");
- tester.assertFunction("serving_default.y",
- "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
- tfMnist);
- // Macro:
- tester.assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add",
+ // Generated function
+ tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add",
"join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))",
tfMnist);
- FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument
- assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", ")));
+
+ // Function
+ assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function
+ tester.assertFunction("serving_default.y",
+ "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))",
+ tfMnist);
+ ExpressionFunction function = tfMnist.functions().get(1);
+ assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get());
+ assertEquals(1, function.arguments().size());
+ assertEquals("input", function.arguments().get(0));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input"));
+
+ // Evaluator
+ FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default");
+ evaluator.bind("input", inputTensor());
assertEquals(-0.714629131972222, evaluator.evaluate().sum().asDouble(), delta);
}
}
+ private Tensor inputTensor() {
+ Tensor.Builder b = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[],d1[784])"));
+ for (int i = 0; i < 784; i++)
+ b.cell(0.0, 0, i);
+ return b.build();
+ }
+
+ private String commaSeparated(List<?> items) {
+ return items.stream().map(item -> item.toString()).sorted().collect(Collectors.joining(", "));
+ }
+
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
index 9a3e59aed80..50dd1d1d05f 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java
@@ -60,7 +60,6 @@ public class ModelTester {
public void assertBoundFunction(String name, String expression, Model model) {
ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get());
assertNotNull("Function '" + name + "' is present", function);
- assertEquals(name, function.getName());
assertEquals(expression, function.getBody().getRoot().toString());
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index bd1ff6b8ed7..8824be05006 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -5,11 +5,19 @@ import com.yahoo.config.subscription.ConfigGetter;
import com.yahoo.config.subscription.FileSource;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.path.Path;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
+import com.yahoo.yolean.Exceptions;
import org.junit.Test;
+import java.util.ArrayList;
+import java.util.List;
+
import static org.junit.Assert.assertEquals;
/**
@@ -20,16 +28,7 @@ public class ModelsEvaluatorTest {
private static final double delta = 0.00000000001;
@Test
- public void testTensorEvaluation() {
- ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/");
- FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum");
- function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}"));
- function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}"));
- assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), function.evaluate());
- }
-
- @Test
- public void testEvaluationDependingOnMacroTakingArguments() {
+ public void testEvaluationDependingFunctionTakingArguments() {
ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/");
FunctionEvaluator function = models.evaluatorOf("macros", "secondphase");
function.bind("match", 3);
@@ -37,10 +36,70 @@ public class ModelsEvaluatorTest {
assertEquals(32.0, function.evaluate().asDouble(), delta);
}
+ @Test
+ public void testBindingValidation() {
+ List<ExpressionFunction> functions = new ArrayList<>();
+ ExpressionFunction function = new ExpressionFunction("test", RankingExpression.from("sum(arg1 * arg2)"));
+ function = function.withArgument("arg1", TensorType.fromSpec("tensor(d0[1])"));
+ function = function.withArgument("arg2", TensorType.fromSpec("tensor(d1{})"));
+ functions.add(function);
+ Model model = new Model("test-model", functions);
+
+ try { // No bindings
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.evaluate();
+ }
+ catch (IllegalStateException e) {
+ assertEquals("Missing argument 'arg2': Must be bound to a value of type tensor(d1{})",
+ Exceptions.toMessageString(e));
+ }
+
+ try { // Just one binding
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.bind("arg2", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}"));
+ evaluator.evaluate();
+ }
+ catch (IllegalStateException e) {
+ assertEquals("Missing argument 'arg1': Must be bound to a value of type tensor(d0[1])",
+ Exceptions.toMessageString(e));
+ }
+
+ try { // Wrong binding argument
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.bind("argNone", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}"));
+ evaluator.evaluate();
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg2: tensor(d1{}), arg1: tensor(d0[1])",
+ Exceptions.toMessageString(e));
+ }
+
+ try { // Wrong binding type
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d3{})"), "{{d3:foo}:0.1}"));
+ evaluator.evaluate();
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("'arg1' must be of type tensor(d0[1]), not tensor(d3{})",
+ Exceptions.toMessageString(e));
+ }
+
+ try { // Attempt to reuse evaluator
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d0[1])"), "{{d0:0}:0.1}"));
+ evaluator.bind("arg2", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}"));
+ evaluator.evaluate();
+ evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d0[1])"), "{{d0:0}:0.1}"));
+ }
+ catch (IllegalStateException e) {
+ assertEquals("Cannot bind a new value in a used evaluator",
+ Exceptions.toMessageString(e));
+ }
+
+ }
+
// TODO: Test argument-less function
- // TODO: Test that binding nonexisting variable doesn't work
- // TODO: Test that rebinding doesn't work
- // TODO: Test with nested macros
+ // TODO: Test with nested functions
private ModelsEvaluator createModels(String path) {
Path configDir = Path.fromString(path);
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 6726f117c05..b915ee72a79 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -9,6 +9,8 @@ import com.yahoo.container.jdisc.HttpRequest;
import com.yahoo.container.jdisc.HttpResponse;
import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.path.Path;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
import com.yahoo.vespa.config.search.core.RankingConstantsConfig;
import org.junit.BeforeClass;
@@ -94,39 +96,39 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval";
- String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
- assertResponse(url, 200, expected);
+ String expected = "{\"error\":\"Missing argument 'Placeholder': Must be bound to a value of type tensor(d0[],d1[784])\"}";
+ assertResponse(url, 400, expected);
}
@Test
public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
- String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
- assertResponse(url, 200, expected);
+ String expected = "{\"error\":\"Missing argument 'Placeholder': Must be bound to a value of type tensor(d0[],d1[784])\"}";
+ assertResponse(url, 400, expected);
}
@Test
public void testMnistSoftmaxEvaluateDefaultFunctionWithBindings() {
Map<String, String> properties = new HashMap<>();
- properties.put("Placeholder", "{1.0}");
+ properties.put("Placeholder", inputTensor());
String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval";
- String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}";
+ String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
assertResponse(url, properties, 200, expected);
}
@Test
public void testMnistSoftmaxEvaluateSpecificFunctionWithBindings() {
Map<String, String> properties = new HashMap<>();
- properties.put("Placeholder", "{1.0}");
+ properties.put("Placeholder", inputTensor());
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
- String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}";
+ String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
assertResponse(url, properties, 200, expected);
}
@Test
public void testMnistSavedDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
- String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}";
+ String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_function_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}";
assertResponse(url, 200, expected);
}
@@ -140,16 +142,16 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSavedEvaluateDefaultFunctionShouldFail() {
String url = "http://localhost/model-evaluation/v1/mnist_saved/eval";
- String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_macro_mnist_saved_dnn_hidden1_add, serving_default.y\"}";
+ String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_function_mnist_saved_dnn_hidden1_add, serving_default.y\"}";
assertResponse(url, 404, expected);
}
@Test
public void testMnistSavedEvaluateSpecificFunction() {
Map<String, String> properties = new HashMap<>();
- properties.put("input", "-1.0");
+ properties.put("input", inputTensor());
String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval";
- String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-2.72208123403445},{\"address\":{\"d1\":\"1\"},\"value\":6.465137496457595},{\"address\":{\"d1\":\"2\"},\"value\":-7.078050386283122},{\"address\":{\"d1\":\"3\"},\"value\":-10.485296462655546},{\"address\":{\"d1\":\"4\"},\"value\":0.19508378636937004},{\"address\":{\"d1\":\"5\"},\"value\":6.348870746681019},{\"address\":{\"d1\":\"6\"},\"value\":10.756191852397258},{\"address\":{\"d1\":\"7\"},\"value\":1.476101533270058},{\"address\":{\"d1\":\"8\"},\"value\":-17.778398655804875},{\"address\":{\"d1\":\"9\"},\"value\":-2.0597690508530295}]}";
+ String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.6319251673007533},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":-7.577770600619843E-4},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":-0.010707969042025622},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.6344759233540788},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":-0.17529455385847528},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":0.7490809723192187},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.022790284182901716},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.26799240657608936},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-0.3152438845465862},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":0.05949304847735276}]}";
assertResponse(url, properties, 200, expected);
}
@@ -171,10 +173,10 @@ public class ModelsEvaluationHandlerTest {
static private void assertResponse(HttpRequest request, int expectedCode, String expectedResult) {
HttpResponse response = handler.handle(request);
assertEquals("application/json", response.getContentType());
- assertEquals(expectedCode, response.getStatus());
if (expectedResult != null) {
assertEquals(expectedResult, getContents(response));
}
+ assertEquals(expectedCode, response.getStatus());
}
static private String getContents(HttpResponse response) {
@@ -198,4 +200,11 @@ public class ModelsEvaluationHandlerTest {
return new ModelsEvaluator(importer.importFrom(config, constantsConfig));
}
+ private String inputTensor() {
+ Tensor.Builder b = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[],d1[784])"));
+ for (int i = 0; i < 784; i++)
+ b.cell(0.0, 0, i);
+ return b.build().toString();
+ }
+
}
diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
index 1cc36f75158..c25c5ba555b 100644
--- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
+++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
@@ -1,14 +1,28 @@
-rankprofile[0].name "mnist_saved"
-rankprofile[0].fef.property[0].name "rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add).rankingScript"
-rankprofile[0].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"
-rankprofile[0].fef.property[1].name "rankingExpression(serving_default.y).rankingScript"
-rankprofile[0].fef.property[1].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))"
+rankprofile[0].name "mnist_softmax"
+rankprofile[0].fef.property[0].name "rankingExpression(default.add).rankingScript"
+rankprofile[0].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))"
+rankprofile[0].fef.property[1].name "rankingExpression(default.add).Placeholder.type"
+rankprofile[0].fef.property[1].value "tensor(d0[],d1[784])"
+rankprofile[0].fef.property[2].name "rankingExpression(default.add).type"
+rankprofile[0].fef.property[2].value "tensor(d1[10])"
rankprofile[1].name "xgboost_2_2"
rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript"
rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)"
rankprofile[2].name "mnist_softmax_saved"
rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript"
rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))"
-rankprofile[3].name "mnist_softmax"
-rankprofile[3].fef.property[0].name "rankingExpression(default.add).rankingScript"
-rankprofile[3].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))"
+rankprofile[2].fef.property[1].name "rankingExpression(serving_default.y).Placeholder.type"
+rankprofile[2].fef.property[1].value "tensor(d0[],d1[784])"
+rankprofile[2].fef.property[2].name "rankingExpression(serving_default.y).type"
+rankprofile[2].fef.property[2].value "tensor(d1[10])"
+rankprofile[3].name "mnist_saved"
+rankprofile[3].fef.property[0].name "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript"
+rankprofile[3].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))"
+rankprofile[3].fef.property[1].name "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).type"
+rankprofile[3].fef.property[1].value "tensor(d3[300])"
+rankprofile[3].fef.property[2].name "rankingExpression(serving_default.y).rankingScript"
+rankprofile[3].fef.property[2].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))"
+rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).input.type"
+rankprofile[3].fef.property[3].value "tensor(d0[],d1[784])"
+rankprofile[3].fef.property[4].name "rankingExpression(serving_default.y).type"
+rankprofile[3].fef.property[4].value "tensor(d1[10])"
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index f6502a9801d..787b857839d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -11,6 +11,7 @@ import com.yahoo.text.Utf8;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
@@ -97,7 +98,12 @@ public class ExpressionFunction {
return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType));
}
- public ExpressionFunction withArgumentTypes(Map<String, TensorType> argumentTypes) {
+ /** Returns a copy of this with the given argument and argument type added */
+ public ExpressionFunction withArgument(String argument, TensorType type) {
+ List<String> arguments = new ArrayList<>(this.arguments);
+ arguments.add(argument);
+ Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes);
+ argumentTypes.put(argument, type);
return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
index 17157ab385f..8aa7446cae7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
@@ -9,7 +9,6 @@ import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
* In a boolean context doubles are true if they are different from 0.0
*
* @author bratseth
- * @since 5.1.5
*/
public final class DoubleValue extends DoubleCompatibleValue {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 9ff391a5cfe..f26f2dea04f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -121,7 +121,7 @@ public class ImportedModel {
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
expressions.add(new Pair<>(signatureEntry.getKey(),
new ExpressionFunction(signatureEntry.getKey(),
- new ArrayList<>(signatureEntry.getValue().inputs().keySet()),
+ new ArrayList<>(signatureEntry.getValue().inputs().values()),
expressions().get(signatureEntry.getKey()),
signatureEntry.getValue().inputMap(),
Optional.empty())));
@@ -182,8 +182,11 @@ public class ImportedModel {
/** Returns the name and type of all inputs in this signature as an immutable map */
public Map<String, TensorType> inputMap() {
ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
+ // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to*
+ // in the model, as these are the names which must actually be bound, if we are to avoid creating an
+ // "input mapping" to accomodate this complexity in
for (Map.Entry<String, String> inputEntry : inputs().entrySet())
- inputs.put(inputEntry.getKey(), owner().inputs().get(inputEntry.getValue()));
+ inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue()));
return inputs.build();
}
@@ -207,7 +210,7 @@ public class ImportedModel {
/** Returns the expression this output references */
public ExpressionFunction outputExpression(String outputName) {
return new ExpressionFunction(outputName,
- new ArrayList<>(inputs.keySet()),
+ new ArrayList<>(inputs.values()),
owner().expressions().get(outputs.get(outputName)),
inputMap(),
Optional.empty());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index 593e7b54c10..e325c3d11b4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -15,17 +15,18 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test",
+ "src/test/files/integration/tensorflow/batch_norm/saved");
ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ExpressionFunction output = signature.outputExpression("y");
- assertNotNull(output);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName());
- model.assertEqualResult("X", output.getBody().getName());
- assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ ExpressionFunction function = signature.outputExpression("y");
+ assertNotNull(function);
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", function.getBody().getName());
+ model.assertEqualResult("X", function.getBody().getName());
+ assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index 59712c0152f..8ca5a9a7888 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -30,13 +30,13 @@ public class DropoutImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ExpressionFunction output = signature.outputExpression("y");
- assertNotNull(output);
- assertEquals("outputs/Maximum", output.getBody().getName());
+ ExpressionFunction function = signature.outputExpression("y");
+ assertNotNull(function);
+ assertEquals("outputs/Maximum", function.getBody().getName());
assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
- output.getBody().getRoot().toString());
- model.assertEqualResult("X", output.getBody().getName());
- assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ function.getBody().getRoot().toString());
+ model.assertEqualResult("X", function.getBody().getName());
+ assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
index 0a48ecfce21..feba40601e3 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
@@ -62,7 +62,7 @@ public class TensorFlowMnistSoftmaxImportTestCase {
assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
output.getBody().getRoot().toString());
- assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
// Test execution
model.assertEqualResult("Placeholder", "MatMul");
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index a1334dc729c..000f33696f2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -27,7 +27,7 @@ class TensorParser {
else {
if (type.isPresent() && ! type.get().equals(TensorType.empty))
throw new IllegalArgumentException("Got zero-dimensional tensor '" + tensorString +
- "but type is not empty but " + type.get());
+ "' where type " + type.get() + " is required");
return Tensor.Builder.of(TensorType.empty).cell(Double.parseDouble(tensorString)).build();
}
}