diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-06-26 13:25:29 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-06-26 14:53:17 +0000 |
commit | 89150530a47690fa0df603069789002f79ae7123 (patch) | |
tree | d9cc51199be753125f2e92d0a705b75093289b5f | |
parent | cc517d86dc886058cdc5f95a318945a6a328da28 (diff) |
override type resolving to do some sanity checking
3 files changed, 68 insertions, 4 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 7f70deb0991..76d007dd633 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1706,7 +1706,7 @@ "fields" : [ ] }, "com.yahoo.tensor.functions.CosineSimilarity" : { - "superClass" : "com.yahoo.tensor.functions.CompositeTensorFunction", + "superClass" : "com.yahoo.tensor.functions.TensorFunction", "interfaces" : [ ], "attributes" : [ "public" @@ -1715,6 +1715,8 @@ "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.functions.TensorFunction, java.lang.String)", "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", "public int hashCode()" @@ -1757,7 +1759,7 @@ "fields" : [ ] }, "com.yahoo.tensor.functions.EuclideanDistance" : { - "superClass" : "com.yahoo.tensor.functions.CompositeTensorFunction", + "superClass" : "com.yahoo.tensor.functions.TensorFunction", "interfaces" : [ ], "attributes" : [ "public" @@ -1766,6 +1768,8 @@ "public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.functions.TensorFunction, java.lang.String)", "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", "public int hashCode()" diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java index ede0355a3a6..ebb8a11fd8a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java @@ -1,7 +1,12 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TensorType.Dimension; import java.util.Collections; import java.util.List; @@ -12,7 +17,7 @@ import java.util.Objects; * cosine_similarity(a, b, mydim) == sum(a*b, mydim) / sqrt(sum(a*a, mydim) * sum(b*b, mydim)) * @author arnej */ -public class CosineSimilarity<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { +public class CosineSimilarity<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> { private final TensorFunction<NAMETYPE> arg1; private final TensorFunction<NAMETYPE> arg2; @@ -38,6 +43,31 @@ public class CosineSimilarity<NAMETYPE extends Name> extends CompositeTensorFunc } @Override + public TensorType type(TypeContext<NAMETYPE> context) { + TensorType t1 = arg1.toPrimitive().type(context); + TensorType t2 = arg2.toPrimitive().type(context); + var d1 = t1.dimension(dimension); + var d2 = t2.dimension(dimension); + if (d1.isEmpty() || d2.isEmpty() + || d1.get().type() != Dimension.Type.indexedBound + || d2.get().type() != Dimension.Type.indexedBound + || d1.get().size().get() != d2.get().size().get()) + { + throw new IllegalArgumentException("cosine_similarity expects both arguments to have the '" + + dimension + "' dimension with same size, but input types were " + + t1 + " and " + t2); + } + // Finds the type this produces by first converting it to a primitive function + return toPrimitive().type(context); + } + + /** Evaluates this by first converting it to a primitive function */ + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + return toPrimitive().evaluate(context); + } + + @Override public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { TensorFunction<NAMETYPE> a = arg1.toPrimitive(); TensorFunction<NAMETYPE> b = arg2.toPrimitive(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java index 25399416c29..f9fc8e195d3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java @@ -1,7 +1,12 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TensorType.Dimension; import java.util.Collections; import java.util.List; @@ -12,7 +17,7 @@ import java.util.Objects; * euclidean_distance(a, b, mydim) == sqrt(sum(pow(a-b, 2), mydim)) * @author arnej */ -public class EuclideanDistance<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { +public class EuclideanDistance<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> { private final TensorFunction<NAMETYPE> arg1; private final TensorFunction<NAMETYPE> arg2; @@ -38,6 +43,31 @@ public class EuclideanDistance<NAMETYPE extends Name> extends CompositeTensorFun } @Override + public TensorType type(TypeContext<NAMETYPE> context) { + TensorType t1 = arg1.toPrimitive().type(context); + TensorType t2 = arg2.toPrimitive().type(context); + var d1 = t1.dimension(dimension); + var d2 = t2.dimension(dimension); + if (d1.isEmpty() || d2.isEmpty() + || d1.get().type() != Dimension.Type.indexedBound + || d2.get().type() != Dimension.Type.indexedBound + || d1.get().size().get() != d2.get().size().get()) + { + throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '" + + dimension + "' dimension with same size, but input types were " + + t1 + " and " + t2); + } + // Finds the type this produces by first converting it to a primitive function + return toPrimitive().type(context); + } + + /** Evaluates this by first converting it to a primitive function */ + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + return toPrimitive().evaluate(context); + } + + @Override public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { TensorFunction<NAMETYPE> primitive1 = arg1.toPrimitive(); TensorFunction<NAMETYPE> primitive2 = arg2.toPrimitive(); |