diff options
author | Lester Solbakken <lesters@oath.com> | 2021-04-22 09:21:05 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-04-22 09:21:05 +0200 |
commit | bd531088df9eb9f7e083fd03d91e5d3f19a1664b (patch) | |
tree | 4c891ef1e0f41244fc150fabe955f163425e2bcc /vespajlib/src/main | |
parent | 1f7363ff53144c40fd27c4332b1cb3619b1525d6 (diff) |
Wire in tensor cell type resolving for concat in Java
Diffstat (limited to 'vespajlib/src/main')
3 files changed, 20 insertions, 30 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java index f9bc9072cfa..651bec6a1aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java @@ -126,6 +126,14 @@ public class TypeResolver { first.name().equals(second.name())); } + private static boolean firstIsSmaller(Dimension first, Dimension second) { + return (first.type() == Dimension.Type.indexedBound && + second.type() == Dimension.Type.indexedBound && + first.name().equals(second.name()) && + first.size().isPresent() && second.size().isPresent() && + first.size().get() < second.size().get()); + } + static public TensorType join(TensorType lhs, TensorType rhs) { Value cellType = Value.DOUBLE; if (lhs.rank() > 0 && rhs.rank() > 0) { @@ -219,9 +227,13 @@ public class TypeResolver { Dimension other = map.get(dim.name()); if (! other.equals(dim)) { if (firstIsBoundSecond(dim, other)) { - map.put(dim.name(), dim); + map.put(dim.name(), other); // [N] and [] -> [] } else if (firstIsBoundSecond(other, dim)) { - map.put(dim.name(), other); + map.put(dim.name(), dim); // [N] and [] -> [] + } else if (firstIsSmaller(dim, other)) { + map.put(dim.name(), dim); // [N] and [M] -> [ min(N,M] ]. + } else if (firstIsSmaller(other, dim)) { + map.put(dim.name(), other); // [N] and [M] -> [ min(N,M] ]. } else { throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index c6f8171bd18..fe8b2f417aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; @@ -47,7 +48,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM @Override public TensorType type(TypeContext<NAMETYPE> context) { - return new TensorType(valueType, argument.type(context).dimensions()); + return TypeResolver.cell_cast(argument.type(context), valueType); } @Override @@ -56,12 +57,11 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM if (tensor.type().valueType() == valueType) { return tensor; } - TensorType type = new TensorType(valueType, tensor.type().dimensions()); + TensorType type = TypeResolver.cell_cast(tensor.type(), valueType); return cast(tensor, type); } private Tensor cast(Tensor tensor, TensorType type) { - Tensor.Builder builder = Tensor.Builder.of(type); TensorType.Value fromValueType = tensor.type().valueType(); switch (fromValueType) { case DOUBLE: diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index fff2ddaf320..e71d1c717d3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -7,6 +7,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; @@ -60,30 +61,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET @Override public TensorType type(TypeContext<NAMETYPE> context) { - return type(argumentA.type(context), argumentB.type(context)); - } - - /** Returns the type resulting from concatenating a and b */ - private TensorType type(TensorType a, TensorType b) { - // TODO: Fail if concat dimension is present but not indexed in a or b - TensorType.Builder builder = new TensorType.Builder(a, b); - if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) { - builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) + - b.sizeOfDimension(dimension).orElse(1L))); - /* - MutableLong concatSize = new MutableLong(0); - a.sizeOfDimension(dimension).ifPresent(concatSize::add); - b.sizeOfDimension(dimension).ifPresent(concatSize::add); - builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); - */ - } - return builder.build(); - } - - /** Returns true if this dimension is present and unbound */ - private boolean unboundIn(TensorType type, String dimensionName) { - Optional<TensorType.Dimension> dimension = type.dimension(dimensionName); - return dimension.isPresent() && ! dimension.get().size().isPresent(); + return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension); } @Override @@ -97,7 +75,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; - TensorType concatType = type(a.type(), b.type()); + TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); |