summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2021-03-16 13:30:13 +0100
committerGitHub <noreply@github.com>2021-03-16 13:30:13 +0100
commit73702b1c05deaaf08bcfed78c15494d2e53684a9 (patch)
tree60a9790b7223fce7f9da2c7355bde425968a763b /vespajlib
parent700345986b877638da6ea8d8d7160ed50ea7cd5f (diff)
parentd2c61030d6c62b8c4889d3471d2ee5f17bb14a5f (diff)
Merge pull request #16976 from vespa-engine/revert-16975-lesters/cell-cast-java
Revert "Lesters/cell cast java"
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java83
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java38
4 files changed, 1 insertions, 145 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index e51569da988..c6727aa372e 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1167,7 +1167,6 @@
"public com.yahoo.tensor.Tensor concat(com.yahoo.tensor.Tensor, java.lang.String)",
"public com.yahoo.tensor.Tensor rename(java.util.List, java.util.List)",
"public static com.yahoo.tensor.Tensor generate(com.yahoo.tensor.TensorType, java.util.function.Function)",
- "public com.yahoo.tensor.Tensor cellCast(com.yahoo.tensor.TensorType$Value)",
"public com.yahoo.tensor.Tensor l1Normalize(java.lang.String)",
"public com.yahoo.tensor.Tensor l2Normalize(java.lang.String)",
"public com.yahoo.tensor.Tensor matmul(com.yahoo.tensor.Tensor, java.lang.String)",
@@ -1570,23 +1569,6 @@
],
"fields": []
},
- "com.yahoo.tensor.functions.CellCast": {
- "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
- "interfaces": [],
- "attributes": [
- "public"
- ],
- "methods": [
- "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.TensorType$Value)",
- "public java.util.List arguments()",
- "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
- "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
- ],
- "fields": []
- },
"com.yahoo.tensor.functions.CompositeTensorFunction": {
"superClass": "com.yahoo.tensor.functions.TensorFunction",
"interfaces": [],
@@ -3463,4 +3445,4 @@
],
"fields": []
}
-}
+} \ No newline at end of file
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 3378520dc91..fbf5bc35129 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -4,7 +4,6 @@ package com.yahoo.tensor;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Argmax;
import com.yahoo.tensor.functions.Argmin;
-import com.yahoo.tensor.functions.CellCast;
import com.yahoo.tensor.functions.Concat;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Diag;
@@ -180,10 +179,6 @@ public interface Tensor {
return new Generate<>(type, valueSupplier).evaluate();
}
- default Tensor cellCast(TensorType.Value valueType) {
- return new CellCast<>(new ConstantTensor<>(this), valueType).evaluate();
- }
-
// ----------------- Composite tensor functions which have a defined primitive mapping
default Tensor l1Normalize(String dimension) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
deleted file mode 100644
index d052e383c85..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ /dev/null
@@ -1,83 +0,0 @@
-// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor.functions;
-
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.EvaluationContext;
-import com.yahoo.tensor.evaluation.Name;
-import com.yahoo.tensor.evaluation.TypeContext;
-
-import java.util.Collections;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Objects;
-
-/**
- * The <i>cell_cast</i> tensor function creates a new tensor with the specified cell value type.
- *
- * @author lesters
- */
-public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
-
- private final TensorFunction<NAMETYPE> argument;
- private final TensorType.Value valueType;
-
- public CellCast(TensorFunction<NAMETYPE> argument, TensorType.Value valueType) {
- Objects.requireNonNull(argument, "The argument tensor cannot be null");
- Objects.requireNonNull(valueType, "The value type cannot be null");
- this.argument = argument;
- this.valueType = valueType;
- }
-
- @Override
- public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
-
- @Override
- public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if ( arguments.size() != 1)
- throw new IllegalArgumentException("CellCast must have 1 argument, got " + arguments.size());
- return new CellCast<>(arguments.get(0), valueType);
- }
-
- @Override
- public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
- return new CellCast<>(argument.toPrimitive(), valueType);
- }
-
- @Override
- public TensorType type(TypeContext<NAMETYPE> context) {
- return new TensorType(valueType, argument.type(context).dimensions());
- }
-
- @Override
- public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
- Tensor tensor = argument.evaluate(context);
- if (tensor.type().valueType() == valueType) {
- return tensor;
- }
- TensorType type = new TensorType(valueType, tensor.type().dimensions());
- return cast(tensor, type);
- }
-
- private Tensor cast(Tensor tensor, TensorType type) {
- Tensor.Builder builder = Tensor.Builder.of(type);
- TensorType.Value fromValueType = tensor.type().valueType();
- for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
- Tensor.Cell cell = i.next();
- if (fromValueType == TensorType.Value.FLOAT) {
- builder.cell(cell.getKey(), cell.getFloatValue());
- } else if (fromValueType == TensorType.Value.DOUBLE) {
- builder.cell(cell.getKey(), cell.getDoubleValue());
- } else {
- builder.cell(cell.getKey(), cell.getValue());
- }
- }
- return builder.build();
- }
-
- @Override
- public String toString(ToStringContext context) {
- return "cell_cast(" + argument.toString(context) + ", " + valueType + ")";
- }
-
-}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java
deleted file mode 100644
index bc10ecc3abd..00000000000
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/CellCastTestCase.java
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor.functions;
-
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * @author lesters
- */
-public class CellCastTestCase {
-
- @Test
- public void testCellCasting() {
- Tensor tensor;
-
- tensor = Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]");
- assertEquals(TensorType.Value.DOUBLE, tensor.type().valueType());
- assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType());
- assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType());
- assertEquals(tensor, tensor.cellCast(TensorType.Value.FLOAT));
-
- tensor = Tensor.from("tensor<double>(x{}):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}");
- assertEquals(TensorType.Value.DOUBLE, tensor.type().valueType());
- assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType());
- assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType());
- assertEquals(tensor, tensor.cellCast(TensorType.Value.FLOAT));
-
- tensor = Tensor.from("tensor<float>(x[3],y{}):{a:[1.0, 2.0, 3.0],b:[4.0,5.0,6.0]}");
- assertEquals(TensorType.Value.FLOAT, tensor.type().valueType());
- assertEquals(TensorType.Value.DOUBLE, tensor.cellCast(TensorType.Value.DOUBLE).type().valueType());
- assertEquals(TensorType.Value.FLOAT, tensor.cellCast(TensorType.Value.FLOAT).type().valueType());
- assertEquals(tensor, tensor.cellCast(TensorType.Value.DOUBLE));
- }
-
-}