diff options
author | Arne Juul <arnej@vespa.ai> | 2023-11-02 08:47:23 +0000 |
---|---|---|
committer | Arne Juul <arnej@vespa.ai> | 2023-11-02 19:54:24 +0000 |
commit | bd9d7a9f74d41f2e88694aa2f1629ced0bca6428 (patch) | |
tree | af40320eae453618b6c00b854f2cf5d72d17e26e /vespajlib | |
parent | 96f6abe9caa338074ee39cb2fd566d3efff464c9 (diff) |
add reference implementation of MapSubspaces
Diffstat (limited to 'vespajlib')
7 files changed, 442 insertions, 3 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 3e588e24d47..5dbdb7d157b 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1886,6 +1886,25 @@ ], "fields" : [ ] }, + "com.yahoo.tensor.functions.MapSubspaces" : { + "superClass" : "com.yahoo.tensor.functions.PrimitiveTensorFunction", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.lang.String, com.yahoo.tensor.functions.TensorFunction)", + "public com.yahoo.tensor.functions.TensorFunction argument()", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public int hashCode()" + ], + "fields" : [ ] + }, "com.yahoo.tensor.functions.Matmul" : { "superClass" : "com.yahoo.tensor.functions.CompositeTensorFunction", "interfaces" : [ ], diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 084eaf2bf98..7f890a9ec51 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -80,7 +80,7 @@ public class TensorType { }; /** The empty tensor type - which is the same as a double */ - public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList()); + public static final TensorType empty = new TensorType(); private final Value valueType; @@ -90,6 +90,13 @@ public class TensorType { private final TensorType mappedSubtype; private final TensorType indexedSubtype; + private TensorType() { + this.valueType = Value.DOUBLE; + this.dimensions = List.of(); + this.mappedSubtype = this; + this.indexedSubtype = this; + } + public TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; List<Dimension> dimensionList = new ArrayList<>(dimensions); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java index d1407bf3d9b..d875f1ef4eb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -26,5 +26,4 @@ public interface TypeContext<NAMETYPE extends Name> { */ TensorType getType(String name); - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java new file mode 100644 index 00000000000..b6655a15361 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java @@ -0,0 +1,55 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Optional; +import java.util.function.Function; + +/** + * A function suitable for use in MapSubspaces + * + * @author arnej + */ +class DenseSubspaceFunction<NAMETYPE extends Name> { + + private final String argName; + private final TensorFunction<NAMETYPE> function; + + public DenseSubspaceFunction(String argName, TensorFunction<NAMETYPE> function) { + this.argName = argName; + this.function = function; + } + + Tensor map(Tensor subspace) { + var context = new MapEvaluationContext<NAMETYPE>(); + context.put(argName, subspace); + return function.evaluate(context); + } + + class MyTypeContext implements TypeContext<NAMETYPE> { + private final TensorType subspaceType; + MyTypeContext(TensorType subspaceType) { this.subspaceType = subspaceType; } + public TensorType getType(NAMETYPE name) { return getType(name.name()); } + public TensorType getType(String name) { return argName.equals(name) ? subspaceType : null; } + } + + TensorType outputType(TensorType subspaceType) { + var context = new MyTypeContext(subspaceType); + var result = function.type(context); + if (result.mappedSubtype().rank() > 0) { + throw new IllegalArgumentException("function used in map_subspaces type had mapped dimensions: " + result); + } + return result; + } + + public String toString() { + return "f(" + argName + ")(" + function + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java new file mode 100644 index 00000000000..c87ef42976d --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java @@ -0,0 +1,146 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TypeResolver; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * The <i>map_subspaces</i> tensor function transforms each dense subspace in a (mixed) tensor + * + * @author arnej + */ +public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> argument; + private final DenseSubspaceFunction<NAMETYPE> function; + + private MapSubspaces(TensorFunction<NAMETYPE> argument, DenseSubspaceFunction<NAMETYPE> function) { + this.argument = argument; + this.function = function; + } + public MapSubspaces(TensorFunction<NAMETYPE> argument, String functionArg, TensorFunction<NAMETYPE> function) { + this(argument, new DenseSubspaceFunction<>(functionArg, function)); + Objects.requireNonNull(argument, "The argument cannot be null"); + Objects.requireNonNull(functionArg, "The functionArg cannot be null"); + Objects.requireNonNull(function, "The function cannot be null"); + } + + private TensorType outputType(TensorType inputType) { + var m = inputType.mappedSubtype(); + var d = function.outputType(inputType.indexedSubtype()); + if (m.rank() == 0) { + return d; + } + if (d.rank() == 0) { + return TypeResolver.map(m); // decay cell type + } + TensorType.Value cellType = d.valueType(); + Map<String, TensorType.Dimension> dims = new HashMap<>(); + for (var dim : m.dimensions()) { + dims.put(dim.name(), dim); + } + for (var dim : d.dimensions()) { + var old = dims.put(dim.name(), dim); + if (old != null) { + throw new IllegalArgumentException("dimension name collision in map_subspaces: " + m + " vs " + d); + } + } + return new TensorType(cellType, dims.values()); + } + + public TensorFunction<NAMETYPE> argument() { return argument; } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("MapSubspaces must have 1 argument, got " + arguments.size()); + return new MapSubspaces<NAMETYPE>(arguments.get(0), function); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new MapSubspaces<>(argument.toPrimitive(), function); + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return outputType(argument.type(context)); + } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor input = argument().evaluate(context); + TensorType inputType = input.type(); + TensorType inputTypeMapped = inputType.mappedSubtype(); + TensorType inputTypeDense = inputType.indexedSubtype(); + Map<TensorAddress, Tensor.Builder> builders = new HashMap<>(); + for (Iterator<Tensor.Cell> iter = input.cellIterator(); iter.hasNext(); ) { + var cell = iter.next(); + var fullAddr = cell.getKey(); + var mapAddrBuilder = new TensorAddress.Builder(inputTypeMapped); + var idxAddrBuilder = new TensorAddress.Builder(inputTypeDense); + for (int i = 0; i < inputType.dimensions().size(); i++) { + var dim = inputType.dimensions().get(i); + if (dim.isMapped()) { + mapAddrBuilder.add(dim.name(), fullAddr.label(i)); + } else { + idxAddrBuilder.add(dim.name(), fullAddr.label(i)); + } + } + var mapAddr = mapAddrBuilder.build(); + var builder = builders.computeIfAbsent(mapAddr, k -> Tensor.Builder.of(inputTypeDense)); + var idxAddr = idxAddrBuilder.build(); + builder.cell(idxAddr, cell.getValue()); + } + TensorType outputType = outputType(input.type()); + TensorType denseOutputType = outputType.indexedSubtype(); + var denseOutputDims = denseOutputType.dimensions(); + Tensor.Builder builder = Tensor.Builder.of(outputType); + for (var entry : builders.entrySet()) { + TensorAddress mappedAddr = entry.getKey(); + Tensor denseInput = entry.getValue().build(); + Tensor denseOutput = function.map(denseInput); + // XXX check denseOutput.type().dimensions() + for (Iterator<Tensor.Cell> iter = denseOutput.cellIterator(); iter.hasNext(); ) { + var cell = iter.next(); + var denseAddr = cell.getKey(); + var addrBuilder = new TensorAddress.Builder(outputType); + for (int i = 0; i < inputTypeMapped.dimensions().size(); i++) { + var dim = inputTypeMapped.dimensions().get(i); + addrBuilder.add(dim.name(), mappedAddr.label(i)); + } + for (int i = 0; i < denseOutputDims.size(); i++) { + var dim = denseOutputDims.get(i); + addrBuilder.add(dim.name(), denseAddr.label(i)); + } + builder.cell(addrBuilder.build(), cell.getValue()); + } + } + return builder.build(); + } + + @Override + public String toString(ToStringContext<NAMETYPE> context) { + return "map_subspaces(" + argument.toString(context) + ", " + function + ")"; + } + + @Override + public int hashCode() { return Objects.hash("map_subspaces", argument, function); } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 05c5f412c39..790c956f9c8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -44,7 +44,7 @@ public abstract class TensorFunction<NAMETYPE extends Name> { * * @param context a context which must be passed to all nested functions when evaluating */ - public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context); + public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context); /** * Returns the type of the tensor this produces given the input types in the context diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MapSubspacesTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MapSubspacesTestCase.java new file mode 100644 index 00000000000..cf26b630b88 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MapSubspacesTestCase.java @@ -0,0 +1,213 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.evaluation.VariableTensor; + +import java.util.List; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author arnej + */ +public class MapSubspacesTestCase { + + static class MyCellGenSumNext implements ScalarFunction<Name> { + @Override public Double apply(EvaluationContext<Name> context) { + Tensor input = context.getTensor("denseInput"); + long dimIdx = (long) context.getTensor("x").asDouble(); + var addrA = TensorAddress.of(dimIdx); + var addrB = TensorAddress.of(dimIdx + 1); + double value = input.get(addrA) + input.get(addrB); + return value; + } + } + + private static Tensor map3to2(Tensor input, String cellType) { + TensorType tt = TensorType.fromSpec("tensor<" + cellType + ">(x[2])"); + var tfun = Generate.<Name>bound(tt, new MyCellGenSumNext()); + var constInput = new ConstantTensor<Name>(input); + var mapper = new MapSubspaces<Name>(constInput, "denseInput", tfun); + Tensor mapped = mapper.evaluate(); + System.err.println("Mapped 3->2: " + mapped); + return mapped; + } + + private static void checkResult(Tensor expect, Tensor result, TensorType.Value cellType) { + Tensor withType = expect.cellCast(cellType); + assertEquals(withType, result); + assertEquals(cellType, result.type().valueType()); + } + + @Test + public void testBasicMap() { + Tensor t1, t2; + t1 = Tensor.from("tensor(a{},x[3]):{foo:[1,2,3],bar:[4,5,6]}"); + t2 = Tensor.from("tensor(a{},x[2]):{foo:[3,5],bar:[9,11]}"); + checkResult(t2, map3to2(t1, "double"), TensorType.Value.DOUBLE); + checkResult(t2, map3to2(t1, "float"), TensorType.Value.FLOAT); + checkResult(t2, map3to2(t1, "bfloat16"), TensorType.Value.BFLOAT16); + checkResult(t2, map3to2(t1, "int8"), TensorType.Value.INT8); + t1 = Tensor.from("tensor(x[3]):[3,4,6]"); + t2 = Tensor.from("tensor(x[2]):[7,10]"); + checkResult(t2, map3to2(t1, "double"), TensorType.Value.DOUBLE); + checkResult(t2, map3to2(t1, "float"), TensorType.Value.FLOAT); + checkResult(t2, map3to2(t1, "bfloat16"), TensorType.Value.BFLOAT16); + checkResult(t2, map3to2(t1, "int8"), TensorType.Value.INT8); + t1 = Tensor.from("tensor(x[4],z{}):{foo:[1,2,3,99],bar:[4,5,6,99]}"); + t2 = Tensor.from("tensor(x[2],z{}):{foo:[3,5],bar:[9,11]}"); + checkResult(t2, map3to2(t1, "double"), TensorType.Value.DOUBLE); + checkResult(t2, map3to2(t1, "float"), TensorType.Value.FLOAT); + checkResult(t2, map3to2(t1, "bfloat16"), TensorType.Value.BFLOAT16); + checkResult(t2, map3to2(t1, "int8"), TensorType.Value.INT8); + t1 = Tensor.from("tensor(a{},x[3],z{}):{" + + "{a:aa,x:0,z:kz}:1," + + "{a:aa,x:1,z:kz}:2," + + "{a:aa,x:2,z:kz}:3," + + "{a:ba,x:0,z:kz}:4," + + "{a:ba,x:1,z:kz}:5," + + "{a:ba,x:2,z:kz}:6," + + "{a:ba,x:0,z:nz}:7," + + "{a:ba,x:1,z:nz}:8," + + "{a:ba,x:2,z:nz}:9" + "}"); + t2 = Tensor.from("tensor(a{},x[2],z{}):{" + + "{a:aa,x:0,z:kz}:3," + + "{a:aa,x:1,z:kz}:5," + + "{a:ba,x:0,z:kz}:9," + + "{a:ba,x:1,z:kz}:11," + + "{a:ba,x:0,z:nz}:15," + + "{a:ba,x:1,z:nz}:17" + "}"); + checkResult(t2, map3to2(t1, "double"), TensorType.Value.DOUBLE); + checkResult(t2, map3to2(t1, "float"), TensorType.Value.FLOAT); + checkResult(t2, map3to2(t1, "bfloat16"), TensorType.Value.BFLOAT16); + checkResult(t2, map3to2(t1, "int8"), TensorType.Value.INT8); + } + + static class MyCellGenFromScalar implements ScalarFunction<Name> { + @Override public Double apply(EvaluationContext<Name> context) { + double input = context.getTensor("denseInput").asDouble(); + double dimIdx = context.getTensor("x").asDouble(); + double value = input + dimIdx * 2; + return value; + } + } + + private static Tensor map1to3(Tensor input, String cellType) { + TensorType tt = TensorType.fromSpec("tensor<" + cellType + ">(x[3])"); + var tfun = Generate.<Name>bound(tt, new MyCellGenFromScalar()); + var constInput = new ConstantTensor<Name>(input); + var mapper = new MapSubspaces<Name>(constInput, "denseInput", tfun); + Tensor mapped = mapper.evaluate(); + System.err.println("Mapped 1->3: " + mapped); + return mapped; + } + + @Test + public void testFromSparse() { + Tensor t1, t2; + t1 = Tensor.from("tensor(a{}):{foo:2,bar:17}"); + t2 = Tensor.from("tensor(a{},x[3]):{foo:[2,4,6],bar:[17,19,21]}"); + checkResult(t2, map1to3(t1, "double"), TensorType.Value.DOUBLE); + checkResult(t2, map1to3(t1, "float"), TensorType.Value.FLOAT); + checkResult(t2, map1to3(t1, "bfloat16"), TensorType.Value.BFLOAT16); + checkResult(t2, map1to3(t1, "int8"), TensorType.Value.INT8); + t1 = Tensor.from("tensor():{5}"); + t2 = Tensor.from("tensor(x[3]):[5,7,9]"); + checkResult(t2, map1to3(t1, "double"), TensorType.Value.DOUBLE); + checkResult(t2, map1to3(t1, "float"), TensorType.Value.FLOAT); + checkResult(t2, map1to3(t1, "bfloat16"), TensorType.Value.BFLOAT16); + checkResult(t2, map1to3(t1, "int8"), TensorType.Value.INT8); + t1 = Tensor.from("tensor<float>(a{}):{foo:2,bar:17}"); + t2 = Tensor.from("tensor(a{},x[3]):{foo:[2,4,6],bar:[17,19,21]}"); + checkResult(t2, map1to3(t1, "double"), TensorType.Value.DOUBLE); + checkResult(t2, map1to3(t1, "float"), TensorType.Value.FLOAT); + checkResult(t2, map1to3(t1, "bfloat16"), TensorType.Value.BFLOAT16); + checkResult(t2, map1to3(t1, "int8"), TensorType.Value.INT8); + } + + static class MyWeightedSum extends TensorFunction<Name> { + public List<TensorFunction<Name>> arguments() { return List.of(); } + public TensorFunction<Name> withArguments(List<TensorFunction<Name>> arguments) { return this; } + public PrimitiveTensorFunction<Name> toPrimitive() { return null; } + public Tensor evaluate(EvaluationContext<Name> context) { + Tensor input = context.getTensor("denseInput"); + double value = 0.0; + double w = 8.0; + long sz = input.type().dimensions().get(0).size().get(); + for (long i = 0; i < sz; i++) { + var addr = TensorAddress.of(i); + value += w * input.get(addr); + w = w * 0.5; + } + return Tensor.from(value); + } + public TensorType type(TypeContext<Name> context) { return TensorType.empty; } + public String toString(ToStringContext<Name> context) { return "MyWeightedSum(denseInput)"; } + public int hashCode() { return 0; } + } + + private static Tensor mapNto1(Tensor input) { + var tfun = new MyWeightedSum(); + var constInput = new ConstantTensor<Name>(input); + var mapper = new MapSubspaces<Name>(constInput, "denseInput", tfun); + Tensor mapped = mapper.evaluate(); + System.err.println("Mapped N->1: " + mapped); + return mapped; + } + + @Test + public void testToSparse() { + Tensor t1, t2; + t1 = Tensor.from("tensor(a{},x[3]):{foo:[2,4,6],bar:[17,19,21]}"); + t2 = Tensor.from("tensor(a{}):{foo:44,bar:254}"); + checkResult(t2, mapNto1(t1), TensorType.Value.DOUBLE); + checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.FLOAT)), TensorType.Value.FLOAT); + checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.BFLOAT16)), TensorType.Value.FLOAT); + checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.INT8)), TensorType.Value.FLOAT); + t1 = Tensor.from("tensor(a{},x[4]):{foo:[2,4,6,8],bar:[1,1,1,1]}"); + t2 = Tensor.from("tensor(a{}):{foo:52,bar:15}"); + checkResult(t2, mapNto1(t1), TensorType.Value.DOUBLE); + checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.FLOAT)), TensorType.Value.FLOAT); + checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.BFLOAT16)), TensorType.Value.FLOAT); + checkResult(t2, mapNto1(t1.cellCast(TensorType.Value.INT8)), TensorType.Value.FLOAT); + } + + private static Tensor mapIdentity(Tensor input) { + var tfun = new VariableTensor<Name>("denseInput"); + var constInput = new ConstantTensor<Name>(input); + var mapper = new MapSubspaces<Name>(constInput, "denseInput", tfun); + Tensor mapped = mapper.evaluate(); + System.err.println("Identity mapped: " + mapped); + return mapped; + } + + @Test + public void testIdentityMapping() { + Tensor t1; + t1 = Tensor.from("tensor(a{},x[3]):{foo:[2,4,6],bar:[17,19,21]}"); + checkResult(t1, mapIdentity(t1), TensorType.Value.DOUBLE); + checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.FLOAT)), TensorType.Value.FLOAT); + checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.BFLOAT16)), TensorType.Value.BFLOAT16); + checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.INT8)), TensorType.Value.INT8); + t1 = Tensor.from("tensor(a{}):{foo:17,bar:42}"); + checkResult(t1, mapIdentity(t1), TensorType.Value.DOUBLE); + checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.FLOAT)), TensorType.Value.FLOAT); + checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.BFLOAT16)), TensorType.Value.FLOAT); + checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.INT8)), TensorType.Value.FLOAT); + t1 = Tensor.from("tensor(y[4]):[2,3,4,5]"); + checkResult(t1, mapIdentity(t1), TensorType.Value.DOUBLE); + checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.FLOAT)), TensorType.Value.FLOAT); + checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.BFLOAT16)), TensorType.Value.BFLOAT16); + checkResult(t1, mapIdentity(t1.cellCast(TensorType.Value.INT8)), TensorType.Value.INT8); + t1 = Tensor.from(42); + assertEquals(t1, mapIdentity(t1)); + } + +} |