diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-10-11 18:44:05 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-10-11 18:44:05 +0200 |
commit | e9c7a4fcd3e8902eace31366bcda26bf49e458a2 (patch) | |
tree | a94397d5c23d28e868edb05ce0b3dd2267ac7a6b /model-evaluation/src | |
parent | 77259b9db997e703a450f5989b6e294aa841dca6 (diff) |
Default value may be of any type
Diffstat (limited to 'model-evaluation/src')
4 files changed, 27 insertions, 37 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 9db26d7ecd8..caa2db13ff2 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -1,9 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import com.google.common.annotations.Beta; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -16,7 +14,6 @@ import java.util.stream.Collectors; * * @author bratseth */ -@Beta // This wraps all access to the context and the ranking expression to avoid incorrect usage public class FunctionEvaluator { @@ -65,10 +62,15 @@ public class FunctionEvaluator { public Tensor evaluate() { for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { - if (argument.getValue().rank() == 0) continue; // Scalar arguments can be skipped (defaults to 0) + System.out.println("Checking " + argument.getKey() + " default " + context.defaultValue() + " is assignable to " + argument.getValue() + + "? " + context.defaultValue().type().isAssignableTo(argument.getValue())); if (context.isMissing(argument.getKey())) throw new IllegalStateException("Missing argument '" + argument.getKey() + "': Must be bound to a value of type " + argument.getValue()); + if (! context.get(argument.getKey()).type().isAssignableTo(argument.getValue())) + throw new IllegalStateException("Argument '" + argument.getKey() + + "' must be bound to a value of type " + argument.getValue()); + } evaluated = true; return function.getBody().evaluate(context).asTensor(); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 9045e335167..51daf278a4a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -16,7 +16,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; import java.util.Arrays; -import java.util.BitSet; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; @@ -42,9 +41,9 @@ public final class LazyArrayContext extends Context implements ContextIndex { Map<FunctionReference, ExpressionFunction> referencedFunctions, List<Constant> constants, Model model, - Value missingValue) { + Value defaultFeatureValue) { this.function = function; - this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model, missingValue); + this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model, defaultFeatureValue); } /** @@ -121,8 +120,12 @@ public final class LazyArrayContext extends Context implements ContextIndex { } boolean isMissing(String name) { - Integer index = indexedBindings.indexOf(name); - return index == null || indexedBindings.isMissing(index); + return indexedBindings.indexOf(name) == null; + } + + /** Returns the value which should be used when no value is set */ + public Value defaultValue() { + return indexedBindings.defaultValue; } /** @@ -138,28 +141,23 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** The mapping from variable name to index */ private final ImmutableMap<String, Integer> nameToIndex; - /** The names which needs to be bound externally when envoking this (i.e not constant or invocation */ + /** The names which needs to be bound externally when invoking this (i.e not constant or invocation */ private final ImmutableSet<String> arguments; /** The current values set, pre-converted to doubles */ private final Value[] values; - /** The values that actually have been set */ - private final BitSet setValues; - /** The value to return if not set */ - private final Value missingValue; + private final Value defaultValue; private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values, ImmutableSet<String> arguments, - BitSet setValues, - Value missingValue) { + Value defaultValue) { this.nameToIndex = nameToIndex; this.values = values; this.arguments = arguments; - this.setValues = setValues; - this.missingValue = missingValue.freeze(); + this.defaultValue = defaultValue.freeze(); } /** @@ -171,17 +169,16 @@ public final class LazyArrayContext extends Context implements ContextIndex { List<Constant> constants, LazyArrayContext owner, Model model, - Value missingValue) { + Value defaultFeatureValue) { // 1. Determine and prepare bind targets Set<String> bindTargets = new LinkedHashSet<>(); Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments); this.arguments = ImmutableSet.copyOf(arguments); - this.missingValue = missingValue.freeze(); + this.defaultValue = defaultFeatureValue.freeze(); values = new Value[bindTargets.size()]; - Arrays.fill(values, this.missingValue); - setValues = new BitSet(bindTargets.size()); + Arrays.fill(values, this.defaultValue); int i = 0; ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); @@ -195,7 +192,6 @@ public final class LazyArrayContext extends Context implements ContextIndex { Integer index = nameToIndex.get(constantReference); if (index != null) { values[index] = new TensorValue(constant.value()); - setValues.set(index); } } @@ -203,7 +199,6 @@ public final class LazyArrayContext extends Context implements ContextIndex { Integer index = nameToIndex.get(referencedFunction.getKey().serialForm()); if (index != null) { // Referenced in this, so bind it values[index] = new LazyValue(referencedFunction.getKey(), owner, model); - setValues.set(index); } } } @@ -249,24 +244,17 @@ public final class LazyArrayContext extends Context implements ContextIndex { Value get(int index) { return values[index]; } void set(int index, Value value) { values[index] = value; - setValues.set(index); } Set<String> names() { return nameToIndex.keySet(); } Set<String> arguments() { return arguments; } Integer indexOf(String name) { return nameToIndex.get(name); } - boolean isMissing(int index) { return ! setValues.get(index); } IndexedBindings copy(Context context) { Value[] valueCopy = new Value[values.length]; - BitSet setValuesCopy = new BitSet(values.length); - for (int i = 0; i < values.length; i++) { + for (int i = 0; i < values.length; i++) valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue) values[i]).copyFor(context) : values[i]; - if (setValues.get(i)) { - setValuesCopy.set(i); - } - } - return new IndexedBindings(nameToIndex, valueCopy, arguments, setValuesCopy, missingValue); + return new IndexedBindings(nameToIndex, valueCopy, arguments, defaultValue); } } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index 8824be05006..e8620670dd6 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -50,7 +50,7 @@ public class ModelsEvaluatorTest { evaluator.evaluate(); } catch (IllegalStateException e) { - assertEquals("Missing argument 'arg2': Must be bound to a value of type tensor(d1{})", + assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})", Exceptions.toMessageString(e)); } @@ -60,7 +60,7 @@ public class ModelsEvaluatorTest { evaluator.evaluate(); } catch (IllegalStateException e) { - assertEquals("Missing argument 'arg1': Must be bound to a value of type tensor(d0[1])", + assertEquals("Argument 'arg1' must be bound to a value of type tensor(d0[1])", Exceptions.toMessageString(e)); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 95f9888024a..5d30fc93a4c 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -97,14 +97,14 @@ public class ModelsEvaluationHandlerTest { @Test public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval"; - String expected = "{\"error\":\"Missing argument 'Placeholder': Must be bound to a value of type tensor(d0[],d1[784])\"}"; + String expected = "{\"error\":\"Argument 'Placeholder' must be bound to a value of type tensor(d0[],d1[784])\"}"; assertResponse(url, 400, expected); } @Test public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() { String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; - String expected = "{\"error\":\"Missing argument 'Placeholder': Must be bound to a value of type tensor(d0[],d1[784])\"}"; + String expected = "{\"error\":\"Argument 'Placeholder' must be bound to a value of type tensor(d0[],d1[784])\"}"; assertResponse(url, 400, expected); } |