summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java10
1 files changed, 7 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 91ab4f9d046..a48ac19fbff 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -73,8 +73,8 @@ public class Concat extends PrimitiveTensorFunction {
MutableLong concatSize = new MutableLong(0);
a.sizeOfDimension(dimension).ifPresent(concatSize::add);
b.sizeOfDimension(dimension).ifPresent(concatSize::add);
- builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
- */
+ builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
+ */
}
return builder.build();
}
@@ -141,7 +141,11 @@ public class Concat extends PrimitiveTensorFunction {
if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
throw new IllegalArgumentException("Concat requires an indexed tensor, " +
"but got a tensor with type " + tensor.type());
- Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType())
+ .indexed(dimensionName, 1)
+ .build())
+ .cell(1,0)
+ .build();
return tensor.multiply(unitTensor);
}