diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2018-10-01 14:22:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-01 14:22:55 +0200 |
commit | fbca8fc6115fbf924cc688d927c50d8e9d99a321 (patch) | |
tree | caf9b072fdaf5b7aff2c6dad5056402caed3a393 | |
parent | e317da1b538ced3dd49d7f582a1c942a4a00d772 (diff) | |
parent | da1a20ab27fff180baf3f574774c3bbb57488fee (diff) |
Merge pull request #7155 from vespa-engine/bratseth/expose-type-information
Bratseth/expose type information
21 files changed, 409 insertions, 121 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java index 4c8b5910b78..3d1ef48c9dd 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java @@ -6,6 +6,7 @@ import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankProfileRegistry; import com.yahoo.searchdefinition.Search; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; @@ -60,7 +61,7 @@ public class RankingExpressionTypeResolver extends Processor { private void resolveTypesIn(RankProfile profile, boolean validate) { TypeContext<Reference> context = profile.typeContext(queryProfiles); for (Map.Entry<String, RankProfile.RankingExpressionFunction> function : profile.getFunctions().entrySet()) { - if ( ! function.getValue().function().arguments().isEmpty()) continue; + if (hasUntypedArguments(function.getValue().function())) continue; TensorType type = resolveType(function.getValue().function().getBody(), "function '" + function.getKey() + "'", context); @@ -74,6 +75,10 @@ public class RankingExpressionTypeResolver extends Processor { } } + private boolean hasUntypedArguments(ExpressionFunction function) { + return function.arguments().size() > function.argumentTypes().size(); + } + private TensorType resolveType(RankingExpression expression, String expressionDescription, TypeContext context) { if (expression == null) return null; return resolveType(expression.getRoot(), expressionDescription, context); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java index 9ed82b9eef5..d2d5ecbd5aa 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java @@ -54,7 +54,7 @@ public class MlModelsTest { assertEquals("test", config.rankprofile(2).name()); RankProfilesConfig.Rankprofile.Fef test = config.rankprofile(2).fef(); - // Compare string content in a denser for that config: + // Compare profile content in a denser format than config: StringBuilder b = new StringBuilder(); for (RankProfilesConfig.Rankprofile.Fef.Property p : test.property()) b.append(p.name()).append(": ").append(p.value()).append("\n"); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index 22bba9dd079..10de10bcdfe 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -36,6 +36,27 @@ import static org.junit.Assert.assertTrue; */ public class ModelEvaluationTest { + /** Tests that we do not load models (which would waste memory) when not requested */ + @Test + public void testMl_serving_not_activated() { + Path appDir = Path.fromString("src/test/cfg/application/ml_serving_not_activated"); + try { + ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir); + VespaModel model = tester.createVespaModel(); + ContainerCluster cluster = model.getContainerClusters().get("container"); + assertNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName()))); + + RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); + cluster.getConfig(b); + RankProfilesConfig config = new RankProfilesConfig(b); + + assertEquals(0, config.rankprofile().size()); + } + finally { + IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + } + @Test public void testMl_serving() throws IOException { Path appDir = Path.fromString("src/test/cfg/application/ml_serving"); @@ -58,27 +79,6 @@ public class ModelEvaluationTest { } } - /** Tests that we do not load models (which will waste memory) when not requested */ - @Test - public void testMl_serving_not_activated() { - Path appDir = Path.fromString("src/test/cfg/application/ml_serving_not_activated"); - try { - ImportedModelTester tester = new ImportedModelTester("ml_serving", appDir); - VespaModel model = tester.createVespaModel(); - ContainerCluster cluster = model.getContainerClusters().get("container"); - assertNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName()))); - - RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); - cluster.getConfig(b); - RankProfilesConfig config = new RankProfilesConfig(b); - - assertEquals(0, config.rankprofile().size()); - } - finally { - IOUtils.recursiveDeleteDir(appDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); - } - } - private void assertHasMlModels(VespaModel model) { ContainerCluster cluster = model.getContainerClusters().get("container"); assertNotNull(cluster.getComponentsMap().get(new ComponentId(ModelsEvaluator.class.getName()))); @@ -90,6 +90,7 @@ public class ModelEvaluationTest { RankProfilesConfig.Builder b = new RankProfilesConfig.Builder(); cluster.getConfig(b); RankProfilesConfig config = new RankProfilesConfig(b); + System.out.println(config); RankingConstantsConfig.Builder cb = new RankingConstantsConfig.Builder(); cluster.getConfig(cb); @@ -102,6 +103,12 @@ public class ModelEvaluationTest { assertTrue(modelNames.contains("mnist_softmax")); assertTrue(modelNames.contains("mnist_softmax_saved")); + // Compare profile content in a denser format than config: + StringBuilder sb = new StringBuilder(); + for (RankProfilesConfig.Rankprofile.Fef.Property p : findProfile("mnist_saved", config).property()) + sb.append(p.name()).append(": ").append(p.value()).append("\n"); + assertEquals(mnistProfile, sb.toString()); + ModelsEvaluator evaluator = new ModelsEvaluator(new ToleratingMissingConstantFilesRankProfilesConfigImporter(MockFileAcquirer.returnFile(null)) .importFrom(config, constantsConfig)); @@ -136,6 +143,21 @@ public class ModelEvaluationTest { assertNotNull(tensorflow_mnist_softmax.evaluatorOf("serving_default", "y")); } + private final String mnistProfile = + "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" + + "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).type: tensor(d3[300])\n" + + "rankingExpression(serving_default.y).rankingScript: join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" + + "rankingExpression(serving_default.y).input.type: tensor(d0[],d1[784])\n" + + "rankingExpression(serving_default.y).type: tensor(d1[10])\n"; + + private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) { + for (RankProfilesConfig.Rankprofile profile : config.rankprofile()) { + if (profile.name().equals(name)) + return profile.fef(); + } + throw new IllegalArgumentException("No profile named " + name); + } + // We don't have function file distribution so just return empty tensor constants private static class ToleratingMissingConstantFilesRankProfilesConfigImporter extends RankProfilesConfigImporter { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 1412936d4a0..8c728867f45 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -3,10 +3,14 @@ package ai.vespa.models.evaluation; import com.google.common.annotations.Beta; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.Map; +import java.util.stream.Collectors; + /** * An evaluator which can be used to evaluate a single function once. * @@ -34,7 +38,15 @@ public class FunctionEvaluator { */ public FunctionEvaluator bind(String name, Tensor value) { if (evaluated) - throw new IllegalStateException("You cannot bind a value in a used evaluator"); + throw new IllegalStateException("Cannot bind a new value in a used evaluator"); + TensorType requiredType = function.argumentTypes().get(name); + if (requiredType == null) + throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function + + ". Expected arguments: " + function.argumentTypes().entrySet().stream() + .map(e -> e.getKey() + ": " + e.getValue()) + .collect(Collectors.joining(", "))); + if ( ! value.type().isAssignableTo(requiredType)) + throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type()); context.put(name, new TensorValue(value)); return this; } @@ -52,10 +64,19 @@ public class FunctionEvaluator { } public Tensor evaluate() { + for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { + if (argument.getValue().rank() == 0) continue; // Scalar argumentds can be skipped (defaults to 0) + if (context.get(argument.getKey()) == LazyArrayContext.defaultContextValue) + throw new IllegalStateException("Missing argument '" + argument.getKey() + + "': Must be bound to a value of type " + argument.getValue()); + } evaluated = true; return function.getBody().evaluate(context).asTensor(); } + /** Returns the function evaluated by this */ + public ExpressionFunction function() { return function; } + public LazyArrayContext context() { return context; } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java index 00fcad94ce8..fa45920f3c8 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionReference.java @@ -1,6 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.yahoo.collections.Pair; +import com.yahoo.tensor.TensorType; + import java.util.Objects; import java.util.Optional; import java.util.regex.Matcher; @@ -23,6 +26,10 @@ class FunctionReference { private static final Pattern referencePattern = Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)(\\.rankingScript)?"); + private static final Pattern argumentTypePattern = + Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.([a-zA-Z0-9_]+)\\.type?"); + private static final Pattern returnTypePattern = + Pattern.compile("rankingExpression\\(([a-zA-Z0-9_.]+)(@[a-f0-9]+\\.[a-f0-9]+)?\\)\\.type?"); /** The name of the function referenced */ private final String name; @@ -73,4 +80,35 @@ class FunctionReference { return Optional.of(new FunctionReference(name, instance)); } + /** + * Returns a function reference and argument name string from the given serial form, + * or empty if the string is not a valid function argument serial form + */ + static Optional<Pair<FunctionReference, String>> fromTypeArgumentSerial(String serialForm) { + Matcher expressionMatcher = argumentTypePattern.matcher(serialForm); + if ( ! expressionMatcher.matches()) return Optional.empty(); + + String name = expressionMatcher.group(1); + String instance = expressionMatcher.group(2); + String argument = expressionMatcher.group(3); + return Optional.of(new Pair<>(new FunctionReference(name, instance), argument)); + } + + /** + * Returns a function reference from the given return type serial form, + * or empty if the string is not a valid function return typoe serial form + */ + static Optional<FunctionReference> fromReturnTypeSerial(String serialForm) { + Matcher expressionMatcher = returnTypePattern.matcher(serialForm); + if ( ! expressionMatcher.matches()) return Optional.empty(); + + String name = expressionMatcher.group(1); + String instance = expressionMatcher.group(2); + return Optional.of(new FunctionReference(name, instance)); + } + + public static FunctionReference fromName(String name) { + return new FunctionReference(name, null); + } + } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index c7d0cbd8f30..78b30f0c873 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -2,6 +2,7 @@ package ai.vespa.models.evaluation; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.Reference; @@ -28,6 +29,8 @@ import java.util.Set; */ public final class LazyArrayContext extends Context implements ContextIndex { + public final static Value defaultContextValue = DoubleValue.zero; + private final IndexedBindings indexedBindings; private LazyArrayContext(IndexedBindings indexedBindings) { @@ -110,6 +113,9 @@ public final class LazyArrayContext extends Context implements ContextIndex { @Override public Set<String> names() { return indexedBindings.names(); } + /** Returns the (immutable) subset of names in this which must be bound when invoking */ + public Set<String> arguments() { return indexedBindings.arguments(); } + private Integer requireIndexOf(String name) { Integer index = indexedBindings.indexOf(name); if (index == null) @@ -130,12 +136,18 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** The mapping from variable name to index */ private final ImmutableMap<String, Integer> nameToIndex; + /** The names which neeeds to be bound externally when envoking this (i.e not constant or invocation */ + private final ImmutableSet<String> arguments; + /** The current values set, pre-converted to doubles */ private final Value[] values; - private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values) { + private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, + Value[] values, + ImmutableSet<String> arguments) { this.nameToIndex = nameToIndex; this.values = values; + this.arguments = arguments; } /** @@ -149,10 +161,12 @@ public final class LazyArrayContext extends Context implements ContextIndex { Model model) { // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); - extractBindTargets(expression.getRoot(), functions, bindTargets); + Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation + extractBindTargets(expression.getRoot(), functions, bindTargets, arguments); + this.arguments = ImmutableSet.copyOf(arguments); values = new Value[bindTargets.size()]; - Arrays.fill(values, DoubleValue.zero); + Arrays.fill(values, defaultContextValue); int i = 0; ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); @@ -178,23 +192,25 @@ public final class LazyArrayContext extends Context implements ContextIndex { private void extractBindTargets(ExpressionNode node, Map<FunctionReference, ExpressionFunction> functions, - Set<String> bindTargets) { + Set<String> bindTargets, + Set<String> arguments) { if (isFunctionReference(node)) { FunctionReference reference = FunctionReference.fromSerial(node.toString()).get(); bindTargets.add(reference.serialForm()); - extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets); + extractBindTargets(functions.get(reference).getBody().getRoot(), functions, bindTargets, arguments); } else if (isConstant(node)) { bindTargets.add(node.toString()); } else if (node instanceof ReferenceNode) { bindTargets.add(node.toString()); + arguments.add(node.toString()); } else if (node instanceof CompositeNode) { CompositeNode cNode = (CompositeNode)node; for (ExpressionNode child : cNode.children()) - extractBindTargets(child, functions, bindTargets); + extractBindTargets(child, functions, bindTargets, arguments); } } @@ -215,13 +231,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { Value get(int index) { return values[index]; } void set(int index, Value value) { values[index] = value; } Set<String> names() { return nameToIndex.keySet(); } + Set<String> arguments() { return arguments; } Integer indexOf(String name) { return nameToIndex.get(name); } IndexedBindings copy(Context context) { Value[] valueCopy = new Value[values.length]; for (int i = 0; i < values.length; i++) valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue)values[i]).copyFor(context) : values[i]; - return new IndexedBindings(nameToIndex, valueCopy); + return new IndexedBindings(nameToIndex, valueCopy, arguments); } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 3fb43d73187..fda1ae935ca 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -7,6 +7,7 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; @@ -38,29 +39,39 @@ public class Model { /** Programmatically create a model containing functions without constant of function references only */ public Model(String name, Collection<ExpressionFunction> functions) { - this(name, functions, Collections.emptyMap(), Collections.emptyList()); + this(name, + functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), + Collections.emptyMap(), + Collections.emptyList()); } Model(String name, - Collection<ExpressionFunction> functions, + Map<FunctionReference, ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants) { - // TODO: Optimize functions this.name = name; - this.functions = ImmutableList.copyOf(functions); + // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); - for (ExpressionFunction function : functions) { + for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { - contextBuilder.put(function.getName(), - new LazyArrayContext(function.getBody(), referencedFunctions, constants, this)); + LazyArrayContext context = new LazyArrayContext(function.getValue().getBody(), referencedFunctions, constants, this); + contextBuilder.put(function.getValue().getName(), context); + for (String argument : context.arguments()) { + if (function.getValue().argumentTypes().get(argument) == null) + functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); + } + if ( ! function.getValue().returnType().isPresent()) + functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); } catch (RuntimeException e) { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); } } this.contextPrototypes = contextBuilder.build(); + this.functions = ImmutableList.copyOf(functions.values()); + // Optimize functions ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) { ExpressionFunction optimizedFunction = optimize(function.getValue(), @@ -79,7 +90,10 @@ public class Model { public String name() { return name; } - /** Returns an immutable list of the free functions of this */ + /** + * Returns an immutable list of the free functions of this. + * The functions returned always specifies types of all arguments and the return value + */ public List<ExpressionFunction> functions() { return functions; } /** Returns the given function, or throws a IllegalArgumentException if it does not exist */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java index 7bea2d0825a..fb424439592 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/RankProfilesConfigImporter.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.yahoo.collections.Pair; import com.yahoo.config.FileReference; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; import com.yahoo.io.GrowableByteBuffer; @@ -18,7 +19,10 @@ import java.io.File; import java.io.IOException; import java.io.UncheckedIOException; import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -60,26 +64,43 @@ public class RankProfilesConfigImporter { private Model importProfile(RankProfilesConfig.Rankprofile profile, RankingConstantsConfig constantsConfig) throws ParseException { - List<ExpressionFunction> functions = new ArrayList<>(); - Map<FunctionReference, ExpressionFunction> referencedFunctions = new HashMap<>(); - SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo(); - ExpressionFunction firstPhase = null; - ExpressionFunction secondPhase = null; List<Constant> constants = readLargeConstants(constantsConfig); + Map<FunctionReference, ExpressionFunction> functions = new LinkedHashMap<>(); + Map<FunctionReference, ExpressionFunction> referencedFunctions = new LinkedHashMap<>(); + SmallConstantsInfo smallConstantsInfo = new SmallConstantsInfo(); + ExpressionFunction firstPhase = null; + ExpressionFunction secondPhase = null; for (RankProfilesConfig.Rankprofile.Fef.Property property : profile.fef().property()) { Optional<FunctionReference> reference = FunctionReference.fromSerial(property.name()); + Optional<Pair<FunctionReference, String>> argumentType = FunctionReference.fromTypeArgumentSerial(property.name()); + Optional<FunctionReference> returnType = FunctionReference.fromReturnTypeSerial(property.name()); if ( reference.isPresent()) { - List<String> arguments = new ArrayList<>(); // TODO: Arguments? RankingExpression expression = new RankingExpression(reference.get().functionName(), property.value()); + ExpressionFunction function = new ExpressionFunction(reference.get().functionName(), + Collections.emptyList(), + expression); if (reference.get().isFree()) // make available in model under configured name - functions.add(new ExpressionFunction(reference.get().functionName(), arguments, expression)); // - - // Make all functions, bound or not available under the name they are referenced by in expressions - referencedFunctions.put(reference.get(), - new ExpressionFunction(reference.get().serialForm(), arguments, expression)); + functions.put(reference.get(), function); + // Make all functions, bound or not, available under the name they are referenced by in expressions + referencedFunctions.put(reference.get(), function); + } + else if (argumentType.isPresent()) { // Arguments always follows the function in properties + FunctionReference argReference = argumentType.get().getFirst(); + ExpressionFunction function = referencedFunctions.get(argReference); + function = function.withArgument(argumentType.get().getSecond(), TensorType.fromSpec(property.value())); + if (argReference.isFree()) + functions.put(argReference, function); + referencedFunctions.put(argReference, function); + } + else if (returnType.isPresent()) { // Return type always follows the function in properties + ExpressionFunction function = referencedFunctions.get(returnType.get()); + function = function.withReturnType(TensorType.fromSpec(property.value())); + if (returnType.get().isFree()) + functions.put(returnType.get(), function); + referencedFunctions.put(returnType.get(), function); } else if (property.name().equals("vespa.rank.firstphase")) { // Include in addition to functions firstPhase = new ExpressionFunction("firstphase", new ArrayList<>(), @@ -93,10 +114,10 @@ public class RankProfilesConfigImporter { smallConstantsInfo.addIfSmallConstantInfo(property.name(), property.value()); } } - if (functionByName("firstphase", functions) == null && firstPhase != null) // may be already included, depending on body - functions.add(firstPhase); - if (functionByName("secondphase", functions) == null && secondPhase != null) // may be already included, depending on body - functions.add(secondPhase); + if (functionByName("firstphase", functions.values()) == null && firstPhase != null) // may be already included, depending on body + functions.put(FunctionReference.fromName("firstphase"), firstPhase); + if (functionByName("secondphase", functions.values()) == null && secondPhase != null) // may be already included, depending on body + functions.put(FunctionReference.fromName("secondphase"), secondPhase); constants.addAll(smallConstantsInfo.asConstants()); @@ -108,7 +129,7 @@ public class RankProfilesConfigImporter { } } - private ExpressionFunction functionByName(String name, List<ExpressionFunction> functions) { + private ExpressionFunction functionByName(String name, Collection<ExpressionFunction> functions) { for (ExpressionFunction function : functions) if (function.getName().equals(name)) return function; diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index 683a1f345d8..6edcd84272e 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -10,13 +10,16 @@ import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.slime.Cursor; import com.yahoo.slime.Slime; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.serialization.JsonFormat; +import com.yahoo.yolean.Exceptions; import java.io.IOException; import java.io.OutputStream; import java.net.URI; import java.nio.charset.Charset; import java.util.Arrays; +import java.util.Map; import java.util.Optional; import java.util.concurrent.Executor; @@ -60,14 +63,17 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { return listModelInformation(request, model, function); } catch (IllegalArgumentException e) { - return new ErrorResponse(404, e.getMessage()); + return new ErrorResponse(404, Exceptions.toMessageString(e)); + } catch (IllegalStateException e) { // On missing bindings + return new ErrorResponse(400, Exceptions.toMessageString(e)); } } private HttpResponse evaluateModel(HttpRequest request, Model model, String[] function) { FunctionEvaluator evaluator = model.evaluatorOf(function); - for (String bindingName : evaluator.context().names()) { - property(request, bindingName).ifPresent(s -> evaluator.bind(bindingName, Tensor.from(s))); + for (Map.Entry<String, TensorType> argument : evaluator.function().argumentTypes().entrySet()) { + property(request, argument.getKey()).ifPresent(value -> evaluator.bind(argument.getKey(), + Tensor.from(argument.getValue(), value))); } Tensor result = evaluator.evaluate(); return new Response(200, JsonFormat.encode(result)); diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java index 6e55c0c9a53..68c3b954675 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java @@ -1,8 +1,12 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import org.junit.Test; +import java.util.List; import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; @@ -25,9 +29,18 @@ public class MlModelsImportingTest { // TODO: When we get type information in Models, replace the evaluator.context().names() check below by that { Model xgboost = tester.models().get("xgboost_2_2"); + + // Function + assertEquals(1, xgboost.functions().size()); tester.assertFunction("xgboost_2_2", "(optimized sum of condition trees of size 192 bytes)", xgboost); + ExpressionFunction function = xgboost.functions().get(0); + assertEquals(TensorType.fromSpec("tensor()"), function.returnType().get()); + assertEquals("f109, f29, f56, f60", commaSeparated(function.arguments())); + function.arguments().forEach(arg -> assertEquals(TensorType.empty, function.argumentTypes().get(arg))); + + // Evaluator FunctionEvaluator evaluator = xgboost.evaluatorOf(); assertEquals("f109, f29, f56, f60", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); assertEquals(-8.17695, evaluator.evaluate().sum().asDouble(), delta); @@ -36,39 +49,80 @@ public class MlModelsImportingTest { { Model onnxMnistSoftmax = tester.models().get("mnist_softmax"); + + // Function + assertEquals(1, onnxMnistSoftmax.functions().size()); tester.assertFunction("default.add", "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))", onnxMnistSoftmax); + ExpressionFunction function = onnxMnistSoftmax.functions().get(0); + assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + assertEquals(1, function.arguments().size()); + assertEquals("Placeholder", function.arguments().get(0)); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); + + // Evaluator assertEquals("tensor(d1[10],d2[784])", onnxMnistSoftmax.evaluatorOf("default.add").context().get("constant(mnist_softmax_Variable)").type().toString()); FunctionEvaluator evaluator = onnxMnistSoftmax.evaluatorOf(); // Verify exactly one output available - assertEquals("Placeholder, constant(mnist_softmax_Variable), constant(mnist_softmax_Variable_1)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + evaluator.bind("Placeholder", inputTensor()); assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); } { Model tfMnistSoftmax = tester.models().get("mnist_softmax_saved"); + + // Function + assertEquals(1, tfMnistSoftmax.functions().size()); tester.assertFunction("serving_default.y", "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))", tfMnistSoftmax); + ExpressionFunction function = tfMnistSoftmax.functions().get(0); + assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + assertEquals(1, function.arguments().size()); + assertEquals("Placeholder", function.arguments().get(0)); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("Placeholder")); + + // Evaluator FunctionEvaluator evaluator = tfMnistSoftmax.evaluatorOf(); // Verify exactly one output available - assertEquals("Placeholder, constant(mnist_softmax_saved_layer_Variable_1_read), constant(mnist_softmax_saved_layer_Variable_read)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + evaluator.bind("Placeholder", inputTensor()); assertEquals(-1.6372650861740112E-6, evaluator.evaluate().sum().asDouble(), delta); } { Model tfMnist = tester.models().get("mnist_saved"); - tester.assertFunction("serving_default.y", - "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", - tfMnist); - // Macro: - tester.assertFunction("imported_ml_macro_mnist_saved_dnn_hidden1_add", + // Generated function + tester.assertFunction("imported_ml_function_mnist_saved_dnn_hidden1_add", "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))", tfMnist); - FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); // TODO: Macro is offered as an alternative output currently, so need to specify argument - assertEquals("constant(mnist_saved_dnn_hidden1_bias_read), constant(mnist_saved_dnn_hidden1_weights_read), constant(mnist_saved_dnn_hidden2_bias_read), constant(mnist_saved_dnn_hidden2_weights_read), constant(mnist_saved_dnn_outputs_bias_read), constant(mnist_saved_dnn_outputs_weights_read), input, rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add)", evaluator.context().names().stream().sorted().collect(Collectors.joining(", "))); + + // Function + assertEquals(2, tfMnist.functions().size()); // TODO: Filter out generated function + tester.assertFunction("serving_default.y", + "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))", + tfMnist); + ExpressionFunction function = tfMnist.functions().get(1); + assertEquals(TensorType.fromSpec("tensor(d1[10])"), function.returnType().get()); + assertEquals(1, function.arguments().size()); + assertEquals("input", function.arguments().get(0)); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), function.argumentTypes().get("input")); + + // Evaluator + FunctionEvaluator evaluator = tfMnist.evaluatorOf("serving_default"); + evaluator.bind("input", inputTensor()); assertEquals(-0.714629131972222, evaluator.evaluate().sum().asDouble(), delta); } } + private Tensor inputTensor() { + Tensor.Builder b = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[],d1[784])")); + for (int i = 0; i < 784; i++) + b.cell(0.0, 0, i); + return b.build(); + } + + private String commaSeparated(List<?> items) { + return items.stream().map(item -> item.toString()).sorted().collect(Collectors.joining(", ")); + } + } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java index 9a3e59aed80..50dd1d1d05f 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelTester.java @@ -60,7 +60,6 @@ public class ModelTester { public void assertBoundFunction(String name, String expression, Model model) { ExpressionFunction function = model.referencedFunctions().get(FunctionReference.fromSerial(name).get()); assertNotNull("Function '" + name + "' is present", function); - assertEquals(name, function.getName()); assertEquals(expression, function.getBody().getRoot().toString()); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index bd1ff6b8ed7..8824be05006 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -5,11 +5,19 @@ import com.yahoo.config.subscription.ConfigGetter; import com.yahoo.config.subscription.FileSource; import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; import com.yahoo.path.Path; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; +import com.yahoo.yolean.Exceptions; import org.junit.Test; +import java.util.ArrayList; +import java.util.List; + import static org.junit.Assert.assertEquals; /** @@ -20,16 +28,7 @@ public class ModelsEvaluatorTest { private static final double delta = 0.00000000001; @Test - public void testTensorEvaluation() { - ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); - FunctionEvaluator function = models.evaluatorOf("macros", "fourtimessum"); - function.bind("var1", Tensor.from("{{x:0}:3,{x:1}:5}")); - function.bind("var2", Tensor.from("{{x:0}:7,{x:1}:11}")); - assertEquals(Tensor.from("{{x:0}:40.0,{x:1}:64.0}"), function.evaluate()); - } - - @Test - public void testEvaluationDependingOnMacroTakingArguments() { + public void testEvaluationDependingFunctionTakingArguments() { ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); function.bind("match", 3); @@ -37,10 +36,70 @@ public class ModelsEvaluatorTest { assertEquals(32.0, function.evaluate().asDouble(), delta); } + @Test + public void testBindingValidation() { + List<ExpressionFunction> functions = new ArrayList<>(); + ExpressionFunction function = new ExpressionFunction("test", RankingExpression.from("sum(arg1 * arg2)")); + function = function.withArgument("arg1", TensorType.fromSpec("tensor(d0[1])")); + function = function.withArgument("arg2", TensorType.fromSpec("tensor(d1{})")); + functions.add(function); + Model model = new Model("test-model", functions); + + try { // No bindings + FunctionEvaluator evaluator = model.evaluatorOf("test"); + evaluator.evaluate(); + } + catch (IllegalStateException e) { + assertEquals("Missing argument 'arg2': Must be bound to a value of type tensor(d1{})", + Exceptions.toMessageString(e)); + } + + try { // Just one binding + FunctionEvaluator evaluator = model.evaluatorOf("test"); + evaluator.bind("arg2", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}")); + evaluator.evaluate(); + } + catch (IllegalStateException e) { + assertEquals("Missing argument 'arg1': Must be bound to a value of type tensor(d0[1])", + Exceptions.toMessageString(e)); + } + + try { // Wrong binding argument + FunctionEvaluator evaluator = model.evaluatorOf("test"); + evaluator.bind("argNone", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}")); + evaluator.evaluate(); + } + catch (IllegalArgumentException e) { + assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg2: tensor(d1{}), arg1: tensor(d0[1])", + Exceptions.toMessageString(e)); + } + + try { // Wrong binding type + FunctionEvaluator evaluator = model.evaluatorOf("test"); + evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d3{})"), "{{d3:foo}:0.1}")); + evaluator.evaluate(); + } + catch (IllegalArgumentException e) { + assertEquals("'arg1' must be of type tensor(d0[1]), not tensor(d3{})", + Exceptions.toMessageString(e)); + } + + try { // Attempt to reuse evaluator + FunctionEvaluator evaluator = model.evaluatorOf("test"); + evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d0[1])"), "{{d0:0}:0.1}")); + evaluator.bind("arg2", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}")); + evaluator.evaluate(); + evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d0[1])"), "{{d0:0}:0.1}")); + } + catch (IllegalStateException e) { + assertEquals("Cannot bind a new value in a used evaluator", + Exceptions.toMessageString(e)); + } + + } + // TODO: Test argument-less function - // TODO: Test that binding nonexisting variable doesn't work - // TODO: Test that rebinding doesn't work - // TODO: Test with nested macros + // TODO: Test with nested functions private ModelsEvaluator createModels(String path) { Path configDir = Path.fromString(path); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 6726f117c05..b915ee72a79 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -9,6 +9,8 @@ import com.yahoo.container.jdisc.HttpRequest; import com.yahoo.container.jdisc.HttpResponse; import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; import com.yahoo.path.Path; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; import com.yahoo.vespa.config.search.core.RankingConstantsConfig; import org.junit.BeforeClass; @@ -94,39 +96,39 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; - String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; - assertResponse(url, 200, expected); + String expected = "{\"error\":\"Missing argument 'Placeholder': Must be bound to a value of type tensor(d0[],d1[784])\"}"; + assertResponse(url, 400, expected); } @Test public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; - String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; - assertResponse(url, 200, expected); + String expected = "{\"error\":\"Missing argument 'Placeholder': Must be bound to a value of type tensor(d0[],d1[784])\"}"; + assertResponse(url, 400, expected); } @Test public void testMnistSoftmaxEvaluateDefaultFunctionWithBindings() { Map<String, String> properties = new HashMap<>(); - properties.put("Placeholder", "{1.0}"); + properties.put("Placeholder", inputTensor()); String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; - String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}"; + String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; assertResponse(url, properties, 200, expected); } @Test public void testMnistSoftmaxEvaluateSpecificFunctionWithBindings() { Map<String, String> properties = new HashMap<>(); - properties.put("Placeholder", "{1.0}"); + properties.put("Placeholder", inputTensor()); String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; - String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":2.7147769462592217},{\"address\":{\"d1\":\"1\"},\"value\":-19.710327346521872},{\"address\":{\"d1\":\"2\"},\"value\":9.496512226053643},{\"address\":{\"d1\":\"3\"},\"value\":13.11241075176957},{\"address\":{\"d1\":\"4\"},\"value\":-12.355567088005559},{\"address\":{\"d1\":\"5\"},\"value\":10.39812446509341},{\"address\":{\"d1\":\"6\"},\"value\":-1.3739236534397499},{\"address\":{\"d1\":\"7\"},\"value\":-3.4260787871386995},{\"address\":{\"d1\":\"8\"},\"value\":6.471120687192041},{\"address\":{\"d1\":\"9\"},\"value\":-5.327024804970982}]}"; + String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; assertResponse(url, properties, 200, expected); } @Test public void testMnistSavedDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_saved"; - String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_macro_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}"; + String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"imported_ml_function_mnist_saved_dnn_hidden1_add\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/imported_ml_function_mnist_saved_dnn_hidden1_add/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]},{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"bindings\":[{\"binding\":\"input\",\"type\":\"\"}]}]}"; assertResponse(url, 200, expected); } @@ -140,16 +142,16 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSavedEvaluateDefaultFunctionShouldFail() { String url = "http://localhost/model-evaluation/v1/mnist_saved/eval"; - String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_macro_mnist_saved_dnn_hidden1_add, serving_default.y\"}"; + String expected = "{\"error\":\"More than one function is available in model 'mnist_saved', but no name is given. Available functions: imported_ml_function_mnist_saved_dnn_hidden1_add, serving_default.y\"}"; assertResponse(url, 404, expected); } @Test public void testMnistSavedEvaluateSpecificFunction() { Map<String, String> properties = new HashMap<>(); - properties.put("input", "-1.0"); + properties.put("input", inputTensor()); String url = "http://localhost/model-evaluation/v1/mnist_saved/serving_default.y/eval"; - String expected = "{\"cells\":[{\"address\":{\"d1\":\"0\"},\"value\":-2.72208123403445},{\"address\":{\"d1\":\"1\"},\"value\":6.465137496457595},{\"address\":{\"d1\":\"2\"},\"value\":-7.078050386283122},{\"address\":{\"d1\":\"3\"},\"value\":-10.485296462655546},{\"address\":{\"d1\":\"4\"},\"value\":0.19508378636937004},{\"address\":{\"d1\":\"5\"},\"value\":6.348870746681019},{\"address\":{\"d1\":\"6\"},\"value\":10.756191852397258},{\"address\":{\"d1\":\"7\"},\"value\":1.476101533270058},{\"address\":{\"d1\":\"8\"},\"value\":-17.778398655804875},{\"address\":{\"d1\":\"9\"},\"value\":-2.0597690508530295}]}"; + String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.6319251673007533},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":-7.577770600619843E-4},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":-0.010707969042025622},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.6344759233540788},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":-0.17529455385847528},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":0.7490809723192187},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.022790284182901716},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.26799240657608936},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-0.3152438845465862},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":0.05949304847735276}]}"; assertResponse(url, properties, 200, expected); } @@ -171,10 +173,10 @@ public class ModelsEvaluationHandlerTest { static private void assertResponse(HttpRequest request, int expectedCode, String expectedResult) { HttpResponse response = handler.handle(request); assertEquals("application/json", response.getContentType()); - assertEquals(expectedCode, response.getStatus()); if (expectedResult != null) { assertEquals(expectedResult, getContents(response)); } + assertEquals(expectedCode, response.getStatus()); } static private String getContents(HttpResponse response) { @@ -198,4 +200,11 @@ public class ModelsEvaluationHandlerTest { return new ModelsEvaluator(importer.importFrom(config, constantsConfig)); } + private String inputTensor() { + Tensor.Builder b = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[],d1[784])")); + for (int i = 0; i < 784; i++) + b.cell(0.0, 0, i); + return b.build().toString(); + } + } diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg index 1cc36f75158..c25c5ba555b 100644 --- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg +++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg @@ -1,14 +1,28 @@ -rankprofile[0].name "mnist_saved" -rankprofile[0].fef.property[0].name "rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add).rankingScript" -rankprofile[0].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))" -rankprofile[0].fef.property[1].name "rankingExpression(serving_default.y).rankingScript" -rankprofile[0].fef.property[1].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_macro_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))" +rankprofile[0].name "mnist_softmax" +rankprofile[0].fef.property[0].name "rankingExpression(default.add).rankingScript" +rankprofile[0].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))" +rankprofile[0].fef.property[1].name "rankingExpression(default.add).Placeholder.type" +rankprofile[0].fef.property[1].value "tensor(d0[],d1[784])" +rankprofile[0].fef.property[2].name "rankingExpression(default.add).type" +rankprofile[0].fef.property[2].value "tensor(d1[10])" rankprofile[1].name "xgboost_2_2" rankprofile[1].fef.property[0].name "rankingExpression(xgboost_2_2).rankingScript" rankprofile[1].fef.property[0].value "if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)" rankprofile[2].name "mnist_softmax_saved" rankprofile[2].fef.property[0].name "rankingExpression(serving_default.y).rankingScript" rankprofile[2].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))" -rankprofile[3].name "mnist_softmax" -rankprofile[3].fef.property[0].name "rankingExpression(default.add).rankingScript" -rankprofile[3].fef.property[0].value "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(mnist_softmax_Variable), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_Variable_1), f(a,b)(a + b))" +rankprofile[2].fef.property[1].name "rankingExpression(serving_default.y).Placeholder.type" +rankprofile[2].fef.property[1].value "tensor(d0[],d1[784])" +rankprofile[2].fef.property[2].name "rankingExpression(serving_default.y).type" +rankprofile[2].fef.property[2].value "tensor(d1[10])" +rankprofile[3].name "mnist_saved" +rankprofile[3].fef.property[0].name "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript" +rankprofile[3].fef.property[0].value "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))" +rankprofile[3].fef.property[1].name "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).type" +rankprofile[3].fef.property[1].value "tensor(d3[300])" +rankprofile[3].fef.property[2].name "rankingExpression(serving_default.y).rankingScript" +rankprofile[3].fef.property[2].value "join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.050701 * if (a >= 0, a, 1.673263 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))" +rankprofile[3].fef.property[3].name "rankingExpression(serving_default.y).input.type" +rankprofile[3].fef.property[3].value "tensor(d0[],d1[784])" +rankprofile[3].fef.property[4].name "rankingExpression(serving_default.y).type" +rankprofile[3].fef.property[4].value "tensor(d1[10])" diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index f6502a9801d..787b857839d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -11,6 +11,7 @@ import com.yahoo.text.Utf8; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.HashMap; @@ -97,7 +98,12 @@ public class ExpressionFunction { return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType)); } - public ExpressionFunction withArgumentTypes(Map<String, TensorType> argumentTypes) { + /** Returns a copy of this with the given argument and argument type added */ + public ExpressionFunction withArgument(String argument, TensorType type) { + List<String> arguments = new ArrayList<>(this.arguments); + arguments.add(argument); + Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes); + argumentTypes.put(argument, type); return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java index 17157ab385f..8aa7446cae7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java @@ -9,7 +9,6 @@ import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; * In a boolean context doubles are true if they are different from 0.0 * * @author bratseth - * @since 5.1.5 */ public final class DoubleValue extends DoubleCompatibleValue { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index 9ff391a5cfe..f26f2dea04f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -121,7 +121,7 @@ public class ImportedModel { if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs expressions.add(new Pair<>(signatureEntry.getKey(), new ExpressionFunction(signatureEntry.getKey(), - new ArrayList<>(signatureEntry.getValue().inputs().keySet()), + new ArrayList<>(signatureEntry.getValue().inputs().values()), expressions().get(signatureEntry.getKey()), signatureEntry.getValue().inputMap(), Optional.empty()))); @@ -182,8 +182,11 @@ public class ImportedModel { /** Returns the name and type of all inputs in this signature as an immutable map */ public Map<String, TensorType> inputMap() { ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>(); + // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to* + // in the model, as these are the names which must actually be bound, if we are to avoid creating an + // "input mapping" to accomodate this complexity in for (Map.Entry<String, String> inputEntry : inputs().entrySet()) - inputs.put(inputEntry.getKey(), owner().inputs().get(inputEntry.getValue())); + inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue())); return inputs.build(); } @@ -207,7 +210,7 @@ public class ImportedModel { /** Returns the expression this output references */ public ExpressionFunction outputExpression(String outputName) { return new ExpressionFunction(outputName, - new ArrayList<>(inputs.keySet()), + new ArrayList<>(inputs.values()), owner().expressions().get(outputs.get(outputName)), inputMap(), Optional.empty()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java index 593e7b54c10..e325c3d11b4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java @@ -15,17 +15,18 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", + "src/test/files/integration/tensorflow/batch_norm/saved"); ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName()); - model.assertEqualResult("X", output.getBody().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + ExpressionFunction function = signature.outputExpression("y"); + assertNotNull(function); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", function.getBody().getName()); + model.assertEqualResult("X", function.getBody().getName()); + assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index 59712c0152f..8ca5a9a7888 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -30,13 +30,13 @@ public class DropoutImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("outputs/Maximum", output.getBody().getName()); + ExpressionFunction function = signature.outputExpression("y"); + assertNotNull(function); + assertEquals("outputs/Maximum", function.getBody().getName()); assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", - output.getBody().getRoot().toString()); - model.assertEqualResult("X", output.getBody().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + function.getBody().getRoot().toString()); + model.assertEqualResult("X", function.getBody().getName()); + assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index 0a48ecfce21..feba40601e3 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -62,7 +62,7 @@ public class TensorFlowMnistSoftmaxImportTestCase { assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", output.getBody().getRoot().toString()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); // Test execution model.assertEqualResult("Placeholder", "MatMul"); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index a1334dc729c..000f33696f2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -27,7 +27,7 @@ class TensorParser { else { if (type.isPresent() && ! type.get().equals(TensorType.empty)) throw new IllegalArgumentException("Got zero-dimensional tensor '" + tensorString + - "but type is not empty but " + type.get()); + "' where type " + type.get() + " is required"); return Tensor.Builder.of(TensorType.empty).cell(Double.parseDouble(tensorString)).build(); } } |