aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-11-28 08:54:41 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-11-28 08:54:41 +0100
commitb9161cb0f3eec983af285e01fae9b28756f038a0 (patch)
tree082316e679ccd2a81d8fff12fffe67c798811a2a
parent2f55986b4de9420e5728c5abbaafb69fb2f10a34 (diff)
Propagate set/getChildren
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java7
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java37
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java7
15 files changed, 116 insertions, 36 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
index cade223d51a..4ece81860e2 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java
@@ -4,6 +4,7 @@ import com.yahoo.collections.Pair;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
import com.yahoo.searchdefinition.parser.ParseException;
+import org.junit.Ignore;
import org.junit.Test;
import java.util.ArrayList;
@@ -165,11 +166,11 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase
for (Object o : testRankProperties)
System.out.println(o);
assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(0).toString());
- assertEquals("(rankingExpression(relu@).rankingScript,max(1.0,sum(query(q) * constant(W_hidden), input) + constant(b_input)))", censorBindingHash(testRankProperties.get(1).toString()));
+ assertEquals("(rankingExpression(relu@).rankingScript,max(1.0,reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input)))", censorBindingHash(testRankProperties.get(1).toString()));
assertEquals("(rankingExpression(hidden_layer).rankingScript,rankingExpression(relu@))", censorBindingHash(testRankProperties.get(2).toString()));
- assertEquals("(rankingExpression(final_layer).rankingScript,sigmoid(sum(rankingExpression(hidden_layer) * constant(W_final), hidden) + constant(b_final)))", testRankProperties.get(3).toString());
+ assertEquals("(rankingExpression(final_layer).rankingScript,sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))", testRankProperties.get(3).toString());
assertEquals("(vespa.rank.secondphase,rankingExpression(secondphase))", testRankProperties.get(4).toString());
- assertEquals("(rankingExpression(secondphase).rankingScript,sum(rankingExpression(final_layer)))", testRankProperties.get(5).toString());
+ assertEquals("(rankingExpression(secondphase).rankingScript,reduce(rankingExpression(final_layer), sum))", testRankProperties.get(5).toString());
}
private String censorBindingHash(String s) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
index 19547c5039b..f6a7ddd4983 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorTestCase.java
@@ -54,31 +54,6 @@ public class RankingExpressionWithTensorTestCase {
}
@Test
- public void requireThatExpressionWithSingleLineTensorCanBeParsed() throws ParseException {
- SearchFixture f = new SearchFixture(
- " rank-profile my_profile {\n" +
- " first-phase {\n" +
- " expression: sum({ {x:1}:1, {x:2,y:1}:2 })\n" +
- " }\n" +
- " }");
- f.assertFirstPhaseExpression("sum({{x:1}:1.0,{x:2,y:1}:2.0})", "my_profile");
- }
-
- @Test
- public void requireThatExpressionWithMultiLineTensorCanBeParsed() throws ParseException {
- SearchFixture f = new SearchFixture(
- " rank-profile my_profile {\n" +
- " first-phase {\n" +
- " expression {\n" +
- " sum({ {x:1}:1,\n" +
- " {x:2,y:1}:2 })\n" +
- " }\n" +
- " }\n" +
- " }");
- f.assertFirstPhaseExpression("sum({{x:1}:1.0,{x:2,y:1}:2.0})", "my_profile");
- }
-
- @Test
public void requireThatSingleLineConstantTensorAndTypeCanBeParsed() throws ParseException {
SearchFixture f = new SearchFixture(
" rank-profile my_profile {\n" +
@@ -92,7 +67,7 @@ public class RankingExpressionWithTensorTestCase {
" }\n" +
" }\n" +
" }");
- f.assertFirstPhaseExpression("sum(constant(my_tensor))", "my_profile");
+ f.assertFirstPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile");
f.assertRankProperty("{{x:1}:1.0,{x:2,y:1}:2.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor(x{},y{})", "constant(my_tensor).type", "my_profile");
}
@@ -114,7 +89,7 @@ public class RankingExpressionWithTensorTestCase {
" }\n" +
" }\n" +
" }");
- f.assertFirstPhaseExpression("sum(constant(my_tensor))", "my_profile");
+ f.assertFirstPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile");
f.assertRankProperty("{{x:1}:1.0,{x:2,y:1}:2.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor(x{},y{})", "constant(my_tensor).type", "my_profile");
}
@@ -132,7 +107,7 @@ public class RankingExpressionWithTensorTestCase {
" }\n" +
" }\n" +
" }");
- f.assertSecondPhaseExpression("sum(constant(my_tensor))", "my_profile");
+ f.assertSecondPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile");
f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor", "constant(my_tensor).type", "my_profile");
}
@@ -152,7 +127,7 @@ public class RankingExpressionWithTensorTestCase {
" expression: sum(my_tensor)\n" +
" }\n" +
" }");
- f.assertFirstPhaseExpression("sum(constant(my_tensor))", "my_profile");
+ f.assertFirstPhaseExpression("reduce(constant(my_tensor), sum)", "my_profile");
f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor", "constant(my_tensor).type", "my_profile");
}
@@ -174,7 +149,7 @@ public class RankingExpressionWithTensorTestCase {
" }\n" +
" }");
f.assertFirstPhaseExpression("5.0 + my_macro", "my_profile");
- f.assertMacro("sum(constant(my_tensor))", "my_macro", "my_profile");
+ f.assertMacro("reduce(constant(my_tensor), sum)", "my_macro", "my_profile");
f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor", "constant(my_tensor).type", "my_profile");
}
@@ -194,7 +169,7 @@ public class RankingExpressionWithTensorTestCase {
" my_number_2: 5.0\n" +
" }\n" +
" }");
- f.assertFirstPhaseExpression("3.0 + sum(constant(my_tensor)) + 5.0", "my_profile");
+ f.assertFirstPhaseExpression("3.0 + reduce(constant(my_tensor), sum) + 5.0", "my_profile");
f.assertRankProperty("{{x:1}:1.0}", "constant(my_tensor).value", "my_profile");
f.assertRankProperty("tensor", "constant(my_tensor).type", "my_profile");
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 26d3f1dcc0e..93d551ebfd7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.EvaluationContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
@@ -39,7 +40,10 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
- throw new UnsupportedOperationException("Not implemented");
+ List<TensorFunction> wrappedChildren = children.stream()
+ .map(TensorFunctionExpressionNode::new)
+ .collect(Collectors.toList());
+ return new TensorFunctionNode(function.replaceArguments(wrappedChildren));
}
@Override
@@ -71,7 +75,23 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
+ public List<TensorFunction> functionArguments() {
+ if (expression instanceof CompositeNode)
+ return ((CompositeNode)expression).children().stream()
+ .map(TensorFunctionExpressionNode::new)
+ .collect(Collectors.toList());
+ else
+ return Collections.emptyList();
+ }
+
+ @Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if (arguments.size() == 0) return this;
+ List<ExpressionNode> unwrappedChildren = arguments.stream()
+ .map(arg -> ((TensorFunctionExpressionNode)arg).expression)
+ .collect(Collectors.toList());
+ return new TensorFunctionExpressionNode(((CompositeNode)expression).setChildren(unwrappedChildren));
+ }
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 0727579a331..153a3f896de 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -27,6 +27,13 @@ public class ConstantTensor extends PrimitiveTensorFunction {
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size());
+ return this;
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index c0e5776bf48..013a95fe51f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -44,6 +44,13 @@ public class Generate extends PrimitiveTensorFunction {
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size());
+ return this;
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 323da5906c3..ce1f123a216 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -44,6 +44,13 @@ public class Join extends PrimitiveTensorFunction {
public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 2)
+ throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
+ return new Join(arguments.get(0), arguments.get(1), combinator);
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index 4467b378b3f..2e61792aa90 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -20,6 +20,13 @@ public class L1Normalize extends CompositeTensorFunction {
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size());
+ return new L1Normalize(arguments.get(0), dimension);
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
TensorFunction primitiveArgument = argument.toPrimitive();
// join(x, reduce(x, "avg", "dimension"), f(x,y) (x / y))
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index 0e96b43bd22..40d1b2a95c1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -20,6 +20,13 @@ public class L2Normalize extends CompositeTensorFunction {
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size());
+ return new L2Normalize(arguments.get(0), dimension);
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
TensorFunction primitiveArgument = argument.toPrimitive();
return new Join(primitiveArgument,
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 5db88953c64..c1b148ff82f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -34,6 +34,13 @@ public class Map extends PrimitiveTensorFunction {
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size());
+ return new Map(arguments.get(0), mapper);
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
return new Map(argument.toPrimitive(), mapper);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 4492ab083d4..8a6622213e5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -22,6 +22,13 @@ public class Matmul extends CompositeTensorFunction {
public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 2)
+ throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size());
+ return new Matmul(arguments.get(0), arguments.get(1), dimension);
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
TensorFunction primitiveArgument1 = argument1.toPrimitive();
TensorFunction primitiveArgument2 = argument2.toPrimitive();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index ef18cb61b17..e6f9874c0bd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -63,6 +63,13 @@ public class Reduce extends PrimitiveTensorFunction {
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
+ return new Reduce(arguments.get(0), aggregator, dimensions);
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
return new Reduce(argument.toPrimitive(), aggregator, dimensions);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 05af86c33e8..0995e56eb9a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -44,6 +44,13 @@ public class Rename extends PrimitiveTensorFunction {
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size());
+ return new Rename(arguments.get(0), fromDimensions, toDimensions);
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index b05b8172b42..713452d55d2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -20,6 +20,13 @@ public class Softmax extends CompositeTensorFunction {
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size());
+ return new Softmax(arguments.get(0), dimension);
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
TensorFunction primitiveArgument = argument.toPrimitive();
return new Join(new Map(primitiveArgument, ScalarFunctions.exp()),
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index a717292632e..34ccf0704ca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -17,6 +17,13 @@ public abstract class TensorFunction {
public abstract List<TensorFunction> functionArguments();
/**
+ * Returns a copy of this tensor function with the arguments replaced by the given list of arguments.
+ *
+ * @throws IllegalArgumentException if the argument list has the wrong size for this function
+ */
+ public abstract TensorFunction replaceArguments(List<TensorFunction> arguments);
+
+ /**
* Translate this function - and all of its arguments recursively -
* to a tree of primitive functions only.
*
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
index 1988c1d2390..e83a514bd13 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -23,6 +23,13 @@ public class XwPlusB extends CompositeTensorFunction {
public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); }
@Override
+ public TensorFunction replaceArguments(List<TensorFunction> arguments) {
+ if ( arguments.size() != 3)
+ throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size());
+ return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension);
+ }
+
+ @Override
public PrimitiveTensorFunction toPrimitive() {
TensorFunction primitiveX = x.toPrimitive();
TensorFunction primitiveW = w.toPrimitive();