diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-06-26 06:31:44 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-06-26 08:37:46 +0000 |
commit | 643a09268b71ea0ebf128552874f1a3ee15aca2e (patch) | |
tree | 0d224c2130b9f9c213152c4fe0a2a00ecf625b1d /vespajlib/src | |
parent | f374e7d08c7e492130956c757ebdbd6cccdda74f (diff) |
add euclidean_distance
Diffstat (limited to 'vespajlib/src')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java | 57 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java | 43 |
2 files changed, 100 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java new file mode 100644 index 00000000000..4feddf9f808 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java @@ -0,0 +1,57 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.evaluation.Name; + +import java.util.Collections; +import java.util.List; +import java.util.Objects; + +/** + * @author arnej + */ +public class EuclideanDistance<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> arg1; + private final TensorFunction<NAMETYPE> arg2; + private final String dimension; + + public EuclideanDistance(TensorFunction<NAMETYPE> argument1, + TensorFunction<NAMETYPE> argument2, + String dimension) + { + this.arg1 = argument1; + this.arg2 = argument2; + this.dimension = dimension; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return List.of(arg1, arg2); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if ( arguments.size() != 2) + throw new IllegalArgumentException("EuclideanDistance must have 2 arguments, got " + arguments.size()); + return new EuclideanDistance<>(arguments.get(0), arguments.get(1), dimension); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitive1 = arg1.toPrimitive(); + TensorFunction<NAMETYPE> primitive2 = arg2.toPrimitive(); + // this should match the C++ optimized "l2_distance" + var diffs = new Join<>(primitive1, primitive2, ScalarFunctions.subtract()); + var squaredDiffs = new Map<>(diffs, ScalarFunctions.square()); + var sumOfSquares = new Reduce<>(squaredDiffs, Reduce.Aggregator.sum, dimension); + return new Map<>(sumOfSquares, ScalarFunctions.sqrt()); + } + + @Override + public String toString(ToStringContext<NAMETYPE> context) { + return "euclidean_distance(" + arg1.toString(context) + ", " + arg2.toString(context) + ", " + dimension + ")"; + } + + @Override + public int hashCode() { return Objects.hash("euclidean_distance", arg1, arg2, dimension); } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java new file mode 100644 index 00000000000..9d06c313ecc --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/EuclideanDistanceTestCase.java @@ -0,0 +1,43 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * @author arnej + */ +public class EuclideanDistanceTestCase { + + @Test + public void testVectorDistances() { + var a = Tensor.from("tensor(x[3]):[1.0, 2.0, 3.0]"); + var b = Tensor.from("tensor(x[3]):[4.0, 2.0, 7.0]"); + var c = Tensor.from("tensor(x[3]):[1.0, 6.0, 6.0]"); + var op = new EuclideanDistance<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + op = new EuclideanDistance<>(new ConstantTensor<>(b), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + op = new EuclideanDistance<>(new ConstantTensor<>(c), new ConstantTensor<>(a), "x"); + result = op.evaluate(); + assertEquals(5.0, result.asDouble(), 0.000001); + } + + @Test + public void testDistancesInMixed() { + var a = Tensor.from("tensor(c{},x[3]):{foo:[1.0, 2.0, 3.0],bar:[0.0, 0.0, 0.0]}"); + var b = Tensor.from("tensor(c{},x[3]):{foo:[4.0, 2.0, 7.0],bar:[12.0, 0.0, 5.0]}"); + var op = new EuclideanDistance<>(new ConstantTensor<>(a), new ConstantTensor<>(b), "x"); + Tensor result = op.evaluate(); + var expect = Tensor.from("tensor(c{}):{foo:5.0,bar:13.0}"); + assertEquals(expect, result); + } + +} |