diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-09-13 11:55:48 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-09-13 12:31:26 +0000 |
commit | 00de3d2653e08a35f1ddb02f555d364f0741ae35 (patch) | |
tree | 39ad5884d558d7ccfbf6fc394ec2ce2c376acc16 /vespajlib/src/test/java/com | |
parent | 6d9d3fb1265a3bf61fdb2582ceb2f148ef9680c1 (diff) |
fix dimension size comparison
Diffstat (limited to 'vespajlib/src/test/java/com')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java | 23 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java | 23 |
2 files changed, 40 insertions, 6 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java index b303e2c1739..4697b4edca3 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java @@ -3,10 +3,15 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.evaluation.VariableTensor; import org.junit.Test; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.junit.Assert.assertEquals; @@ -44,11 +49,18 @@ public class CosineSimilarityTestCase { assertEquals(expect, result); } + static class MyContext implements TypeContext<Name> { + Map<String, TensorType> map = new HashMap<>(); + public TensorType getType(Name name) { return getType(name.name()); } + public TensorType getType(String name) { return map.get(name); } + } + @Test public void testExpansion() { - var tType = TensorType.fromSpec("tensor(vecdim[128])"); - var a = new VariableTensor<>("left", tType); - var b = new VariableTensor<>("right", tType); + var tTypeA = TensorType.fromSpec("tensor(foo{},vecdim[128])"); + var tTypeB = TensorType.fromSpec("tensor(vecdim[128],z[4])"); + var a = new VariableTensor<>("left", tTypeA); + var b = new VariableTensor<>("right", tTypeB); var op = new CosineSimilarity<>(a, b, "vecdim"); assertEquals("join(" + ( "reduce(join(left, right, f(a,b)(a * b)), sum, vecdim), " + @@ -61,6 +73,11 @@ public class CosineSimilarityTestCase { "f(a,b)(a / b)" ) + ")", op.toPrimitive().toString()); + var context = new MyContext(); + context.map.put("left", tTypeA); + context.map.put("right", tTypeB); + var resType = op.type(context); + assertEquals("tensor(foo{},z[4])", resType.toString()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java index 4fae432b3ca..da9529afa77 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java @@ -3,10 +3,15 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.evaluation.VariableTensor; import org.junit.Test; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.junit.Assert.assertEquals; @@ -41,14 +46,26 @@ public class EuclideanDistanceTestCase { assertEquals(expect, result); } + static class MyContext implements TypeContext<Name> { + Map<String, TensorType> map = new HashMap<>(); + public TensorType getType(Name name) { return getType(name.name()); } + public TensorType getType(String name) { return map.get(name); } + } + @Test public void testExpansion() { - var tType = TensorType.fromSpec("tensor(vecdim[128])"); - var a = new VariableTensor<>("left", tType); - var b = new VariableTensor<>("right", tType); + var tTypeA = TensorType.fromSpec("tensor(vecdim[128])"); + var tTypeB = TensorType.fromSpec("tensor(vecdim[128])"); + var a = new VariableTensor<>("left", tTypeA); + var b = new VariableTensor<>("right", tTypeB); var op = new EuclideanDistance<>(a, b, "vecdim"); assertEquals("map(reduce(map(join(left, right, f(a,b)(a - b)), f(a)(a * a)), sum, vecdim), f(a)(sqrt(a)))", op.toPrimitive().toString()); + var context = new MyContext(); + context.map.put("left", tTypeA); + context.map.put("right", tTypeB); + var resType = op.type(context); + assertEquals("tensor()", resType.toString()); } } |