aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
blob: 8a9a85d343c12b9a0b3f2c261752a10a11a6d204 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.evaluation;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;

import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.Optional;

/**
 * A tensor variable name which resolves to a tensor in the context at evaluation time.
 * A required type may optionally be given; the resolved type (at type resolution) or the
 * resolved tensor's type (at evaluation) must then be assignable to it.
 *
 * @author bratseth
 */
public class VariableTensor<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {

    /** The name this variable is looked up by in the type and evaluation contexts. Never null. */
    private final String name;

    /** The type a resolved value must be assignable to, or empty if any type is accepted. */
    private final Optional<TensorType> requiredType;

    /**
     * Creates a variable tensor which accepts a value of any type.
     *
     * @param name the name to resolve in the context, not null
     * @throws NullPointerException if name is null
     */
    public VariableTensor(String name) {
        // Fail here rather than later at lookup/toString time, where the cause would be obscure
        this.name = Objects.requireNonNull(name, "name cannot be null");
        this.requiredType = Optional.empty();
    }

    /**
     * Creates a variable tensor which must be compatible with the given type.
     *
     * @param name the name to resolve in the context, not null
     * @param requiredType the type a resolved value must be assignable to, not null
     * @throws NullPointerException if name or requiredType is null
     */
    public VariableTensor(String name, TensorType requiredType) {
        this.name = Objects.requireNonNull(name, "name cannot be null");
        this.requiredType = Optional.of(Objects.requireNonNull(requiredType,
                                                               "requiredType cannot be null"));
    }

    /** Returns an empty list: a variable is a leaf and takes no arguments. */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }

    /**
     * Returns this instance unchanged. A variable has no arguments to replace,
     * so any given arguments are ignored.
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { return this; }

    /** Returns this, as a variable is already a primitive tensor function. */
    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }

    /**
     * Resolves the type of this variable from the given context.
     *
     * @return the type the context binds to this variable's name, or null if the context has none
     * @throws IllegalArgumentException if a required type is set and the resolved type is not assignable to it
     */
    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        TensorType givenType = context.getType(name);
        if (givenType == null) return null;
        verifyType(givenType);
        return givenType;
    }

    /**
     * Resolves the value of this variable from the given context.
     *
     * @return the tensor the context binds to this variable's name, or null if the context has none
     * @throws IllegalArgumentException if a required type is set and the tensor's type is not assignable to it
     */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor tensor = context.getTensor(name);
        if (tensor == null) return null;
        verifyType(tensor.type());
        return tensor;
    }

    /** Returns the variable name. The required type, if any, is not included. */
    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return name;
    }

    // NOTE(review): hashCode is value-based (name + requiredType), but equals is not overridden
    // in this class, so equality is identity-based unless a superclass defines it. Confirm this is intended.
    @Override
    public int hashCode() { return Objects.hash("variableTensor", name, requiredType); }

    /**
     * Checks a resolved type against the required type, if one is set.
     *
     * @throws IllegalArgumentException if a required type is set and givenType is not assignable to it
     */
    private void verifyType(TensorType givenType) {
        if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get()))
            throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " +
                                               requiredType.get() + " but was " + givenType);
    }

}