aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorArne Juul <arnej@vespa.ai>2023-11-02 08:47:23 +0000
committerArne Juul <arnej@vespa.ai>2023-11-02 19:54:24 +0000
commitbd9d7a9f74d41f2e88694aa2f1629ced0bca6428 (patch)
treeaf40320eae453618b6c00b854f2cf5d72d17e26e /vespajlib/src
parent96f6abe9caa338074ee39cb2fd566d3efff464c9 (diff)
add reference implementation of MapSubspaces
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java55
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java146
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/MapSubspacesTestCase.java213
6 files changed, 423 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 084eaf2bf98..7f890a9ec51 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -80,7 +80,7 @@ public class TensorType {
};
/** The empty tensor type - which is the same as a double */
- public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList());
+ public static final TensorType empty = new TensorType();
private final Value valueType;
@@ -90,6 +90,13 @@ public class TensorType {
private final TensorType mappedSubtype;
private final TensorType indexedSubtype;
+ private TensorType() {
+ this.valueType = Value.DOUBLE;
+ this.dimensions = List.of();
+ this.mappedSubtype = this;
+ this.indexedSubtype = this;
+ }
+
public TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
List<Dimension> dimensionList = new ArrayList<>(dimensions);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
index d1407bf3d9b..d875f1ef4eb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
@@ -26,5 +26,4 @@ public interface TypeContext<NAMETYPE extends Name> {
*/
TensorType getType(String name);
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java
new file mode 100644
index 00000000000..b6655a15361
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java
@@ -0,0 +1,55 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Optional;
+import java.util.function.Function;
+
+/**
+ * A function suitable for use in MapSubspaces
+ *
+ * @author arnej
+ */
+class DenseSubspaceFunction<NAMETYPE extends Name> {
+
+ private final String argName;
+ private final TensorFunction<NAMETYPE> function;
+
+ public DenseSubspaceFunction(String argName, TensorFunction<NAMETYPE> function) {
+ this.argName = argName;
+ this.function = function;
+ }
+
+ Tensor map(Tensor subspace) {
+ var context = new MapEvaluationContext<NAMETYPE>();
+ context.put(argName, subspace);
+ return function.evaluate(context);
+ }
+
+ class MyTypeContext implements TypeContext<NAMETYPE> {
+ private final TensorType subspaceType;
+ MyTypeContext(TensorType subspaceType) { this.subspaceType = subspaceType; }
+ public TensorType getType(NAMETYPE name) { return getType(name.name()); }
+ public TensorType getType(String name) { return argName.equals(name) ? subspaceType : null; }
+ }
+
+ TensorType outputType(TensorType subspaceType) {
+ var context = new MyTypeContext(subspaceType);
+ var result = function.type(context);
+ if (result.mappedSubtype().rank() > 0) {
+ throw new IllegalArgumentException("function used in map_subspaces type had mapped dimensions: " + result);
+ }
+ return result;
+ }
+
+ public String toString() {
+ return "f(" + argName + ")(" + function + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java
new file mode 100644
index 00000000000..c87ef42976d
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java
@@ -0,0 +1,146 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * The <i>map_subspaces</i> tensor function transforms each dense subspace in a (mixed) tensor
+ *
+ * @author arnej
+ */
+public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+
+ private final TensorFunction<NAMETYPE> argument;
+ private final DenseSubspaceFunction<NAMETYPE> function;
+
+ private MapSubspaces(TensorFunction<NAMETYPE> argument, DenseSubspaceFunction<NAMETYPE> function) {
+ this.argument = argument;
+ this.function = function;
+ }
+ public MapSubspaces(TensorFunction<NAMETYPE> argument, String functionArg, TensorFunction<NAMETYPE> function) {
+ this(argument, new DenseSubspaceFunction<>(functionArg, function));
+ Objects.requireNonNull(argument, "The argument cannot be null");
+ Objects.requireNonNull(functionArg, "The functionArg cannot be null");
+ Objects.requireNonNull(function, "The function cannot be null");
+ }
+
+ private TensorType outputType(TensorType inputType) {
+ var m = inputType.mappedSubtype();
+ var d = function.outputType(inputType.indexedSubtype());
+ if (m.rank() == 0) {
+ return d;
+ }
+ if (d.rank() == 0) {
+ return TypeResolver.map(m); // decay cell type
+ }
+ TensorType.Value cellType = d.valueType();
+ Map<String, TensorType.Dimension> dims = new HashMap<>();
+ for (var dim : m.dimensions()) {
+ dims.put(dim.name(), dim);
+ }
+ for (var dim : d.dimensions()) {
+ var old = dims.put(dim.name(), dim);
+ if (old != null) {
+ throw new IllegalArgumentException("dimension name collision in map_subspaces: " + m + " vs " + d);
+ }
+ }
+ return new TensorType(cellType, dims.values());
+ }
+
+ public TensorFunction<NAMETYPE> argument() { return argument; }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
+
+ @Override
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("MapSubspaces must have 1 argument, got " + arguments.size());
+ return new MapSubspaces<NAMETYPE>(arguments.get(0), function);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new MapSubspaces<>(argument.toPrimitive(), function);
+ }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ return outputType(argument.type(context));
+ }
+
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor input = argument().evaluate(context);
+ TensorType inputType = input.type();
+ TensorType inputTypeMapped = inputType.mappedSubtype();
+ TensorType inputTypeDense = inputType.indexedSubtype();
+ Map<TensorAddress, Tensor.Builder> builders = new HashMap<>();
+ for (Iterator<Tensor.Cell> iter = input.cellIterator(); iter.hasNext(); ) {
+ var cell = iter.next();
+ var fullAddr = cell.getKey();
+ var mapAddrBuilder = new TensorAddress.Builder(inputTypeMapped);
+ var idxAddrBuilder = new TensorAddress.Builder(inputTypeDense);
+ for (int i = 0; i < inputType.dimensions().size(); i++) {
+ var dim = inputType.dimensions().get(i);
+ if (dim.isMapped()) {
+ mapAddrBuilder.add(dim.name(), fullAddr.label(i));
+ } else {
+ idxAddrBuilder.add(dim.name(), fullAddr.label(i));
+ }
+ }
+ var mapAddr = mapAddrBuilder.build();
+ var builder = builders.computeIfAbsent(mapAddr, k -> Tensor.Builder.of(inputTypeDense));
+ var idxAddr = idxAddrBuilder.build();
+ builder.cell(idxAddr, cell.getValue());
+ }
+ TensorType outputType = outputType(input.type());
+ TensorType denseOutputType = outputType.indexedSubtype();
+ var denseOutputDims = denseOutputType.dimensions();
+ Tensor.Builder builder = Tensor.Builder.of(outputType);
+ for (var entry : builders.entrySet()) {
+ TensorAddress mappedAddr = entry.getKey();
+ Tensor denseInput = entry.getValue().build();
+ Tensor denseOutput = function.map(denseInput);
+ // XXX check denseOutput.type().dimensions()
+ for (Iterator<Tensor.Cell> iter = denseOutput.cellIterator(); iter.hasNext(); ) {
+ var cell = iter.next();
+ var denseAddr = cell.getKey();
+ var addrBuilder = new TensorAddress.Builder(outputType);
+ for (int i = 0; i < inputTypeMapped.dimensions().size(); i++) {
+ var dim = inputTypeMapped.dimensions().get(i);
+ addrBuilder.add(dim.name(), mappedAddr.label(i));
+ }
+ for (int i = 0; i < denseOutputDims.size(); i++) {
+ var dim = denseOutputDims.get(i);
+ addrBuilder.add(dim.name(), denseAddr.label(i));
+ }
+ builder.cell(addrBuilder.build(), cell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
+ public String toString(ToStringContext<NAMETYPE> context) {
+ return "map_subspaces(" + argument.toString(context) + ", " + function + ")";
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash("map_subspaces", argument, function); }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 05c5f412c39..790c956f9c8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -44,7 +44,7 @@ public abstract class TensorFunction<NAMETYPE extends Name> {
*
* @param context a context which must be passed to all nested functions when evaluating
*/
- public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context);
+ public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context);
/**
* Returns the type of the tensor this produces given the input types in the context
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MapSubspacesTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MapSubspacesTestCase.java
new file mode 100644
index 00000000000..cf26b630b88
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MapSubspacesTestCase.java
@@ -0,0 +1,213 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.evaluation.VariableTensor;
+
+import java.util.List;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author arnej
+ */
+public class MapSubspacesTestCase {
+
+ static class MyCellGenSumNext implements ScalarFunction<Name> {
+ @Override public Double apply(EvaluationContext<Name> context) {
+ Tensor input = context.getTensor("denseInput");
+ long dimIdx = (long) context.getTensor("x").asDouble();
+ var addrA = TensorAddress.of(dimIdx);
+ var addrB = TensorAddress.of(dimIdx + 1);
+ double value = input.get(addrA) + input.get(addrB);
+ return value;
+ }
+ }
+
+ private static Tensor map3to2(Tensor input, String cellType) {
+ TensorType tt = TensorType.fromSpec("tensor<" + cellType + ">(x[2])");
+ var tfun = Generate.<Name>bound(tt, new MyCellGenSumNext());
+ var constInput = new ConstantTensor<Name>(input);
+ var mapper = new MapSubspaces<Name>(constInput, "denseInput", tfun);
+ Tensor mapped = mapper.evaluate();
+ System.err.println("Mapped 3->2: " + mapped);
+ return mapped;
+ }
+
+ private static void checkResult(Tensor expect, Tensor result, TensorType.Value cellType) {
+ Tensor withType = expect.cellCast(cellType);
+ assertEquals(withType, result);
+ assertEquals(cellType, result.type().valueType());
+ }
+
+ @Test
+ public void testBasicMap() {
+ Tensor t1, t2;
+ t1 = Tensor.from("tensor(a{},x[3]):{foo:[1,2,3],bar:[4,5,6]}");
+ t2 = Tensor.from("tensor(a{},x[2]):{foo:[3,5],bar:[9,11]}");
+ checkResult(t2, map3to2(t1, "double"), TensorType.Value.DOUBLE);
+ checkResult(t2, map3to2(t1, "float"), TensorType.Value.FLOAT);
+ checkResult(t2, map3to2(t1, "bfloat16"), TensorType.Value.BFLOAT16);
+ checkResult(t2, map3to2(t1, "int8"), TensorType.Value.INT8);
+ t1 = Tensor.from("tensor(x[3]):[3,4,6]");
+ t2 = Tensor.from("tensor(x[2]):[7,10]");
+ checkResult(t2, map3to2(t1, "double"), TensorType.Value.DOUBLE);
+ checkResult(t2, map3to2(t1, "float"), TensorType.Value.FLOAT);
+ checkResult(t2, map3to2(t1, "bfloat16"), TensorType.Value.BFLOAT16);
+ checkResult(t2, map3to2(t1, "int8"), TensorType.Value.INT8);
+ t1 = Tensor.from("tensor(x[4],z{}):{foo:[1,2,3,99],bar:[4,5,6,99]}");
+ t2 = Tensor.from("tensor(x[2],z{}):{foo:[3,5],bar:[9,11]}");
+ checkResult(t2, map3to2(t1, "double"), TensorType.Value.DOUBLE);
+ checkResult(t2, map3to2(t1, "float"), TensorType.Value.FLOAT);
+ checkResult(t2, map3to2(t1, "bfloat16"), TensorType.Value.BFLOAT16);
+ checkResult(t2, map3to2(t1, "int8"), TensorType.Value.INT8);
+ t1 = Tensor.from("tensor(a{},x[3],z{}):{" +
+ "{a:aa,x:0,z:kz}:1," +
+ "{a:aa,x:1,z:kz}:2," +
+ "{a:aa,x:2,z:kz}:3," +
+ "{a:ba,x:0,z:kz}:4," +
+ "{a:ba,x:1,z:kz}:5," +
+ "{a:ba,x:2,z:kz}:6," +
+ "{a:ba,x:0,z:nz}:7," +
+ "{a:ba,x:1,z:nz}:8," +
+ "{a:ba,x:2,z:nz}:9" + "}");
+ t2 = Tensor.from("tensor(a{},x[2],z{}):{" +
+ "{a:aa,x:0,z:kz}:3," +
+ "{a:aa,x:1,z:kz}:5," +
+ "{a:ba,x:0,z:kz}:9," +
+ "{a:ba,x:1,z:kz}:11," +
+ "{a:ba,x:0,z:nz}:15," +
+ "{a:ba,x:1,z:nz}:17" + "}");
+ checkResult(t2, map3to2(t1, "double"), TensorType.Value.DOUBLE);
+ checkResult(t2, map3to2(t1, "float"), TensorType.Value.FLOAT);
+ checkResult(t2, map3to2(t1, "bfloat16"), TensorType.Value.BFLOAT16);
+ checkResult(t2, map3to2(t1, "int8"), TensorType.Value.INT8);
+ }
+
+ static class MyCellGenFromScalar implements ScalarFunction<Name> {
+ @Override public Double apply(EvaluationContext<Name> context) {
+ double input = context.getTensor("denseInput").asDouble();
+ double dimIdx = context.getTensor("x").asDouble();
+ double value = input + dimIdx * 2;
+ return value;
+ }
+ }
+
+ private static Tensor map1to3(Tensor input, String cellType) {
+ TensorType tt = TensorType.fromSpec("tensor<" + cellType + ">(x[3])");
+ var tfun = Generate.<Name>bound(tt, new MyCellGenFromScalar());
+ var constInput = new ConstantTensor<Name>(input);
+ var mapper = new MapSubspaces<Name>(constInput, "denseInput", tfun);
+ Tensor mapped = mapper.evaluate();
+ System.err.println("Mapped 1->3: " + mapped);
+ return mapped;
+ }
+
+ @Test
+ public void testFromSparse() {
+ Tensor t1, t2;
+ t1 = Tensor.from("tensor(a{}):{foo:2,bar:17}");
+ t2 = Tensor.from("tensor(a{},x[3]):{foo:[2,4,6],bar:[17,19,21]}");
+ checkResult(t2, map1to3(t1, "double"), TensorType.Value.DOUBLE);
+ checkResult(t2, map1to3(t1, "float"), TensorType.Value.FLOAT);
+ checkResult(t2, map1to3(t1, "bfloat16"), TensorType.Value.BFLOAT16);
+ checkResult(t2, map1to3(t1, "int8"), TensorType.Value.INT8);
+ t1 = Tensor.from("tensor():{5}");
+ t2 = Tensor.from("tensor(x[3]):[5,7,9]");
+ checkResult(t2, map1to3(t1, "double"), TensorType.Value.DOUBLE);
+ checkResult(t2, map1to3(t1, "float"), TensorType.Value.FLOAT);
+ checkResult(t2, map1to3(t1, "bfloat16"), TensorType.Value.BFLOAT16);
+ checkResult(t2, map1to3(t1, "int8"), TensorType.Value.INT8);
+ t1 = Tensor.from("tensor<float>(a{}):{foo:2,bar:17}");
+ t2 = Tensor.from("tensor(a{},x[3]):{foo:[2,4,6],bar:[17,19,21]}");
+ checkResult(t2, map1to3(t1, "double"), TensorType.Value.DOUBLE);
+ checkResult(t2, map1to3(t1, "float"), TensorType.Value.FLOAT);
+ checkResult(t2, map1to3(t1, "bfloat16"), TensorType.Value.BFLOAT16);
+ checkResult(t2, map1to3(t1, "int8"), TensorType.Value.INT8);
+ }
+
+ static class MyWeightedSum extends TensorFunction<Name> {
+ public List<TensorFunction<Name>> arguments() { return List.of(); }
+ public TensorFunction<Name> withArguments(List<TensorFunction<Name>> arguments) { return this; }
+ public PrimitiveTensorFunction<Name> toPrimitive() { return null; }
+ public Tensor evaluate(EvaluationContext<Name> context) {
+ Tensor input = context.getTensor("denseInput");
+ double value = 0.0;
+ double w = 8.0;
+ long sz = input.type().dimensions().get(0).size().get();
+ for (long i = 0; i < sz; i++) {
+ var addr = TensorAddress.of(i);
+ value += w * input.get(addr);
+ w = w * 0.5;
+ }
+ return Tensor.from(value);
+ }
+ public TensorType type(TypeContext<Name> context) { return TensorType.empty; }
+ public String toString(ToStringContext<Name> context) { return "MyWeightedSum(denseInput)"; }
+ public int hashCode() { return 0; }
+ }
+
+ private static Tensor mapNto1(Tensor input) {
+ var tfun = new MyWeightedSum();
+ var constInput = new ConstantTensor<Name>(input);
+ var mapper = new MapSubspaces<Name>(constInput, "denseInput", tfun);
+ Tensor mapped = mapper.evaluate();
+ System.err.println("Mapped N->1: " + mapped);
+ return mapped;
+ }
+
+ @Test
+ public void testToSparse() {
+ Tensor t1, t2;
+ t1 = Tensor.from("tensor(a{},x[3]):{foo:[2,4,6],bar:[17,19,21]}");
+ t2 = Tensor.from("tensor(a{}):{foo:44,bar:254}");
+ checkResult(t2, mapNto1(t1), TensorType.Value.DOUBLE);
+ checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.FLOAT)), TensorType.Value.FLOAT);
+ checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.BFLOAT16)), TensorType.Value.FLOAT);
+ checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.INT8)), TensorType.Value.FLOAT);
+ t1 = Tensor.from("tensor(a{},x[4]):{foo:[2,4,6,8],bar:[1,1,1,1]}");
+ t2 = Tensor.from("tensor(a{}):{foo:52,bar:15}");
+ checkResult(t2, mapNto1(t1), TensorType.Value.DOUBLE);
+ checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.FLOAT)), TensorType.Value.FLOAT);
+ checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.BFLOAT16)), TensorType.Value.FLOAT);
+ checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.INT8)), TensorType.Value.FLOAT);
+ }
+
+ private static Tensor mapIdentity(Tensor input) {
+ var tfun = new VariableTensor<Name>("denseInput");
+ var constInput = new ConstantTensor<Name>(input);
+ var mapper = new MapSubspaces<Name>(constInput, "denseInput", tfun);
+ Tensor mapped = mapper.evaluate();
+ System.err.println("Identity mapped: " + mapped);
+ return mapped;
+ }
+
+ @Test
+ public void testIdentityMapping() {
+ Tensor t1;
+ t1 = Tensor.from("tensor(a{},x[3]):{foo:[2,4,6],bar:[17,19,21]}");
+ checkResult(t1, mapIdentity(t1), TensorType.Value.DOUBLE);
+ checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.FLOAT)), TensorType.Value.FLOAT);
+ checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.BFLOAT16)), TensorType.Value.BFLOAT16);
+ checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.INT8)), TensorType.Value.INT8);
+ t1 = Tensor.from("tensor(a{}):{foo:17,bar:42}");
+ checkResult(t1, mapIdentity(t1), TensorType.Value.DOUBLE);
+ checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.FLOAT)), TensorType.Value.FLOAT);
+ checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.BFLOAT16)), TensorType.Value.FLOAT);
+ checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.INT8)), TensorType.Value.FLOAT);
+ t1 = Tensor.from("tensor(y[4]):[2,3,4,5]");
+ checkResult(t1, mapIdentity(t1), TensorType.Value.DOUBLE);
+ checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.FLOAT)), TensorType.Value.FLOAT);
+ checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.BFLOAT16)), TensorType.Value.BFLOAT16);
+ checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.INT8)), TensorType.Value.INT8);
+ t1 = Tensor.from(42);
+ assertEquals(t1, mapIdentity(t1));
+ }
+
+}