diff options
author | Arne Juul <arnej@yahoo-inc.com> | 2019-08-20 12:21:37 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahoo-inc.com> | 2019-08-20 12:21:37 +0000 |
commit | 7df067cfb84f0d6e00e87bf69276d7a353c9f972 (patch) | |
tree | c537a1291c9f91a47e7a660cc49de11f722783bb /vespajlib/src/main/java/com | |
parent | d88f2b235136691dcf08014cca60121ad2e3b62a (diff) |
use same rules for cell value type resolving as C++
* pick cell value type from tensors with dimensions only
* in Concat, use the expected combined cell value type for unit tensor
Diffstat (limited to 'vespajlib/src/main/java/com')
4 files changed, 18 insertions, 9 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 319947607d2..d64a62143f4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -87,6 +87,16 @@ public class TensorType { this.dimensions = ImmutableList.copyOf(dimensionList); } + static public Value combinedValueType(TensorType ... types) { + List<Value> valueTypes = new ArrayList<>(); + for (TensorType type : types) { + if (type.rank() > 0) { + valueTypes.add(type.valueType()); + } + } + return Value.largestOf(valueTypes); + } + /** * Returns a tensor type instance from a string on the format * <code>tensor(dimension1, dimension2, ...)</code> @@ -456,7 +466,7 @@ public class TensorType { * The value type will be the largest of the value types of the input types */ public Builder(TensorType ... types) { - this.valueType = TensorType.Value.largestOf(Arrays.stream(types).map(type -> type.valueType()).collect(Collectors.toList())); + this.valueType = TensorType.combinedValueType(types); for (TensorType type : types) addDimensionsOf(type); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index a48ac19fbff..42c6fe2f4aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -89,8 +89,9 @@ public class Concat extends PrimitiveTensorFunction { public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - a = ensureIndexedDimension(dimension, a); - b = ensureIndexedDimension(dimension, b); + TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type()); + a = ensureIndexedDimension(dimension, a, combinedValueType); + b = ensureIndexedDimension(dimension, b, combinedValueType); IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; @@ -128,7 +129,7 @@ public class Concat extends PrimitiveTensorFunction { } } - private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor) { + private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) { Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName); if ( dimension.isPresent() ) { if ( ! dimension.get().isIndexed()) @@ -141,7 +142,7 @@ public class Concat extends PrimitiveTensorFunction { if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); - Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType()) + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType) .indexed(dimensionName, 1) .build()) .cell(1,0) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 062e0d92e80..2939b964f04 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -390,8 +390,7 @@ public class Join extends PrimitiveTensorFunction { private static TensorType commonDimensions(Tensor a, Tensor b) { TensorType aType = a.type(); TensorType bType = b.type(); - TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(), - bType.valueType())); + TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.combinedValueType(aType, bType)); for (int i = 0; i < aType.dimensions().size(); ++i) { TensorType.Dimension aDim = aType.dimensions().get(i); for (int j = 0; j < bType.dimensions().size(); ++j) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index db950e6c8b9..1134e8177ad 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -268,8 +268,7 @@ public class ReduceJoin extends CompositeTensorFunction { } private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) { - TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(), - b.type().valueType())); + TensorType.Builder builder = new TensorType.Builder(TensorType.combinedValueType(a.type(), b.type())); for (TensorType.Dimension aDim : a.type().dimensions()) { for (TensorType.Dimension bDim : b.type().dimensions()) { if (aDim.name().equals(bDim.name())) { |