summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java26
1 files changed, 24 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index 6c149724aca..34beb465d4c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.functions.ToStringContext;
import java.util.Collections;
import java.util.List;
+import java.util.Optional;
/**
* A tensor variable name which resolves to a tensor in the context at evaluation time
@@ -20,9 +21,17 @@ import java.util.List;
public class VariableTensor extends PrimitiveTensorFunction {
private final String name;
+ private final Optional<TensorType> requiredType;
public VariableTensor(String name) {
this.name = name;
+ this.requiredType = Optional.empty();
+ }
+
+ /** A variable tensor which must be compatible with the given type */
+ public VariableTensor(String name, TensorType requiredType) {
+ this.name = name;
+ this.requiredType = Optional.of(requiredType);
}
@Override
@@ -35,11 +44,19 @@ public class VariableTensor extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public TensorType type(EvaluationContext context) { return context.getTensorType(name); }
+ public TensorType type(TypeContext context) {
+ TensorType givenType = context.getType(name);
+ if (givenType == null) return null;
+ verifyType(givenType);
+ return givenType;
+ }
@Override
public Tensor evaluate(EvaluationContext context) {
- return context.getTensor(name);
+ Tensor tensor = context.getTensor(name);
+ if (tensor == null) return null;
+ verifyType(tensor.type());
+ return tensor;
}
@Override
@@ -47,4 +64,9 @@ public class VariableTensor extends PrimitiveTensorFunction {
return name;
}
+ private void verifyType(TensorType givenType) {
+ if (requiredType.isPresent() && ! givenType.isAssignableTo(requiredType.get()))
+ throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " +
+ requiredType.get() + " but was " + givenType);
+ }
}