diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-06-26 14:30:53 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-26 14:30:53 +0200 |
commit | 626bcc6c265229d8c97f4e0a1c996013650b335e (patch) | |
tree | 8042b0897155d1049d2bbc2ea20dc68ff3bda03b /vespajlib/src/test/java | |
parent | 0c341f8ed39b3edcd1938d964cbdf9ce7c179411 (diff) | |
parent | 9faebe628164657eaad3de625b9b799a385aea6e (diff) |
Merge pull request #27544 from vespa-engine/arnej/add-euclidean-distance
add euclidean_distance
Diffstat (limited to 'vespajlib/src/test/java')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java | 54 |
1 files changed, 54 insertions, 0 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java new file mode 100644 index 00000000000..4fae432b3ca --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java @@ -0,0 +1,54 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * @author arnej + */ +public class EuclideanDistanceTestCase { + + @Test + public void testVectorDistances() { + var a = Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]"); + var b = Tensor.from("tensor(x[3]):[4.0, 2.0, 7.0]"); + var c = Tensor.from("tensor(x[3]):[1.0, 6.0, 6.0]"); + var op = new EuclideanDistance<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + op = new EuclideanDistance<>(new ConstantTensor<>(b), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + op = new EuclideanDistance<>(new ConstantTensor<>(c), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + } + + @Test + public void testDistancesInMixed() { + var a = Tensor.from("tensor(c{},x[3]):{foo:[1.0, 2.0, 3.0],bar:[0.0, 0.0, 0.0]}"); + var b = Tensor.from("tensor(c{},x[3]):{foo:[4.0, 2.0, 7.0],bar:[12.0, 0.0, 5.0]}"); + var op = new EuclideanDistance<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + var expect = Tensor.from("tensor(c{}):{foo:5.0,bar:13.0}"); + assertEquals(expect, result); + } + + @Test + public void testExpansion() { + var tType = TensorType.fromSpec("tensor(vecdim[128])"); + var a = new VariableTensor<>("left", tType); + var b = new VariableTensor<>("right", tType); + var op = new EuclideanDistance<>(a, b, "vecdim"); + assertEquals("map(reduce(map(join(left, right, f(a,b)(a - b)), f(a)(a * a)), sum, vecdim), f(a)(sqrt(a)))", + op.toPrimitive().toString()); + } + +} |