diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java | 26 |
1 files changed, 24 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 6c149724aca..34beb465d4c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.functions.ToStringContext; import java.util.Collections; import java.util.List; +import java.util.Optional; /** * A tensor variable name which resolves to a tensor in the context at evaluation time @@ -20,9 +21,17 @@ import java.util.List; public class VariableTensor extends PrimitiveTensorFunction { private final String name; + private final Optional<TensorType> requiredType; public VariableTensor(String name) { this.name = name; + this.requiredType = Optional.empty(); + } + + /** A variable tensor which must be compatible with the given type */ + public VariableTensor(String name, TensorType requiredType) { + this.name = name; + this.requiredType = Optional.of(requiredType); } @Override @@ -35,11 +44,19 @@ public class VariableTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(EvaluationContext context) { return context.getTensorType(name); } + public TensorType type(TypeContext context) { + TensorType givenType = context.getType(name); + if (givenType == null) return null; + verifyType(givenType); + return givenType; + } @Override public Tensor evaluate(EvaluationContext context) { - return context.getTensor(name); + Tensor tensor = context.getTensor(name); + if (tensor == null) return null; + verifyType(tensor.type()); + return tensor; } @Override @@ -47,4 +64,9 @@ public class VariableTensor extends PrimitiveTensorFunction { return name; } + private void verifyType(TensorType givenType) { + if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get())) + throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " + + requiredType.get() + " but was " + givenType); + } } |