aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-10-11 18:44:05 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-10-11 18:44:05 +0200
commite9c7a4fcd3e8902eace31366bcda26bf49e458a2 (patch)
treea94397d5c23d28e868edb05ce0b3dd2267ac7a6b /model-evaluation
parent77259b9db997e703a450f5989b6e294aa841dca6 (diff)
Default value may be of any type
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java10
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java46
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java4
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java4
4 files changed, 27 insertions, 37 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 9db26d7ecd8..caa2db13ff2 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -1,9 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import com.google.common.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -16,7 +14,6 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-@Beta
// This wraps all access to the context and the ranking expression to avoid incorrect usage
public class FunctionEvaluator {
@@ -65,10 +62,15 @@ public class FunctionEvaluator {
public Tensor evaluate() {
for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
- if (argument.getValue().rank() == 0) continue; // Scalar arguments can be skipped (defaults to 0)
+ System.out.println("Checking " + argument.getKey() + " default " + context.defaultValue() + " is assignable to " + argument.getValue() +
+ "? " + context.defaultValue().type().isAssignableTo(argument.getValue()));
if (context.isMissing(argument.getKey()))
throw new IllegalStateException("Missing argument '" + argument.getKey() +
"': Must be bound to a value of type " + argument.getValue());
+ if (! context.get(argument.getKey()).type().isAssignableTo(argument.getValue()))
+ throw new IllegalStateException("Argument '" + argument.getKey() +
+ "' must be bound to a value of type " + argument.getValue());
+
}
evaluated = true;
return function.getBody().evaluate(context).asTensor();
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 9045e335167..51daf278a4a 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -16,7 +16,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
-import java.util.BitSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
@@ -42,9 +41,9 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Map<FunctionReference, ExpressionFunction> referencedFunctions,
List<Constant> constants,
Model model,
- Value missingValue) {
+ Value defaultFeatureValue) {
this.function = function;
- this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model, missingValue);
+ this.indexedBindings = new IndexedBindings(function, referencedFunctions, constants, this, model, defaultFeatureValue);
}
/**
@@ -121,8 +120,12 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
boolean isMissing(String name) {
- Integer index = indexedBindings.indexOf(name);
- return index == null || indexedBindings.isMissing(index);
+ return indexedBindings.indexOf(name) == null;
+ }
+
+ /** Returns the value which should be used when no value is set */
+ public Value defaultValue() {
+ return indexedBindings.defaultValue;
}
/**
@@ -138,28 +141,23 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** The mapping from variable name to index */
private final ImmutableMap<String, Integer> nameToIndex;
- /** The names which needs to be bound externally when envoking this (i.e not constant or invocation */
+ /** The names which needs to be bound externally when invoking this (i.e not constant or invocation */
private final ImmutableSet<String> arguments;
/** The current values set, pre-converted to doubles */
private final Value[] values;
- /** The values that actually have been set */
- private final BitSet setValues;
-
/** The value to return if not set */
- private final Value missingValue;
+ private final Value defaultValue;
private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
Value[] values,
ImmutableSet<String> arguments,
- BitSet setValues,
- Value missingValue) {
+ Value defaultValue) {
this.nameToIndex = nameToIndex;
this.values = values;
this.arguments = arguments;
- this.setValues = setValues;
- this.missingValue = missingValue.freeze();
+ this.defaultValue = defaultValue.freeze();
}
/**
@@ -171,17 +169,16 @@ public final class LazyArrayContext extends Context implements ContextIndex {
List<Constant> constants,
LazyArrayContext owner,
Model model,
- Value missingValue) {
+ Value defaultFeatureValue) {
// 1. Determine and prepare bind targets
Set<String> bindTargets = new LinkedHashSet<>();
Set<String> arguments = new LinkedHashSet<>(); // Arguments: Bind targets which need to be bound before invocation
extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments);
this.arguments = ImmutableSet.copyOf(arguments);
- this.missingValue = missingValue.freeze();
+ this.defaultValue = defaultFeatureValue.freeze();
values = new Value[bindTargets.size()];
- Arrays.fill(values, this.missingValue);
- setValues = new BitSet(bindTargets.size());
+ Arrays.fill(values, this.defaultValue);
int i = 0;
ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
@@ -195,7 +192,6 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Integer index = nameToIndex.get(constantReference);
if (index != null) {
values[index] = new TensorValue(constant.value());
- setValues.set(index);
}
}
@@ -203,7 +199,6 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Integer index = nameToIndex.get(referencedFunction.getKey().serialForm());
if (index != null) { // Referenced in this, so bind it
values[index] = new LazyValue(referencedFunction.getKey(), owner, model);
- setValues.set(index);
}
}
}
@@ -249,24 +244,17 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Value get(int index) { return values[index]; }
void set(int index, Value value) {
values[index] = value;
- setValues.set(index);
}
Set<String> names() { return nameToIndex.keySet(); }
Set<String> arguments() { return arguments; }
Integer indexOf(String name) { return nameToIndex.get(name); }
- boolean isMissing(int index) { return ! setValues.get(index); }
IndexedBindings copy(Context context) {
Value[] valueCopy = new Value[values.length];
- BitSet setValuesCopy = new BitSet(values.length);
- for (int i = 0; i < values.length; i++) {
+ for (int i = 0; i < values.length; i++)
valueCopy[i] = values[i] instanceof LazyValue ? ((LazyValue) values[i]).copyFor(context) : values[i];
- if (setValues.get(i)) {
- setValuesCopy.set(i);
- }
- }
- return new IndexedBindings(nameToIndex, valueCopy, arguments, setValuesCopy, missingValue);
+ return new IndexedBindings(nameToIndex, valueCopy, arguments, defaultValue);
}
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index 8824be05006..e8620670dd6 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -50,7 +50,7 @@ public class ModelsEvaluatorTest {
evaluator.evaluate();
}
catch (IllegalStateException e) {
- assertEquals("Missing argument 'arg2': Must be bound to a value of type tensor(d1{})",
+ assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})",
Exceptions.toMessageString(e));
}
@@ -60,7 +60,7 @@ public class ModelsEvaluatorTest {
evaluator.evaluate();
}
catch (IllegalStateException e) {
- assertEquals("Missing argument 'arg1': Must be bound to a value of type tensor(d0[1])",
+ assertEquals("Argument 'arg1' must be bound to a value of type tensor(d0[1])",
Exceptions.toMessageString(e));
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 95f9888024a..5d30fc93a4c 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -97,14 +97,14 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testMnistSoftmaxEvaluateDefaultFunctionWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/eval";
- String expected = "{\"error\":\"Missing argument 'Placeholder': Must be bound to a value of type tensor(d0[],d1[784])\"}";
+ String expected = "{\"error\":\"Argument 'Placeholder' must be bound to a value of type tensor(d0[],d1[784])\"}";
assertResponse(url, 400, expected);
}
@Test
public void testMnistSoftmaxEvaluateSpecificFunctionWithoutBindings() {
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
- String expected = "{\"error\":\"Missing argument 'Placeholder': Must be bound to a value of type tensor(d0[],d1[784])\"}";
+ String expected = "{\"error\":\"Argument 'Placeholder' must be bound to a value of type tensor(d0[],d1[784])\"}";
assertResponse(url, 400, expected);
}