diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 12 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 4 |
2 files changed, 8 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index aaa25a0b058..df78f3dfc3a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -4,6 +4,7 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.LinkedHashMap; @@ -410,7 +411,7 @@ public class TensorType { private final Value valueType; - /** Creates an empty builder with cells of type double*/ + /** Creates an empty builder with cells of type double */ public Builder() { this(Value.DOUBLE); } @@ -425,17 +426,16 @@ public class TensorType { * If the same dimension is indexed with different size restrictions the largest size will be used. * If it is size restricted in one argument but not the other it will not be size restricted. * If it is indexed in one and mapped in the other it will become mapped. + * + * The value type will be the largest of the value types of the input types */ public Builder(TensorType ... types) { - this(Value.DOUBLE, types); - } - public Builder(Value valueType, TensorType ... types) { - this.valueType = valueType; + this.valueType = TensorType.Value.largestOf(Arrays.stream(types).map(type -> type.valueType()).collect(Collectors.toList())); for (TensorType type : types) addDimensionsOf(type); } - /** Creates a builder from the given dimensions */ + /** Creates a builder from the given dimensions, having double as the value type */ public Builder(Iterable<Dimension> dimensions) { this(Value.DOUBLE, dimensions); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index a0a257bb909..a48ac19fbff 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -73,8 +73,8 @@ public class Concat extends PrimitiveTensorFunction { MutableLong concatSize = new MutableLong(0); a.sizeOfDimension(dimension).ifPresent(concatSize::add); b.sizeOfDimension(dimension).ifPresent(concatSize::add); - builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); - */ + builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); + */ } return builder.build(); } |