diff options
author | Lester Solbakken <lesters@oath.com> | 2021-04-21 11:50:08 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-04-21 11:50:08 +0200 |
commit | 2076b7950d83a860688d923d577e97c20b5470f6 (patch) | |
tree | 6e4e74479632ab45a1142fa4b734ca97ed02e6e4 | |
parent | c02603afee45b09cc1fa6d8b5448aa346a0984a8 (diff) |
Wire inn tensor cell type resolving for join in Java
3 files changed, 10 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java index 37a4bf375d0..f9bc9072cfa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java @@ -153,6 +153,10 @@ public class TypeResolver { map.put(dim.name(), dim); } else if (firstIsBoundSecond(other, dim)) { map.put(dim.name(), other); + } else if (dim.isMapped() && other.isIndexed()) { + map.put(dim.name(), dim); // {} and [] -> {}. Note: this is not allowed in C++ + } else if (dim.isIndexed() && other.isMapped()) { + map.put(dim.name(), other); // {} and [] -> {}. Note: this is not allowed in C++ } else { throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 5419d04a4fb..d43b7889982 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.PartialAddress; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; @@ -49,7 +50,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP /** Returns the type resulting from applying Join to the two given types */ public static TensorType outputType(TensorType a, TensorType b) { try { - return new TensorType.Builder(false, a, b).build(); + return TypeResolver.join(a, b); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Can not join " + a + " and " + b, e); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java index 8e4205c8c27..8ed9a57c195 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java @@ -77,10 +77,12 @@ public class TypeResolverTestCase { checkJoin("tensor(x{})", "tensor<bfloat16>(y{})", "tensor(x{},y{})"); checkJoin("tensor(x{})", "tensor<float>(y{})", "tensor(x{},y{})"); checkJoin("tensor(x{})", "tensor<int8>(y{})", "tensor(x{},y{})"); + // specific for Java + checkJoin("tensor(x[])", "tensor(x{})", "tensor(x{})"); + checkJoin("tensor(x{})", "tensor(x[])", "tensor(x{})"); // dimension mismatch should fail: checkJoinFails("tensor(x[3])", "tensor(x[5])"); checkJoinFails("tensor(x[5])", "tensor(x[3])"); - checkJoinFails("tensor(x{})", "tensor(x[5])"); } @Test @@ -156,6 +158,7 @@ public class TypeResolverTestCase { checkMerge("tensor(x{},y{})", "tensor<float>(x{},y{})", "tensor(x{},y{})"); checkMerge("tensor(x{},y{})", "tensor<int8>(x{},y{})", "tensor(x{},y{})"); checkMerge("tensor(y{})", "tensor(y{})", "tensor(y{})"); + checkMerge("tensor(x{})", "tensor(x[5])", "tensor(x{})"); checkMergeFails("tensor(a[10])", "tensor()"); checkMergeFails("tensor(a[10])", "tensor(x{},y{},z{})"); checkMergeFails("tensor<bfloat16>(x[5])", "tensor()"); @@ -168,7 +171,6 @@ public class TypeResolverTestCase { checkMergeFails("tensor(x[3])", "tensor(x[5])"); checkMergeFails("tensor(x[5])", "tensor(x[3])"); checkMergeFails("tensor(x{})", "tensor()"); - checkMergeFails("tensor(x{})", "tensor(x[5])"); checkMergeFails("tensor(x{},y{})", "tensor(x{},z{})"); checkMergeFails("tensor(y{})", "tensor()"); } |