diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-10-23 16:28:51 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-10-23 16:28:51 +0200 |
commit | 68249a1c7541c9c5e1b1e43afacf6118f2d22689 (patch) | |
tree | b0627fab0557be8020d64f6e8fb7f5cfe843c6b6 /model-evaluation | |
parent | 51d8e8546649327cef3b892090c1631a603a4949 (diff) |
Add dynamically settable default
Diffstat (limited to 'model-evaluation')
3 files changed, 74 insertions, 12 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index caa2db13ff2..c0f96dfb161 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -60,10 +60,31 @@ public class FunctionEvaluator { return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build()); } + /** + * Sets the default value to use for variables which are not bound + * + * @param value the default value + * @return this for chaining + */ + public FunctionEvaluator setUnboundValue(Tensor value) { + if (evaluated) + throw new IllegalStateException("Cannot change the unbound value in a used evaluator"); + context.setUnboundValue(value); + return this; + } + + /** + * Sets the default value to use for variables which are not bound + * + * @param value the default value + * @return this for chaining + */ + public FunctionEvaluator setUnboundValue(double value) { + return setUnboundValue(Tensor.Builder.of(TensorType.empty).cell(value).build()); + } + public Tensor evaluate() { for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { - System.out.println("Checking " + argument.getKey() + " default " + context.defaultValue() + " is assignable to " + argument.getValue() + - "? " + context.defaultValue().type().isAssignableTo(argument.getValue())); if (context.isMissing(argument.getKey())) throw new IllegalStateException("Missing argument '" + argument.getKey() + "': Must be bound to a value of type " + argument.getValue()); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 51daf278a4a..84606146f79 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -13,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -47,6 +48,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { } /** + * Sets the value to use for lookups to existing values which are not set in this context. + * The default value that will be returned is NaN + */ + public void setUnboundValue(Tensor value) { + indexedBindings.setUnboundValue(value); + } + + /** * Puts a value by name. * The value will be frozen if it isn't already. * @@ -90,10 +99,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { @Override public double getDouble(int index) { - double value = get(index).asDouble(); - if (value == Double.NaN) - throw new UnsupportedOperationException("Value at " + index + " has no double representation"); - return value; + return get(index).asDouble(); } @Override @@ -144,11 +150,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { /** The names which needs to be bound externally when invoking this (i.e not constant or invocation */ private final ImmutableSet<String> arguments; - /** The current values set, pre-converted to doubles */ + /** The current values set */ private final Value[] values; - /** The value to return if not set */ - private final Value defaultValue; + /** The object instance which encodes "no value is set". The actual value of this is never used. */ + private static final Value missing = new DoubleValue(Double.NaN).freeze(); + + /** The value to return for lookups where no value is set */ + private Value defaultValue; private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values, @@ -178,7 +187,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { this.arguments = ImmutableSet.copyOf(arguments); this.defaultValue = defaultFeatureValue.freeze(); values = new Value[bindTargets.size()]; - Arrays.fill(values, this.defaultValue); + Arrays.fill(values, missing); int i = 0; ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); @@ -203,6 +212,10 @@ public final class LazyArrayContext extends Context implements ContextIndex { } } + private void setUnboundValue(Tensor value) { + defaultValue = new TensorValue(value).freeze(); + } + private void extractBindTargets(ExpressionNode node, Map<FunctionReference, ExpressionFunction> functions, Set<String> bindTargets, @@ -241,7 +254,11 @@ public final class LazyArrayContext extends Context implements ContextIndex { return reference.getName().equals("constant") && reference.getArguments().size() == 1; } - Value get(int index) { return values[index]; } + Value get(int index) { + Value value = values[index]; + return value == missing ? defaultValue : value; + } + void set(int index, Value value) { values[index] = value; } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index e8620670dd6..4e31dd89e6a 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -7,7 +7,6 @@ import com.yahoo.filedistribution.fileacquirer.MockFileAcquirer; import com.yahoo.path.Path; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.vespa.config.search.RankProfilesConfig; @@ -19,6 +18,7 @@ import java.util.ArrayList; import java.util.List; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; /** * @author bratseth @@ -36,6 +36,30 @@ public class ModelsEvaluatorTest { assertEquals(32.0, function.evaluate().asDouble(), delta); } + /** Tests a function defined as 4 * (var1 + var2) */ + @Test + public void testSettingDefaultVariableValue() { + ModelsEvaluator models = createModels("src/test/resources/config/rankexpression/"); + + { + FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); + assertTrue(Double.isNaN(function.evaluate().asDouble())); + } + + { + FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); + function.setUnboundValue(5); + assertEquals(40.0, function.evaluate().asDouble(), delta); + } + + { + FunctionEvaluator function = models.evaluatorOf("macros", "secondphase"); + function.setUnboundValue(5); + function.bind("match", 3); + assertEquals(32.0, function.evaluate().asDouble(), delta); + } + } + @Test public void testBindingValidation() { List<ExpressionFunction> functions = new ArrayList<>(); |