summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
blob: 1928971820c1495526144a28e0ad5c7ecf14ad5b (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;

import org.junit.Test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.fail;

/**
 * Tests of {@link Tensor#from(String)}.
 *
 * Covers the sparse (cell-address) text form, the dense (bracketed value list) text form,
 * and the exact error messages produced for malformed input.
 */
public class TensorParserTestCase {

    /** Returns a new builder for the tensor type described by the given type spec string. */
    private static Tensor.Builder builderOf(String typeSpec) {
        return Tensor.Builder.of(TensorType.fromSpec(typeSpec));
    }

    /**
     * Builds a fresh tensor(x[3],y[2],z[1]) holding 1.0 .. 5.0 followed by {@code lastValue},
     * with the last dimension varying fastest.
     */
    private static Tensor xyzTensor(double lastValue) {
        return builderOf("tensor(x[3],y[2],z[1])")
                .cell(1.0, 0, 0, 0)
                .cell(2.0, 0, 1, 0)
                .cell(3.0, 1, 0, 0)
                .cell(4.0, 1, 1, 0)
                .cell(5.0, 2, 0, 0)
                .cell(lastValue, 2, 1, 0)
                .build();
    }

    @Test
    public void testSparseParsing() {
        // "{}" is expected to equal an empty tensor of type tensor().
        assertEquals(builderOf("tensor()").build(), Tensor.from("{}"));

        // Without a type prefix, both numeric and identifier labels are expected to give a mapped x{} dimension.
        Tensor numericLabel = builderOf("tensor(x{})").cell(1.0, 0).build();
        assertEquals(numericLabel, Tensor.from("{{x:0}:1.0}"));

        Tensor identifierLabel = builderOf("tensor(x{})").cell().label("x", "l0").value(1.0).build();
        assertEquals(identifierLabel, Tensor.from("{{x:l0}:1.0}"));

        Tensor denseFromSparse = builderOf("tensor(x[1])").cell(1.0, 0).build();
        assertEquals("If the type is specified, a dense tensor can be created from the sparse text form",
                     denseFromSparse,
                     Tensor.from("tensor(x[1]):{{x:0}:1.0}"));
    }

    @Test
    public void testDenseParsing() {
        // A tensor() with no cell set is written and read as {0.0}.
        assertDense(builderOf("tensor()").build(), "tensor():{0.0}");
        assertDense(builderOf("tensor()").cell(1.3).build(), "tensor():{1.3}");

        // An unbound dimension x[] is written and read in the cell-address form.
        assertDense(builderOf("tensor(x[])").cell(1.0, 0).build(), "tensor(x[]):{{x:0}:1.0}");

        // Bound dimensions use nested brackets, the last dimension being innermost.
        assertDense(builderOf("tensor(x[1])").cell(1.0, 0).build(), "tensor(x[1]):[1.0]");
        assertDense(builderOf("tensor(x[2])").cell(1.0, 0).cell(2.0, 1).build(), "tensor(x[2]):[1.0, 2.0]");

        Tensor twoByThree = builderOf("tensor(x[2],y[3])")
                .cell(1.0, 0, 0)
                .cell(2.0, 0, 1)
                .cell(3.0, 0, 2)
                .cell(4.0, 1, 0)
                .cell(5.0, 1, 1)
                .cell(6.0, 1, 2)
                .build();
        assertDense(twoByThree, "tensor(x[2],y[3]):[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]");

        Tensor oneByTwoByThree = builderOf("tensor(x[1],y[2],z[3])")
                .cell(1.0, 0, 0, 0)
                .cell(2.0, 0, 0, 1)
                .cell(3.0, 0, 0, 2)
                .cell(4.0, 0, 1, 0)
                .cell(5.0, 0, 1, 1)
                .cell(6.0, 0, 1, 2)
                .build();
        assertDense(oneByTwoByThree, "tensor(x[1],y[2],z[3]):[[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]]");

        assertDense(xyzTensor(6.0),
                    "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]]");

        // The next input has extra whitespace, an integer literal (5), and bracket nesting that does
        // not mirror the declared type; it is still expected to yield the values in order.
        Tensor messyExpected = xyzTensor(-6.0);
        assertEquals("Messy input",
                     messyExpected,
                     Tensor.from("tensor( x[3],y[2],z[1]) : [  [ [1.0, 2.0, 3.0] , [4.0, 5,-6.0] ]  ]"));

        // Inner brackets may be omitted entirely.
        Tensor flatExpected = xyzTensor(-6.0);
        assertEquals("Skipping syntactic sugar",
                     flatExpected,
                     Tensor.from("tensor( x[3],y[2],z[1]) : [1.0, 2.0, 3.0 , 4.0, 5, -6.0]"));
    }

    /**
     * Asserts that {@code text} parses to {@code expected}, and that {@code expected}
     * serializes back to exactly {@code text}.
     */
    private void assertDense(Tensor expected, String text) {
        Tensor parsed = Tensor.from(text);
        assertEquals(text, expected, parsed);
        assertEquals(text, expected.toString());
    }

    @Test
    public void testIllegalStrings() {
        // Quoted labels and quoted dimension names are rejected.
        assertIllegal("label must be an identifier or integer, not '\"l0\"'",
                      "{{x:\"l0\"}:1.0}");
        assertIllegal("dimension must be an identifier or integer, not ''x''",
                      "{{'x':\"l0\"}:1.0}");
        assertIllegal("dimension must be an identifier or integer, not '\"x\"'",
                      "{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}");

        // Malformed numbers are reported with the cell address (sparse) or value index (dense).
        assertIllegal("At {x:0}: '1-.0' is not a valid double",
                      "{{x:0}:1-.0}");
        assertIllegal("At index 0: '1-.0' is not a valid double",
                      "tensor(x[1]):[1-.0]");
    }

    /**
     * Asserts that parsing {@code text} throws an IllegalArgumentException whose message is
     * exactly {@code expectedMessage}; any other exception type propagates to the caller.
     */
    private void assertIllegal(String expectedMessage, String text) {
        try {
            Tensor.from(text);
        }
        catch (IllegalArgumentException e) {
            assertEquals(expectedMessage, e.getMessage());
            return;
        }
        fail("Expected an IllegalArgumentException when parsing " + text);
    }

}