aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2016-11-26 22:45:20 +0100
committerGitHub <noreply@github.com>2016-11-26 22:45:20 +0100
commit2f55986b4de9420e5728c5abbaafb69fb2f10a34 (patch)
tree9a6a77f76d25620771dfe7ab5de49910c4321fc5 /searchlib/src/test
parent2bc82ba9d9698214e703f19039387609d82b12f8 (diff)
Revert "Revert "Bratseth/tensor functions 3""
Diffstat (limited to 'searchlib/src/test')
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java175
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java359
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java27
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java8
4 files changed, 299 insertions, 270 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index 24d7c82235c..f28ff739b4c 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -6,7 +6,10 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
-import junit.framework.TestCase;
+import org.junit.Test;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.assertFalse;
import java.io.BufferedReader;
import java.io.File;
@@ -14,15 +17,18 @@ import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.*;
/**
- * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a>
+ * @author Simon Thoresen
+ * @author bratseth
*/
-public class RankingExpressionTestCase extends TestCase {
+public class RankingExpressionTestCase {
+ @Test
public void testParamInFeature() throws ParseException {
assertParse("if (1 > 2, dotProduct(allparentid,query(cate1_parentid)), 2)",
"if ( 1 > 2,\n" +
@@ -31,6 +37,7 @@ public class RankingExpressionTestCase extends TestCase {
")");
}
+ @Test
public void testDollarShorthand() throws ParseException {
assertParse("query(var1)", " $var1");
assertParse("query(var1)", " $var1 ");
@@ -44,6 +51,7 @@ public class RankingExpressionTestCase extends TestCase {
assertParse("if (if (f1.out < query(p1), 0, 1) < if (f2.out < query(p2), 0, 1), f3.out, query(p3))", "if(if(f1.out<$p1,0,1)<if(f2.out<$p2,0,1),f3.out,$p3)");
}
+ @Test
public void testLookaheadIndefinitely() throws Exception {
ExecutorService exec = Executors.newSingleThreadExecutor();
Future<Boolean> future = exec.submit(new Callable<Boolean>() {
@@ -60,7 +68,8 @@ public class RankingExpressionTestCase extends TestCase {
assertTrue(future.get(60, TimeUnit.SECONDS));
}
- public void testSelfRecursionScript() throws ParseException {
+ @Test
+ public void testSelfRecursionSerialization() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("foo", null, new RankingExpression("foo")));
@@ -72,7 +81,8 @@ public class RankingExpressionTestCase extends TestCase {
}
}
- public void testMacroCycleScript() throws ParseException {
+ @Test
+ public void testMacroCycleSerialization() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("foo", null, new RankingExpression("bar")));
macros.add(new ExpressionFunction("bar", null, new RankingExpression("foo")));
@@ -85,42 +95,48 @@ public class RankingExpressionTestCase extends TestCase {
}
}
- public void testScript() throws ParseException {
+ @Test
+ public void testSerialization() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("foo", Arrays.asList("arg1", "arg2"), new RankingExpression("min(arg1, pow(arg2, 2))")));
macros.add(new ExpressionFunction("bar", Arrays.asList("arg1", "arg2"), new RankingExpression("arg1 * arg1 + 2 * arg1 * arg2 + arg2 * arg2")));
macros.add(new ExpressionFunction("baz", Arrays.asList("arg1", "arg2"), new RankingExpression("foo(1, 2) / bar(arg1, arg2)")));
macros.add(new ExpressionFunction("cox", null, new RankingExpression("10 + 08 * 1977")));
- assertScript("foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros,
- Arrays.asList(
- "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)",
- "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))",
- "min(6,pow(7,2))",
- "min(1,pow(2,2))",
- "min(3,pow(4,2))",
- "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"
- ));
- assertScript("foo(1, 2) + bar(3, 4)", macros,
- Arrays.asList(
- "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)",
- "min(1,pow(2,2))",
- "3 * 3 + 2 * 3 * 4 + 4 * 4"
- ));
- assertScript("baz(1, 2)", macros,
- Arrays.asList(
- "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)",
- "min(1,pow(2,2))",
- "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)",
- "1 * 1 + 2 * 1 * 2 + 2 * 2"
- ));
- assertScript("cox", macros,
- Arrays.asList(
- "rankingExpression(cox)",
- "10 + 08 * 1977"
- ));
+ assertSerialization(Arrays.asList(
+ "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(foo@af74e3fd9070bd18.a368ed0a5ba3a5d0) * rankingExpression(foo@dbab346efdad5362.e5c39e42ebd91c30)",
+ "min(5,pow(rankingExpression(foo@d1d1417259cdc651.573bbcd4be18f379),2))",
+ "min(6,pow(7,2))",
+ "min(1,pow(2,2))",
+ "min(3,pow(4,2))",
+ "min(rankingExpression(foo@84951be88255b0ec.d0303e061b36fab8),pow(8,2))"), "foo(1,2) + foo(3,4) * foo(5, foo(foo(6, 7), 8))", macros);
+ assertSerialization(Arrays.asList(
+ "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) + rankingExpression(bar@af74e3fd9070bd18.a368ed0a5ba3a5d0)",
+ "min(1,pow(2,2))",
+ "3 * 3 + 2 * 3 * 4 + 4 * 4"), "foo(1, 2) + bar(3, 4)", macros);
+ assertSerialization(Arrays.asList(
+ "rankingExpression(baz@e2dc17a89864aed0.12232eb692c6c502)",
+ "min(1,pow(2,2))",
+ "rankingExpression(foo@e2dc17a89864aed0.12232eb692c6c502) / rankingExpression(bar@e2dc17a89864aed0.12232eb692c6c502)",
+ "1 * 1 + 2 * 1 * 2 + 2 * 2"), "baz(1, 2)", macros);
+ assertSerialization(Arrays.asList(
+ "rankingExpression(cox)",
+ "10 + 08 * 1977"), "cox", macros
+ );
+ }
+
+ @Test
+ public void testTensorSerialization() {
+ assertSerialization("map(constant(tensor0), f(a)(cos(a)))",
+ "map(constant(tensor0), f(a)(cos(a)))");
+ assertSerialization("map(constant(tensor0), f(a)(cos(a))) + join(attribute(tensor1), map(reduce(map(attribute(tensor1), f(a)(a * a)), sum, x), f(a)(sqrt(a))), f(a,b)(a / b))",
+ "map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)");
+ assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))",
+ "xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)");
+
}
+ @Test
public void testBug3464208() throws ParseException {
List<ExpressionFunction> macros = new ArrayList<>();
macros.add(new ExpressionFunction("log10tweetage", null, new RankingExpression("69")));
@@ -135,18 +151,11 @@ public class RankingExpressionTestCase extends TestCase {
String expRhs = "(rankingExpression(log10tweetage) * rankingExpression(log10tweetage) * " +
"rankingExpression(log10tweetage)) + 5.0 * attribute(ythl)";
- assertScript(lhs + " + " + rhs, macros,
- Arrays.asList(
- expLhs + " + " + expRhs,
- "69"
- ));
- assertScript(lhs + " - " + rhs, macros,
- Arrays.asList(
- expLhs + " - " + expRhs,
- "69"
- ));
+ assertSerialization(Arrays.asList(expLhs + " + " + expRhs, "69"), lhs + " + " + rhs, macros);
+ assertSerialization(Arrays.asList(expLhs + " - " + expRhs, "69"), lhs + " - " + rhs, macros);
}
+ @Test
public void testParse() throws ParseException, IOException {
BufferedReader reader = new BufferedReader(new FileReader("src/tests/rankingexpression/rankingexpressionlist"));
String line;
@@ -181,36 +190,43 @@ public class RankingExpressionTestCase extends TestCase {
}
}
+ @Test
public void testIssue() throws ParseException {
assertEquals("feature.0", new RankingExpression("feature.0").toString());
assertEquals("if (1 > 2, 3, 4) + feature(arg1).out.out",
new RankingExpression("if ( 1 > 2 , 3 , 4 ) + feature ( arg1 ) . out.out").toString());
}
+ @Test
public void testNegativeConstantArgument() throws ParseException {
assertEquals("foo(-1.2)", new RankingExpression("foo(-1.2)").toString());
}
+ @Test
public void testNaming() throws ParseException {
RankingExpression test = new RankingExpression("a+b");
test.setName("test");
assertEquals("test: a + b", test.toString());
}
+ @Test
public void testCondition() throws ParseException {
RankingExpression expression = new RankingExpression("if(1<2,3,4)");
assertTrue(expression.getRoot() instanceof IfNode);
}
+ @Test
public void testFileImporting() throws ParseException {
RankingExpression expression = new RankingExpression(new File("src/test/files/simple.expression"));
assertEquals("simple: a + b", expression.toString());
}
+ @Test
public void testNonCanonicalLegalStrings() throws ParseException {
assertParse("a * b + c * d", "a* (b) + \nc*d");
}
+ @Test
public void testEquality() throws ParseException {
assertEquals(new RankingExpression("if ( attribute(foo)==\"BAR\",log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"),
new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)"));
@@ -219,6 +235,7 @@ public class RankingExpressionTestCase extends TestCase {
new RankingExpression("if(attribute(foo)==\"BAR\", log(attribute(popularity)+5),log(fieldMatch(title).earliness) * fieldMatch(title).completeness)")));
}
+ @Test
public void testSetMembershipConditions() throws ParseException {
assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],log(attribute(popularity)+5),log(fieldMatch(title).proximity)*fieldMatch(title).completeness)"),
new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)"));
@@ -231,6 +248,7 @@ public class RankingExpressionTestCase extends TestCase {
assertEquals(new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)"), new RankingExpression("if (GENDER$ in [-1.0, 1.0], 1, 0)"));
}
+ @Test
public void testComments() throws ParseException {
assertEquals(new RankingExpression("if ( attribute(foo) in [\"FOO\", \"BAR\"],\n" +
"# a comment\n" +
@@ -241,6 +259,7 @@ public class RankingExpressionTestCase extends TestCase {
new RankingExpression("if(attribute(foo) in [\"FOO\",\"BAR\"], log(attribute(popularity)+5),log(fieldMatch(title).proximity) * fieldMatch(title).completeness)"));
}
+ @Test
public void testIsNan() throws ParseException {
String strExpr = "if (isNan(attribute(foo)) == 1.0, 1.0, attribute(foo))";
RankingExpression expr = new RankingExpression(strExpr);
@@ -255,27 +274,59 @@ public class RankingExpressionTestCase extends TestCase {
assertEquals(expected, new RankingExpression(expression).toString());
}
- private void assertScript(String expression, List<ExpressionFunction> macros, List<String> expectedScripts)
- throws ParseException {
- boolean print = false;
- if (print)
- System.out.println("Parsing expression '" + expression + "'.");
-
- RankingExpression exp = new RankingExpression(expression);
- Map<String, String> scripts = exp.getRankProperties(macros);
- if (print) {
- for (String key : scripts.keySet()) {
- System.out.println("Script '" + key + "': " + scripts.get(key));
- }
+ /** Test serialization with no macros */
+ private void assertSerialization(String expectedSerialization, String expressionString) {
+ String serializedExpression;
+ try {
+ RankingExpression expression = new RankingExpression(expressionString);
+ // No macros -> expect one rank property
+ serializedExpression = expression.getRankProperties(Collections.emptyList()).values().iterator().next();
+ assertEquals(expectedSerialization, serializedExpression);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
}
- for (Map.Entry<String, String> m : scripts.entrySet())
- System.out.println(m);
- for (int i = 0; i < expectedScripts.size();) {
- String val = expectedScripts.get(i++);
- assertTrue("Script contains " + val, scripts.containsValue(val));
+ try {
+ // No macros -> output should be parseable to a ranking expression
+ // (but not the same one due to primitivization)
+ RankingExpression reparsedExpression = new RankingExpression(serializedExpression);
+ // Serializing the primitivized expression should yield the same expression again
+ String reserializedExpression =
+ reparsedExpression.getRankProperties(Collections.emptyList()).values().iterator().next();
+ assertEquals(expectedSerialization, reserializedExpression);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException("Could not parse the serialized expression", e);
}
- if (print)
- System.out.println("");
}
+
+ private void assertSerialization(List<String> expectedSerialization, String expressionString,
+ List<ExpressionFunction> macros) {
+ assertSerialization(expectedSerialization, expressionString, macros, false);
+ }
+ private void assertSerialization(List<String> expectedSerialization, String expressionString,
+ List<ExpressionFunction> macros, boolean print) {
+ try {
+ if (print)
+ System.out.println("Parsing expression '" + expressionString + "'.");
+
+ RankingExpression expression = new RankingExpression(expressionString);
+ Map<String, String> rankProperties = expression.getRankProperties(macros);
+ if (print) {
+ for (String key : rankProperties.keySet())
+ System.out.println("Property '" + key + "': " + rankProperties.get(key));
+ }
+ for (int i = 0; i < expectedSerialization.size();) {
+ String val = expectedSerialization.get(i++);
+ assertTrue("Properties contains " + val, rankProperties.containsValue(val));
+ }
+ if (print)
+ System.out.println("");
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index b67a423181d..93800e2c246 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -20,7 +20,7 @@ import java.util.Set;
*/
public class EvaluationTestCase extends junit.framework.TestCase {
- private Context defaultContext;
+ private MapContext defaultContext;
@Override
protected void setUp() {
@@ -100,201 +100,180 @@ public class EvaluationTestCase extends junit.framework.TestCase {
@Test
public void testTensorEvaluation() {
- assertEvaluates("{}", "{}"); // empty
- assertEvaluates("( {{x:-}:1} * {} )", "( {{x:-}:1} * {} )"); // empty with dimensions
+ assertEvaluates("{}", "tensor0", "{}");
- // sum(tensor)
- assertEvaluates(5.0, "sum({{}:5.0})");
- assertEvaluates(-5.0, "sum({{}:-5.0})");
- assertEvaluates(12.5, "sum({ {d1:l1}:5.5, {d2:l2}:7.0 })");
- assertEvaluates(0.0, "sum({ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0})");
-
- // scalar functions on tensors
+ // tensor map
assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }",
- "log10({ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 })");
- assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }",
- "5 * { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
- assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }",
- "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } + 3");
- assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }",
- "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 } / 10");
+ "map(tensor0, f(x) (log10(x)))", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:4, {d1:l1}:9, {d1:l1,d2:l1 }:16 }",
+ "map(tensor0, f(x) (x * x))", "{ {}:2, {d1:l1}:3, {d1:l1,d2:l1}:4 }");
+ // -- tensor map composites
+ assertEvaluates("{ {}:1, {d1:l1}:2, {d1:l1,d2:l1 }:3 }",
+ "log10(tensor0)", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
assertEvaluates("{ {}:-10, {d1:l1}:-100, {d1:l1,d2:l1 }:-1000 }",
- "- { {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ "- tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
assertEvaluates("{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1 }:0 }",
- "min({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)");
+ "min(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }");
assertEvaluates("{ {}:0, {d1:l1}:0, {d1:l1,d2:l1 }:10 }",
- "max({ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }, 0)");
- assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + {{h:1}:1.0,{h:2}:1.0}");
-
- // sum(tensor, dimension)
- assertEvaluates("{ {y:1}:4.0, {y:2}:12.0 }",
- "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, x)");
- assertEvaluates("{ {x:1}:6.0, {x:2}:10.0 }",
- "sum({ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }, y)");
-
- // tensor sum
- assertEvaluates("{ }", "{} + {}");
- assertEvaluates("{ {x:1}:3, {x:2}:5 }",
- "{ {x:1}:3 } + { {x:2}:5 }");
- assertEvaluates("{ {x:1}:8 }",
- "{ {x:1}:3 } + { {x:1}:5 }");
- assertEvaluates("{ {x:1}:3, {y:1}:5 }",
- "{ {x:1}:3 } + { {y:1}:5 }");
- assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
- "{ {x:1}:3, {x:2}:7 } + { {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
- "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } + { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "{ {x:1}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
- "{ {x:1}:5, {x:1,y:1}:1 } + { {z:1}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "{ {}:5, {x:1,y:1}:1 } + { {y:1,z:1}:7 }");
- assertEvaluates("{ {}:16, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "{ {}:5, {x:1,y:1}:1 } + { {}:11, {y:1,z:1}:7 }");
-
- // tensor difference
- assertEvaluates("{ }", "{} - {}");
- assertEvaluates("{ {x:1}:3, {x:2}:-5 }",
- "{ {x:1}:3 } - { {x:2}:5 }");
- assertEvaluates("{ {x:1}:-2 }",
- "{ {x:1}:3 } - { {x:1}:5 }");
- assertEvaluates("{ {x:1}:3, {y:1}:-5 }",
- "{ {x:1}:3 } - { {y:1}:5 }");
- assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:-5 }",
- "{ {x:1}:3, {x:2}:7 } - { {y:1}:5 }");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:-7, {y:2,z:1}:-11, {y:1,z:2}:-13 }",
- "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } - { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }",
- "{ {x:1}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:-11, {y:1,z:1}:-7 }",
- "{ {x:1}:5, {x:1,y:1}:1 } - { {z:1}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:-7 }",
- "{ {}:5, {x:1,y:1}:1 } - { {y:1,z:1}:7 }");
- assertEvaluates("{ {}:-6, {x:1,y:1}:1, {y:1,z:1}:-7 }",
- "{ {}:5, {x:1,y:1}:1 } - { {}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1}:0 }",
- "{ {x:1}:3 } - { {x:1}:3 }");
- assertEvaluates("{ {x:1}:0, {x:2}:1 }",
- "{ {x:1}:3, {x:2}:1 } - { {x:1}:3 }");
-
- // tensor product
- assertEvaluates("{ }", "{} * {}");
- assertEvaluates("( {{x:-,y:-,z:-}:1}*{} )", "( {{x:-}:1} * {} ) * ( {{y:-,z:-}:1} * {} )"); // empty dimensions are preserved
- assertEvaluates("( {{x:-}:1} * {} )",
- "{ {x:1}:3 } * { {x:2}:5 }");
+ "max(tensor0, 0)", "{ {}:-10, {d1:l1}:0, {d1:l1,d2:l1}:10 }");
+ // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence)
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "abs(tensor0)", "{ {x:1}:1, {x:2}:-2 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "acos(tensor0)", "{ {x:1}:1, {x:2}:1 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "asin(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "atan(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "ceil(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cos(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:1 }", "cosh(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "elu(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:1 }", "exp(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "fabs(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "floor(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "isNan(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "log(tensor0)", "{ {x:1}:1, {x:2}:1 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:1 }", "log10(tensor0)", "{ {x:1}:1, {x:2}:10 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:2 }", "mod(tensor0, 3)", "{ {x:1}:3, {x:2}:8 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:8 }", "pow(tensor0, 3)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "relu(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:2 }", "round(tensor0)", "{ {x:1}:1, {x:2}:1.8 }");
+ assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "sigmoid(tensor0)","{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:-1 }", "sign(tensor0)", "{ {x:1}:3, {x:2}:-5 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sin(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "sinh(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:4 }", "square(tensor0)", "{ {x:1}:1, {x:2}:2 }");
+ assertEvaluates("{ {x:1}:1, {x:2}:3 }", "sqrt(tensor0)", "{ {x:1}:1, {x:2}:9 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tan(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+ assertEvaluates("{ {x:1}:0, {x:2}:0 }", "tanh(tensor0)", "{ {x:1}:0, {x:2}:0 }");
+
+ // tensor reduce
+ // -- reduce 2 dimensions
+ assertEvaluates("{ {}:4 }",
+ "reduce(tensor0, avg, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:4 }",
+ "reduce(tensor0, count, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:105 }",
+ "reduce(tensor0, prod, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:16 }",
+ "reduce(tensor0, sum, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:7 }",
+ "reduce(tensor0, max, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:1 }",
+ "reduce(tensor0, min, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ // -- reduce 2 by specifying no arguments
+ assertEvaluates("{ {}:4 }",
+ "reduce(tensor0, avg)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ // -- reduce 1 dimension
+ assertEvaluates("{ {y:1}:2, {y:2}:6 }",
+ "reduce(tensor0, avg, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {y:1}:2, {y:2}:2 }",
+ "reduce(tensor0, count, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {y:1}:3, {y:2}:35 }",
+ "reduce(tensor0, prod, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {y:1}:4, {y:2}:12 }",
+ "reduce(tensor0, sum, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {y:1}:3, {y:2}:7 }",
+ "reduce(tensor0, max, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {y:1}:1, {y:2}:5 }",
+ "reduce(tensor0, min, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ // -- reduce composites
+ assertEvaluates("{ {}: 5 }", "sum(tensor0)", "5.0");
+ assertEvaluates("{ {}:-5 }", "sum(tensor0)", "-5.0");
+ assertEvaluates("{ {}:12.5 }", "sum(tensor0)", "{ {d1:l1}:5.5, {d2:l2}:7.0 }");
+ assertEvaluates("{ {}: 0 }", "sum(tensor0)", "{ {d1:l1}:5.0, {d2:l2}:7.0, {}:-12.0}");
+ assertEvaluates("{ {y:1}:4, {y:2}:12.0 }",
+ "sum(tensor0, x)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {x:1}:6, {x:2}:10.0 }",
+ "sum(tensor0, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+ assertEvaluates("{ {}:16 }",
+ "sum(tensor0, x, y)", "{ {x:1,y:1}:1.0, {x:2,y:1}:3.0, {x:1,y:2}:5.0, {x:2,y:2}:7.0 }");
+
+ // tensor join
+ assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }", "join(tensor0, tensor1, f(x,y) (x*y))", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ // -- join composites
+ assertEvaluates("{ }", "tensor0 * tensor0", "{}");
+ assertEvaluates("tensor(x{},y{},z{}):{}", "( tensor0 * tensor1 ) * ( tensor2 * tensor1 )",
+ "{{x:-}:1}", "{}", "{{y:-,z:-}:1}"); // empty dimensions are preserved
+ assertEvaluates("tensor(x{}):{}",
+ "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:2}:5 }");
assertEvaluates("{ {x:1}:15 }",
- "{ {x:1}:3 } * { {x:1}:5 }");
+ "tensor0 * tensor1", "{ {x:1}:3 }", "{ {x:1}:5 }");
assertEvaluates("{ {x:1,y:1}:15 }",
- "{ {x:1}:3 } * { {y:1}:5 }");
+ "tensor0 * tensor1", "{ {x:1}:3 }", "{ {y:1}:5 }");
assertEvaluates("{ {x:1,y:1}:15, {x:2,y:1}:35 }",
- "{ {x:1}:3, {x:2}:7 } * { {y:1}:5 }");
+ "tensor0 * tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:8, {x:2,y:1}:12 }",
+ "tensor0 + tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:-2, {x:2,y:1}:2 }",
+ "tensor0 - tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:4 }",
+ "tensor0 / tensor1", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }");
+ assertEvaluates("{ {x:1,y:1}:5, {x:2,y:1}:7 }",
+ "max(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:3, {x:2,y:1}:5 }",
+ "min(tensor0, tensor1)", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,y:1,z:2}:13, {x:2,y:1,z:1}:21, {x:2,y:1,z:2}:39, {x:1,y:2,z:1}:55 }",
- "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 } * { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
- assertEvaluates("{ {x:1,y:1,z:1}:7 }",
- "{ {x:1}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1,y:1,z:1}:7, {x:1,z:1}:55 }",
- "{ {x:1}:5, {x:1,y:1}:1 } * { {z:1}:11, {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1,y:1,z:1}:7 }",
- "{ {}:5, {x:1,y:1}:1 } * { {y:1,z:1}:7 }");
- assertEvaluates("{ {x:1,y:1,z:1}:7, {}:55 }",
- "{ {}:5, {x:1,y:1}:1 } * { {}:11, {y:1,z:1}:7 }");
-
- // match product
- assertEvaluates("{ }", "match({}, {})");
- assertEvaluates("( {{x:-}:1} * {} )",
- "match({ {x:1}:3 }, { {x:2}:5 })");
- assertEvaluates("{ {x:1}:15 }",
- "match({ {x:1}:3 }, { {x:1}:5 })");
- assertEvaluates("( {{x:-,y:-}:1} * {} )",
- "match({ {x:1}:3 }, { {y:1}:5 })");
- assertEvaluates("( {{x:-,y:-}:1} * {} )",
- "match({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })");
- assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
- "match({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })");
- assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
- "match({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
- "match({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })");
- assertEvaluates("( {{x:-,y:-,z:-}:1} * {} )",
- "match({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("( {{x:-,y:-,z:-}:1} * { {}:55 } )",
- "match({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })");
- assertEvaluates("( {{z:-}:1} * { {x:1}:15, {x:1,y:1}:7 } )",
- "match({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })");
-
- // min
- assertEvaluates("{ {x:1}:3, {x:2}:5 }",
- "min({ {x:1}:3 }, { {x:2}:5 })");
- assertEvaluates("{ {x:1}:3 }",
- "min({ {x:1}:3 }, { {x:1}:5 })");
- assertEvaluates("{ {x:1}:3, {y:1}:5 }",
- "min({ {x:1}:3 }, { {y:1}:5 })");
- assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
- "min({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
- "min({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "min({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
- "min({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "min({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "min({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })");
- assertEvaluates("{ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }",
- "min({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })");
-
- // max
- assertEvaluates("{ {x:1}:3, {x:2}:5 }",
- "max({ {x:1}:3 }, { {x:2}:5 })");
- assertEvaluates("{ {x:1}:5 }",
- "max({ {x:1}:3 }, { {x:1}:5 })");
- assertEvaluates("{ {x:1}:3, {y:1}:5 }",
- "max({ {x:1}:3 }, { {y:1}:5 })");
- assertEvaluates("{ {x:1}:3, {x:2}:7, {y:1}:5 }",
- "max({ {x:1}:3, {x:2}:7 }, { {y:1}:5 })");
- assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5, {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }",
- "max({ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }, { {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 })");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "max({ {x:1}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("{ {x:1}:5, {x:1,y:1}:1, {z:1}:11, {y:1,z:1}:7 }",
- "max({ {x:1}:5, {x:1,y:1}:1 }, { {z:1}:11, {y:1,z:1}:7 })");
- assertEvaluates("{ {}:5, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "max({ {}:5, {x:1,y:1}:1 }, { {y:1,z:1}:7 })");
- assertEvaluates("{ {}:11, {x:1,y:1}:1, {y:1,z:1}:7 }",
- "max({ {}:5, {x:1,y:1}:1 }, { {}:11, {y:1,z:1}:7 })");
- assertEvaluates("{ {}:5, {x:1}:5, {x:2}:4, {x:1,y:1}:7, {x:1,y:2}:6, {z:1,y:1,x:1}:10 }",
- "max({ {}:5, {x:1}:3, {x:2}:4, {x:1,y:1}:1, {x:1,y:2}:6 }, { {x:1}:5, {y:1,x:1}:7, {z:1,y:1,x:1}:10 })");
-
- // Combined
- assertEvaluates(7.5 + 45 + 1.7,
- "sum( " + // model computation
- " match( " + // model weight application
- " { {x:1}:1, {x:2}:2 } * { {y:1}:3, {y:2}:4 } * { {z:1}:5 }, " + // feature combinations
- " { {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }" + // model weights
- "))+1.7");
-
- // undefined is not the same as 0
- assertEvaluates(1.0, "sum({ {x:1}:0, {x:2}:0 } * { {x:1}:1, {x:2}:1 } + 0.5)");
- assertEvaluates(0.0, "sum({ } * { {x:1}:1, {x:2}:1 } + 0.5)");
+ "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:1,z:1}:7, {y:2,z:1}:11, {y:1,z:2}:13 }");
+ assertEvaluates("{ {x:1,y:2,z:1}:35, {x:1,y:2,z:2}:65 }",
+ "tensor0 * tensor1", "{ {x:1,y:1}:1, {x:2,y:1}:3, {x:1,y:2}:5 }", "{ {y:2,z:1}:7, {y:3,z:1}:11, {y:2,z:2}:13 }");
+ assertEvaluates("{{x:1,y:1}:0.0}","tensor1 * tensor2 * tensor3", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1,y:1}:1 }", "{ {x:1,y:1}:1 }");
+ assertEvaluates("{ {}:50, {d1:l1}:500, {d1:l1,d2:l1}:5000 }",
+ "5 * tensor0", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:13, {d1:l1}:103, {d1:l1,d2:l1}:1003 }",
+ "tensor0 + 3","{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {}:1, {d1:l1}:10, {d1:l1,d2:l1 }:100 }",
+ "tensor0 / 10", "{ {}:10, {d1:l1}:100, {d1:l1,d2:l1}:1000 }");
+ assertEvaluates("{ {h:1}:1.5, {h:2}:1.5 }", "0.5 + tensor0", "{ {h:1}:1.0,{h:2}:1.0 }");
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:0 }",
+ "atan2(tensor0, tensor1)", "{ {x:1}:0, {x:2}:0 }", "{ {y:1}:1 }");
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
+ "tensor0 > tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 < tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
+ "tensor0 >= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 <= tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:5 }");
+ assertEvaluates("{ {x:1,y:1}:0, {x:2,y:1}:1 }",
+ "tensor0 == tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
+ // TODO
+ // argmax
+ // argmin
+ assertEvaluates("{ {x:1,y:1}:1, {x:2,y:1}:0 }",
+ "tensor0 != tensor1", "{ {x:1}:3, {x:2}:7 }", "{ {y:1}:7 }");
+
+ // tensor rename
+ assertEvaluates("{ {newX:1,y:2}:3 }", "rename(tensor0, x, newX)", "{ {x:1,y:2}:3.0 }");
+ assertEvaluates("{ {x:2,y:1}:3 }", "rename(tensor0, (x, y), (y, x))", "{ {x:1,y:2}:3.0 }");
+
+ // tensor generate - TODO
+ // assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0, {x:2,y:2}:1, {x:1,y:2}:0 }", "tensor(x[2],y[2])(x==y)");
+ // range
+ // diag
+ // fill
+ // random
+
+ // composite functions
+ assertEvaluates("{ {x:1}:0.25, {x:2}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }");
+ assertEvaluates("{ {x:1}:0.31622776601683794, {x:2}:0.9486832980505138 }", "l2_normalize(tensor0, x)", "{ {x:1}:1, {x:2}:3 }");
+ assertEvaluates("{ {y:1}:81.0 }", "matmul(tensor0, tensor1, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }");
+ assertEvaluates("{ {x:1}:0.5, {x:2}:0.5 }", "softmax(tensor0, x)", "{ {x:1}:1, {x:2}:1 }", "{ {y:1}:1 }");
+ assertEvaluates("{ {x:1,y:1}:88.0 }", "xw_plus_b(tensor0, tensor1, tensor2, x)", "{ {x:1}:15, {x:2}:12 }", "{ {y:1}:3 }", "{ {x:1}:7 }");
+
+ // expressions combining functions
+ assertEvaluates(String.valueOf(7.5 + 45 + 1.7),
+ "sum( " + // model computation:
+ " tensor0 * tensor1 * tensor2 " + // - feature combinations
+ " * tensor3" + // - model weights application
+ ") + 1.7",
+ "{ {x:1}:1, {x:2}:2 }", "{ {y:1}:3, {y:2}:4 }", "{ {z:1}:5 }",
+ "{ {x:1,y:1,z:1}:0.5, {x:2,y:1,z:1}:1.5, {x:1,y:1,z:2}:4.5 }");
+ assertEvaluates("1.0", "sum(tensor0 * tensor1 + 0.5)", "{ {x:1}:0, {x:2}:0 }", "{ {x:1}:1, {x:2}:1 }");
+ assertEvaluates("0.0", "sum(tensor0 * tensor1 + 0.5)", "{}", "{ {x:1}:1, {x:2}:1 }");
// tensor result dimensions are given from argument dimensions, not the resulting values
- assertEvaluates("x", "( {{x:-}:1.0} * {} )", "{ {x:1}:1 } * { {x:2}:1 }");
- assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }");
-
- // demonstration of where this produces different results: { {x:1}:1 } with 2 dimensions ...
- assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )","{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 } * { {x:1,y:1}:1 }");
- // ... vs { {x:1}:1 } with only one dimension
- assertEvaluates("x, y", "{{x:1,y:1}:1.0}", "{ {x:1}:1 } * { {x:1,y:1}:1 }");
-
- // check that dimensions are preserved through other operations
- String d2 = "{ {x:1}:1 } * { {x:2,y:1}:1, {x:1}:1 }"; // creates a 2d tensor with only an 1d value
- assertEvaluates("x, y", "( {{x:-,y:-}:1.0} * {} )", "match(" + d2 + ", {})");
- assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " - {}");
- assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", d2 + " + {}");
- assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "min(1.5, " + d2 +")");
- assertEvaluates("x, y", "( {{y:-}:1.0} * {{x:1}:1.0} )", "max({{x:1}:0}, " + d2 +")");
+ assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2}:1 }");
+ assertEvaluates("tensor(x{},y{}):{{x:1}:1.0}", "tensor0 * tensor1", "{ {x:1}:1 }", "{ {x:2,y:1}:1, {x:1}:1 }");
}
public void testProgrammaticBuildingAndPrecedence() {
@@ -316,12 +295,16 @@ public class EvaluationTestCase extends junit.framework.TestCase {
assertEvaluates(77, "average(\"2*3\",\"pow(2,3)\")+average(\"2*3\",\"pow(2,3)\").timesten", context);
}
- private RankingExpression assertEvaluates(String tensorValue, String expressionString) {
- return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, defaultContext);
+ private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) {
+ MapContext context = defaultContext.thawedCopy();
+ int argumentIndex = 0;
+ for (String tensorArgument : tensorArguments)
+ context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument)));
+ return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context);
}
/** Validate also that the dimension of the resulting tensors are as expected */
- private RankingExpression assertEvaluates(String tensorDimensions, String resultTensor, String expressionString) {
+ private RankingExpression assertEvaluates_old(String tensorDimensions, String resultTensor, String expressionString) {
RankingExpression expression = assertEvaluates(new TensorValue(MapTensor.from(resultTensor)), expressionString, defaultContext);
TensorValue value = (TensorValue)expression.evaluate(defaultContext);
assertEquals(toSet(tensorDimensions), value.asTensor().dimensions());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java
index 95c4402a612..08fdc9917a4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/NeuralNetEvaluationTestCase.java
@@ -17,22 +17,25 @@ public class NeuralNetEvaluationTestCase {
/** "XOR" neural network, separate expression per layer */
@Test
public void testPerLayerExpression() {
- String input = "{ {x:1}:0, {x:2}:1 }";
-
- String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }";
- String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }";
- String firstLayerInput = "sum(" + input + "*" + firstLayerWeights + ", x) + " + firstLayerBias;
+ String input = "{ {x:1}:0, {x:2}:1 }"; // tensor0
+ String firstLayerWeights = "{ {x:1,h:1}:1, {x:1,h:2}:1, {x:2,h:1}:1, {x:2,h:2}:1 }"; // tensor1
+ String firstLayerBias = "{ {h:1}:-0.5, {h:2}:-1.5 }"; // tensor2
+ String firstLayerInput = "sum(tensor0 * tensor1, x) + tensor2";
String firstLayerOutput = "min(1.0, max(0.0, 0.5 + " + firstLayerInput + "))"; // non-linearity, "poor man's sigmoid"
- assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput);
- String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }";
- String secondLayerBias = "{ {y:1}:-0.5 }";
- String secondLayerInput = "sum(" + firstLayerOutput + "*" + secondLayerWeights + ", h) + " + secondLayerBias;
+ assertEvaluates("{ {h:1}:1.0, {h:2}:0.0} }", firstLayerOutput, input, firstLayerWeights, firstLayerBias);
+ String secondLayerWeights = "{ {h:1,y:1}:1, {h:2,y:1}:-1 }"; // tensor3
+ String secondLayerBias = "{ {y:1}:-0.5 }"; // tensor4
+ String secondLayerInput = "sum(" + firstLayerOutput + "* tensor3, h) + tensor4";
String secondLayerOutput = "min(1.0, max(0.0, 0.5 + " + secondLayerInput + "))"; // non-linearity, "poor man's sigmoid"
- assertEvaluates("{ {y:1}:1 }", secondLayerOutput);
+ assertEvaluates("{ {y:1}:1 }", secondLayerOutput, input, firstLayerWeights, firstLayerBias, secondLayerWeights, secondLayerBias);
}
- private RankingExpression assertEvaluates(String tensorValue, String expressionString) {
- return assertEvaluates(new TensorValue(MapTensor.from(tensorValue)), expressionString, new MapContext());
+ private RankingExpression assertEvaluates(String expectedTensor, String expressionString, String ... tensorArguments) {
+ MapContext context = new MapContext();
+ int argumentIndex = 0;
+ for (String tensorArgument : tensorArguments)
+ context.put("tensor" + (argumentIndex++), new TensorValue(MapTensor.from(tensorArgument)));
+ return assertEvaluates(new TensorValue(MapTensor.from(expectedTensor)), expressionString, context);
}
private RankingExpression assertEvaluates(Value value, String expressionString, Context context) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
index 9d94ec0bc99..61b230ab390 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/transform/SimplifierTestCase.java
@@ -69,12 +69,4 @@ public class SimplifierTestCase {
assertEquals("a + (b + c) / 100000000.0", transformed.toString());
}
- @Test
- public void testSimplificationWithTensorConstants() throws ParseException {
- new Simplifier().transform(new RankingExpression(
- "sum(sum((tensorFromWeightedSet(query(wset_query),x)+" +
- " tensorFromWeightedSet(attribute(wset),x)) * " +
- " {{x:0,y:0}:54, {x:0,y:1} :69, {x:1,y:0} :72, {x:1,y:1} :93},x))"));
- }
-
}