aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com/yahoo
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java10
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java20
7 files changed, 37 insertions, 33 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java
index 439aac5578a..9c4fa0cf931 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
@@ -42,10 +43,10 @@ public class MatrixDotProductBenchmark {
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double largest = Double.MIN_VALUE;
- TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
- new VariableTensor("argument"), (a, b) -> a * b),
- Reduce.Aggregator.sum).toPrimitive();
- MapEvaluationContext context = new MapEvaluationContext();
+ TensorFunction<TypeContext.Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor),
+ new VariableTensor<>("argument"), (a, b) -> a * b),
+ Reduce.Aggregator.sum).toPrimitive();
+ MapEvaluationContext<TypeContext.Name> context = new MapEvaluationContext<>();
for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor
context.put("argument", tensorElement);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index 7b856dde2d5..3c07dd9e6d4 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
@@ -49,10 +50,10 @@ public class TensorFunctionBenchmark {
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double largest = Double.MIN_VALUE;
- TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
- new VariableTensor("argument"), (a, b) -> a * b),
- Reduce.Aggregator.sum).toPrimitive();
- MapEvaluationContext context = new MapEvaluationContext();
+ TensorFunction<TypeContext.Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor),
+ new VariableTensor<>("argument"), (a, b) -> a * b),
+ Reduce.Aggregator.sum).toPrimitive();
+ MapEvaluationContext<TypeContext.Name> context = new MapEvaluationContext<>();
for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor
context.put("argument", tensorElement);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index c6fbb9c009d..ae73770b7f7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
@@ -307,10 +308,10 @@ public class TensorTestCase {
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double sum = 0;
- TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
- new VariableTensor("argument"), (a, b) -> a * b),
- Reduce.Aggregator.sum).toPrimitive();
- MapEvaluationContext context = new MapEvaluationContext();
+ TensorFunction<TypeContext.Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor),
+ new VariableTensor<>("argument"), (a, b) -> a * b),
+ Reduce.Aggregator.sum).toPrimitive();
+ MapEvaluationContext<TypeContext.Name> context = new MapEvaluationContext<>();
for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor
context.put("argument", tensorElement);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
index eafa5c4addf..0476fe1c757 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
@@ -101,8 +101,8 @@ public class ConcatTestCase {
private void assertConcat(String expectedType, String expected, Tensor a, Tensor b, String dimension) {
Tensor expectedAsTensor = Tensor.from(expected);
- TensorType inferredType = new Concat(new ConstantTensor(a), new ConstantTensor(b), dimension)
- .type(new MapEvaluationContext());
+ TensorType inferredType = new Concat<>(new ConstantTensor<>(a), new ConstantTensor<>(b), dimension)
+ .type(new MapEvaluationContext<>());
Tensor result = a.concat(b, dimension);
if (expectedType != null)
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
index e1ae7f13c48..0f8fbade910 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -22,15 +22,15 @@ public class DynamicTensorTestCase {
@Test
public void testDynamicTensorFunction() {
TensorType dense = TensorType.fromSpec("tensor(x[3])");
- DynamicTensor t1 = DynamicTensor.from(dense,
- List.of(new Constant(1), new Constant(2), new Constant(3)));
+ DynamicTensor<TypeContext.Name> t1 = DynamicTensor.from(dense,
+ List.of(new Constant(1), new Constant(2), new Constant(3)));
assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate());
assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString());
TensorType sparse = TensorType.fromSpec("tensor(x{})");
- DynamicTensor t2 = DynamicTensor.from(sparse,
- Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
- new Constant(5)));
+ DynamicTensor<TypeContext.Name> t2 = DynamicTensor.from(sparse,
+ Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
+ new Constant(5)));
assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate());
assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString());
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
index e37bee2d990..ff035f3aed2 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -15,14 +16,14 @@ public class TensorFunctionTestCase {
@Test
public void testTranslation() {
assertTranslated("join(tensor(x{}):{{x:1}:1.0}, reduce(tensor(x{}):{{x:1}:1.0}, sum, x), f(a,b)(a / b))",
- new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x"));
+ new L1Normalize<>(new ConstantTensor<>("{{x:1}:1.0}"), "x"));
assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))",
- new Diag(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build()));
+ new Diag<>(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build()));
assertTranslated("join(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))",
- new Argmax(new ConstantTensor("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x"));
+ new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x"));
}
- private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) {
+ private void assertTranslated(String expectedTranslation, TensorFunction<TypeContext.Name> inputFunction) {
assertEquals(expectedTranslation, inputFunction.toPrimitive().toString());
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java
index ffb5e1433ca..7127abde016 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java
@@ -19,9 +19,9 @@ public class ValueTestCase {
@Test
public void testValueFunctionGeneralForm() {
Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
- Tensor result = new Value(new ConstantTensor(input),
- List.of(new Value.DimensionValue("key", "bar"),
- new Value.DimensionValue("x", 0)))
+ Tensor result = new Value<>(new ConstantTensor<>(input),
+ List.of(new Value.DimensionValue<>("key", "bar"),
+ new Value.DimensionValue<>("x", 0)))
.evaluate();
assertEquals(0, result.type().rank());
assertEquals(2.3, result.asDouble(), delta);
@@ -30,8 +30,8 @@ public class ValueTestCase {
@Test
public void testValueFunctionSingleMappedDimension() {
Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }");
- Tensor result = new Value(new ConstantTensor(input),
- List.of(new Value.DimensionValue("foo")))
+ Tensor result = new Value<>(new ConstantTensor<>(input),
+ List.of(new Value.DimensionValue<>("foo")))
.evaluate();
assertEquals(0, result.type().rank());
assertEquals(1.4, result.asDouble(), delta);
@@ -40,8 +40,8 @@ public class ValueTestCase {
@Test
public void testValueFunctionSingleIndexedDimension() {
Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]");
- Tensor result = new Value(new ConstantTensor(input),
- List.of(new Value.DimensionValue(2)))
+ Tensor result = new Value<>(new ConstantTensor<>(input),
+ List.of(new Value.DimensionValue<>(2)))
.evaluate();
assertEquals(0, result.type().rank());
assertEquals(3.3, result.asDouble(), delta);
@@ -51,9 +51,9 @@ public class ValueTestCase {
public void testValueFunctionShortFormWithMultipleDimensionsIsNotAllowed() {
try {
Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
- new Value(new ConstantTensor(input),
- List.of(new Value.DimensionValue("bar"),
- new Value.DimensionValue(0)))
+ new Value<>(new ConstantTensor<>(input),
+ List.of(new Value.DimensionValue<>("bar"),
+ new Value.DimensionValue<>(0)))
.evaluate();
fail("Expected exception");
}