diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-06-26 06:31:44 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-06-26 08:37:46 +0000 |
commit | 643a09268b71ea0ebf128552874f1a3ee15aca2e (patch) | |
tree | 0d224c2130b9f9c213152c4fe0a2a00ecf625b1d /vespajlib/src/test | |
parent | f374e7d08c7e492130956c757ebdbd6cccdda74f (diff) |
add euclidean_distance
Diffstat (limited to 'vespajlib/src/test')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java new file mode 100644 index 00000000000..9d06c313ecc --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java @@ -0,0 +1,43 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * @author arnej + */ +public class EuclideanDistanceTestCase { + + @Test + public void testVectorDistances() { + var a = Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]"); + var b = Tensor.from("tensor(x[3]):[4.0, 2.0, 7.0]"); + var c = Tensor.from("tensor(x[3]):[1.0, 6.0, 6.0]"); + var op = new EuclideanDistance<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + op = new EuclideanDistance<>(new ConstantTensor<>(b), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + op = new EuclideanDistance<>(new ConstantTensor<>(c), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + } + + @Test + public void testDistancesInMixed() { + var a = Tensor.from("tensor(c{},x[3]):{foo:[1.0, 2.0, 3.0],bar:[0.0, 0.0, 0.0]}"); + var b = Tensor.from("tensor(c{},x[3]):{foo:[4.0, 2.0, 7.0],bar:[12.0, 0.0, 5.0]}"); + var op = new EuclideanDistance<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + var expect = Tensor.from("tensor(c{}):{foo:5.0,bar:13.0}"); + assertEquals(expect, result); + } + +} |