diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-22 15:44:09 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-22 15:44:09 +0100 |
commit | 07b29b192fa5e373a90fe0c7e6661f9e8024577e (patch) | |
tree | 4b6068b80c549ca8fa05a9d7884a444b5a224247 /vespajlib/src/main/java/com | |
parent | 1887446f4eb928d4208e9e33d18cbb0e2c164e13 (diff) |
Concat skeleton
Diffstat (limited to 'vespajlib/src/main/java/com')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 5 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 58 |
2 files changed, 63 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 13ddf3c2e20..5645ba6eb8e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -108,6 +108,11 @@ public class TensorType { public Set<String> dimensionNames() { return dimensions.stream().map(Dimension::name).collect(Collectors.toSet()); } + + /** Returns the dimension with this name, or empty if not present */ + public Optional<Dimension> dimension(String name) { + return indexOfDimension(name).map(i -> dimensions.get(i)); + } /** Returns the 0-base index of this dimension, or empty if it is not present */ public Optional<Integer> indexOfDimension(String dimension) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java new file mode 100644 index 00000000000..a39f46e5a73 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -0,0 +1,58 @@ +package com.yahoo.tensor.functions; + +import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; + +import java.util.List; +import java.util.Optional; + +/** + * Concatenation of two tensors along an (indexed) dimension + * + * @author bratseth + */ +@Beta +public class Concat extends PrimitiveTensorFunction { + + private final TensorFunction argumentA, argumentB; + private final String dimension; + + public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) { + this.argumentA = argumentA; + this.argumentB = argumentB; + this.dimension = dimension; + } + + @Override + public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } + + @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if (arguments.size() != 2) + throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); + return new Concat(arguments.get(0), arguments.get(1), dimension); + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + return new Concat(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); + } + + @Override + public String toString(ToStringContext context) { + return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")"; + } + + @Override + public Tensor evaluate(EvaluationContext context) { + Tensor a = argumentA.evaluate(context); + Tensor b = argumentB.evaluate(context); + Optional<TensorType.Dimension> aDimension = a.type().dimension(dimension); + Optional<TensorType.Dimension> bDimension = a.type().dimension(dimension); + throw new UnsupportedOperationException("Not implemented"); // TODO + } + +} |