aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java
blob: 255dcc19974b4ac8f7e3335fd946937bef1d935f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;

import org.junit.Test;

import static org.junit.Assert.assertTrue;

public class DimensionRenamerTest {

    /**
     * Registers the dimensions of three operands x, w and b and constrains them the way a matmul of
     * x and w followed by adding b would (an MNIST-style layer, as the test name suggests). It then
     * solves the renamer and checks that the resolved dimension names satisfy every constraint.
     *
     * <p>NOTE(review): the boolean passed to each {@code Constraint} factory appears to mark the
     * constraint as soft. The efficiency hints below pass {@code true`; the structural relations pass
     * {@code false}. Confirm against {@code DimensionRenamer.Constraint}. The trailing {@code null}
     * passed to {@code addConstraint} is presumably the originating operation, which this test does
     * not need.
     */
    @Test
    public void testMnistRenaming() {
        DimensionRenamer renamer = new DimensionRenamer(new IntermediateGraph("test"));

        renamer.addDimension("first_dimension_of_x");
        renamer.addDimension("second_dimension_of_x");
        renamer.addDimension("first_dimension_of_w");
        renamer.addDimension("second_dimension_of_w");
        renamer.addDimension("first_dimension_of_b");

        // which dimension to join on matmul
        renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer.Constraint.equal(false), null);

        // the remaining (non-joined) matmul dimensions must differ; a strict ordering enforces that
        renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer.Constraint.lessThan(false), null);

        // for efficiency, put the dimension to join on innermost: in both x and w the joined
        // dimension is constrained to sort after the operand's other dimension
        renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer.Constraint.lessThan(true), null);
        renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer.Constraint.greaterThan(true), null);

        // bias
        renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer.Constraint.equal(false), null);

        renamer.solve();

        String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get();
        String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get();
        String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get();
        String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get();
        String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get();

        // Each message carries both resolved names, so a failed relation can be diagnosed from
        // the report alone rather than from a bare AssertionError.
        assertTrue(expected(firstDimensionOfXName, "<", secondDimensionOfXName),
                   firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0);
        assertTrue(expected(firstDimensionOfWName, ">", secondDimensionOfWName),
                   firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0);
        assertTrue(expected(secondDimensionOfXName, "==", firstDimensionOfWName),
                   secondDimensionOfXName.equals(firstDimensionOfWName));
        assertTrue(expected(firstDimensionOfXName, "<", secondDimensionOfWName),
                   firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0);
        assertTrue(expected(secondDimensionOfWName, "==", firstDimensionOfBName),
                   secondDimensionOfWName.equals(firstDimensionOfBName));
    }

    /**
     * Builds an assertion failure message stating the relation two resolved names should satisfy.
     *
     * @param left     the left-hand resolved dimension name
     * @param relation the expected relation, e.g. {@code "<"} or {@code "=="}
     * @param right    the right-hand resolved dimension name
     * @return a human-readable description of the expected relation
     */
    private static String expected(String left, String relation, String right) {
        return "expected '" + left + "' " + relation + " '" + right + "'";
    }

}