1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
|
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.HashMap;
import java.util.Map;
import static org.junit.Assert.assertEquals;
/**
 * Test helper which asserts that ranking expression strings evaluate to an expected value.
 * <p>
 * Each instance holds a default context with the bindings {@code zero}, {@code one},
 * {@code one_half}, {@code a_quarter} and {@code foo}.
 *
 * @author bratseth
 */
public class EvaluationTester {

    /** The bindings made available to the expressions under test; assigned once in the constructor */
    private final MapContext defaultContext;

    /** Creates a tester whose default context contains a few named double values and one string value */
    public EvaluationTester() {
        Map<String, Value> bindings = new HashMap<>();
        bindings.put("zero", DoubleValue.frozen(0d));
        bindings.put("one", DoubleValue.frozen(1d));
        bindings.put("one_half", DoubleValue.frozen(0.5d));
        bindings.put("a_quarter", DoubleValue.frozen(0.25d));
        bindings.put("foo", StringValue.frozen("foo"));
        defaultContext = new MapContext(bindings);
    }

    /**
     * Asserts that the expression evaluates to the expected tensor, first with the untyped
     * arguments converted to indexed tensors and then with them left as mapped tensors.
     *
     * @param expectedTensor the expected result, as a tensor string
     * @param expressionString the ranking expression to evaluate
     * @param tensorArguments tensor strings which are bound to the names tensor0, tensor1, ... in order
     * @return the expression created by the second (mapped) run
     */
    public RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) {
        assertEvaluates(expectedTensor, expressionString, false, tensorArguments);
        return assertEvaluates(expectedTensor, expressionString, true, tensorArguments);
    }

    // TODO: Test both bound and unbound indexed
    /**
     * Asserts that the expression evaluates to the expected tensor when the given arguments
     * are bound in a copy of the default context.
     *
     * @param expectedTensor the expected result, as a tensor string
     * @param expressionString the ranking expression to evaluate
     * @param mappedTensors decides the type of arguments which do not start with "tensor":
     *                      true keeps the type the tensor string yields by itself,
     *                      false turns every dimension into an indexed dimension
     * @param tensorArgumentStrings tensor strings which are bound to the names tensor0, tensor1, ... in order
     * @return the expression which was evaluated
     */
    public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
                                             String ... tensorArgumentStrings) {
        // Work on a copy so that argument bindings from one call are not added to the default context
        MapContext context = defaultContext.thawedCopy();
        int argumentIndex = 0;
        for (String argumentString : tensorArgumentStrings) {
            Tensor argument;
            if (argumentString.startsWith("tensor")) // explicitly decided type
                argument = Tensor.from(argumentString);
            else // use mappedTensors+dimensions in tensor to decide type
                argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString);
            context.put("tensor" + (argumentIndex++), new TensorValue(argument));
        }
        return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
                               mappedTensors ? "Mapped tensors" : "Indexed tensors");
    }

    /** Asserts that the expression evaluates to the given value in the default context */
    public RankingExpression assertEvaluates(Value value, String expressionString) {
        return assertEvaluates(value, expressionString, defaultContext, "");
    }

    /** Asserts that the expression evaluates to the given double in the default context */
    public RankingExpression assertEvaluates(double value, String expressionString) {
        return assertEvaluates(value, expressionString, defaultContext);
    }

    /** Asserts that the expression evaluates to the given boolean in the default context */
    public RankingExpression assertEvaluates(boolean value, String expressionString) {
        return assertEvaluates(value, expressionString, defaultContext);
    }

    /** Asserts that the expression evaluates to the given double, wrapped as a DoubleValue, in the given context */
    public RankingExpression assertEvaluates(double value, String expressionString, Context context) {
        return assertEvaluates(new DoubleValue(value), expressionString, context, "");
    }

    /** Asserts that the expression evaluates to the given boolean, wrapped as a BooleanValue, in the given context */
    public RankingExpression assertEvaluates(boolean value, String expressionString, Context context) {
        return assertEvaluates(new BooleanValue(value), expressionString, context, "");
    }

    /**
     * Parses the expression string, evaluates it in the given context and asserts, using assertEquals,
     * that the result equals the expected value.
     *
     * @param value the expected result of the evaluation
     * @param expressionString the ranking expression to evaluate
     * @param context the context to evaluate the expression in
     * @param explanation a text which, unless empty, is prepended (followed by ": ") to the
     *                    expression in the assertion message
     * @return the parsed expression
     * @throws RuntimeException wrapping the ParseException if the expression string cannot be parsed
     */
    public RankingExpression assertEvaluates(Value value, String expressionString, Context context, String explanation) {
        try {
            RankingExpression expression = new RankingExpression(expressionString);
            String messagePrefix = explanation.isEmpty() ? "" : explanation + ": ";
            assertEquals(messagePrefix + expression, value, expression.evaluate(context));
            return expression;
        }
        catch (ParseException e) {
            // Name the offending expression so that the failing test case can be identified from the message
            throw new RuntimeException("Could not parse ranking expression '" + expressionString + "'", e);
        }
    }

    /** Create a tensor type from a tensor string which may or may not contain type info */
    private TensorType typeFrom(String argument, boolean mappedTensors) {
        Tensor tensor = Tensor.from(argument); // Create tensor just to get the dimensions
        if (mappedTensors) {
            return tensor.type(); // implicit type is mapped by default
        }
        else { // convert to indexed: keep the dimension names, make each dimension indexed
            TensorType.Builder builder = new TensorType.Builder();
            for (TensorType.Dimension dimension : tensor.type().dimensions())
                builder.indexed(dimension.name());
            return builder.build();
        }
    }

}
|