summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java32
1 files changed, 31 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
index 25399416c29..f9fc8e195d3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
@@ -1,7 +1,12 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorType.Dimension;
import java.util.Collections;
import java.util.List;
@@ -12,7 +17,7 @@ import java.util.Objects;
* euclidean_distance(a, b, mydim) == sqrt(sum(pow(a-b, 2), mydim))
* @author arnej
*/
-public class EuclideanDistance<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
+public class EuclideanDistance<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> {
private final TensorFunction<NAMETYPE> arg1;
private final TensorFunction<NAMETYPE> arg2;
@@ -38,6 +43,31 @@ public class EuclideanDistance<NAMETYPE extends Name> extends CompositeTensorFun
}
@Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ TensorType t1 = arg1.toPrimitive().type(context);
+ TensorType t2 = arg2.toPrimitive().type(context);
+ var d1 = t1.dimension(dimension);
+ var d2 = t2.dimension(dimension);
+ if (d1.isEmpty() || d2.isEmpty()
+ || d1.get().type() != Dimension.Type.indexedBound
+ || d2.get().type() != Dimension.Type.indexedBound
+ || d1.get().size().get() != d2.get().size().get())
+ {
+ throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '"
+ + dimension + "' dimension with same size, but input types were "
+ + t1 + " and " + t2);
+ }
+ // Finds the type this produces by first converting it to a primitive function
+ return toPrimitive().type(context);
+ }
+
+ /** Evaluates this by first converting it to a primitive function */
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ return toPrimitive().evaluate(context);
+ }
+
+ @Override
public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
TensorFunction<NAMETYPE> primitive1 = arg1.toPrimitive();
TensorFunction<NAMETYPE> primitive2 = arg2.toPrimitive();