diff options
Diffstat (limited to 'config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java')
-rw-r--r-- | config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java | 86 |
1 files changed, 43 insertions, 43 deletions
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java index aaf5f381c62..028ad5dea86 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/TensorTransformTestCase.java @@ -20,109 +20,109 @@ import com.yahoo.schema.derived.AttributeFields; import com.yahoo.schema.derived.RawRankProfile; import com.yahoo.schema.parser.ParseException; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; -import org.junit.Test; +import org.junit.jupiter.api.Test; import java.util.List; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; public class TensorTransformTestCase extends AbstractSchemaTestCase { @Test - public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException { + void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException { assertTransformedExpression("max(1.0,2.0)", - "max(1.0,2.0)"); + "max(1.0,2.0)"); assertTransformedExpression("min(attribute(double_field),x)", - "min(attribute(double_field),x)"); + "min(attribute(double_field),x)"); assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))", - "max(attribute(double_field),attribute(double_array_field))"); + "max(attribute(double_field),attribute(double_array_field))"); assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))", - "min(attribute(tensor_field_1),attribute(double_field))"); + "min(attribute(tensor_field_1),attribute(double_field))"); assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)", - "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)"); + "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)"); assertTransformedExpression("min(constant(test_constant_tensor),1.0)", - "min(test_constant_tensor,1.0)"); + "min(test_constant_tensor,1.0)"); assertTransformedExpression("max(constant(base_constant_tensor),1.0)", - "max(base_constant_tensor,1.0)"); + "max(base_constant_tensor,1.0)"); assertTransformedExpression("min(constant(file_constant_tensor),1.0)", - "min(constant(file_constant_tensor),1.0)"); + "min(constant(file_constant_tensor),1.0)"); assertTransformedExpression("max(query(q),1.0)", - "max(query(q),1.0)"); + "max(query(q),1.0)"); assertTransformedExpression("max(query(n),1.0)", - "max(query(n),1.0)"); + "max(query(n),1.0)"); } @Test - public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException { + void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException { assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)", - "max(attribute(tensor_field_1),x)"); + "max(attribute(tensor_field_1),x)"); assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)", - "1 + max(attribute(tensor_field_1),x)"); + "1 + max(attribute(tensor_field_1),x)"); assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),reduce(attribute(tensor_field_1),sum,x))", - "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),reduce(attribute(tensor_field_1), sum, x))"); + "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),reduce(attribute(tensor_field_1), sum, x))"); assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)", - "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)"); + "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)"); assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)", - "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)"); + "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)"); assertTransformedExpression("max(reduce(attribute(tensor_field_1),max,x),x)", - "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error. + "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error. assertTransformedExpression("reduce(reduce(attribute(tensor_field_2),max,x),max,y)", - "max(max(attribute(tensor_field_2),x),y)"); + "max(max(attribute(tensor_field_2),x),y)"); } @Test - public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException { + void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException { assertTransformedExpression("reduce(constant(test_constant_tensor),max,x)", - "max(test_constant_tensor,x)"); + "max(test_constant_tensor,x)"); assertTransformedExpression("reduce(constant(base_constant_tensor),max,x)", - "max(base_constant_tensor,x)"); + "max(base_constant_tensor,x)"); assertTransformedExpression("reduce(constant(file_constant_tensor),min,x)", - "min(constant(file_constant_tensor),x)"); + "min(constant(file_constant_tensor),x)"); } @Test - public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException { + void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException { assertTransformedExpression("reduce(attribute(double_field)+attribute(tensor_field_1),min,x)", - "min(attribute(double_field) + attribute(tensor_field_1),x)"); + "min(attribute(double_field) + attribute(tensor_field_1),x)"); assertTransformedExpression("reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)", - "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)"); + "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)"); assertTransformedExpression("reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)", - "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)"); + "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)"); assertTransformedExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", - "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) + "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) assertTransformedExpression("reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)", - "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); + "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); } @Test - public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException { + void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException { assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)", - "max(tensorFromLabels(attribute(double_array_field)),double_array_field)"); + "max(tensorFromLabels(attribute(double_array_field)),double_array_field)"); assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field),x),max,x)", - "max(tensorFromLabels(attribute(double_array_field),x),x)"); + "max(tensorFromLabels(attribute(double_array_field),x),x)"); assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)", - "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)"); + "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)"); assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)", - "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)"); + "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)"); } @Test - public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException { + void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException { assertTransformedExpression("reduce(query(q),max,x)", "max(query(q),x)"); assertTransformedExpression("max(query(n),x)", "max(query(n),x)"); } @Test - public void requireThatMaxAndMinWithTensorsReturnedFromFunctionsAreReplaced() throws ParseException { + void requireThatMaxAndMinWithTensorsReturnedFromFunctionsAreReplaced() throws ParseException { assertTransformedExpression("reduce(rankingExpression(returns_tensor),max,x)", - "max(returns_tensor,x)"); + "max(returns_tensor,x)"); assertTransformedExpression("reduce(rankingExpression(wraps_returns_tensor),max,x)", - "max(wraps_returns_tensor,x)"); + "max(wraps_returns_tensor,x)"); assertTransformedExpression("reduce(rankingExpression(tensor_inheriting),max,x)", - "max(tensor_inheriting,x)"); + "max(tensor_inheriting,x)"); assertTransformedExpression("reduce(rankingExpression(returns_tensor_with_arg@),max,x)", - "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)"); + "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)"); } private void assertTransformedExpression(String expected, String original) throws ParseException { |