summaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-22 15:44:09 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-22 15:44:09 +0100
commit07b29b192fa5e373a90fe0c7e6661f9e8024577e (patch)
tree4b6068b80c549ca8fa05a9d7884a444b5a224247 /vespajlib/src
parent1887446f4eb928d4208e9e33d18cbb0e2c164e13 (diff)
Concat skeleton
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java58
2 files changed, 63 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 13ddf3c2e20..5645ba6eb8e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -108,6 +108,11 @@ public class TensorType {
public Set<String> dimensionNames() {
return dimensions.stream().map(Dimension::name).collect(Collectors.toSet());
}
+
+ /** Returns the dimension with this name, or empty if not present */
+ public Optional<Dimension> dimension(String name) {
+ return indexOfDimension(name).map(i -> dimensions.get(i));
+ }
/** Returns the 0-base index of this dimension, or empty if it is not present */
public Optional<Integer> indexOfDimension(String dimension) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
new file mode 100644
index 00000000000..a39f46e5a73
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -0,0 +1,58 @@
+package com.yahoo.tensor.functions;
+
+import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+
+import java.util.List;
+import java.util.Optional;
+
+/**
+ * Concatenation of two tensors along an (indexed) dimension
+ *
+ * @author bratseth
+ */
+@Beta
+public class Concat extends PrimitiveTensorFunction {
+
+ private final TensorFunction argumentA, argumentB;
+ private final String dimension;
+
+ public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) {
+ this.argumentA = argumentA;
+ this.argumentB = argumentB;
+ this.dimension = dimension;
+ }
+
+ @Override
+ public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); }
+
+ @Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if (arguments.size() != 2)
+ throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
+ return new Concat(arguments.get(0), arguments.get(1), dimension);
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() {
+ return new Concat(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
+ }
+
+ @Override
+ public Tensor evaluate(EvaluationContext context) {
+ Tensor a = argumentA.evaluate(context);
+ Tensor b = argumentB.evaluate(context);
+ Optional<TensorType.Dimension> aDimension = a.type().dimension(dimension);
+ Optional<TensorType.Dimension> bDimension = a.type().dimension(dimension);
+ throw new UnsupportedOperationException("Not implemented"); // TODO
+ }
+
+}