diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 10 |
1 files changed, 7 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 91ab4f9d046..a48ac19fbff 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -73,8 +73,8 @@ public class Concat extends PrimitiveTensorFunction { MutableLong concatSize = new MutableLong(0); a.sizeOfDimension(dimension).ifPresent(concatSize::add); b.sizeOfDimension(dimension).ifPresent(concatSize::add); - builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); - */ + builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); + */ } return builder.build(); } @@ -141,7 +141,11 @@ public class Concat extends PrimitiveTensorFunction { if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); - Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType()) + .indexed(dimensionName, 1) + .build()) + .cell(1,0) + .build(); return tensor.multiply(unitTensor); } |