diff options
author | Lester Solbakken <lesters@oath.com> | 2021-03-16 12:11:26 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-03-16 12:11:26 +0100 |
commit | a72c0d645fcf7dcf9225d9e13a16b3bc0434c6ca (patch) | |
tree | f2a58b3268c39ecd4686d2408dc09b0274c67435 /vespajlib | |
parent | f59e36dd56d18f1148a0665823c0dfe7e40dd805 (diff) |
Add Java-side tensor type cell casting
Diffstat (limited to 'vespajlib')
4 files changed, 163 insertions, 0 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index c6727aa372e..6e6791aebc9 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1167,6 +1167,7 @@ "public com.yahoo.tensor.Tensor concat(com.yahoo.tensor.Tensor, java.lang.String)", "public com.yahoo.tensor.Tensor rename(java.util.List, java.util.List)", "public static com.yahoo.tensor.Tensor generate(com.yahoo.tensor.TensorType, java.util.function.Function)", + "public com.yahoo.tensor.Tensor cellCast(com.yahoo.tensor.TensorType$Value)", "public com.yahoo.tensor.Tensor l1Normalize(java.lang.String)", "public com.yahoo.tensor.Tensor l2Normalize(java.lang.String)", "public com.yahoo.tensor.Tensor matmul(com.yahoo.tensor.Tensor, java.lang.String)", @@ -1569,6 +1570,23 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.CellCast": { + "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.TensorType$Value)", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.CompositeTensorFunction": { "superClass": "com.yahoo.tensor.functions.TensorFunction", "interfaces": [], @@ -1651,6 +1669,25 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.Expand": { + "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.lang.String)", + "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)", + "public java.util.List dimensions()", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.Generate": { "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", "interfaces": [], diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index fbf5bc35129..3378520dc91 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -4,6 +4,7 @@ package com.yahoo.tensor; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Argmax; import com.yahoo.tensor.functions.Argmin; +import com.yahoo.tensor.functions.CellCast; import com.yahoo.tensor.functions.Concat; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Diag; @@ -179,6 +180,10 @@ public interface Tensor { return new Generate<>(type, valueSupplier).evaluate(); } + default Tensor cellCast(TensorType.Value valueType) { + return new CellCast<>(new ConstantTensor<>(this), valueType).evaluate(); + } + // ----------------- Composite tensor functions which have a defined primitive mapping default Tensor l1Normalize(String dimension) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java new file mode 100644 index 00000000000..d052e383c85 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -0,0 +1,83 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Objects; + +/** + * The <i>cell_cast</i> tensor function creates a new tensor with the specified cell value type. + * + * @author lesters + */ +public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> argument; + private final TensorType.Value valueType; + + public CellCast(TensorFunction<NAMETYPE> argument, TensorType.Value valueType) { + Objects.requireNonNull(argument, "The argument tensor cannot be null"); + Objects.requireNonNull(valueType, "The value type cannot be null"); + this.argument = argument; + this.valueType = valueType; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("CellCast must have 1 argument, got " + arguments.size()); + return new CellCast<>(arguments.get(0), valueType); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new CellCast<>(argument.toPrimitive(), valueType); + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return new TensorType(valueType, argument.type(context).dimensions()); + } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor tensor = argument.evaluate(context); + if (tensor.type().valueType() == valueType) { + return tensor; + } + TensorType type = new TensorType(valueType, tensor.type().dimensions()); + return cast(tensor, type); + } + + private Tensor cast(Tensor tensor, TensorType type) { + Tensor.Builder builder = Tensor.Builder.of(type); + TensorType.Value fromValueType = tensor.type().valueType(); + for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { + Tensor.Cell cell = i.next(); + if (fromValueType == TensorType.Value.FLOAT) { + builder.cell(cell.getKey(), cell.getFloatValue()); + } else if (fromValueType == TensorType.Value.DOUBLE) { + builder.cell(cell.getKey(), cell.getDoubleValue()); + } else { + builder.cell(cell.getKey(), cell.getValue()); + } + } + return builder.build(); + } + + @Override + public String toString(ToStringContext context) { + return "cell_cast(" + argument.toString(context) + ", " + valueType + ")"; + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java new file mode 100644 index 00000000000..bc10ecc3abd --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java @@ -0,0 +1,38 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author lesters + */ +public class CellCastTestCase { + + @Test + public void testCellCasting() { + Tensor tensor; + + tensor = Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]"); + assertEquals(TensorType.Value.DOUBLE, tensor.type().valueType()); + assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType()); + assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType()); + assertEquals(tensor, tensor.cellCast(TensorType.Value.FLOAT)); + + tensor = Tensor.from("tensor<double>(x{}):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}"); + assertEquals(TensorType.Value.DOUBLE, tensor.type().valueType()); + assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType()); + assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType()); + assertEquals(tensor, tensor.cellCast(TensorType.Value.FLOAT)); + + tensor = Tensor.from("tensor<float>(x[3],y{}):{a:[1.0, 2.0, 3.0],b:[4.0,5.0,6.0]}"); + assertEquals(TensorType.Value.FLOAT, tensor.type().valueType()); + assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType()); + assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType()); + assertEquals(tensor, tensor.cellCast(TensorType.Value.DOUBLE)); + } + +} |