diff options
author | Lester Solbakken <lesters@oath.com> | 2020-11-15 12:56:50 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-11-15 12:56:50 +0100 |
commit | 288ef3a1fbc3b1db5c54ab55b5d76d1feedeab0c (patch) | |
tree | 0908f8b5b5df6d2d3fa49930fbc149b183c7fea2 | |
parent | 8626f8690482c85e6e6e7350aa024dde528fccf9 (diff) |
Support mixed tensor type in Java evaluation
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java | 18 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 20 |
2 files changed, 19 insertions, 19 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 1bf4dc5698d..123fa5ac43b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -10,6 +10,7 @@ import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.IfNode; +import com.yahoo.tensor.Tensor; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -393,6 +394,23 @@ public class EvaluationTestCase { } @Test + public void testMixedTensorType() throws ParseException { + String expected = "tensor(x[1],y{},z[2]):{{x:0,y:a,z:0}:4.0,{x:0,y:a,z:1}:5.0,{x:0,y:b,z:0}:7.0,{x:0,y:b,z:1}:8.0}"; + String a = "tensor(x[1],y{}):{ {x:0,y:a}:1, {x:0,y:b}:2 }"; + String b = "tensor(y{},z[2]):{ {y:a,z:0}:3, {y:a,z:1}:4, {y:b,z:0}:5, {y:b,z:1}:6 }"; + String expression = "a + b"; + + MapContext context = new MapContext(); + context.put("a", new TensorValue(Tensor.from(a))); + context.put("b", new TensorValue(Tensor.from(b))); + + Tensor expectedResult = Tensor.from(expected); + Tensor result = new RankingExpression(expression).evaluate(context).asTensor(); + assertEquals(expectedResult, result); + assertEquals(expectedResult.type(), result.type()); + } + + @Test public void testTile() { EvaluationTester tester = new EvaluationTester(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index d8959147ee0..790743c745c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -520,27 +520,9 @@ public class TensorType { } } - private static final boolean supportsMixedTypes = false; - private void addDimensionsOf(TensorType type, boolean allowDifferentSizes) { - if ( ! supportsMixedTypes) { // TODO: Support it - addDimensionsOfAndDisallowMixedDimensions(type, allowDifferentSizes); - } - else { - for (Dimension dimension : type.dimensions) - set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), allowDifferentSizes)); - } - } - - private void addDimensionsOfAndDisallowMixedDimensions(TensorType type, boolean allowDifferentSizes) { - boolean containsMapped = dimensions.values().stream().anyMatch(d -> ! d.isIndexed()); - containsMapped = containsMapped || type.dimensions().stream().anyMatch(d -> ! d.isIndexed()); - for (Dimension dimension : type.dimensions) { - if (containsMapped) - dimension = new MappedDimension(dimension.name()); - Dimension existing = dimensions.get(dimension.name()); - set(dimension.combineWith(Optional.ofNullable(existing), allowDifferentSizes)); + set(dimension.combineWith(Optional.ofNullable(dimensions.get(dimension.name())), allowDifferentSizes)); } } |