diff options
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo')
7 files changed, 37 insertions, 33 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java index 439aac5578a..9c4fa0cf931 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java @@ -2,6 +2,7 @@ package com.yahoo.tensor; import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; @@ -42,10 +43,10 @@ public class MatrixDotProductBenchmark { private double dotProduct(Tensor tensor, List<Tensor> tensors) { double largest = Double.MIN_VALUE; - TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), - new VariableTensor("argument"), (a, b) -> a * b), - Reduce.Aggregator.sum).toPrimitive(); - MapEvaluationContext context = new MapEvaluationContext(); + TensorFunction<TypeContext.Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor), + new VariableTensor<>("argument"), (a, b) -> a * b), + Reduce.Aggregator.sum).toPrimitive(); + MapEvaluationContext<TypeContext.Name> context = new MapEvaluationContext<>(); for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor context.put("argument", tensorElement); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index 7b856dde2d5..3c07dd9e6d4 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -2,6 +2,7 @@ package com.yahoo.tensor; import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; @@ -49,10 +50,10 @@ public class TensorFunctionBenchmark { private double dotProduct(Tensor tensor, List<Tensor> tensors) { double largest = Double.MIN_VALUE; - TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), - new VariableTensor("argument"), (a, b) -> a * b), - Reduce.Aggregator.sum).toPrimitive(); - MapEvaluationContext context = new MapEvaluationContext(); + TensorFunction<TypeContext.Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor), + new VariableTensor<>("argument"), (a, b) -> a * b), + Reduce.Aggregator.sum).toPrimitive(); + MapEvaluationContext<TypeContext.Name> context = new MapEvaluationContext<>(); for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor context.put("argument", tensorElement); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index c6fbb9c009d..ae73770b7f7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -3,6 +3,7 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; @@ -307,10 +308,10 @@ public class TensorTestCase { private double dotProduct(Tensor tensor, List<Tensor> tensors) { double sum = 0; - TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), - new VariableTensor("argument"), (a, b) -> a * b), - Reduce.Aggregator.sum).toPrimitive(); - MapEvaluationContext context = new MapEvaluationContext(); + TensorFunction<TypeContext.Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor), + new VariableTensor<>("argument"), (a, b) -> a * b), + Reduce.Aggregator.sum).toPrimitive(); + MapEvaluationContext<TypeContext.Name> context = new MapEvaluationContext<>(); for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor context.put("argument", tensorElement); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java index eafa5c4addf..0476fe1c757 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java @@ -101,8 +101,8 @@ public class ConcatTestCase { private void assertConcat(String expectedType, String expected, Tensor a, Tensor b, String dimension) { Tensor expectedAsTensor = Tensor.from(expected); - TensorType inferredType = new Concat(new ConstantTensor(a), new ConstantTensor(b), dimension) - .type(new MapEvaluationContext()); + TensorType inferredType = new Concat<>(new ConstantTensor<>(a), new ConstantTensor<>(b), dimension) + .type(new MapEvaluationContext<>()); Tensor result = a.concat(b, dimension); if (expectedType != null) diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index e1ae7f13c48..0f8fbade910 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -22,15 +22,15 @@ public class DynamicTensorTestCase { @Test public void testDynamicTensorFunction() { TensorType dense = TensorType.fromSpec("tensor(x[3])"); - DynamicTensor t1 = DynamicTensor.from(dense, - List.of(new Constant(1), new Constant(2), new Constant(3))); + DynamicTensor<TypeContext.Name> t1 = DynamicTensor.from(dense, + List.of(new Constant(1), new Constant(2), new Constant(3))); assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate()); assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString()); TensorType sparse = TensorType.fromSpec("tensor(x{})"); - DynamicTensor t2 = DynamicTensor.from(sparse, - Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), - new Constant(5))); + DynamicTensor<TypeContext.Name> t2 = DynamicTensor.from(sparse, + Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), + new Constant(5))); assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index e37bee2d990..ff035f3aed2 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -15,14 +16,14 @@ public class TensorFunctionTestCase { @Test public void testTranslation() { assertTranslated("join(tensor(x{}):{{x:1}:1.0}, reduce(tensor(x{}):{{x:1}:1.0}, sum, x), f(a,b)(a / b))", - new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x")); + new L1Normalize<>(new ConstantTensor<>("{{x:1}:1.0}"), "x")); assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", - new Diag(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); + new Diag<>(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); assertTranslated("join(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))", - new Argmax(new ConstantTensor("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); + new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); } - private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) { + private void assertTranslated(String expectedTranslation, TensorFunction<TypeContext.Name> inputFunction) { assertEquals(expectedTranslation, inputFunction.toPrimitive().toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java index ffb5e1433ca..7127abde016 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java @@ -19,9 +19,9 @@ public class ValueTestCase { @Test public void testValueFunctionGeneralForm() { Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); - Tensor result = new Value(new ConstantTensor(input), - List.of(new Value.DimensionValue("key", "bar"), - new Value.DimensionValue("x", 0))) + Tensor result = new Value<>(new ConstantTensor<>(input), + List.of(new Value.DimensionValue<>("key", "bar"), + new Value.DimensionValue<>("x", 0))) .evaluate(); assertEquals(0, result.type().rank()); assertEquals(2.3, result.asDouble(), delta); @@ -30,8 +30,8 @@ public class ValueTestCase { @Test public void testValueFunctionSingleMappedDimension() { Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }"); - Tensor result = new Value(new ConstantTensor(input), - List.of(new Value.DimensionValue("foo"))) + Tensor result = new Value<>(new ConstantTensor<>(input), + List.of(new Value.DimensionValue<>("foo"))) .evaluate(); assertEquals(0, result.type().rank()); assertEquals(1.4, result.asDouble(), delta); @@ -40,8 +40,8 @@ public class ValueTestCase { @Test public void testValueFunctionSingleIndexedDimension() { Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]"); - Tensor result = new Value(new ConstantTensor(input), - List.of(new Value.DimensionValue(2))) + Tensor result = new Value<>(new ConstantTensor<>(input), + List.of(new Value.DimensionValue<>(2))) .evaluate(); assertEquals(0, result.type().rank()); assertEquals(3.3, result.asDouble(), delta); @@ -51,9 +51,9 @@ public class ValueTestCase { public void testValueFunctionShortFormWithMultipleDimensionsIsNotAllowed() { try { Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); - new Value(new ConstantTensor(input), - List.of(new Value.DimensionValue("bar"), - new Value.DimensionValue(0))) + new Value<>(new ConstantTensor<>(input), + List.of(new Value.DimensionValue<>("bar"), + new Value.DimensionValue<>(0))) .evaluate(); fail("Expected exception"); } |