aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-10-23 16:28:51 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-10-23 16:28:51 +0200
commit68249a1c7541c9c5e1b1e43afacf6118f2d22689 (patch)
treeb0627fab0557be8020d64f6e8fb7f5cfe843c6b6 /model-evaluation
parent51d8e8546649327cef3b892090c1631a603a4949 (diff)
Add dynamically settable default
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java25
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java35
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java26
3 files changed, 74 insertions, 12 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index caa2db13ff2..c0f96dfb161 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -60,10 +60,31 @@ public class FunctionEvaluator {
return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build());
}
+ /**
+ * Sets the default value to use for variables which are not bound
+ *
+ * @param value the default value
+ * @return this for chaining
+ */
+ public FunctionEvaluator setUnboundValue(Tensor value) {
+ if (evaluated)
+ throw new IllegalStateException("Cannot change the unbound value in a used evaluator");
+ context.setUnboundValue(value);
+ return this;
+ }
+
+ /**
+ * Sets the default value to use for variables which are not bound
+ *
+ * @param value the default value
+ * @return this for chaining
+ */
+ public FunctionEvaluator setUnboundValue(double value) {
+ return setUnboundValue(Tensor.Builder.of(TensorType.empty).cell(value).build());
+ }
+
public Tensor evaluate() {
for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
- System.out.println("Checking " + argument.getKey() + " default " + context.defaultValue() + " is assignable to " + argument.getValue() +
- "? " + context.defaultValue().type().isAssignableTo(argument.getValue()));
if (context.isMissing(argument.getKey()))
throw new IllegalStateException("Missing argument '" + argument.getKey() +
"': Must be bound to a value of type " + argument.getValue());
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 51daf278a4a..84606146f79 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
@@ -47,6 +48,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
/**
+ * Sets the value to use for lookups to existing values which are not set in this context.
+ * The default value that will be returned is NaN
+ */
+ public void setUnboundValue(Tensor value) {
+ indexedBindings.setUnboundValue(value);
+ }
+
+ /**
* Puts a value by name.
* The value will be frozen if it isn't already.
*
@@ -90,10 +99,7 @@ public final class LazyArrayContext extends Context implements ContextIndex {
@Override
public double getDouble(int index) {
- double value = get(index).asDouble();
- if (value == Double.NaN)
- throw new UnsupportedOperationException("Value at " + index + " has no double representation");
- return value;
+ return get(index).asDouble();
}
@Override
@@ -144,11 +150,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** The names which needs to be bound externally when invoking this (i.e not constant or invocation */
private final ImmutableSet<String> arguments;
- /** The current values set, pre-converted to doubles */
+ /** The current values set */
private final Value[] values;
- /** The value to return if not set */
- private final Value defaultValue;
+ /** The object instance which encodes "no value is set". The actual value of this is never used. */
+ private static final Value missing = new DoubleValue(Double.NaN).freeze();
+
+ /** The value to return for lookups where no value is set */
+ private Value defaultValue;
private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
Value[] values,
@@ -178,7 +187,7 @@ public final class LazyArrayContext extends Context implements ContextIndex {
this.arguments = ImmutableSet.copyOf(arguments);
this.defaultValue = defaultFeatureValue.freeze();
values = new Value[bindTargets.size()];
- Arrays.fill(values, this.defaultValue);
+ Arrays.fill(values, missing);
int i = 0;
ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
@@ -203,6 +212,10 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
}
+ private void setUnboundValue(Tensor value) {
+ defaultValue = new TensorValue(value).freeze();
+ }
+
private void extractBindTargets(ExpressionNode node,
Map<FunctionReference, ExpressionFunction> functions,
Set<String> bindTargets,
@@ -241,7 +254,11 @@ public final class LazyArrayContext extends Context implements ContextIndex {
return reference.getName().equals("constant") && reference.getArguments().size() == 1;
}
- Value get(int index) { return values[index]; }
+ Value get(int index) {
+ Value value = values[index];
+ return value == missing ? defaultValue : value;
+ }
+
void set(int index, Value value) {
values[index] = value;
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index e8620670dd6..4e31dd89e6a 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -7,7 +7,6 @@ import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer;
import com.yahoo.path.Path;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.vespa.config.search.RankProfilesConfig;
@@ -19,6 +18,7 @@ import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
/**
* @author bratseth
@@ -36,6 +36,30 @@ public class ModelsEvaluatorTest {
assertEquals(32.0, function.evaluate().asDouble(), delta);
}
+ /** Tests a function defined as 4 * (var1 + var2) */
+ @Test
+ public void testSettingDefaultVariableValue() {
+ ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/");
+
+ {
+ FunctionEvaluator function = models.evaluatorOf("macros", "secondphase");
+ assertTrue(Double.isNaN(function.evaluate().asDouble()));
+ }
+
+ {
+ FunctionEvaluator function = models.evaluatorOf("macros", "secondphase");
+ function.setUnboundValue(5);
+ assertEquals(40.0, function.evaluate().asDouble(), delta);
+ }
+
+ {
+ FunctionEvaluator function = models.evaluatorOf("macros", "secondphase");
+ function.setUnboundValue(5);
+ function.bind("match", 3);
+ assertEquals(32.0, function.evaluate().asDouble(), delta);
+ }
+ }
+
@Test
public void testBindingValidation() {
List<ExpressionFunction> functions = new ArrayList<>();