aboutsummaryrefslogtreecommitdiffstats
path: root/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/expressions/ArithmeticTestCase.java
blob: f6dc7f839ed93dcb0762c39853074dac0828331c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.indexinglanguage.expressions;

import com.yahoo.document.DataType;
import com.yahoo.document.datatypes.FieldValue;
import com.yahoo.document.datatypes.IntegerFieldValue;
import com.yahoo.document.datatypes.LongFieldValue;
import com.yahoo.document.datatypes.StringFieldValue;
import com.yahoo.vespa.indexinglanguage.SimpleTestAdapter;
import org.junit.Test;

import static com.yahoo.vespa.indexinglanguage.expressions.ArithmeticExpression.Operator;
import static org.junit.Assert.*;

/**
 * Tests for {@link ArithmeticExpression}: accessors, equals/hashCode, constructor null checks,
 * type verification, evaluation results, and the numeric result type chosen for each operand pair.
 *
 * @author Simon Thoresen Hult
 */
public class ArithmeticTestCase {

    @Test
    public void requireThatAccessorsWork() {
        ArithmeticExpression exp = newArithmetic(6, Operator.ADD, 9);
        assertEquals(newLong(6), exp.getLeftHandSide());
        assertEquals(Operator.ADD, exp.getOperator());
        assertEquals(newLong(9), exp.getRightHandSide());
    }

    @Test
    public void requireThatHashCodeAndEqualsAreImplemented() {
        Expression exp = newArithmetic(6, Operator.ADD, 9);
        assertFalse(exp.equals(new Object()));
        assertFalse(exp.equals(newArithmetic(1, Operator.DIV, 1)));
        assertFalse(exp.equals(newArithmetic(6, Operator.DIV, 1)));
        assertFalse(exp.equals(newArithmetic(6, Operator.ADD, 1)));
        assertEquals(exp, newArithmetic(6, Operator.ADD, 9));
        assertEquals(exp.hashCode(), newArithmetic(6, Operator.ADD, 9).hashCode());
    }

    @Test
    public void requireThatConstructorDoesNotAcceptNull() {
        try {
            newArithmetic(null, Operator.ADD, new SimpleExpression());
            fail("Expected NullPointerException for a null left-hand side");
        } catch (NullPointerException ignored) {
            // Expected: a null left-hand side must be rejected.
        }
        try {
            newArithmetic(new SimpleExpression(), null, new SimpleExpression());
            fail("Expected NullPointerException for a null operator");
        } catch (NullPointerException ignored) {
            // Expected: a null operator must be rejected.
        }
        try {
            newArithmetic(new SimpleExpression(), Operator.ADD, null);
            fail("Expected NullPointerException for a null right-hand side");
        } catch (NullPointerException ignored) {
            // Expected: a null right-hand side must be rejected.
        }
    }

    @Test
    public void requireThatVerifyCallsAreForwarded() {
        assertVerify(SimpleExpression.newOutput(DataType.INT), Operator.ADD,
                     SimpleExpression.newOutput(DataType.INT), null);
        assertVerifyThrows(SimpleExpression.newOutput(null), Operator.ADD,
                           SimpleExpression.newOutput(DataType.INT), null,
                           "Attempting to perform arithmetic on a null value");
        assertVerifyThrows(SimpleExpression.newOutput(DataType.INT), Operator.ADD,
                           SimpleExpression.newOutput(null), null,
                           "Attempting to perform arithmetic on a null value");
        assertVerifyThrows(SimpleExpression.newOutput(null), Operator.ADD,
                           SimpleExpression.newOutput(null), null,
                           "Attempting to perform arithmetic on a null value");
        assertVerifyThrows(SimpleExpression.newOutput(DataType.INT), Operator.ADD,
                           SimpleExpression.newOutput(DataType.STRING), null,
                           "Attempting to perform unsupported arithmetic: [int] + [string]");
        assertVerifyThrows(SimpleExpression.newOutput(DataType.STRING), Operator.ADD,
                           SimpleExpression.newOutput(DataType.STRING), null,
                           "Attempting to perform unsupported arithmetic: [string] + [string]");
    }

    @Test
    public void requireThatOperandInputCanBeNull() {
        SimpleExpression reqNull = new SimpleExpression();
        SimpleExpression reqInt = new SimpleExpression(DataType.INT);
        assertNull(newArithmetic(reqNull, Operator.ADD, reqNull).requiredInputType());
        assertEquals(DataType.INT, newArithmetic(reqInt, Operator.ADD, reqNull).requiredInputType());
        assertEquals(DataType.INT, newArithmetic(reqInt, Operator.ADD, reqInt).requiredInputType());
        assertEquals(DataType.INT, newArithmetic(reqNull, Operator.ADD, reqInt).requiredInputType());
    }

    @Test
    public void requireThatOperandsAreInputCompatible() {
        assertVerify(new SimpleExpression(DataType.INT), Operator.ADD,
                     new SimpleExpression(DataType.INT), DataType.INT);
        assertVerifyThrows(new SimpleExpression(DataType.INT), Operator.ADD,
                           new SimpleExpression(DataType.STRING), null,
                           "Operands require conflicting input types, int vs string");
    }

    @Test
    public void requireThatResultIsCalculated() {
        for (int i = 0; i < 50; ++i) {
            LongFieldValue lhs = new LongFieldValue(i);
            LongFieldValue rhs = new LongFieldValue(100 - i);
            assertAllOperatorResults(lhs, rhs);
            if (i != 0) {
                // With lhs < rhs, DIV always yields 0 and MOD always yields lhs, which would hide a
                // broken implementation; also evaluate with the larger operand on the left.
                // i == 0 is skipped because it would make the divisor zero.
                assertAllOperatorResults(rhs, lhs);
            }
        }
    }

    @Test
    public void requireThatArithmeticWithNullEvaluatesToNull() {
        assertNull(newArithmetic(new SimpleExpression(), Operator.ADD,
                                 new ConstantExpression(new LongFieldValue(69))).execute());
        assertNull(newArithmetic(new ConstantExpression(new LongFieldValue(69)), Operator.ADD,
                                 new SimpleExpression()).execute());
    }

    @Test
    public void requireThatNonNumericOperandThrows() {
        try {
            newArithmetic(new ConstantExpression(new IntegerFieldValue(6)), Operator.ADD,
                          new ConstantExpression(new StringFieldValue("9"))).execute();
            fail("Expected IllegalArgumentException for [int] + [string]");
        } catch (IllegalArgumentException e) {
            assertEquals("Unsupported operation: [int] + [string]", e.getMessage());
        }
        try {
            newArithmetic(new ConstantExpression(new StringFieldValue("6")), Operator.ADD,
                          new ConstantExpression(new IntegerFieldValue(9))).execute();
            fail("Expected IllegalArgumentException for [string] + [int]");
        } catch (IllegalArgumentException e) {
            assertEquals("Unsupported operation: [string] + [int]", e.getMessage());
        }
    }

    @Test
    public void requireThatProperNumericalTypeIsUsed() {
        for (Operator op : Operator.values()) {
            assertType(DataType.INT, op, DataType.INT, DataType.INT);
            assertType(DataType.LONG, op, DataType.INT, DataType.LONG);
            assertType(DataType.LONG, op, DataType.LONG, DataType.LONG);
            assertType(DataType.INT, op, DataType.LONG, DataType.LONG);

            assertType(DataType.FLOAT, op, DataType.FLOAT, DataType.FLOAT);
            assertType(DataType.DOUBLE, op, DataType.FLOAT, DataType.DOUBLE);
            assertType(DataType.DOUBLE, op, DataType.DOUBLE, DataType.DOUBLE);
            assertType(DataType.FLOAT, op, DataType.DOUBLE, DataType.DOUBLE);

            assertType(DataType.INT, op, DataType.FLOAT, DataType.FLOAT);
            assertType(DataType.INT, op, DataType.DOUBLE, DataType.DOUBLE);
        }
    }

    /**
     * Asserts the result of every arithmetic operator on the given long operands.
     * Each expected value is computed with Java's own long arithmetic.
     * The caller must ensure {@code rhs} is non-zero, since DIV and MOD are evaluated.
     */
    private void assertAllOperatorResults(LongFieldValue lhs, LongFieldValue rhs) {
        assertResult(lhs, Operator.ADD, rhs, new LongFieldValue(lhs.getLong() + rhs.getLong()));
        assertResult(lhs, Operator.SUB, rhs, new LongFieldValue(lhs.getLong() - rhs.getLong()));
        assertResult(lhs, Operator.DIV, rhs, new LongFieldValue(lhs.getLong() / rhs.getLong()));
        assertResult(lhs, Operator.MOD, rhs, new LongFieldValue(lhs.getLong() % rhs.getLong()));
        assertResult(lhs, Operator.MUL, rhs, new LongFieldValue(lhs.getLong() * rhs.getLong()));
    }

    /**
     * Asserts that applying {@code op} to two constant operands leaves {@code expected}
     * as the execution context's value.
     */
    private void assertResult(FieldValue lhs, Operator op, FieldValue rhs, FieldValue expected) {
        assertEquals(expected, evaluate(new ConstantExpression(lhs), op,
                                        new ConstantExpression(rhs)));
    }

    /**
     * Asserts that both verification and execution of {@code lhs op rhs} produce the
     * {@code expected} data type. Execution uses the sample values 6 and 9 created by each data type.
     */
    private void assertType(DataType lhs, Operator op, DataType rhs, DataType expected) {
        assertEquals(expected, newArithmetic(SimpleExpression.newOutput(lhs), op,
                                             SimpleExpression.newOutput(rhs)).verify());
        assertEquals(expected, newArithmetic(lhs.createFieldValue(6), op,
                                             rhs.createFieldValue(9)).execute().getDataType());
    }

    /** Executes {@code lhs op rhs} in a fresh execution context and returns the resulting context value. */
    private static FieldValue evaluate(Expression lhs, Operator op, Expression rhs) {
        ExecutionContext ctx = new ExecutionContext(new SimpleTestAdapter());
        new ArithmeticExpression(lhs, op, rhs).execute(ctx);
        return ctx.getValue();
    }

    /** Creates an arithmetic expression over two long constants. */
    private static ArithmeticExpression newArithmetic(long lhs, Operator op, long rhs) {
        return newArithmetic(new LongFieldValue(lhs), op, new LongFieldValue(rhs));
    }

    /** Creates an arithmetic expression over two constant field values. */
    private static ArithmeticExpression newArithmetic(FieldValue lhs, Operator op, FieldValue rhs) {
        return newArithmetic(new ConstantExpression(lhs), op, new ConstantExpression(rhs));
    }

    /** Creates an arithmetic expression over two arbitrary operand expressions. */
    private static ArithmeticExpression newArithmetic(Expression lhs, Operator op, Expression rhs) {
        return new ArithmeticExpression(lhs, op, rhs);
    }

    /** Creates a constant expression holding the given long value. */
    private static ConstantExpression newLong(long val) {
        return new ConstantExpression(new LongFieldValue(val));
    }

    /**
     * Verifies {@code lhs op rhs} against the given input type.
     * It passes as long as verification does not throw.
     */
    private static void assertVerify(Expression lhs, Operator op, Expression rhs, DataType val) {
        new ArithmeticExpression(lhs, op, rhs).verify(val);
    }

    /**
     * Asserts that verifying {@code lhs op rhs} against the given input type throws a
     * {@link VerificationException} with exactly {@code expectedException} as its message.
     */
    private static void assertVerifyThrows(Expression lhs, Operator op, Expression rhs, DataType val,
                                           String expectedException) {
        try {
            new ArithmeticExpression(lhs, op, rhs).verify(val);
            fail("Expected VerificationException: " + expectedException);
        } catch (VerificationException e) {
            assertEquals(expectedException, e.getMessage());
        }
    }
}