From 643a09268b71ea0ebf128552874f1a3ee15aca2e Mon Sep 17 00:00:00 2001 From: Arne Juul Date: Mon, 26 Jun 2023 06:31:44 +0000 Subject: add euclidean_distance --- .../functions/EuclideanDistanceTestCase.java | 43 ++++++++++++++++++++++ 1 file changed, 43 insertions(+) create mode 100644 vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java (limited to 'vespajlib/src/test') diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java new file mode 100644 index 00000000000..9d06c313ecc --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java @@ -0,0 +1,43 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * @author arnej + */ +public class EuclideanDistanceTestCase { + + @Test + public void testVectorDistances() { + var a = Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]"); + var b = Tensor.from("tensor(x[3]):[4.0, 2.0, 7.0]"); + var c = Tensor.from("tensor(x[3]):[1.0, 6.0, 6.0]"); + var op = new EuclideanDistance<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + op = new EuclideanDistance<>(new ConstantTensor<>(b), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + op = new EuclideanDistance<>(new ConstantTensor<>(c), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + } + + @Test + public void testDistancesInMixed() { + var a = Tensor.from("tensor(c{},x[3]):{foo:[1.0, 2.0, 3.0],bar:[0.0, 0.0, 0.0]}"); + var b = Tensor.from("tensor(c{},x[3]):{foo:[4.0, 2.0, 7.0],bar:[12.0, 0.0, 5.0]}"); + var op = new EuclideanDistance<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + var expect = Tensor.from("tensor(c{}):{foo:5.0,bar:13.0}"); + assertEquals(expect, result); + } + +} -- cgit v1.2.3