summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java
blob: bc10ecc3abdbb4077f6d20ca6011f231c97bfaae (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;

import static org.junit.Assert.assertEquals;

/**
 * @author lesters
 */
public class CellCastTestCase {

    @Test
    public void testCellCasting() {
        Tensor tensor;

        tensor = Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]");
        assertEquals(TensorType.Value.DOUBLE, tensor.type().valueType());
        assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType());
        assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType());
        assertEquals(tensor, tensor.cellCast(TensorType.Value.FLOAT));

        tensor = Tensor.from("tensor<double>(x{}):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}");
        assertEquals(TensorType.Value.DOUBLE, tensor.type().valueType());
        assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType());
        assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType());
        assertEquals(tensor, tensor.cellCast(TensorType.Value.FLOAT));

        tensor = Tensor.from("tensor<float>(x[3],y{}):{a:[1.0, 2.0, 3.0],b:[4.0,5.0,6.0]}");
        assertEquals(TensorType.Value.FLOAT, tensor.type().valueType());
        assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType());
        assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType());
        assertEquals(tensor, tensor.cellCast(TensorType.Value.DOUBLE));
    }

}