aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-09-13 11:55:48 +0000
committerArne Juul <arnej@yahooinc.com>2023-09-13 12:31:26 +0000
commit00de3d2653e08a35f1ddb02f555d364f0741ae35 (patch)
tree39ad5884d558d7ccfbf6fc394ec2ce2c376acc16 /vespajlib
parent6d9d3fb1265a3bf61fdb2582ceb2f148ef9680c1 (diff)
fix dimension size comparison
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java23
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java23
4 files changed, 42 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
index ebb8a11fd8a..0e5b031c2cc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
@@ -51,7 +51,7 @@ public class CosineSimilarity<NAMETYPE extends Name> extends TensorFunction<NAME
if (d1.isEmpty() || d2.isEmpty()
|| d1.get().type() != Dimension.Type.indexedBound
|| d2.get().type() != Dimension.Type.indexedBound
- || d1.get().size().get() != d2.get().size().get())
+ || ! d1.get().size().equals(d2.get().size()))
{
throw new IllegalArgumentException("cosine_similarity expects both arguments to have the '"
+ dimension + "' dimension with same size, but input types were "
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
index f9fc8e195d3..4c771fe8843 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
@@ -51,7 +51,7 @@ public class EuclideanDistance<NAMETYPE extends Name> extends TensorFunction<NAM
if (d1.isEmpty() || d2.isEmpty()
|| d1.get().type() != Dimension.Type.indexedBound
|| d2.get().type() != Dimension.Type.indexedBound
- || d1.get().size().get() != d2.get().size().get())
+ || ! d1.get().size().equals(d2.get().size()))
{
throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '"
+ dimension + "' dimension with same size, but input types were "
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java
index b303e2c1739..4697b4edca3 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java
@@ -3,10 +3,15 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.evaluation.VariableTensor;
import org.junit.Test;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import static org.junit.Assert.assertEquals;
@@ -44,11 +49,18 @@ public class CosineSimilarityTestCase {
assertEquals(expect, result);
}
+ static class MyContext implements TypeContext<Name> {
+ Map<String, TensorType> map = new HashMap<>();
+ public TensorType getType(Name name) { return getType(name.name()); }
+ public TensorType getType(String name) { return map.get(name); }
+ }
+
@Test
public void testExpansion() {
- var tType = TensorType.fromSpec("tensor(vecdim[128])");
- var a = new VariableTensor<>("left", tType);
- var b = new VariableTensor<>("right", tType);
+ var tTypeA = TensorType.fromSpec("tensor(foo{},vecdim[128])");
+ var tTypeB = TensorType.fromSpec("tensor(vecdim[128],z[4])");
+ var a = new VariableTensor<>("left", tTypeA);
+ var b = new VariableTensor<>("right", tTypeB);
var op = new CosineSimilarity<>(a, b, "vecdim");
assertEquals("join(" +
( "reduce(join(left, right, f(a,b)(a * b)), sum, vecdim), " +
@@ -61,6 +73,11 @@ public class CosineSimilarityTestCase {
"f(a,b)(a / b)" ) +
")",
op.toPrimitive().toString());
+ var context = new MyContext();
+ context.map.put("left", tTypeA);
+ context.map.put("right", tTypeB);
+ var resType = op.type(context);
+ assertEquals("tensor(foo{},z[4])", resType.toString());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java
index 4fae432b3ca..da9529afa77 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java
@@ -3,10 +3,15 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.evaluation.VariableTensor;
import org.junit.Test;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import static org.junit.Assert.assertEquals;
@@ -41,14 +46,26 @@ public class EuclideanDistanceTestCase {
assertEquals(expect, result);
}
+ static class MyContext implements TypeContext<Name> {
+ Map<String, TensorType> map = new HashMap<>();
+ public TensorType getType(Name name) { return getType(name.name()); }
+ public TensorType getType(String name) { return map.get(name); }
+ }
+
@Test
public void testExpansion() {
- var tType = TensorType.fromSpec("tensor(vecdim[128])");
- var a = new VariableTensor<>("left", tType);
- var b = new VariableTensor<>("right", tType);
+ var tTypeA = TensorType.fromSpec("tensor(vecdim[128])");
+ var tTypeB = TensorType.fromSpec("tensor(vecdim[128])");
+ var a = new VariableTensor<>("left", tTypeA);
+ var b = new VariableTensor<>("right", tTypeB);
var op = new EuclideanDistance<>(a, b, "vecdim");
assertEquals("map(reduce(map(join(left, right, f(a,b)(a - b)), f(a)(a * a)), sum, vecdim), f(a)(sqrt(a)))",
op.toPrimitive().toString());
+ var context = new MyContext();
+ context.map.put("left", tTypeA);
+ context.map.put("right", tTypeB);
+ var resType = op.type(context);
+ assertEquals("tensor()", resType.toString());
}
}