aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
blob: ce21e13298021ea20ce07d523d8ef1990e8411b6 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.rule;

import com.google.common.annotations.Beta;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;

import java.util.Collections;
import java.util.Deque;
import java.util.List;
import java.util.stream.Collectors;

/**
 * A node which performs a tensor function.
 * <p>
 * Ranking expressions passed as arguments to the tensor function are wrapped in
 * {@link TensorFunctionExpressionNode} instances so that they can be used where tensor functions are expected.
 *
 * @author bratseth
 */
@Beta
public class TensorFunctionNode extends CompositeNode {

    /** The tensor function performed by this node */
    private final TensorFunction function;

    public TensorFunctionNode(TensorFunction function) {
        this.function = function;
    }

    /**
     * Returns the arguments of the tensor function of this as expression nodes.
     * Arguments which wrap an expression are unwrapped; any other tensor function argument
     * is wrapped in a new TensorFunctionNode instead of being blindly cast.
     */
    @Override
    public List<ExpressionNode> children() {
        return function.functionArguments().stream()
                                           .map(TensorFunctionNode::toExpressionNode)
                                           .collect(Collectors.toList());
    }

    /**
     * Returns a new node performing the tensor function of this with its arguments replaced by
     * the given children, each wrapped as a tensor function.
     */
    @Override
    public CompositeNode setChildren(List<ExpressionNode> children) {
        List<TensorFunction> wrappedChildren = children.stream()
                                                        .map(TensorFunctionExpressionNode::new)
                                                        .collect(Collectors.toList());
        return new TensorFunctionNode(function.replaceArguments(wrappedChildren));
    }

    @Override
    public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) {
        // Serialize as primitive, passing the serialization arguments through to wrapped expression arguments
        return function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this));
    }

    /** Evaluates the tensor function of this in the given context and returns the result as a tensor value */
    @Override
    public Value evaluate(Context context) {
        return new TensorValue(function.evaluate(context));
    }

    /** Wraps an expression node such that it can be passed as a tensor function argument */
    public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
        return new TensorFunctionExpressionNode(node);
    }

    /**
     * Returns the given tensor function as an expression node: If it wraps an expression, that expression
     * is returned, otherwise the function itself is wrapped in a new TensorFunctionNode.
     */
    private static ExpressionNode toExpressionNode(TensorFunction argument) {
        if (argument instanceof TensorFunctionExpressionNode)
            return ((TensorFunctionExpressionNode)argument).expression;
        return new TensorFunctionNode(argument);
    }

    /**
     * A tensor function implemented by an expression.
     * This allows us to pass expressions as tensor function arguments.
     */
    public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction {

        /** An expression which produces a tensor */
        private final ExpressionNode expression;

        public TensorFunctionExpressionNode(ExpressionNode expression) {
            this.expression = expression;
        }

        /** Returns the children of the wrapped expression as tensor functions, or an empty list if it is not composite */
        @Override
        public List<TensorFunction> functionArguments() {
            if (expression instanceof CompositeNode)
                return ((CompositeNode)expression).children().stream()
                                                             .map(TensorFunctionExpressionNode::new)
                                                             .collect(Collectors.toList());
            else
                return Collections.emptyList();
        }

        /**
         * Returns a function wrapping a copy of the expression of this with its children replaced by
         * the given arguments, or this if no arguments are given.
         *
         * @throws IllegalArgumentException if arguments are given but the expression of this is not composite
         */
        @Override
        public TensorFunction replaceArguments(List<TensorFunction> arguments) {
            if (arguments.isEmpty()) return this;
            // Only composite expressions have children which can be replaced
            if ( ! (expression instanceof CompositeNode))
                throw new IllegalArgumentException("Cannot replace the arguments of '" + expression + "' by " +
                                                   arguments.size() + " arguments: It is not a composite expression");
            List<ExpressionNode> unwrappedChildren = arguments.stream()
                                                              .map(TensorFunctionNode::toExpressionNode)
                                                              .collect(Collectors.toList());
            return new TensorFunctionExpressionNode(((CompositeNode)expression).setChildren(unwrappedChildren));
        }

        @Override
        public PrimitiveTensorFunction toPrimitive() { return this; }

        /**
         * Evaluates the wrapped expression and returns the tensor it produces.
         *
         * @throws IllegalArgumentException if the expression does not produce a tensor
         */
        @Override
        public Tensor evaluate(EvaluationContext context) {
            // The given context is cast to a ranking expression Context (a ClassCastException is thrown otherwise)
            Value result = expression.evaluate((Context)context);
            if ( ! ( result instanceof TensorValue))
                throw new IllegalArgumentException("Attempted to evaluate tensor function '" + expression + "', " +
                                                   "but this returns " + result + ", not a tensor");
            return ((TensorValue)result).asTensor();
        }

        /** Returns the wrapped expression serialized by its own default serialization */
        @Override
        public String toString() {
            // Delegate to the expression's default serialization rather than passing a null
            // SerializationContext down to it
            return expression.toString();
        }

        /**
         * Serializes the wrapped expression, using the serialization arguments of the given context
         * if it carries any, and the expression's default serialization otherwise.
         */
        @Override
        public String toString(ToStringContext c) {
            if (c instanceof ExpressionNodeToStringContext) {
                ExpressionNodeToStringContext context = (ExpressionNodeToStringContext) c;
                return expression.toString(context.context, context.path, context.parent);
            }
            else {
                return expression.toString();
            }
        }

    }

    /** Allows passing serialization context arguments through TensorFunctions */
    private static class ExpressionNodeToStringContext implements ToStringContext {

        final SerializationContext context;
        final Deque<String> path;
        final CompositeNode parent;

        public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
            this.context = context;
            this.path = path;
            this.parent = parent;
        }

    }

}