aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTester.java
blob: cc278b3d73b766bd68af3b93db87cfe51f2a5dce (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;

import java.util.HashMap;
import java.util.Map;

import static org.junit.Assert.assertEquals;

/**
 * Test helper which parses a ranking expression, evaluates it in a context and asserts that
 * the result equals an expected value.
 *
 * Overloads which do not take a context evaluate against a default context binding the names
 * {@code zero}, {@code one}, {@code one_half} and {@code a_quarter} to doubles and
 * {@code foo} to a string.
 *
 * @author bratseth
 */
public class EvaluationTester {

    /** The bindings available to every expression evaluated without an explicit context */
    private final MapContext defaultContext;

    /** Creates a tester with the default bindings described in the class documentation. */
    public EvaluationTester() {
        Map<String, Value> bindings = new HashMap<>();
        bindings.put("zero", DoubleValue.frozen(0d));
        bindings.put("one", DoubleValue.frozen(1d));
        bindings.put("one_half", DoubleValue.frozen(0.5d));
        bindings.put("a_quarter", DoubleValue.frozen(0.25d));
        bindings.put("foo", StringValue.frozen("foo"));
        defaultContext = new MapContext(bindings);
    }

    /**
     * Asserts that the expression evaluates to the expected tensor, first with the arguments
     * interpreted as indexed tensors and then as mapped tensors.
     *
     * @param expectedTensor the expected result, in tensor string form
     * @param expressionString the ranking expression to evaluate
     * @param tensorArguments tensor strings which are bound to the names tensor0, tensor1, ... in order
     * @return the expression parsed during the mapped tensor run
     */
    public RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) {
        assertEvaluates(expectedTensor, expressionString, false, tensorArguments);
        return assertEvaluates(expectedTensor, expressionString, true, tensorArguments);
    }

    // TODO: Test both bound and unbound indexed
    /**
     * Asserts that the expression evaluates to the expected tensor when the given tensor arguments
     * are bound to the names tensor0, tensor1, ... in a copy of the default context
     * (as returned by {@code thawedCopy()}).
     *
     * @param expectedTensor the expected result, in tensor string form
     * @param expressionString the ranking expression to evaluate
     * @param mappedTensors whether arguments without an explicit type keep their implicit type (true)
     *                      or are converted to a type where all dimensions are indexed (false)
     * @param tensorArgumentStrings tensor strings; those starting by "tensor" are parsed as-is
     * @return the parsed expression
     */
    public RankingExpression assertEvaluates(String expectedTensor, String expressionString, boolean mappedTensors,
                                             String ... tensorArgumentStrings) {
        MapContext context = defaultContext.thawedCopy();
        int argumentIndex = 0;
        for (String argumentString : tensorArgumentStrings) {
            Tensor argument;
            if (argumentString.startsWith("tensor")) // explicitly decided type
                argument = Tensor.from(argumentString);
            else // use mappedTensors+dimensions in tensor to decide type
                argument = Tensor.from(typeFrom(argumentString, mappedTensors), argumentString);
            context.put("tensor" + (argumentIndex++), new TensorValue(argument));
        }
        return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context,
                               mappedTensors ? "Mapped tensors" : "Indexed tensors");
    }

    /** Asserts that the expression evaluates to the given value in the default context. */
    public RankingExpression assertEvaluates(Value value, String expressionString) {
        return assertEvaluates(value, expressionString, defaultContext, "");
    }

    /** Asserts that the expression evaluates to the given double in the default context. */
    public RankingExpression assertEvaluates(double value, String expressionString) {
        return assertEvaluates(value, expressionString, defaultContext);
    }

    /** Asserts that the expression evaluates to the given boolean in the default context. */
    public RankingExpression assertEvaluates(boolean value, String expressionString) {
        return assertEvaluates(value, expressionString, defaultContext);
    }

    /** Asserts that the expression evaluates to the given double in the given context. */
    public RankingExpression assertEvaluates(double value, String expressionString, Context context) {
        return assertEvaluates(new DoubleValue(value), expressionString, context, "");
    }

    /** Asserts that the expression evaluates to the given boolean in the given context. */
    public RankingExpression assertEvaluates(boolean value, String expressionString, Context context) {
        return assertEvaluates(new BooleanValue(value), expressionString, context, "");
    }

    /**
     * Parses the expression, evaluates it in the given context and asserts that the result
     * equals the expected value. All other assertEvaluates overloads end up here.
     *
     * @param value the expected result of the evaluation
     * @param expressionString the ranking expression to parse and evaluate
     * @param context the context to evaluate the expression in
     * @param explanation a text prepended to the assertion message, or empty to add nothing
     * @return the parsed expression
     * @throws RuntimeException wrapping the ParseException if the expression string cannot be parsed
     */
    public RankingExpression assertEvaluates(Value value, String expressionString, Context context, String explanation) {
        try {
            RankingExpression expression = new RankingExpression(expressionString);
            String prefix = explanation.isEmpty() ? "" : explanation + ": ";
            assertEquals(prefix + expression, value, expression.evaluate(context));
            return expression;
        }
        catch (ParseException e) {
            // Name the offending expression so a failing test points directly at the bad input
            throw new RuntimeException("Could not parse ranking expression '" + expressionString + "'", e);
        }
    }

    /** Create a tensor type from a tensor string which may or may not contain type info */
    private TensorType typeFrom(String argument, boolean mappedTensors) {
        Tensor tensor = Tensor.from(argument); // Create tensor just to get the dimensions
        if (mappedTensors) {
            return tensor.type(); // implicit type is mapped by default
        }
        else { // convert to indexed
            TensorType.Builder builder = new TensorType.Builder();
            for (TensorType.Dimension dimension : tensor.type().dimensions())
                builder.indexed(dimension.name());
            return builder.build();
        }
    }

}