diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-09-13 11:55:48 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-09-13 12:31:26 +0000 |
commit | 00de3d2653e08a35f1ddb02f555d364f0741ae35 (patch) | |
tree | 39ad5884d558d7ccfbf6fc394ec2ce2c376acc16 /vespajlib | |
parent | 6d9d3fb1265a3bf61fdb2582ceb2f148ef9680c1 (diff) |
fix dimension size comparison
Diffstat (limited to 'vespajlib')
4 files changed, 42 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java index ebb8a11fd8a..0e5b031c2cc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java @@ -51,7 +51,7 @@ public class CosineSimilarity<NAMETYPE extends Name> extends TensorFunction<NAME if (d1.isEmpty() || d2.isEmpty() || d1.get().type() != Dimension.Type.indexedBound || d2.get().type() != Dimension.Type.indexedBound - || d1.get().size().get() != d2.get().size().get()) + || ! d1.get().size().equals(d2.get().size())) { throw new IllegalArgumentException("cosine_similarity expects both arguments to have the '" + dimension + "' dimension with same size, but input types were " diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java index f9fc8e195d3..4c771fe8843 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java @@ -51,7 +51,7 @@ public class EuclideanDistance<NAMETYPE extends Name> extends TensorFunction<NAM if (d1.isEmpty() || d2.isEmpty() || d1.get().type() != Dimension.Type.indexedBound || d2.get().type() != Dimension.Type.indexedBound - || d1.get().size().get() != d2.get().size().get()) + || ! d1.get().size().equals(d2.get().size())) { throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '" + dimension + "' dimension with same size, but input types were " diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java index b303e2c1739..4697b4edca3 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java @@ -3,10 +3,15 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.evaluation.VariableTensor; import org.junit.Test; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.junit.Assert.assertEquals; @@ -44,11 +49,18 @@ public class CosineSimilarityTestCase { assertEquals(expect, result); } + static class MyContext implements TypeContext<Name> { + Map<String, TensorType> map = new HashMap<>(); + public TensorType getType(Name name) { return getType(name.name()); } + public TensorType getType(String name) { return map.get(name); } + } + @Test public void testExpansion() { - var tType = TensorType.fromSpec("tensor(vecdim[128])"); - var a = new VariableTensor<>("left", tType); - var b = new VariableTensor<>("right", tType); + var tTypeA = TensorType.fromSpec("tensor(foo{},vecdim[128])"); + var tTypeB = TensorType.fromSpec("tensor(vecdim[128],z[4])"); + var a = new VariableTensor<>("left", tTypeA); + var b = new VariableTensor<>("right", tTypeB); var op = new CosineSimilarity<>(a, b, "vecdim"); assertEquals("join(" + ( "reduce(join(left, right, f(a,b)(a * b)), sum, vecdim), " + @@ -61,6 +73,11 @@ public class CosineSimilarityTestCase { "f(a,b)(a / b)" ) + ")", op.toPrimitive().toString()); + var context = new MyContext(); + context.map.put("left", tTypeA); + context.map.put("right", tTypeB); + var resType = op.type(context); + assertEquals("tensor(foo{},z[4])", resType.toString()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java index 4fae432b3ca..da9529afa77 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java @@ -3,10 +3,15 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.evaluation.VariableTensor; import org.junit.Test; +import java.util.HashMap; import java.util.List; +import java.util.Map; import static org.junit.Assert.assertEquals; @@ -41,14 +46,26 @@ public class EuclideanDistanceTestCase { assertEquals(expect, result); } + static class MyContext implements TypeContext<Name> { + Map<String, TensorType> map = new HashMap<>(); + public TensorType getType(Name name) { return getType(name.name()); } + public TensorType getType(String name) { return map.get(name); } + } + @Test public void testExpansion() { - var tType = TensorType.fromSpec("tensor(vecdim[128])"); - var a = new VariableTensor<>("left", tType); - var b = new VariableTensor<>("right", tType); + var tTypeA = TensorType.fromSpec("tensor(vecdim[128])"); + var tTypeB = TensorType.fromSpec("tensor(vecdim[128])"); + var a = new VariableTensor<>("left", tTypeA); + var b = new VariableTensor<>("right", tTypeB); var op = new EuclideanDistance<>(a, b, "vecdim"); assertEquals("map(reduce(map(join(left, right, f(a,b)(a - b)), f(a)(a * a)), sum, vecdim), f(a)(sqrt(a)))", op.toPrimitive().toString()); + var context = new MyContext(); + context.map.put("left", tTypeA); + context.map.put("right", tTypeB); + var resType = op.type(context); + assertEquals("tensor()", resType.toString()); } } |