// path: root/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java (blob 6720269506d41dd31534d3292ec517c323e8938e)
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.evaluation;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import org.junit.Test;
import static org.junit.Assert.assertEquals;

/**
 * Tests evaluating neural nets expressed as tensors
 *
 * @author bratseth
 */
public class NeuralNetEvaluationTestCase {

    /**
     * "XOR" neural network, separate expression per layer.
     *
     * Each layer is a sum over the previous layer's dimension of input times weights, plus a bias,
     * shifted by 0.5 and clamped to [0, 1]. The first layer's output is verified on its own
     * before its expression is embedded into the expression of the second layer.
     */
    @Test
    public void testPerLayerExpression() {
        String input = "{ {x:1}:0, {x:2}:1 }"; // tensor0
        String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }"; // tensor1
        String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }"; // tensor2
        String firstLayerInput = "sum(tensor0 * tensor1, x) + tensor2";
        String firstLayerOutput = "min(1.0, max(0.0, 0.5 + " + firstLayerInput + "))"; // non-linearity, "poor man's sigmoid"
        assertEvaluates("{ {h:1}:1.0, {h:2}:0.0 }", firstLayerOutput, input, firstLayerWeights, firstLayerBias);

        String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }"; // tensor3
        String secondLayerBias = "{ {y:1}:-0.5 }"; // tensor4
        // The second layer consumes the first layer's output expression inline, summing over h
        String secondLayerInput = "sum(" + firstLayerOutput + "* tensor3, h) + tensor4";
        String secondLayerOutput = "min(1.0, max(0.0, 0.5 + " + secondLayerInput + "))"; // non-linearity, "poor man's sigmoid"
        assertEvaluates("{ {y:1}:1 }", secondLayerOutput, input, firstLayerWeights, firstLayerBias, secondLayerWeights, secondLayerBias);
    }

    /**
     * Asserts that an expression evaluates to the expected tensor when the given tensors are bound
     * in the context under the names "tensor0", "tensor1", ... in argument order.
     *
     * @param expectedTensor the expected result, as a tensor string
     * @param expressionString the ranking expression to parse and evaluate
     * @param tensorArguments the tensor strings to bind; the i'th argument is bound to "tensor" + i
     * @return the parsed expression
     */
    private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) {
        MapContext context = new MapContext();
        for (int i = 0; i < tensorArguments.length; i++) {
            context.put("tensor" + i, new TensorValue(Tensor.from(tensorArguments[i])));
        }
        return assertEvaluates(new TensorValue(Tensor.from(expectedTensor)), expressionString, context);
    }

    /**
     * Asserts that an expression evaluates to the expected value in the given context.
     * The string form of the parsed expression is used as the assertion message.
     *
     * @param value the expected result of the evaluation
     * @param expressionString the ranking expression to parse and evaluate
     * @param context the context to evaluate the expression in
     * @return the parsed expression
     * @throws RuntimeException wrapping the ParseException if the expression string cannot be parsed
     */
    private RankingExpression assertEvaluates(Value value, String expressionString, Context context) {
        try {
            RankingExpression expression = new RankingExpression(expressionString);
            assertEquals(expression.toString(), value, expression.evaluate(context));
            return expression;
        }
        catch (ParseException e) {
            throw new RuntimeException(e);
        }
    }

}