summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-04-12 11:24:28 +0000
committerArne Juul <arnej@verizonmedia.com>2021-04-12 11:24:41 +0000
commitc307bf170a02a39602bce19ee516a39b173a7e6b (patch)
tree5d234fdcf779908137a921ab9b535793dde47692 /vespajlib
parent791c4b163669d5ef8ea671be1efacb89655d3935 (diff)
add TypeResolver
* with unit tests mostly cribbed from C++ version
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json19
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java251
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java455
3 files changed, 725 insertions, 0 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 6d99eeac816..9e9d32a5a6e 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1459,6 +1459,25 @@
],
"fields": []
},
+ "com.yahoo.tensor.TypeResolver": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>()",
+ "public static com.yahoo.tensor.TensorType map(com.yahoo.tensor.TensorType)",
+ "public static com.yahoo.tensor.TensorType reduce(com.yahoo.tensor.TensorType, java.util.List)",
+ "public static com.yahoo.tensor.TensorType peek(com.yahoo.tensor.TensorType, java.util.List)",
+ "public static com.yahoo.tensor.TensorType rename(com.yahoo.tensor.TensorType, java.util.List, java.util.List)",
+ "public static com.yahoo.tensor.TensorType cell_cast(com.yahoo.tensor.TensorType, com.yahoo.tensor.TensorType$Value)",
+ "public static com.yahoo.tensor.TensorType join(com.yahoo.tensor.TensorType, com.yahoo.tensor.TensorType)",
+ "public static com.yahoo.tensor.TensorType merge(com.yahoo.tensor.TensorType, com.yahoo.tensor.TensorType)",
+ "public static com.yahoo.tensor.TensorType concat(com.yahoo.tensor.TensorType, com.yahoo.tensor.TensorType, java.lang.String)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.evaluation.EvaluationContext": {
"superClass": "java.lang.Object",
"interfaces": [
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
new file mode 100644
index 00000000000..3fe5e01295a
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
@@ -0,0 +1,251 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import static com.yahoo.tensor.TensorType.Dimension;
+import static com.yahoo.tensor.TensorType.Value;
+
+/**
+ * Common type resolving for basic tensor operations.
+ *
+ * @author arnej
+ */
+public class TypeResolver {
+
+ static private TensorType scalar() {
+ return TensorType.empty;
+ }
+
+ static public TensorType map(TensorType inputType) {
+ Value orig = inputType.valueType();
+ Value cellType = Value.largestOf(orig, Value.FLOAT);
+ if (cellType == orig) {
+ return inputType;
+ }
+ return new TensorType(cellType, inputType.dimensions());
+ }
+
+ static public TensorType reduce(TensorType inputType, List<String> reduceDimensions) {
+ if (reduceDimensions.isEmpty()) {
+ return scalar();
+ }
+ Map<String, Dimension> map = new HashMap<>();
+ for (Dimension dim : inputType.dimensions()) {
+ map.put(dim.name(), dim);
+ }
+ for (String name : reduceDimensions) {
+ if (map.containsKey(name)) {
+ map.remove(name);
+ } else {
+ throw new IllegalArgumentException("reducing non-existing dimension "+name+" in type "+inputType);
+ }
+ }
+ if (map.isEmpty()) {
+ return scalar();
+ }
+ Value cellType = Value.largestOf(inputType.valueType(), Value.FLOAT);
+ return new TensorType(cellType, map.values());
+ }
+
+ static public TensorType peek(TensorType inputType, List<String> peekDimensions) {
+ if (peekDimensions.isEmpty()) {
+ throw new IllegalArgumentException("peeking no dimensions makes no sense");
+ }
+ Map<String, Dimension> map = new HashMap<>();
+ for (Dimension dim : inputType.dimensions()) {
+ map.put(dim.name(), dim);
+ }
+ for (String name : peekDimensions) {
+ if (map.containsKey(name)) {
+ map.remove(name);
+ } else {
+ throw new IllegalArgumentException("peeking non-existing dimension "+name+" in type "+inputType);
+ }
+ }
+ if (map.isEmpty()) {
+ return scalar();
+ }
+ Value cellType = inputType.valueType();
+ return new TensorType(cellType, map.values());
+ }
+
+ static public TensorType rename(TensorType inputType, List<String> from, List<String> to) {
+ if (from.isEmpty()) {
+ throw new IllegalArgumentException("renaming no dimensions");
+ }
+ if (from.size() != to.size()) {
+ throw new IllegalArgumentException("bad rename, from size "+from.size()+" != to.size "+to.size());
+ }
+ Map<String,Dimension> oldDims = new HashMap<>();
+ for (Dimension dim : inputType.dimensions()) {
+ oldDims.put(dim.name(), dim);
+ }
+ Map<String,Dimension> newDims = new HashMap<>();
+ for (int i = 0; i < from.size(); ++i) {
+ String oldName = from.get(i);
+ String newName = to.get(i);
+ if (oldDims.containsKey(oldName)) {
+ var dim = oldDims.remove(oldName);
+ newDims.put(newName, dim.withName(newName));
+ } else {
+ throw new IllegalArgumentException("bad rename, dimension "+oldName+" not found");
+ }
+ }
+ for (var keep : oldDims.values()) {
+ newDims.put(keep.name(), keep);
+ }
+ if (inputType.dimensions().size() == newDims.size()) {
+ return new TensorType(inputType.valueType(), newDims.values());
+ } else {
+ throw new IllegalArgumentException("bad rename, lost some dimenions");
+ }
+ }
+
+ static public TensorType cell_cast(TensorType inputType, Value toCellType) {
+ if (toCellType != Value.DOUBLE && inputType.dimensions().isEmpty()) {
+ throw new IllegalArgumentException("cannot cast "+inputType+" to valueType"+toCellType);
+ }
+ return new TensorType(toCellType, inputType.dimensions());
+ }
+
+ private static boolean firstIsBoundSecond(Dimension first, Dimension second) {
+ return (first.type() == Dimension.Type.indexedBound &&
+ second.type() == Dimension.Type.indexedUnbound &&
+ first.name().equals(second.name()));
+ }
+
+ static public TensorType join(TensorType lhs, TensorType rhs) {
+ Value cellType = Value.DOUBLE;
+ if (lhs.rank() > 0 && rhs.rank() > 0) {
+ // both types decide the new cell type
+ cellType = Value.largestOf(lhs.valueType(), rhs.valueType());
+ } else if (lhs.rank() > 0) {
+ // only the tensor decide the new cell type
+ cellType = lhs.valueType();
+ } else if (rhs.rank() > 0) {
+ // only the tensor decide the new cell type
+ cellType = rhs.valueType();
+ }
+ // result of computation must be at least float
+ cellType = Value.largestOf(cellType, Value.FLOAT);
+
+ Map<String, Dimension> map = new HashMap<>();
+ for (Dimension dim : lhs.dimensions()) {
+ map.put(dim.name(), dim);
+ }
+ for (Dimension dim : rhs.dimensions()) {
+ if (map.containsKey(dim.name())) {
+ Dimension other = map.get(dim.name());
+ if (! other.equals(dim)) {
+ if (firstIsBoundSecond(dim, other)) {
+ map.put(dim.name(), dim);
+ } else if (firstIsBoundSecond(other, dim)) {
+ map.put(dim.name(), other);
+ } else {
+ throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs);
+ }
+ }
+ } else {
+ map.put(dim.name(), dim);
+ }
+ }
+ return new TensorType(cellType, map.values());
+ }
+
+ static public TensorType merge(TensorType lhs, TensorType rhs) {
+ int sz = lhs.dimensions().size();
+ boolean allOk = (rhs.dimensions().size() == sz);
+ if (allOk) {
+ for (int i = 0; i < sz; i++) {
+ String lName = lhs.dimensions().get(i).name();
+ String rName = rhs.dimensions().get(i).name();
+ if (! lName.equals(rName)) {
+ allOk = false;
+ }
+ }
+ }
+ if (allOk) {
+ return join(lhs, rhs);
+ } else {
+ throw new IllegalArgumentException("types in merge() dimensions mismatch: "+lhs+" != "+rhs);
+ }
+ }
+
+ static public TensorType concat(TensorType lhs, TensorType rhs, String concatDimension) {
+ Value cellType = Value.DOUBLE;
+ if (lhs.rank() > 0 && rhs.rank() > 0) {
+ if (lhs.valueType() == rhs.valueType()) {
+ cellType = lhs.valueType();
+ } else {
+ cellType = Value.largestOf(lhs.valueType(), rhs.valueType());
+ // when changing cell type, make it at least float
+ cellType = Value.largestOf(cellType, Value.FLOAT);
+ }
+ } else if (lhs.rank() > 0) {
+ cellType = lhs.valueType();
+ } else if (rhs.rank() > 0) {
+ cellType = rhs.valueType();
+ }
+ Optional<Dimension> first = Optional.empty();
+ Optional<Dimension> second = Optional.empty();
+ Map<String, Dimension> map = new HashMap<>();
+ for (Dimension dim : lhs.dimensions()) {
+ if (dim.name().equals(concatDimension)) {
+ first = Optional.of(dim);
+ } else {
+ map.put(dim.name(), dim);
+ }
+ }
+ for (Dimension dim : rhs.dimensions()) {
+ if (dim.name().equals(concatDimension)) {
+ second = Optional.of(dim);
+ } else if (map.containsKey(dim.name())) {
+ Dimension other = map.get(dim.name());
+ if (! other.equals(dim)) {
+ if (firstIsBoundSecond(dim, other)) {
+ map.put(dim.name(), dim);
+ } else if (firstIsBoundSecond(other, dim)) {
+ map.put(dim.name(), other);
+ } else {
+ throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs);
+ }
+ }
+ } else {
+ map.put(dim.name(), dim);
+ }
+ }
+ if (first.isPresent() && first.get().type() == Dimension.Type.mapped) {
+ throw new IllegalArgumentException("Bad concat dimension "+concatDimension+" in lhs: "+lhs);
+ }
+ if (second.isPresent() && second.get().type() == Dimension.Type.mapped) {
+ throw new IllegalArgumentException("Bad concat dimension "+concatDimension+" in rhs: "+rhs);
+ }
+ if (first.isPresent() && first.get().type() == Dimension.Type.indexedUnbound) {
+ map.put(concatDimension, first.get());
+ } else if (second.isPresent() && second.get().type() == Dimension.Type.indexedUnbound) {
+ map.put(concatDimension, second.get());
+ } else {
+ long concatSize = 0;
+ if (first.isPresent()) {
+ concatSize += first.get().size().get();
+ } else {
+ concatSize += 1;
+ }
+ if (second.isPresent()) {
+ concatSize += second.get().size().get();
+ } else {
+ concatSize += 1;
+ }
+ map.put(concatDimension, Dimension.indexed(concatDimension, concatSize));
+ }
+ return new TensorType(cellType, map.values());
+ }
+
+}
+
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
new file mode 100644
index 00000000000..5271ed6f42c
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
@@ -0,0 +1,455 @@
+// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertThat;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/**
+ * @author arnej
+ */
+public class TypeResolverTestCase {
+
+ private static List<String> mkl(String ...values) {
+ return Arrays.asList(values);
+ }
+
+ @Test
+ public void verifyMap() {
+ checkMap("tensor()", "tensor()");
+ checkMap("tensor(x[10])", "tensor(x[10])");
+ checkMap("tensor(a[10],b[20],c[30])", "tensor(a[10],b[20],c[30])");
+ checkMap("tensor(y{})", "tensor(y{})");
+ checkMap("tensor(x[10],y{})", "tensor(x[10],y{})");
+ checkMap("tensor<float>(x[10])", "tensor<float>(x[10])");
+ checkMap("tensor<float>(y{})", "tensor<float>(y{})");
+ checkMap("tensor<bfloat16>(x[10])", "tensor<float>(x[10])");
+ checkMap("tensor<bfloat16>(y{})", "tensor<float>(y{})");
+ checkMap("tensor<int8>(x[10])", "tensor<float>(x[10])");
+ checkMap("tensor<int8>(y{})", "tensor<float>(y{})");
+ }
+
+ @Test
+ public void verifyJoin() {
+ checkJoin("tensor()", "tensor()", "tensor()");
+ checkJoin("tensor()", "tensor(x{})", "tensor(x{})");
+ checkJoin("tensor(x{})", "tensor()", "tensor(x{})");
+ checkJoin("tensor(x{})", "tensor(x{})", "tensor(x{})");
+ checkJoin("tensor(x{})", "tensor(y{})", "tensor(x{},y{})");
+ checkJoin("tensor(x{},y{})", "tensor(y{},z{})", "tensor(x{},y{},z{})");
+ checkJoin("tensor(y{})", "tensor()", "tensor(y{})");
+ checkJoin("tensor(y{})", "tensor(y{})", "tensor(y{})");
+ checkJoin("tensor(a[10])", "tensor(a[10])", "tensor(a[10])");
+ checkJoin("tensor(a[10])", "tensor()", "tensor(a[10])");
+ checkJoin("tensor(a[10])", "tensor(x{},y{},z{})", "tensor(a[10],x{},y{},z{})");
+ // with cell types
+ checkJoin("tensor<bfloat16>(x[5])", "tensor<bfloat16>(x[5])", "tensor<float>(x[5])");
+ checkJoin("tensor<bfloat16>(x[5])", "tensor<float>(x[5])", "tensor<float>(x[5])");
+ checkJoin("tensor<bfloat16>(x[5])", "tensor<int8>(x[5])", "tensor<float>(x[5])");
+ checkJoin("tensor<bfloat16>(x[5])", "tensor()", "tensor<float>(x[5])");
+ checkJoin("tensor<bfloat16>(x[5])", "tensor(x[5])", "tensor(x[5])");
+ checkJoin("tensor<bfloat16>(x{})", "tensor<bfloat16>(y{})", "tensor<float>(x{},y{})");
+ checkJoin("tensor<bfloat16>(x{})", "tensor<int8>(y{})", "tensor<float>(x{},y{})");
+ checkJoin("tensor<bfloat16>(x{})", "tensor()", "tensor<float>(x{})");
+ checkJoin("tensor<float>(x[5])", "tensor<float>(x[5])", "tensor<float>(x[5])");
+ checkJoin("tensor<float>(x[5])", "tensor<int8>(x[5])", "tensor<float>(x[5])");
+ checkJoin("tensor<float>(x[5])", "tensor()", "tensor<float>(x[5])");
+ checkJoin("tensor<float>(x[5])", "tensor(x[5])", "tensor(x[5])");
+ checkJoin("tensor<float>(x{})", "tensor<bfloat16>(y{})", "tensor<float>(x{},y{})");
+ checkJoin("tensor<float>(x{})", "tensor<float>(y{})", "tensor<float>(x{},y{})");
+ checkJoin("tensor<float>(x{})", "tensor<int8>(y{})", "tensor<float>(x{},y{})");
+ checkJoin("tensor<float>(x{})", "tensor()", "tensor<float>(x{})");
+ checkJoin("tensor<int8>(x[5])", "tensor<int8>(x[5])", "tensor<float>(x[5])");
+ checkJoin("tensor<int8>(x{})", "tensor<int8>(y{})", "tensor<float>(x{},y{})");
+ checkJoin("tensor<int8>(x{})", "tensor()", "tensor<float>(x{})");
+ checkJoin("tensor()", "tensor<int8>(x[5])", "tensor<float>(x[5])");
+ checkJoin("tensor(x[5])", "tensor<int8>(x[5])", "tensor(x[5])");
+ checkJoin("tensor(x[5])", "tensor(x[5])", "tensor(x[5])");
+ checkJoin("tensor(x{})", "tensor<bfloat16>(y{})", "tensor(x{},y{})");
+ checkJoin("tensor(x{})", "tensor<float>(y{})", "tensor(x{},y{})");
+ checkJoin("tensor(x{})", "tensor<int8>(y{})", "tensor(x{},y{})");
+ // dimension mismatch should fail:
+ checkJoinFails("tensor(x[3])", "tensor(x[5])");
+ checkJoinFails("tensor(x[5])", "tensor(x[3])");
+ checkJoinFails("tensor(x{})", "tensor(x[5])");
+ }
+
+ @Test
+ public void verifyReduce() {
+ checkFullReduce("tensor()");
+ checkReduce("tensor(x[10],y[20],z[30])", mkl("x"), "tensor(y[20],z[30])");
+ checkReduce("tensor(x[10],y[20],z[30])", mkl("y"), "tensor(x[10],z[30])");
+ checkReduce("tensor<float>(x[10],y[20],z[30])", mkl("z"), "tensor<float>(x[10],y[20])");
+ checkReduce("tensor<bfloat16>(x[10],y[20],z[30])", mkl("z"), "tensor<float>(x[10],y[20])");
+ checkReduce("tensor<int8>(x[10],y[20],z[30])", mkl("z"), "tensor<float>(x[10],y[20])");
+ checkReduce("tensor(x[10],y[20],z[30])", mkl("x", "z"), "tensor(y[20])");
+ checkReduce("tensor<float>(x[10],y[20],z[30])", mkl("z", "x"), "tensor<float>(y[20])");
+ checkReduce("tensor<bfloat16>(x[10],y[20],z[30])", mkl("z", "x"), "tensor<float>(y[20])");
+ checkReduce("tensor<int8>(x[10],y[20],z[30])", mkl("z", "x"), "tensor<float>(y[20])");
+ checkFullReduce("tensor(x[10],y[20],z[30])");
+ checkFullReduce("tensor<float>(x[10],y[20],z[30])");
+ checkFullReduce("tensor<bfloat16>(x[10],y[20],z[30])");
+ checkFullReduce("tensor<int8>(x[10],y[20],z[30])");
+ checkReduce("tensor(x[10],y[20],z[30])", mkl("x", "y", "z"), "tensor()");
+ checkReduce("tensor<float>(x[10],y[20],z[30])", mkl("x", "y", "z"), "tensor()");
+ checkReduce("tensor<bfloat16>(x[10],y[20],z[30])", mkl("x", "y", "z"), "tensor()");
+ checkReduce("tensor<int8>(x[10],y[20],z[30])", mkl("x", "y", "z"), "tensor()");
+ checkReduce("tensor(x[10],y{},z[30])", mkl("x"), "tensor(y{},z[30])");
+ checkReduce("tensor(x[10],y{},z[30])", mkl("y"), "tensor(x[10],z[30])");
+ checkReduce("tensor<float>(x[10],y{},z[30])", mkl("z"), "tensor<float>(x[10],y{})");
+ checkReduce("tensor<bfloat16>(x[10],y{},z[30])", mkl("z"), "tensor<float>(x[10],y{})");
+ checkReduce("tensor<int8>(x[10],y{},z[30])", mkl("z"), "tensor<float>(x[10],y{})");
+ checkReduce("tensor(x[10],y{},z[30])", mkl("x", "z"), "tensor(y{})");
+ checkReduce("tensor<float>(x[10],y{},z[30])", mkl("z", "x"), "tensor<float>(y{})");
+ checkReduce("tensor<bfloat16>(x[10],y{},z[30])", mkl("z", "x"), "tensor<float>(y{})");
+ checkReduce("tensor<int8>(x[10],y{},z[30])", mkl("z", "x"), "tensor<float>(y{})");
+ checkFullReduce("tensor(x[10],y{},z[30])");
+ checkFullReduce("tensor<float>(x[10],y{},z[30])");
+ checkFullReduce("tensor<bfloat16>(x[10],y{},z[30])");
+ checkFullReduce("tensor<int8>(x[10],y{},z[30])");
+ checkReduce("tensor(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()");
+ checkReduce("tensor<float>(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()");
+ checkReduce("tensor<bfloat16>(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()");
+ checkReduce("tensor<int8>(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()");
+ checkReduceFails("tensor(y{})", "x");
+ checkReduceFails("tensor<float>(y[10])", "x");
+ }
+
+ @Test
+ public void verifyMerge() {
+ checkMerge("tensor(a[10])", "tensor(a[10])", "tensor(a[10])");
+ checkMerge("tensor<bfloat16>(x[5])", "tensor<bfloat16>(x[5])", "tensor<float>(x[5])");
+ checkMerge("tensor<bfloat16>(x[5])", "tensor<float>(x[5])", "tensor<float>(x[5])");
+ checkMerge("tensor<bfloat16>(x[5])", "tensor<int8>(x[5])", "tensor<float>(x[5])");
+ checkMerge("tensor<bfloat16>(x[5])", "tensor(x[5])", "tensor(x[5])");
+ checkMerge("tensor<bfloat16>(y{})", "tensor<bfloat16>(y{})", "tensor<float>(y{})");
+ checkMerge("tensor<bfloat16>(y{})", "tensor<int8>(y{})", "tensor<float>(y{})");
+ checkMerge("tensor<float>(x[5])", "tensor<float>(x[5])", "tensor<float>(x[5])");
+ checkMerge("tensor<float>(x[5])", "tensor<int8>(x[5])", "tensor<float>(x[5])");
+ checkMerge("tensor<float>(x[5])", "tensor(x[5])", "tensor(x[5])");
+ checkMerge("tensor<float>(y{})", "tensor<bfloat16>(y{})", "tensor<float>(y{})");
+ checkMerge("tensor<float>(y{})", "tensor<float>(y{})", "tensor<float>(y{})");
+ checkMerge("tensor<float>(y{})", "tensor<int8>(y{})", "tensor<float>(y{})");
+ checkMerge("tensor<int8>(x[5])", "tensor<int8>(x[5])", "tensor<float>(x[5])");
+ checkMerge("tensor<int8>(y{})", "tensor<int8>(y{})", "tensor<float>(y{})");
+ checkMerge("tensor()", "tensor()", "tensor()");
+ checkMerge("tensor(x[5])", "tensor<int8>(x[5])", "tensor(x[5])");
+ checkMerge("tensor(x[5])", "tensor(x[5])", "tensor(x[5])");
+ checkMerge("tensor(x{})", "tensor(x{})", "tensor(x{})");
+ checkMerge("tensor(x{},y{})", "tensor<bfloat16>(x{},y{})", "tensor(x{},y{})");
+ checkMerge("tensor(x{},y{})", "tensor<float>(x{},y{})", "tensor(x{},y{})");
+ checkMerge("tensor(x{},y{})", "tensor<int8>(x{},y{})", "tensor(x{},y{})");
+ checkMerge("tensor(y{})", "tensor(y{})", "tensor(y{})");
+ checkMergeFails("tensor(a[10])", "tensor()");
+ checkMergeFails("tensor(a[10])", "tensor(x{},y{},z{})");
+ checkMergeFails("tensor<bfloat16>(x[5])", "tensor()");
+ checkMergeFails("tensor<bfloat16>(x{})", "tensor()");
+ checkMergeFails("tensor<float>(x[5])", "tensor()");
+ checkMergeFails("tensor<float>(x{})", "tensor()");
+ checkMergeFails("tensor<int8>(x{})", "tensor()");
+ checkMergeFails("tensor()", "tensor<int8>(x[5])");
+ checkMergeFails("tensor()", "tensor(x{})");
+ checkMergeFails("tensor(x[3])", "tensor(x[5])");
+ checkMergeFails("tensor(x[5])", "tensor(x[3])");
+ checkMergeFails("tensor(x{})", "tensor()");
+ checkMergeFails("tensor(x{})", "tensor(x[5])");
+ checkMergeFails("tensor(x{},y{})", "tensor(x{},z{})");
+ checkMergeFails("tensor(y{})", "tensor()");
+ }
+
+ @Test
+ public void verifyRename() {
+ checkRename("tensor(x[10],y[20],z[30])", mkl("y"), mkl("a"), "tensor(a[20],x[10],z[30])");
+ checkRename("tensor(x{})", mkl("x"), mkl("y"), "tensor(y{})");
+ checkRename("tensor(x{},y[5])", mkl("x","y"), mkl("y","x"), "tensor(x[5],y{})");
+ checkRename("tensor(x[10],y[20],z[30])", mkl("x", "y", "z"), mkl("c", "a", "b"), "tensor(a[20],b[30],c[10])");
+ checkRename("tensor(x{})", mkl("x"), mkl("x"), "tensor(x{})");
+ checkRename("tensor(x{})", mkl("x"), mkl("y"), "tensor(y{})");
+ checkRename("tensor<float>(x{})", mkl("x"), mkl("y"), "tensor<float>(y{})");
+ checkRename("tensor<bfloat16>(x{})", mkl("x"), mkl("y"), "tensor<bfloat16>(y{})");
+ checkRename("tensor<int8>(x{})", mkl("x"), mkl("y"), "tensor<int8>(y{})");
+
+ checkRenameFails("tensor(x{})", mkl(), mkl());
+ checkRenameFails("tensor()", mkl(), mkl());
+ checkRenameFails("tensor(x{},y{})", mkl("x"), mkl("y","z"));
+ checkRenameFails("tensor(x{},y{})", mkl("x","y"), mkl("z"));
+ checkRenameFails("tensor()", mkl("a"), mkl("b"));
+ checkRenameFails("tensor(x[10],y[20],z[30])", mkl("y","z"), mkl("a", "x"));
+ }
+
+ @Test
+ public void verifyConcat() {
+ // types can be concatenated
+ checkConcat("tensor(y[7])", "tensor(x{})", "z", "tensor(x{},y[7],z[2])");
+ checkConcat("tensor()", "tensor()", "x", "tensor(x[2])");
+ checkConcat("tensor(x[2])", "tensor()", "x", "tensor(x[3])");
+ checkConcat("tensor(x[3])", "tensor(x[2])", "x", "tensor(x[5])");
+ checkConcat("tensor(x[2])", "tensor()", "y", "tensor(x[2],y[2])");
+ checkConcat("tensor(x[2])", "tensor(x[2])", "y", "tensor(x[2],y[2])");
+ checkConcat("tensor(x[2],y[2])", "tensor(x[3])", "x", "tensor(x[5],y[2])");
+ checkConcat("tensor(x[2],y[2])", "tensor(y[7])", "y", "tensor(x[2],y[9])");
+ checkConcat("tensor(x[5])", "tensor(y[7])", "z", "tensor(x[5],y[7],z[2])");
+ // cell type is handled correctly for concat
+ checkConcat("tensor(x[3])", "tensor(x[2])", "x", "tensor(x[5])");
+ checkConcat("tensor(x[3])", "tensor<float>(x[2])", "x", "tensor(x[5])");
+ checkConcat("tensor(x[3])", "tensor<bfloat16>(x[2])", "x", "tensor(x[5])");
+ checkConcat("tensor(x[3])", "tensor<int8>(x[2])", "x", "tensor(x[5])");
+ checkConcat("tensor<float>(x[3])", "tensor<float>(x[2])", "x", "tensor<float>(x[5])");
+ checkConcat("tensor<float>(x[3])", "tensor<bfloat16>(x[2])", "x", "tensor<float>(x[5])");
+ checkConcat("tensor<float>(x[3])", "tensor<int8>(x[2])", "x", "tensor<float>(x[5])");
+ checkConcat("tensor<bfloat16>(x[3])", "tensor<bfloat16>(x[2])", "x", "tensor<bfloat16>(x[5])");
+ checkConcat("tensor<bfloat16>(x[3])", "tensor<int8>(x[2])", "x", "tensor<float>(x[5])");
+ checkConcat("tensor<int8>(x[3])", "tensor<int8>(x[2])", "x", "tensor<int8>(x[5])");
+ // concat with number preserves cell type
+ checkConcat("tensor(x[3])", "tensor()", "x", "tensor(x[4])");
+ checkConcat("tensor<float>(x[3])", "tensor()", "x", "tensor<float>(x[4])");
+ checkConcat("tensor<bfloat16>(x[3])", "tensor()", "x", "tensor<bfloat16>(x[4])");
+ checkConcat("tensor<int8>(x[3])", "tensor()", "x", "tensor<int8>(x[4])");
+ // invalid combinations must fail
+ checkConcatFails("tensor(x{})", "tensor(x[2])", "x");
+ checkConcatFails("tensor(x{})", "tensor(x{})", "x");
+ checkConcatFails("tensor(x{})", "tensor()", "x");
+ checkConcatFails("tensor(x[3])", "tensor(x[2])", "y");
+ }
+
+ @Test
+ public void verifyPeek() {
+ checkPeek("tensor(x[10],y[20],z[30])", mkl("x"), "tensor(y[20],z[30])");
+ checkPeek("tensor(x[10],y[20],z[30])", mkl("y"), "tensor(x[10],z[30])");
+ checkPeek("tensor<float>(x[10],y[20],z[30])", mkl("z"), "tensor<float>(x[10],y[20])");
+ checkPeek("tensor<bfloat16>(x[10],y[20],z[30])", mkl("z"), "tensor<bfloat16>(x[10],y[20])");
+ checkPeek("tensor<int8>(x[10],y[20],z[30])", mkl("z"), "tensor<int8>(x[10],y[20])");
+ checkPeek("tensor(x[10],y[20],z[30])", mkl("x", "z"), "tensor(y[20])");
+ checkPeek("tensor<float>(x[10],y[20],z[30])", mkl("z", "x"), "tensor<float>(y[20])");
+ checkPeek("tensor<bfloat16>(x[10],y[20],z[30])", mkl("z", "x"), "tensor<bfloat16>(y[20])");
+ checkPeek("tensor<int8>(x[10],y[20],z[30])", mkl("z", "x"), "tensor<int8>(y[20])");
+ checkPeek("tensor(x[10],y[20],z[30])", mkl("x", "y", "z"), "tensor()");
+ checkPeek("tensor<float>(x[10],y[20],z[30])", mkl("x", "y", "z"), "tensor()");
+ checkPeek("tensor<bfloat16>(x[10],y[20],z[30])", mkl("x", "y", "z"), "tensor()");
+ checkPeek("tensor<int8>(x[10],y[20],z[30])", mkl("x", "y", "z"), "tensor()");
+ checkPeek("tensor(x[10],y{},z[30])", mkl("x"), "tensor(y{},z[30])");
+ checkPeek("tensor(x[10],y{},z[30])", mkl("y"), "tensor(x[10],z[30])");
+ checkPeek("tensor<float>(x[10],y{},z[30])", mkl("z"), "tensor<float>(x[10],y{})");
+ checkPeek("tensor<bfloat16>(x[10],y{},z[30])", mkl("z"), "tensor<bfloat16>(x[10],y{})");
+ checkPeek("tensor<int8>(x[10],y{},z[30])", mkl("z"), "tensor<int8>(x[10],y{})");
+ checkPeek("tensor(x[10],y{},z[30])", mkl("x", "z"), "tensor(y{})");
+ checkPeek("tensor<float>(x[10],y{},z[30])", mkl("z", "x"), "tensor<float>(y{})");
+ checkPeek("tensor<bfloat16>(x[10],y{},z[30])", mkl("z", "x"), "tensor<bfloat16>(y{})");
+ checkPeek("tensor<int8>(x[10],y{},z[30])", mkl("z", "x"), "tensor<int8>(y{})");
+ checkPeek("tensor(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()");
+ checkPeek("tensor<float>(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()");
+ checkPeek("tensor<bfloat16>(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()");
+ checkPeek("tensor<int8>(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()");
+ checkFullPeek("tensor(x[10],y[20],z[30])");
+ checkFullPeek("tensor<float>(x[10],y[20],z[30])");
+ checkFullPeek("tensor<bfloat16>(x[10],y[20],z[30])");
+ checkFullPeek("tensor<int8>(x[10],y[20],z[30])");
+ checkFullPeek("tensor(x[10],y{},z[30])");
+ checkFullPeek("tensor<float>(x[10],y{},z[30])");
+ checkFullPeek("tensor<bfloat16>(x[10],y{},z[30])");
+ checkFullPeek("tensor<int8>(x[10],y{},z[30])");
+ checkPeekFails("tensor()", mkl());
+ checkPeekFails("tensor()", mkl("x"));
+ checkPeekFails("tensor(y{})", mkl("x"));
+ checkPeekFails("tensor(y{})", mkl("y", "y"));
+ checkPeekFails("tensor<float>(y[10])", mkl("x"));
+ }
+
+ @Test
+ public void verifyCellCast() {
+ checkCast("tensor(x[10],y[20],z[30])", TensorType.Value.FLOAT, "tensor<float>(x[10],y[20],z[30])");
+ checkCasts("tensor<double>(x[10])");
+ checkCasts("tensor<float>(x[10])");
+ checkCasts("tensor<bfloat16>(x[10])");
+ checkCasts("tensor<int8>(x[10])");
+ checkCasts("tensor<double>(x{})");
+ checkCasts("tensor<float>(x{})");
+ checkCasts("tensor<bfloat16>(x{})");
+ checkCasts("tensor<int8>(x{})");
+ checkCasts("tensor<double>(x{},y[5])");
+ checkCasts("tensor<float>(x{},y[5])");
+ checkCasts("tensor<bfloat16>(x{},y[5])");
+ checkCasts("tensor<int8>(x{},y[5])");
+ checkCast("tensor()", TensorType.Value.DOUBLE, "tensor()");
+ checkCastFails("tensor()", TensorType.Value.FLOAT);
+ checkCastFails("tensor()", TensorType.Value.BFLOAT16);
+ checkCastFails("tensor()", TensorType.Value.INT8);
+ }
+
+ private static void checkMap(String specA, String expected) {
+ var a = TensorType.fromSpec(specA);
+ var result = TypeResolver.map(a);
+ assertEquals(expected, result.toString());
+ }
+
+ private static void checkJoin(String specA, String specB, String expected) {
+ var a = TensorType.fromSpec(specA);
+ var b = TensorType.fromSpec(specB);
+ var result = TypeResolver.join(a, b);
+ assertEquals(expected, result.toString());
+ }
+
+ private static void checkJoinFails(String specA, String specB) {
+ var a = TensorType.fromSpec(specA);
+ var b = TensorType.fromSpec(specB);
+ boolean caught = false;
+ try {
+ var result = TypeResolver.join(a, b);
+ System.err.println("join of "+a+" and "+b+" produces: "+result);
+ } catch (IllegalArgumentException e) {
+ caught = true;
+ }
+ assertTrue(caught);
+ }
+
+ private static void checkReduce(String specA, List<String> dims, String expected) {
+ var a = TensorType.fromSpec(specA);
+ var result = TypeResolver.reduce(a, dims);
+ assertEquals(expected, result.toString());
+ }
+
+ private static void checkFullReduce(String specA) {
+ String expected = "tensor()";
+ List<String> dims = new ArrayList<>();
+ checkReduce(specA, dims, expected);
+ var a = TensorType.fromSpec(specA);
+ for (var dim : a.dimensions()) {
+ dims.add(dim.name());
+ }
+ checkReduce(specA, dims, expected);
+ }
+
+ private static void checkReduceFails(String specA, String dim) {
+ var a = TensorType.fromSpec(specA);
+ boolean caught = false;
+ try {
+ var result = TypeResolver.reduce(a, mkl(dim));
+ System.err.println("Reduce "+specA+" with dim "+dim+" produces: "+result);
+ } catch (IllegalArgumentException e) {
+ caught = true;
+ }
+ assertTrue(caught);
+ }
+
+ private static void checkMerge(String specA, String specB, String expected) {
+ var a = TensorType.fromSpec(specA);
+ var b = TensorType.fromSpec(specB);
+ var result = TypeResolver.merge(a, b);
+ assertEquals(expected, result.toString());
+ }
+
+ private static void checkMergeFails(String specA, String specB) {
+ var a = TensorType.fromSpec(specA);
+ var b = TensorType.fromSpec(specB);
+ boolean caught = false;
+ try {
+ var result = TypeResolver.merge(a, b);
+ System.err.println("merge of "+a+" and "+b+" produces: "+result);
+ } catch (IllegalArgumentException e) {
+ caught = true;
+ }
+ assertTrue(caught);
+ }
+
+ private static void checkRename(String specA, List<String> fromDims, List<String> toDims, String expected) {
+ var a = TensorType.fromSpec(specA);
+ var result = TypeResolver.rename(a, fromDims, toDims);
+ assertEquals(expected, result.toString());
+ }
+
+ private static void checkRenameFails(String specA, List<String> fromDims, List<String> toDims) {
+ var a = TensorType.fromSpec(specA);
+ boolean caught = false;
+ try {
+ var result = TypeResolver.rename(a, fromDims, toDims);
+ System.err.println("rename "+a+" produces: "+result);
+ } catch (IllegalArgumentException e) {
+ caught = true;
+ }
+ assertTrue(caught);
+ }
+
+ private static void checkConcat(String specA, String specB, String dim, String expected) {
+ var a = TensorType.fromSpec(specA);
+ var b = TensorType.fromSpec(specB);
+ var result = TypeResolver.concat(a, b, dim);
+ assertEquals(expected, result.toString());
+ }
+
+ private static void checkConcatFails(String specA, String specB, String dim) {
+ var a = TensorType.fromSpec(specA);
+ var b = TensorType.fromSpec(specB);
+ boolean caught = false;
+ try {
+ var result = TypeResolver.concat(a, b, dim);
+ System.err.println("concat "+a+" and "+b+" along "+dim+" produces: "+result);
+ } catch (IllegalArgumentException e) {
+ caught = true;
+ }
+ assertTrue(caught);
+ }
+
+ private static void checkPeek(String specA, List<String> dims, String expected) {
+ var a = TensorType.fromSpec(specA);
+ var result = TypeResolver.peek(a, dims);
+ assertEquals(expected, result.toString());
+ }
+
+ private static void checkFullPeek(String specA) {
+ String expected = "tensor()";
+ List<String> dims = new ArrayList<>();
+ var a = TensorType.fromSpec(specA);
+ for (var dim : a.dimensions()) {
+ dims.add(dim.name());
+ }
+ checkPeek(specA, dims, expected);
+ }
+
+ private static void checkPeekFails(String specA, List<String> dims) {
+ var a = TensorType.fromSpec(specA);
+ boolean caught = false;
+ try {
+ var result = TypeResolver.peek(a, dims);
+ System.err.println("Peek "+specA+" with dims "+dims+" produces: "+result);
+ } catch (IllegalArgumentException e) {
+ caught = true;
+ }
+ assertTrue(caught);
+ }
+
+ private static void checkCast(String specA, TensorType.Value newValueType, String expected) {
+ var a = TensorType.fromSpec(specA);
+ var result = TypeResolver.cell_cast(a, newValueType);
+ assertEquals(expected, result.toString());
+ }
+
+ private static void checkCasts(String specA) {
+ var a = TensorType.fromSpec(specA);
+ for (var newValueType : TensorType.Value.values()) {
+ var result = TypeResolver.cell_cast(a, newValueType);
+ assertEquals(result.valueType(), newValueType);
+ assertEquals(result.dimensions(), a.dimensions());
+ }
+ }
+
+ private static void checkCastFails(String specA, TensorType.Value newValueType) {
+ var a = TensorType.fromSpec(specA);
+ boolean caught = false;
+ try {
+ var result = TypeResolver.cell_cast(a, newValueType);
+ System.err.println("cast of "+a+" to "+newValueType+" produces: "+result);
+ } catch (IllegalArgumentException e) {
+ caught = true;
+ }
+ assertTrue(caught);
+ }
+
+}