summaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
blob: f2164a1b177192327047775be5dc42ef664283c2 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import org.junit.Test;

import java.util.List;

import static org.junit.Assert.assertEquals;

/**
 * Tests that TensorFlow models are imported into the expected ranking expressions.
 *
 * @author bratseth
 */
public class TensorFlowImporterTestCase {

    @Test
    public void testModel1() {
        String modelDirectory = "src/test/files/integration/tensorflow/model1/";
        List<RankingExpression> imported = new TensorFlowImporter().importModel(modelDirectory);

        // This model should produce exactly one expression, named "scores"
        assertEquals(1, imported.size());
        RankingExpression scores = imported.get(0);
        assertEquals("scores", scores.getName());

        String expected = "softmax(join(rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " +
                                       "rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " +
                                       "f(a,b)(a + b)), " +
                                  "d1)";
        assertEquals(expected, toNonPrimitiveString(scores));
    }

    /**
     * Returns the string form of the tensor function held by the root node of the given expression.
     * Calling toString on the expression itself would render it in terms of primitive functions,
     * which is harder to read when comparing against an expected value.
     * The root is cast to TensorFunctionNode, so a different root node type fails with a ClassCastException.
     */
    private String toNonPrimitiveString(RankingExpression expression) {
        TensorFunctionNode rootNode = (TensorFunctionNode) expression.getRoot();
        return rootNode.function().toString();
    }

}