diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2016-11-26 22:45:20 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2016-11-26 22:45:20 +0100 |
commit | 2f55986b4de9420e5728c5abbaafb69fb2f10a34 (patch) | |
tree | 9a6a77f76d25620771dfe7ab5de49910c4321fc5 /searchlib/src/test | |
parent | 2bc82ba9d9698214e703f19039387609d82b12f8 (diff) |
Revert "Revert "Bratseth/tensor functions 3""
Diffstat (limited to 'searchlib/src/test')
4 files changed, 299 insertions, 270 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 24d7c82235c..f28ff739b4c 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -6,7 +6,10 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; -import junit.framework.TestCase; +import org.junit.Test; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.assertFalse; import java.io.BufferedReader; import java.io.File; @@ -14,15 +17,18 @@ import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.concurrent.*; /** - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen + * @author bratseth */ -public class RankingExpressionTestCase extends TestCase { +public class RankingExpressionTestCase { + @Test public void testParamInFeature() throws ParseException { assertParse("if (1 > 2, dotProduct(allparentid,query(cate1_parentid)), 2)", "if ( 1 > 2,\n" + @@ -31,6 +37,7 @@ public class RankingExpressionTestCase extends TestCase { ")"); } + @Test public void testDollarShorthand() throws ParseException { assertParse("query(var1)", " $var1"); assertParse("query(var1)", " $var1 "); @@ -44,6 +51,7 @@ public class RankingExpressionTestCase extends TestCase { assertParse("if (if (f1.out < query(p1), 0, 1) < if (f2.out < query(p2), 0, 1), f3.out, query(p3))", "if(if(f1.out<$p1,0,1)<if(f2.out<$p2,0,1),f3.out,$p3)"); } + @Test public void testLookaheadIndefinitely() throws Exception { ExecutorService exec = Executors.newSingleThreadExecutor(); Future<Boolean> future = exec.submit(new Callable<Boolean>() { @@ -60,7 +68,8 @@ public class RankingExpressionTestCase extends TestCase { assertTrue(future.get(60, TimeUnit.SECONDS)); } - public void testSelfRecursionScript() throws ParseException { + @Test + public void testSelfRecursionSerialization() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("foo", null, new RankingExpression("foo"))); @@ -72,7 +81,8 @@ public class RankingExpressionTestCase extends TestCase { } } - public void testMacroCycleScript() throws ParseException { + @Test + public void testMacroCycleSerialization() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("foo", null, new RankingExpression("bar"))); macros.add(new ExpressionFunction("bar", null, new RankingExpression("foo"))); @@ -85,42 +95,48 @@ public class RankingExpressionTestCase extends TestCase { } } - public void testScript() throws ParseException { + @Test + public void testSerialization() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))"))); macros.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2"))); macros.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)"))); macros.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977"))); - assertScript("foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros, - Arrays.asList( - "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)", - "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))", - "min(6,pow(7,2))", - "min(1,pow(2,2))", - "min(3,pow(4,2))", - "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))" - )); - assertScript("foo(1, 2) + bar(3, 4)", macros, - Arrays.asList( - "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)", - "min(1,pow(2,2))", - "3 * 3 + 2 * 3 * 4 + 4 * 4" - )); - assertScript("baz(1, 2)", macros, - Arrays.asList( - "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)", - "min(1,pow(2,2))", - "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)", - "1 * 1 + 2 * 1 * 2 + 2 * 2" - )); - assertScript("cox", macros, - Arrays.asList( - "rankingExpression(cox)", - "10 + 08 * 1977" - )); + assertSerialization(Arrays.asList( + "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)", + "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))", + "min(6,pow(7,2))", + "min(1,pow(2,2))", + "min(3,pow(4,2))", + "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros); + assertSerialization(Arrays.asList( + "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)", + "min(1,pow(2,2))", + "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", macros); + assertSerialization(Arrays.asList( + "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)", + "min(1,pow(2,2))", + "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)", + "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", macros); + assertSerialization(Arrays.asList( + "rankingExpression(cox)", + "10 + 08 * 1977"), "cox", macros + ); + } + + @Test + public void testTensorSerialization() { + assertSerialization("map(constant(tensor0), f(a)(cos(a)))", + "map(constant(tensor0), f(a)(cos(a)))"); + assertSerialization("map(constant(tensor0), f(a)(cos(a))) + join(attribute(tensor1), map(reduce(map(attribute(tensor1), f(a)(a * a)), sum, x), f(a)(sqrt(a))), f(a,b)(a / b))", + "map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)"); + assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))", + "xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)"); + } + @Test public void testBug3464208() throws ParseException { List<ExpressionFunction> macros = new ArrayList<>(); macros.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69"))); @@ -135,18 +151,11 @@ public class RankingExpressionTestCase extends TestCase { String expRhs = "(rankingExpression(log10tweetage) * rankingExpression(log10tweetage) * " + "rankingExpression(log10tweetage)) + 5.0 * attribute(ythl)"; - assertScript(lhs + " + " + rhs, macros, - Arrays.asList( - expLhs + " + " + expRhs, - "69" - )); - assertScript(lhs + " - " + rhs, macros, - Arrays.asList( - expLhs + " - " + expRhs, - "69" - )); + assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, macros); + assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, macros); } + @Test public void testParse() throws ParseException, IOException { BufferedReader reader = new BufferedReader(new FileReader("src/tests/rankingexpression/rankingexpressionlist")); String line; @@ -181,36 +190,43 @@ public class RankingExpressionTestCase extends TestCase { } } + @Test public void testIssue() throws ParseException { assertEquals("feature.0", new RankingExpression("feature.0").toString()); assertEquals("if (1 > 2, 3, 4) + feature(arg1).out.out", new RankingExpression("if ( 1 > 2 , 3 , 4 ) + feature ( arg1 ) . out.out").toString()); } + @Test public void testNegativeConstantArgument() throws ParseException { assertEquals("foo(-1.2)", new RankingExpression("foo(-1.2)").toString()); } + @Test public void testNaming() throws ParseException { RankingExpression test = new RankingExpression("a+b"); test.setName("test"); assertEquals("test: a + b", test.toString()); } + @Test public void testCondition() throws ParseException { RankingExpression expression = new RankingExpression("if(1<2,3,4)"); assertTrue(expression.getRoot() instanceof IfNode); } + @Test public void testFileImporting() throws ParseException { RankingExpression expression = new RankingExpression(new File("src/test/files/simple.expression")); assertEquals("simple: a + b", expression.toString()); } + @Test public void testNonCanonicalLegalStrings() throws ParseException { assertParse("a * b + c * d", "a* (b) + \nc*d"); } + @Test public void testEquality() throws ParseException { assertEquals(new RankingExpression("if ( attribute(foo)==\"BAR\",log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"), new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)")); @@ -219,6 +235,7 @@ public class RankingExpressionTestCase extends TestCase { new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).earliness) * fieldMatch(title).completeness)"))); } + @Test public void testSetMembershipConditions() throws ParseException { assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"), new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)")); @@ -231,6 +248,7 @@ public class RankingExpressionTestCase extends TestCase { assertEquals(new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)"), new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)")); } + @Test public void testComments() throws ParseException { assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],\n" + "# a comment\n" + @@ -241,6 +259,7 @@ public class RankingExpressionTestCase extends TestCase { new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)")); } + @Test public void testIsNan() throws ParseException { String strExpr = "if (isNan(attribute(foo)) == 1.0, 1.0, attribute(foo))"; RankingExpression expr = new RankingExpression(strExpr); @@ -255,27 +274,59 @@ public class RankingExpressionTestCase extends TestCase { assertEquals(expected, new RankingExpression(expression).toString()); } - private void assertScript(String expression, List<ExpressionFunction> macros, List<String> expectedScripts) - throws ParseException { - boolean print = false; - if (print) - System.out.println("Parsing expression '" + expression + "'."); - - RankingExpression exp = new RankingExpression(expression); - Map<String, String> scripts = exp.getRankProperties(macros); - if (print) { - for (String key : scripts.keySet()) { - System.out.println("Script '" + key + "': " + scripts.get(key)); - } + /** Test serialization with no macros */ + private void assertSerialization(String expectedSerialization, String expressionString) { + String serializedExpression; + try { + RankingExpression expression = new RankingExpression(expressionString); + // No macros -> expect one rank property + serializedExpression = expression.getRankProperties(Collections.emptyList()).values().iterator().next(); + assertEquals(expectedSerialization, serializedExpression); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); } - for (Map.Entry<String, String> m : scripts.entrySet()) - System.out.println(m); - for (int i = 0; i < expectedScripts.size();) { - String val = expectedScripts.get(i++); - assertTrue("Script contains " + val, scripts.containsValue(val)); + try { + // No macros -> output should be parseable to a ranking expression + // (but not the same one due to primitivization) + RankingExpression reparsedExpression = new RankingExpression(serializedExpression); + // Serializing the primitivized expression should yield the same expression again + String reserializedExpression = + reparsedExpression.getRankProperties(Collections.emptyList()).values().iterator().next(); + assertEquals(expectedSerialization, reserializedExpression); + } + catch (ParseException e) { + throw new IllegalArgumentException("Could not parse the serialized expression", e); } - if (print) - System.out.println(""); } + + private void assertSerialization(List<String> expectedSerialization, String expressionString, + List<ExpressionFunction> macros) { + assertSerialization(expectedSerialization, expressionString, macros, false); + } + private void assertSerialization(List<String> expectedSerialization, String expressionString, + List<ExpressionFunction> macros, boolean print) { + try { + if (print) + System.out.println("Parsing expression '" + expressionString + "'."); + + RankingExpression expression = new RankingExpression(expressionString); + Map<String, String> rankProperties = expression.getRankProperties(macros); + if (print) { + for (String key : rankProperties.keySet()) + System.out.println("Property '" + key + "': " + rankProperties.get(key)); + } + for (int i = 0; i < expectedSerialization.size();) { + String val = expectedSerialization.get(i++); + assertTrue("Properties contains " + val, rankProperties.containsValue(val)); + } + if (print) + System.out.println(""); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index b67a423181d..93800e2c246 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -20,7 +20,7 @@ import java.util.Set; */ public class EvaluationTestCase extends junit.framework.TestCase { - private Context defaultContext; + private MapContext defaultContext; @Override protected void setUp() { @@ -100,201 +100,180 @@ public class EvaluationTestCase extends junit.framework.TestCase { @Test public void testTensorEvaluation() { - assertEvaluates("{}", "{}"); // empty - assertEvaluates("( {{x:-}:1} * {} )", "( {{x:-}:1} * {} )"); // empty with dimensions + assertEvaluates("{}", "tensor0", "{}"); - // sum(tensor) - assertEvaluates(5.0, "sum({{}:5.0})"); - assertEvaluates(-5.0, "sum({{}:-5.0})"); - assertEvaluates(12.5, "sum({ {d1:l1}:5.5, {d2:l2}:7.0 })"); - assertEvaluates(0.0, "sum({ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0})"); - - // scalar functions on tensors + // tensor map assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }", - "log10({ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 })"); - assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }", - "5 * { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); - assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }", - "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } + 3"); - assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }", - "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } / 10"); + "map(tensor0, f(x) (log10(x)))", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:4, {d1:l1}:9, {d1:l1,d2:l1 }:16 }", + "map(tensor0, f(x) (x * x))", "{ {}:2, {d1:l1}:3, {d1:l1,d2:l1}:4 }"); + // -- tensor map composites + assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }", + "log10(tensor0)", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); assertEvaluates("{ {}:-10, {d1:l1}:-100, {d1:l1,d2:l1 }:-1000 }", - "- { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + "- tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); assertEvaluates("{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1 }:0 }", - "min({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)"); + "min(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }"); assertEvaluates("{ {}:0, {d1:l1}:0, {d1:l1,d2:l1 }:10 }", - "max({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)"); - assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + {{h:1}:1.0,{h:2}:1.0}"); - - // sum(tensor, dimension) - assertEvaluates("{ {y:1}:4.0, {y:2}:12.0 }", - "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, x)"); - assertEvaluates("{ {x:1}:6.0, {x:2}:10.0 }", - "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, y)"); - - // tensor sum - assertEvaluates("{ }", "{} + {}"); - assertEvaluates("{ {x:1}:3, {x:2}:5 }", - "{ {x:1}:3 } + { {x:2}:5 }"); - assertEvaluates("{ {x:1}:8 }", - "{ {x:1}:3 } + { {x:1}:5 }"); - assertEvaluates("{ {x:1}:3, {y:1}:5 }", - "{ {x:1}:3 } + { {y:1}:5 }"); - assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", - "{ {x:1}:3, {x:2}:7 } + { {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", - "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } + { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "{ {x:1}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", - "{ {x:1}:5, {x:1,y:1}:1 } + { {z:1}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "{ {}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }"); - assertEvaluates("{ {}:16, {x:1,y:1}:1, {y:1,z:1}:7 }", - "{ {}:5, {x:1,y:1}:1 } + { {}:11, {y:1,z:1}:7 }"); - - // tensor difference - assertEvaluates("{ }", "{} - {}"); - assertEvaluates("{ {x:1}:3, {x:2}:-5 }", - "{ {x:1}:3 } - { {x:2}:5 }"); - assertEvaluates("{ {x:1}:-2 }", - "{ {x:1}:3 } - { {x:1}:5 }"); - assertEvaluates("{ {x:1}:3, {y:1}:-5 }", - "{ {x:1}:3 } - { {y:1}:5 }"); - assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:-5 }", - "{ {x:1}:3, {x:2}:7 } - { {y:1}:5 }"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:-7, {y:2,z:1}:-11, {y:1,z:2}:-13 }", - "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } - { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }", - "{ {x:1}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:-11, {y:1,z:1}:-7 }", - "{ {x:1}:5, {x:1,y:1}:1 } - { {z:1}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }", - "{ {}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }"); - assertEvaluates("{ {}:-6, {x:1,y:1}:1, {y:1,z:1}:-7 }", - "{ {}:5, {x:1,y:1}:1 } - { {}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1}:0 }", - "{ {x:1}:3 } - { {x:1}:3 }"); - assertEvaluates("{ {x:1}:0, {x:2}:1 }", - "{ {x:1}:3, {x:2}:1 } - { {x:1}:3 }"); - - // tensor product - assertEvaluates("{ }", "{} * {}"); - assertEvaluates("( {{x:-,y:-,z:-}:1}*{} )", "( {{x:-}:1} * {} ) * ( {{y:-,z:-}:1} * {} )"); // empty dimensions are preserved - assertEvaluates("( {{x:-}:1} * {} )", - "{ {x:1}:3 } * { {x:2}:5 }"); + "max(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }"); + // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence) + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "abs(tensor0)", "{ {x:1}:1, {x:2}:-2 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "acos(tensor0)", "{ {x:1}:1, {x:2}:1 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "asin(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "atan(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "ceil(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cos(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cosh(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "elu(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:1 }", "exp(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "fabs(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "floor(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "isNan(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "log(tensor0)", "{ {x:1}:1, {x:2}:1 }"); + assertEvaluates("{ {x:1}:0, {x:2}:1 }", "log10(tensor0)", "{ {x:1}:1, {x:2}:10 }"); + assertEvaluates("{ {x:1}:0, {x:2}:2 }", "mod(tensor0, 3)", "{ {x:1}:3, {x:2}:8 }"); + assertEvaluates("{ {x:1}:1, {x:2}:8 }", "pow(tensor0, 3)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "relu(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:2 }", "round(tensor0)", "{ {x:1}:1, {x:2}:1.8 }"); + assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "sigmoid(tensor0)","{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:-1 }", "sign(tensor0)", "{ {x:1}:3, {x:2}:-5 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sin(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sinh(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:1, {x:2}:4 }", "square(tensor0)", "{ {x:1}:1, {x:2}:2 }"); + assertEvaluates("{ {x:1}:1, {x:2}:3 }", "sqrt(tensor0)", "{ {x:1}:1, {x:2}:9 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tan(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tanh(tensor0)", "{ {x:1}:0, {x:2}:0 }"); + + // tensor reduce + // -- reduce 2 dimensions + assertEvaluates("{ {}:4 }", + "reduce(tensor0, avg, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:4 }", + "reduce(tensor0, count, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:105 }", + "reduce(tensor0, prod, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:16 }", + "reduce(tensor0, sum, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:7 }", + "reduce(tensor0, max, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:1 }", + "reduce(tensor0, min, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + // -- reduce 2 by specifying no arguments + assertEvaluates("{ {}:4 }", + "reduce(tensor0, avg)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + // -- reduce 1 dimension + assertEvaluates("{ {y:1}:2, {y:2}:6 }", + "reduce(tensor0, avg, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {y:1}:2, {y:2}:2 }", + "reduce(tensor0, count, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {y:1}:3, {y:2}:35 }", + "reduce(tensor0, prod, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {y:1}:4, {y:2}:12 }", + "reduce(tensor0, sum, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {y:1}:3, {y:2}:7 }", + "reduce(tensor0, max, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {y:1}:1, {y:2}:5 }", + "reduce(tensor0, min, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + // -- reduce composites + assertEvaluates("{ {}: 5 }", "sum(tensor0)", "5.0"); + assertEvaluates("{ {}:-5 }", "sum(tensor0)", "-5.0"); + assertEvaluates("{ {}:12.5 }", "sum(tensor0)", "{ {d1:l1}:5.5, {d2:l2}:7.0 }"); + assertEvaluates("{ {}: 0 }", "sum(tensor0)", "{ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0}"); + assertEvaluates("{ {y:1}:4, {y:2}:12.0 }", + "sum(tensor0, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {x:1}:6, {x:2}:10.0 }", + "sum(tensor0, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + assertEvaluates("{ {}:16 }", + "sum(tensor0, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }"); + + // tensor join + assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }", "join(tensor0, tensor1, f(x,y) (x*y))", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + // -- join composites + assertEvaluates("{ }", "tensor0 * tensor0", "{}"); + assertEvaluates("tensor(x{},y{},z{}):{}", "( tensor0 * tensor1 ) * ( tensor2 * tensor1 )", + "{{x:-}:1}", "{}", "{{y:-,z:-}:1}"); // empty dimensions are preserved + assertEvaluates("tensor(x{}):{}", + "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:2}:5 }"); assertEvaluates("{ {x:1}:15 }", - "{ {x:1}:3 } * { {x:1}:5 }"); + "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:1}:5 }"); assertEvaluates("{ {x:1,y:1}:15 }", - "{ {x:1}:3 } * { {y:1}:5 }"); + "tensor0 * tensor1", "{ {x:1}:3 }", "{ {y:1}:5 }"); assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }", - "{ {x:1}:3, {x:2}:7 } * { {y:1}:5 }"); + "tensor0 * tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:8, {x:2,y:1}:12 }", + "tensor0 + tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:-2, {x:2,y:1}:2 }", + "tensor0 - tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:4 }", + "tensor0 / tensor1", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }"); + assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:7 }", + "max(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:3, {x:2,y:1}:5 }", + "min(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,y:1,z:2}:13, {x:2,y:1,z:1}:21, {x:2,y:1,z:2}:39, {x:1,y:2,z:1}:55 }", - "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } * { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); - assertEvaluates("{ {x:1,y:1,z:1}:7 }", - "{ {x:1}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,z:1}:55 }", - "{ {x:1}:5, {x:1,y:1}:1 } * { {z:1}:11, {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1,y:1,z:1}:7 }", - "{ {}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }"); - assertEvaluates("{ {x:1,y:1,z:1}:7, {}:55 }", - "{ {}:5, {x:1,y:1}:1 } * { {}:11, {y:1,z:1}:7 }"); - - // match product - assertEvaluates("{ }", "match({}, {})"); - assertEvaluates("( {{x:-}:1} * {} )", - "match({ {x:1}:3 }, { {x:2}:5 })"); - assertEvaluates("{ {x:1}:15 }", - "match({ {x:1}:3 }, { {x:1}:5 })"); - assertEvaluates("( {{x:-,y:-}:1} * {} )", - "match({ {x:1}:3 }, { {y:1}:5 })"); - assertEvaluates("( {{x:-,y:-}:1} * {} )", - "match({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })"); - assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", - "match({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })"); - assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", - "match({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", - "match({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })"); - assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )", - "match({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("( {{x:-,y:-,z:-}:1} * { {}:55 } )", - "match({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })"); - assertEvaluates("( {{z:-}:1} * { {x:1}:15, {x:1,y:1}:7 } )", - "match({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })"); - - // min - assertEvaluates("{ {x:1}:3, {x:2}:5 }", - "min({ {x:1}:3 }, { {x:2}:5 })"); - assertEvaluates("{ {x:1}:3 }", - "min({ {x:1}:3 }, { {x:1}:5 })"); - assertEvaluates("{ {x:1}:3, {y:1}:5 }", - "min({ {x:1}:3 }, { {y:1}:5 })"); - assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", - "min({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", - "min({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "min({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", - "min({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "min({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "min({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })"); - assertEvaluates("{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }", - "min({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })"); - - // max - assertEvaluates("{ {x:1}:3, {x:2}:5 }", - "max({ {x:1}:3 }, { {x:2}:5 })"); - assertEvaluates("{ {x:1}:5 }", - "max({ {x:1}:3 }, { {x:1}:5 })"); - assertEvaluates("{ {x:1}:3, {y:1}:5 }", - "max({ {x:1}:3 }, { {y:1}:5 })"); - assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }", - "max({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })"); - assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }", - "max({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "max({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }", - "max({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })"); - assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }", - "max({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })"); - assertEvaluates("{ {}:11, {x:1,y:1}:1, {y:1,z:1}:7 }", - "max({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })"); - assertEvaluates("{ {}:5, {x:1}:5, {x:2}:4, {x:1,y:1}:7, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }", - "max({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })"); - - // Combined - assertEvaluates(7.5 + 45 + 1.7, - "sum( " + // model computation - " match( " + // model weight application - " { {x:1}:1, {x:2}:2 } * { {y:1}:3, {y:2}:4 } * { {z:1}:5 }, " + // feature combinations - " { {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }" + // model weights - "))+1.7"); - - // undefined is not the same as 0 - assertEvaluates(1.0, "sum({ {x:1}:0, {x:2}:0 } * { {x:1}:1, {x:2}:1 } + 0.5)"); - assertEvaluates(0.0, "sum({ } * { {x:1}:1, {x:2}:1 } + 0.5)"); + "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }"); + assertEvaluates("{ {x:1,y:2,z:1}:35, {x:1,y:2,z:2}:65 }", + "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:2,z:1}:7, {y:3,z:1}:11, {y:2,z:2}:13 }"); + assertEvaluates("{{x:1,y:1}:0.0}","tensor1 * tensor2 * tensor3", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1,y:1}:1 }", "{ {x:1,y:1}:1 }"); + assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }", + "5 * tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }", + "tensor0 + 3","{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }", + "tensor0 / 10", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }"); + assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + tensor0", "{ {h:1}:1.0,{h:2}:1.0 }"); + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:0 }", + "atan2(tensor0, tensor1)", "{ {x:1}:0, {x:2}:0 }", "{ {y:1}:1 }"); + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", + "tensor0 > tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 < tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", + "tensor0 >= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 <= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }"); + assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }", + "tensor0 == tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); + // TODO + // argmax + // argmin + assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }", + "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }"); + + // tensor rename + assertEvaluates("{ {newX:1,y:2}:3 }", "rename(tensor0, x, newX)", "{ {x:1,y:2}:3.0 }"); + assertEvaluates("{ {x:2,y:1}:3 }", "rename(tensor0, (x, y), (y, x))", "{ {x:1,y:2}:3.0 }"); + + // tensor generate - TODO + // assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0, {x:2,y:2}:1, {x:1,y:2}:0 }", "tensor(x[2],y[2])(x==y)"); + // range + // diag + // fill + // random + + // composite functions + assertEvaluates("{ {x:1}:0.25, {x:2}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }"); + assertEvaluates("{ {x:1}:0.31622776601683794, {x:2}:0.9486832980505138 }", "l2_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }"); + assertEvaluates("{ {y:1}:81.0 }", "matmul(tensor0, tensor1, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }"); + assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "softmax(tensor0, x)", "{ {x:1}:1, {x:2}:1 }", "{ {y:1}:1 }"); + assertEvaluates("{ {x:1,y:1}:88.0 }", "xw_plus_b(tensor0, tensor1, tensor2, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }", "{ {x:1}:7 }"); + + // expressions combining functions + assertEvaluates(String.valueOf(7.5 + 45 + 1.7), + "sum( " + // model computation: + " tensor0 * tensor1 * tensor2 " + // - feature combinations + " * tensor3" + // - model weights application + ") + 1.7", + "{ {x:1}:1, {x:2}:2 }", "{ {y:1}:3, {y:2}:4 }", "{ {z:1}:5 }", + "{ {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }"); + assertEvaluates("1.0", "sum(tensor0 * tensor1 + 0.5)", "{ {x:1}:0, {x:2}:0 }", "{ {x:1}:1, {x:2}:1 }"); + assertEvaluates("0.0", "sum(tensor0 * tensor1 + 0.5)", "{}", "{ {x:1}:1, {x:2}:1 }"); // tensor result dimensions are given from argument dimensions, not the resulting values - assertEvaluates("x", "( {{x:-}:1.0} * {} )", "{ {x:1}:1 } * { {x:2}:1 }"); - assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }"); - - // demonstration of where this produces different results: { {x:1}:1 } with 2 dimensions ... - assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )","{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 } * { {x:1,y:1}:1 }"); - // ... vs { {x:1}:1 } with only one dimension - assertEvaluates("x, y", "{{x:1,y:1}:1.0}", "{ {x:1}:1 } * { {x:1,y:1}:1 }"); - - // check that dimensions are preserved through other operations - String d2 = "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }"; // creates a 2d tensor with only an 1d value - assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )", "match(" + d2 + ", {})"); - assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " - {}"); - assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " + {}"); - assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "min(1.5, " + d2 +")"); - assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "max({{x:1}:0}, " + d2 +")"); + assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2}:1 }"); + assertEvaluates("tensor(x{},y{}):{{x:1}:1.0}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1}:1 }"); } public void testProgrammaticBuildingAndPrecedence() { @@ -316,12 +295,16 @@ public class EvaluationTestCase extends junit.framework.TestCase { assertEvaluates(77, "average(\"2*3\",\"pow(2,3)\")+average(\"2*3\",\"pow(2,3)\").timesten", context); } - private RankingExpression assertEvaluates(String tensorValue, String expressionString) { - return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, defaultContext); + private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) { + MapContext context = defaultContext.thawedCopy(); + int argumentIndex = 0; + for (String tensorArgument : tensorArguments) + context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument))); + return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context); } /** Validate also that the dimension of the resulting tensors are as expected */ - private RankingExpression assertEvaluates(String tensorDimensions, String resultTensor, String expressionString) { + private RankingExpression assertEvaluates_old(String tensorDimensions, String resultTensor, String expressionString) { RankingExpression expression = assertEvaluates(new TensorValue(MapTensor.from(resultTensor)), expressionString, defaultContext); TensorValue value = (TensorValue)expression.evaluate(defaultContext); assertEquals(toSet(tensorDimensions), value.asTensor().dimensions()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java index 95c4402a612..08fdc9917a4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java @@ -17,22 +17,25 @@ public class NeuralNetEvaluationTestCase { /** "XOR" neural network, separate expression per layer */ @Test public void testPerLayerExpression() { - String input = "{ {x:1}:0, {x:2}:1 }"; - - String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }"; - String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }"; - String firstLayerInput = "sum(" + input + "*" + firstLayerWeights + ", x) + " + firstLayerBias; + String input = "{ {x:1}:0, {x:2}:1 }"; // tensor0 + String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }"; // tensor1 + String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }"; // tensor2 + String firstLayerInput = "sum(tensor0 * tensor1, x) + tensor2"; String firstLayerOutput = "min(1.0, max(0.0, 0.5 + " + firstLayerInput + "))"; // non-linearity, "poor man's sigmoid" - assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput); - String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }"; - String secondLayerBias = "{ {y:1}:-0.5 }"; - String secondLayerInput = "sum(" + firstLayerOutput + "*" + secondLayerWeights + ", h) + " + secondLayerBias; + assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput, input, firstLayerWeights, firstLayerBias); + String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }"; // tensor3 + String secondLayerBias = "{ {y:1}:-0.5 }"; // tensor4 + String secondLayerInput = "sum(" + firstLayerOutput + "* tensor3, h) + tensor4"; String secondLayerOutput = "min(1.0, max(0.0, 0.5 + " + secondLayerInput + "))"; // non-linearity, "poor man's sigmoid" - assertEvaluates("{ {y:1}:1 }", secondLayerOutput); + assertEvaluates("{ {y:1}:1 }", secondLayerOutput, input, firstLayerWeights, firstLayerBias, secondLayerWeights, secondLayerBias); } - private RankingExpression assertEvaluates(String tensorValue, String expressionString) { - return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, new MapContext()); + private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) { + MapContext context = new MapContext(); + int argumentIndex = 0; + for (String tensorArgument : tensorArguments) + context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument))); + return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context); } private RankingExpression assertEvaluates(Value value, String expressionString, Context context) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java index 9d94ec0bc99..61b230ab390 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java @@ -69,12 +69,4 @@ public class SimplifierTestCase { assertEquals("a + (b + c) / 100000000.0", transformed.toString()); } - @Test - public void testSimplificationWithTensorConstants() throws ParseException { - new Simplifier().transform(new RankingExpression( - "sum(sum((tensorFromWeightedSet(query(wset_query),x)+" + - " tensorFromWeightedSet(attribute(wset),x)) * " + - " {{x:0,y:0}:54, {x:0,y:1} :69, {x:1,y:0} :72, {x:1,y:1} :93},x))")); - } - } |