diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java | 32 |
1 files changed, 31 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java index 25399416c29..f9fc8e195d3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java @@ -1,7 +1,12 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TensorType.Dimension; import java.util.Collections; import java.util.List; @@ -12,7 +17,7 @@ import java.util.Objects; * euclidean_distance(a, b, mydim) == sqrt(sum(pow(a-b, 2), mydim)) * @author arnej */ -public class EuclideanDistance<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { +public class EuclideanDistance<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> { private final TensorFunction<NAMETYPE> arg1; private final TensorFunction<NAMETYPE> arg2; @@ -38,6 +43,31 @@ public class EuclideanDistance<NAMETYPE extends Name> extends CompositeTensorFun } @Override + public TensorType type(TypeContext<NAMETYPE> context) { + TensorType t1 = arg1.toPrimitive().type(context); + TensorType t2 = arg2.toPrimitive().type(context); + var d1 = t1.dimension(dimension); + var d2 = t2.dimension(dimension); + if (d1.isEmpty() || d2.isEmpty() + || d1.get().type() != Dimension.Type.indexedBound + || d2.get().type() != Dimension.Type.indexedBound + || d1.get().size().get() != d2.get().size().get()) + { + throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '" + + dimension + "' dimension with same size, but input types were " + + t1 + " and " + t2); + } + // Finds the type this produces by first converting it to a primitive function + return toPrimitive().type(context); + } + + /** Evaluates this by first converting it to a primitive function */ + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + return toPrimitive().evaluate(context); + } + + @Override public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { TensorFunction<NAMETYPE> primitive1 = arg1.toPrimitive(); TensorFunction<NAMETYPE> primitive2 = arg2.toPrimitive(); |