summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-21 11:50:08 +0200
committerLester Solbakken <lesters@oath.com>2021-04-21 11:50:08 +0200
commit2076b7950d83a860688d923d577e97c20b5470f6 (patch)
tree6e4e74479632ab45a1142fa4b734ca97ed02e6e4 /vespajlib
parentc02603afee45b09cc1fa6d8b5448aa346a0984a8 (diff)
Wire inn tensor cell type resolving for join in Java
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java3
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java6
3 files changed, 10 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
index 37a4bf375d0..f9bc9072cfa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
@@ -153,6 +153,10 @@ public class TypeResolver {
map.put(dim.name(), dim);
} else if (firstIsBoundSecond(other, dim)) {
map.put(dim.name(), other);
+ } else if (dim.isMapped() && other.isIndexed()) {
+ map.put(dim.name(), dim); // {} and [] -> {}. Note: this is not allowed in C++
+ } else if (dim.isIndexed() && other.isMapped()) {
+ map.put(dim.name(), other); // {} and [] -> {}. Note: this is not allowed in C++
} else {
throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 5419d04a4fb..d43b7889982 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -49,7 +50,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
/** Returns the type resulting from applying Join to the two given types */
public static TensorType outputType(TensorType a, TensorType b) {
try {
- return new TensorType.Builder(false, a, b).build();
+ return TypeResolver.join(a, b);
}
catch (IllegalArgumentException e) {
throw new IllegalArgumentException("Can not join " + a + " and " + b, e);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
index 8e4205c8c27..8ed9a57c195 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
@@ -77,10 +77,12 @@ public class TypeResolverTestCase {
checkJoin("tensor(x{})", "tensor<bfloat16>(y{})", "tensor(x{},y{})");
checkJoin("tensor(x{})", "tensor<float>(y{})", "tensor(x{},y{})");
checkJoin("tensor(x{})", "tensor<int8>(y{})", "tensor(x{},y{})");
+ // specific for Java
+ checkJoin("tensor(x[])", "tensor(x{})", "tensor(x{})");
+ checkJoin("tensor(x{})", "tensor(x[])", "tensor(x{})");
// dimension mismatch should fail:
checkJoinFails("tensor(x[3])", "tensor(x[5])");
checkJoinFails("tensor(x[5])", "tensor(x[3])");
- checkJoinFails("tensor(x{})", "tensor(x[5])");
}
@Test
@@ -156,6 +158,7 @@ public class TypeResolverTestCase {
checkMerge("tensor(x{},y{})", "tensor<float>(x{},y{})", "tensor(x{},y{})");
checkMerge("tensor(x{},y{})", "tensor<int8>(x{},y{})", "tensor(x{},y{})");
checkMerge("tensor(y{})", "tensor(y{})", "tensor(y{})");
+ checkMerge("tensor(x{})", "tensor(x[5])", "tensor(x{})");
checkMergeFails("tensor(a[10])", "tensor()");
checkMergeFails("tensor(a[10])", "tensor(x{},y{},z{})");
checkMergeFails("tensor<bfloat16>(x[5])", "tensor()");
@@ -168,7 +171,6 @@ public class TypeResolverTestCase {
checkMergeFails("tensor(x[3])", "tensor(x[5])");
checkMergeFails("tensor(x[5])", "tensor(x[3])");
checkMergeFails("tensor(x{})", "tensor()");
- checkMergeFails("tensor(x{})", "tensor(x[5])");
checkMergeFails("tensor(x{},y{})", "tensor(x{},z{})");
checkMergeFails("tensor(y{})", "tensor()");
}