aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-03-16 12:11:26 +0100
committerLester Solbakken <lesters@oath.com>2021-03-16 12:11:26 +0100
commita72c0d645fcf7dcf9225d9e13a16b3bc0434c6ca (patch)
treef2a58b3268c39ecd4686d2408dc09b0274c67435
parentf59e36dd56d18f1148a0665823c0dfe7e40dd805 (diff)
Add Java-side tensor type cell casting
-rw-r--r--searchlib/abi-spec.json2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj17
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java17
-rw-r--r--vespajlib/abi-spec.json37
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java83
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java38
7 files changed, 197 insertions, 2 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index d412f408350..9e958dd4d4c 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -897,6 +897,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmax()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCellCast()",
"public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()",
"public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()",
"public final com.yahoo.tensor.TensorType tensorType(java.util.List)",
@@ -1046,6 +1047,7 @@
"public static final int XW_PLUS_B",
"public static final int ARGMAX",
"public static final int ARGMIN",
+ "public static final int CELL_CAST",
"public static final int AVG",
"public static final int COUNT",
"public static final int MAX",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 36b1f9627bb..d33e9ccff7f 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -141,6 +141,7 @@ TOKEN :
<XW_PLUS_B: "xw_plus_b"> |
<ARGMAX: "argmax"> |
<ARGMIN: "argmin"> |
+ <CELL_CAST: "cell_cast"> |
<AVG: "avg" > |
<COUNT: "count"> |
@@ -380,7 +381,8 @@ TensorFunctionNode tensorFunction() :
tensorExpression = tensorSoftmax() |
tensorExpression = tensorXwPlusB() |
tensorExpression = tensorArgmax() |
- tensorExpression = tensorArgmin()
+ tensorExpression = tensorArgmin() |
+ tensorExpression = tensorCellCast()
)
{ return tensorExpression; }
}
@@ -597,6 +599,16 @@ TensorFunctionNode tensorArgmin() :
{ return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrap(tensor), dimensions)); }
}
+TensorFunctionNode tensorCellCast() :
+{
+ ExpressionNode tensor;
+ String valueType;
+}
+{
+ <CELL_CAST> <LBRACE> tensor = expression() <COMMA> valueType = identifier() <RBRACE>
+ { return new TensorFunctionNode(new CellCast(TensorFunctionNode.wrap(tensor), TensorType.Value.fromId(valueType)));}
+}
+
LambdaFunctionNode lambdaFunction() :
{
List<String> variables;
@@ -667,7 +679,7 @@ String tensorFunctionName() :
( <MAP> { return token.image; } ) |
( <REDUCE> { return token.image; } ) |
( <JOIN> { return token.image; } ) |
- ( <MERGE> { return token.image; } ) |
+ ( <MERGE> { return token.image; } ) |
( <RENAME> { return token.image; } ) |
( <CONCAT> { return token.image; } ) |
( <TENSOR> { return token.image; } ) |
@@ -681,6 +693,7 @@ String tensorFunctionName() :
( <XW_PLUS_B> { return token.image; } ) |
( <ARGMAX> { return token.image; } ) |
( <ARGMIN> { return token.image; } ) |
+ ( <CELL_CAST> { return token.image; } )
( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 123fa5ac43b..fae5a7a093c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -394,6 +394,23 @@ public class EvaluationTestCase {
}
@Test
+ public void testCellTypeCasting() {
+ EvaluationTester tester = new EvaluationTester();
+
+ tester.assertEvaluates("tensor<float>(x[3]):[1.0, 2.0, 3.0]",
+ "cell_cast(tensor0, float)",
+ "tensor<double>(x[3]):[1, 2, 3]");
+ tester.assertEvaluates("tensor<float>():{1}",
+ "cell_cast(tensor0{x:1}, float)",
+ "tensor<double>(x{}):{1:1, 2:2, 3:3}");
+ tester.assertEvaluates("tensor<float>(x[2]):[3,8]",
+ "cell_cast(tensor0 * tensor1, float)",
+ "tensor<float>(x[2]):[1,2]",
+ "tensor<double>(x[2]):[3,4]");
+ }
+
+
+ @Test
public void testMixedTensorType() throws ParseException {
String expected = "tensor(x[1],y{},z[2]):{{x:0,y:a,z:0}:4.0,{x:0,y:a,z:1}:5.0,{x:0,y:b,z:0}:7.0,{x:0,y:b,z:1}:8.0}";
String a = "tensor(x[1],y{}):{ {x:0,y:a}:1, {x:0,y:b}:2 }";
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index c6727aa372e..6e6791aebc9 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1167,6 +1167,7 @@
"public com.yahoo.tensor.Tensor concat(com.yahoo.tensor.Tensor, java.lang.String)",
"public com.yahoo.tensor.Tensor rename(java.util.List, java.util.List)",
"public static com.yahoo.tensor.Tensor generate(com.yahoo.tensor.TensorType, java.util.function.Function)",
+ "public com.yahoo.tensor.Tensor cellCast(com.yahoo.tensor.TensorType$Value)",
"public com.yahoo.tensor.Tensor l1Normalize(java.lang.String)",
"public com.yahoo.tensor.Tensor l2Normalize(java.lang.String)",
"public com.yahoo.tensor.Tensor matmul(com.yahoo.tensor.Tensor, java.lang.String)",
@@ -1569,6 +1570,23 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.CellCast": {
+ "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.TensorType$Value)",
+ "public java.util.List arguments()",
+ "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
+ "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.CompositeTensorFunction": {
"superClass": "com.yahoo.tensor.functions.TensorFunction",
"interfaces": [],
@@ -1651,6 +1669,25 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.Expand": {
+ "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.lang.String)",
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)",
+ "public java.util.List dimensions()",
+ "public java.util.List arguments()",
+ "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
+ "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.Generate": {
"superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
"interfaces": [],
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index fbf5bc35129..3378520dc91 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Argmax;
import com.yahoo.tensor.functions.Argmin;
+import com.yahoo.tensor.functions.CellCast;
import com.yahoo.tensor.functions.Concat;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Diag;
@@ -179,6 +180,10 @@ public interface Tensor {
return new Generate<>(type, valueSupplier).evaluate();
}
+ default Tensor cellCast(TensorType.Value valueType) {
+ return new CellCast<>(new ConstantTensor<>(this), valueType).evaluate();
+ }
+
// ----------------- Composite tensor functions which have a defined primitive mapping
default Tensor l1Normalize(String dimension) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
new file mode 100644
index 00000000000..d052e383c85
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -0,0 +1,83 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * The <i>cell_cast</i> tensor function creates a new tensor with the specified cell value type.
+ *
+ * @author lesters
+ */
+public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+
+ private final TensorFunction<NAMETYPE> argument;
+ private final TensorType.Value valueType;
+
+ public CellCast(TensorFunction<NAMETYPE> argument, TensorType.Value valueType) {
+ Objects.requireNonNull(argument, "The argument tensor cannot be null");
+ Objects.requireNonNull(valueType, "The value type cannot be null");
+ this.argument = argument;
+ this.valueType = valueType;
+ }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
+
+ @Override
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("CellCast must have 1 argument, got " + arguments.size());
+ return new CellCast<>(arguments.get(0), valueType);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new CellCast<>(argument.toPrimitive(), valueType);
+ }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ return new TensorType(valueType, argument.type(context).dimensions());
+ }
+
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor tensor = argument.evaluate(context);
+ if (tensor.type().valueType() == valueType) {
+ return tensor;
+ }
+ TensorType type = new TensorType(valueType, tensor.type().dimensions());
+ return cast(tensor, type);
+ }
+
+ private Tensor cast(Tensor tensor, TensorType type) {
+ Tensor.Builder builder = Tensor.Builder.of(type);
+ TensorType.Value fromValueType = tensor.type().valueType();
+ for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
+ Tensor.Cell cell = i.next();
+ if (fromValueType == TensorType.Value.FLOAT) {
+ builder.cell(cell.getKey(), cell.getFloatValue());
+ } else if (fromValueType == TensorType.Value.DOUBLE) {
+ builder.cell(cell.getKey(), cell.getDoubleValue());
+ } else {
+ builder.cell(cell.getKey(), cell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "cell_cast(" + argument.toString(context) + ", " + valueType + ")";
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java
new file mode 100644
index 00000000000..bc10ecc3abd
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java
@@ -0,0 +1,38 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class CellCastTestCase {
+
+ @Test
+ public void testCellCasting() {
+ Tensor tensor;
+
+ tensor = Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]");
+ assertEquals(TensorType.Value.DOUBLE, tensor.type().valueType());
+ assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType());
+ assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType());
+ assertEquals(tensor, tensor.cellCast(TensorType.Value.FLOAT));
+
+ tensor = Tensor.from("tensor<double>(x{}):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}");
+ assertEquals(TensorType.Value.DOUBLE, tensor.type().valueType());
+ assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType());
+ assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType());
+ assertEquals(tensor, tensor.cellCast(TensorType.Value.FLOAT));
+
+ tensor = Tensor.from("tensor<float>(x[3],y{}):{a:[1.0, 2.0, 3.0],b:[4.0,5.0,6.0]}");
+ assertEquals(TensorType.Value.FLOAT, tensor.type().valueType());
+ assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType());
+ assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType());
+ assertEquals(tensor, tensor.cellCast(TensorType.Value.DOUBLE));
+ }
+
+}