summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-09-13 11:55:48 +0000
committerArne Juul <arnej@yahooinc.com>2023-09-13 12:31:26 +0000
commit00de3d2653e08a35f1ddb02f555d364f0741ae35 (patch)
tree39ad5884d558d7ccfbf6fc394ec2ce2c376acc16 /vespajlib/src/test/java/com
parent6d9d3fb1265a3bf61fdb2582ceb2f148ef9680c1 (diff)
fix dimension size comparison
Diffstat (limited to 'vespajlib/src/test/java/com')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java23
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java23
2 files changed, 40 insertions, 6 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java
index b303e2c1739..4697b4edca3 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/CosineSimilarityTestCase.java
@@ -3,10 +3,15 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.evaluation.VariableTensor;
import org.junit.Test;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import static org.junit.Assert.assertEquals;
@@ -44,11 +49,18 @@ public class CosineSimilarityTestCase {
assertEquals(expect, result);
}
+ static class MyContext implements TypeContext<Name> {
+ Map<String, TensorType> map = new HashMap<>();
+ public TensorType getType(Name name) { return getType(name.name()); }
+ public TensorType getType(String name) { return map.get(name); }
+ }
+
@Test
public void testExpansion() {
- var tType = TensorType.fromSpec("tensor(vecdim[128])");
- var a = new VariableTensor<>("left", tType);
- var b = new VariableTensor<>("right", tType);
+ var tTypeA = TensorType.fromSpec("tensor(foo{},vecdim[128])");
+ var tTypeB = TensorType.fromSpec("tensor(vecdim[128],z[4])");
+ var a = new VariableTensor<>("left", tTypeA);
+ var b = new VariableTensor<>("right", tTypeB);
var op = new CosineSimilarity<>(a, b, "vecdim");
assertEquals("join(" +
( "reduce(join(left, right, f(a,b)(a * b)), sum, vecdim), " +
@@ -61,6 +73,11 @@ public class CosineSimilarityTestCase {
"f(a,b)(a / b)" ) +
")",
op.toPrimitive().toString());
+ var context = new MyContext();
+ context.map.put("left", tTypeA);
+ context.map.put("right", tTypeB);
+ var resType = op.type(context);
+ assertEquals("tensor(foo{},z[4])", resType.toString());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java
index 4fae432b3ca..da9529afa77 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java
@@ -3,10 +3,15 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.evaluation.VariableTensor;
import org.junit.Test;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import static org.junit.Assert.assertEquals;
@@ -41,14 +46,26 @@ public class EuclideanDistanceTestCase {
assertEquals(expect, result);
}
+ static class MyContext implements TypeContext<Name> {
+ Map<String, TensorType> map = new HashMap<>();
+ public TensorType getType(Name name) { return getType(name.name()); }
+ public TensorType getType(String name) { return map.get(name); }
+ }
+
@Test
public void testExpansion() {
- var tType = TensorType.fromSpec("tensor(vecdim[128])");
- var a = new VariableTensor<>("left", tType);
- var b = new VariableTensor<>("right", tType);
+ var tTypeA = TensorType.fromSpec("tensor(vecdim[128])");
+ var tTypeB = TensorType.fromSpec("tensor(vecdim[128])");
+ var a = new VariableTensor<>("left", tTypeA);
+ var b = new VariableTensor<>("right", tTypeB);
var op = new EuclideanDistance<>(a, b, "vecdim");
assertEquals("map(reduce(map(join(left, right, f(a,b)(a - b)), f(a)(a * a)), sum, vecdim), f(a)(sqrt(a)))",
op.toPrimitive().toString());
+ var context = new MyContext();
+ context.map.put("left", tTypeA);
+ context.map.put("right", tTypeB);
+ var resType = op.type(context);
+ assertEquals("tensor()", resType.toString());
}
}