aboutsummaryrefslogtreecommitdiffstats
path: root/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/ArithmeticExpression.java
blob: 5a12ce76d2b675cebad02f74a34b4cea8bd76d02 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.indexinglanguage.expressions;

import com.yahoo.document.DataType;
import com.yahoo.document.NumericDataType;
import com.yahoo.document.datatypes.*;
import com.yahoo.vespa.indexinglanguage.ExpressionConverter;
import com.yahoo.vespa.objects.ObjectOperation;
import com.yahoo.vespa.objects.ObjectPredicate;

import java.math.BigDecimal;
import java.math.MathContext;
import java.util.Objects;

/**
 * A binary arithmetic expression, {@code left op right}.
 *
 * <p>Both operands are evaluated against the same input value and their numeric results are combined with an
 * {@link Operator}. The computation is carried out on {@link BigDecimal} values. The result type is DOUBLE if either
 * operand is a double, otherwise FLOAT if either is a float, otherwise LONG if either is a long, otherwise INT.
 * If either operand evaluates to null, the result is null.</p>
 *
 * @author Simon Thoresen Hult
 */
public final class ArithmeticExpression extends CompositeExpression {

    /** The supported binary operators, each with a precedence level and its source-text symbol. */
    public enum Operator {

        ADD(1, "+"),
        SUB(1, "-"),
        MUL(0, "*"),
        DIV(0, "/"),
        MOD(0, "%");

        /** Precedence level: 0 for MUL/DIV/MOD, 1 for ADD/SUB; compared by {@link #precedes(Operator)}. */
        private final int precedence;
        /** The symbol returned by {@link #toString()}. */
        private final String img;

        Operator(int precedence, String img) {
            this.precedence = precedence;
            this.img = img;
        }

        /**
         * Returns whether this operator's precedence level is less than or equal to that of the given operator.
         *
         * @param op the operator to compare against
         * @return true if {@code this.precedence <= op.precedence}
         */
        public boolean precedes(Operator op) {
            return precedence <= op.precedence;
        }

        @Override
        public String toString() {
            return img;
        }
    }

    private final Expression left;
    private final Operator op;
    private final Expression right;

    /**
     * Creates the expression {@code left op right}.
     *
     * @param left  the left-hand operand
     * @param op    the operator
     * @param right the right-hand operand
     * @throws NullPointerException  if any argument is null
     * @throws VerificationException if the operands require conflicting input types
     */
    public ArithmeticExpression(Expression left, Operator op, Expression right) {
        // The operands are null-checked inside the super() argument because requiredInputType dereferences them
        // before the constructor body gets a chance to run.
        super(requiredInputType(Objects.requireNonNull(left, "left"), Objects.requireNonNull(right, "right")));
        this.left = left;
        this.op = Objects.requireNonNull(op, "op");
        this.right = right;
    }

    /** Returns a new expression with the same operator and both operands passed through the given converter. */
    @Override
    public ArithmeticExpression convertChildren(ExpressionConverter converter) {
        // TODO: branch()?
        return new ArithmeticExpression(converter.convert(left), op, converter.convert(right));
    }

    /** Returns the left-hand operand. */
    public Expression getLeftHandSide() {
        return left;
    }

    /** Returns the operator. */
    public Operator getOperator() {
        return op;
    }

    /** Returns the right-hand operand. */
    public Expression getRightHandSide() {
        return right;
    }

    /**
     * Evaluates both operands against the current context value and stores the arithmetic result
     * (or null, if either operand produced null) as the new context value.
     */
    @Override
    protected void doExecute(ExecutionContext context) {
        FieldValue input = context.getValue();
        // Reset the context value before each operand so that both see the original input.
        context.setValue(evaluate(context.setValue(input).execute(left).getValue(),
                                  context.setValue(input).execute(right).getValue()));
    }

    /** Verifies both operands against the current value type and sets the promoted result type. */
    @Override
    protected void doVerify(VerificationContext context) {
        DataType input = context.getValueType();
        context.setValueType(evaluate(context.setValueType(input).execute(left).getValueType(),
                                      context.setValueType(input).execute(right).getValueType()));
    }

    /**
     * Combines the input types required by the two operands.
     *
     * @return the type required by whichever operand requires one, or null if neither does
     * @throws VerificationException if both operands require an input type and the types differ
     */
    private static DataType requiredInputType(Expression lhs, Expression rhs) {
        DataType lhsType = lhs.requiredInputType();
        DataType rhsType = rhs.requiredInputType();
        if (lhsType == null) {
            return rhsType;
        }
        if (rhsType == null) {
            return lhsType;
        }
        if (!lhsType.equals(rhsType)) {
            throw new VerificationException(ArithmeticExpression.class, "Operands require conflicting input types, " +
                                                                        lhsType.getName() + " vs " + rhsType.getName());
        }
        return lhsType;
    }

    /** The output type depends on the operand types, which are only resolved during verification. */
    @Override
    public DataType createdOutputType() {
        return UnresolvedDataType.INSTANCE;
    }

    @Override
    public String toString() {
        return left + " " + op + " " + right;
    }

    @Override
    public boolean equals(Object obj) {
        if (!(obj instanceof ArithmeticExpression)) {
            return false;
        }
        ArithmeticExpression exp = (ArithmeticExpression)obj;
        return left.equals(exp.left) && op.equals(exp.op) && right.equals(exp.right);
    }

    @Override
    public int hashCode() {
        // Order-sensitive, so that e.g. "a - b" and "b - a" do not collide by construction.
        return Objects.hash(left, op, right);
    }

    /**
     * Computes the result type of applying the operator to operands of the given types.
     *
     * @return DOUBLE if either is DOUBLE, else FLOAT if either is FLOAT, else LONG if either is LONG, else INT
     * @throws VerificationException if either type is null or not a numeric data type
     */
    private DataType evaluate(DataType lhs, DataType rhs) {
        if (lhs == null || rhs == null) {
            throw new VerificationException(this, "Attempting to perform arithmetic on a null value");
        }
        if (!(lhs instanceof NumericDataType) ||
            !(rhs instanceof NumericDataType))
        {
            throw new VerificationException(this, "Attempting to perform unsupported arithmetic: [" +
                                                  lhs.getName() + "] " + op + " [" + rhs.getName() + "]");
        }
        if (lhs == DataType.FLOAT || lhs == DataType.DOUBLE ||
            rhs == DataType.FLOAT || rhs == DataType.DOUBLE)
        {
            if (lhs == DataType.DOUBLE || rhs == DataType.DOUBLE) {
                return DataType.DOUBLE;
            }
            return DataType.FLOAT;
        }
        if (lhs == DataType.LONG || rhs == DataType.LONG) {
            return DataType.LONG;
        }
        return DataType.INT;
    }

    /**
     * Applies the operator to two runtime values.
     *
     * @return the result, typed as described by {@link #createFieldValue}, or null if either operand is null
     * @throws IllegalArgumentException if either operand is not a numeric field value
     * @throws ArithmeticException      on division or remainder by zero
     */
    private FieldValue evaluate(FieldValue lhs, FieldValue rhs) {
        if (lhs == null || rhs == null) {
            return null;
        }
        if (!(lhs instanceof NumericFieldValue) ||
            !(rhs instanceof NumericFieldValue))
        {
            throw new IllegalArgumentException("Unsupported operation: [" + lhs.getDataType().getName() + "] " +
                                               op + " [" + rhs.getDataType().getName() + "]");
        }
        BigDecimal lhsVal = asBigDecimal((NumericFieldValue)lhs);
        BigDecimal rhsVal = asBigDecimal((NumericFieldValue)rhs);
        return switch (op) {
            case ADD -> createFieldValue(lhs, rhs, lhsVal.add(rhsVal));
            case SUB -> createFieldValue(lhs, rhs, lhsVal.subtract(rhsVal));
            case MUL -> createFieldValue(lhs, rhs, lhsVal.multiply(rhsVal));
            case DIV -> createFieldValue(lhs, rhs, divide(lhs, rhs, lhsVal, rhsVal));
            case MOD -> createFieldValue(lhs, rhs, lhsVal.remainder(rhsVal));
        };
    }

    /**
     * Divides {@code lhsVal} by {@code rhsVal}.
     *
     * <p>When neither operand is floating point the result will be an integer field value, so the exact integral
     * quotient (truncated toward zero, like Java's integer division) is computed. Rounding to DECIMAL64's 16
     * significant digits first would corrupt long quotients with more digits than that.</p>
     *
     * @throws ArithmeticException if {@code rhsVal} is zero
     */
    private static BigDecimal divide(FieldValue lhs, FieldValue rhs, BigDecimal lhsVal, BigDecimal rhsVal) {
        if (isFloatingPoint(lhs) || isFloatingPoint(rhs)) {
            return lhsVal.divide(rhsVal, MathContext.DECIMAL64);
        }
        return lhsVal.divideToIntegralValue(rhsVal);
    }

    /**
     * Wraps an arithmetic result in a field value whose type is promoted from the operand types: double if either
     * operand is a double, else float if either is a float, else long if either is a long, else integer.
     * For the integral types, any fractional part is discarded and only the low-order bits are kept
     * (per {@link BigDecimal#longValue()} / {@link BigDecimal#intValue()}), so overflow wraps around.
     */
    private FieldValue createFieldValue(FieldValue lhs, FieldValue rhs, BigDecimal val) {
        if (isFloatingPoint(lhs) || isFloatingPoint(rhs)) {
            if (lhs instanceof DoubleFieldValue || rhs instanceof DoubleFieldValue) {
                return new DoubleFieldValue(val.doubleValue());
            }
            return new FloatFieldValue(val.floatValue());
        }
        if (lhs instanceof LongFieldValue || rhs instanceof LongFieldValue) {
            return new LongFieldValue(val.longValue());
        }
        return new IntegerFieldValue(val.intValue());
    }

    /** Returns whether the given value is a float or double field value. */
    private static boolean isFloatingPoint(FieldValue value) {
        return value instanceof FloatFieldValue || value instanceof DoubleFieldValue;
    }

    /**
     * Converts a byte, double, float, integer or long field value to a {@link BigDecimal}.
     *
     * @throws IllegalArgumentException if the value is of any other numeric field value type
     * @throws NumberFormatException    if a float or double value is NaN or infinite
     *                                  (as specified by {@link BigDecimal#valueOf(double)})
     */
    public static BigDecimal asBigDecimal(NumericFieldValue value) {
        if (value instanceof ByteFieldValue) {
            return BigDecimal.valueOf(((ByteFieldValue)value).getByte());
        } else if (value instanceof DoubleFieldValue) {
            return BigDecimal.valueOf(((DoubleFieldValue)value).getDouble());
        } else if (value instanceof FloatFieldValue) {
            return BigDecimal.valueOf(((FloatFieldValue)value).getFloat());
        } else if (value instanceof IntegerFieldValue) {
            return BigDecimal.valueOf(((IntegerFieldValue)value).getInteger());
        } else if (value instanceof LongFieldValue) {
            return BigDecimal.valueOf(((LongFieldValue)value).getLong());
        }
        throw new IllegalArgumentException("Unsupported numeric field value type '" +
                                           value.getClass().getName() + "'");
    }

    /** Applies the selection to both operands, left first. */
    @Override
    public void selectMembers(ObjectPredicate predicate, ObjectOperation operation) {
        left.select(predicate, operation);
        right.select(predicate, operation);
    }
}