summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-06-26 13:25:29 +0000
committerArne Juul <arnej@yahooinc.com>2023-06-26 14:53:17 +0000
commit89150530a47690fa0df603069789002f79ae7123 (patch)
treed9cc51199be753125f2e92d0a705b75093289b5f
parentcc517d86dc886058cdc5f95a318945a6a328da28 (diff)
override type resolving to do some sanity checking
-rw-r--r--vespajlib/abi-spec.json8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java32
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java32
3 files changed, 68 insertions, 4 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 7f70deb0991..76d007dd633 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1706,7 +1706,7 @@
"fields" : [ ]
},
"com.yahoo.tensor.functions.CosineSimilarity" : {
- "superClass" : "com.yahoo.tensor.functions.CompositeTensorFunction",
+ "superClass" : "com.yahoo.tensor.functions.TensorFunction",
"interfaces" : [ ],
"attributes" : [
"public"
@@ -1715,6 +1715,8 @@
"public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.functions.TensorFunction, java.lang.String)",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
"public int hashCode()"
@@ -1757,7 +1759,7 @@
"fields" : [ ]
},
"com.yahoo.tensor.functions.EuclideanDistance" : {
- "superClass" : "com.yahoo.tensor.functions.CompositeTensorFunction",
+ "superClass" : "com.yahoo.tensor.functions.TensorFunction",
"interfaces" : [ ],
"attributes" : [
"public"
@@ -1766,6 +1768,8 @@
"public void <init>(com.yahoo.tensor.functions.TensorFunction, com.yahoo.tensor.functions.TensorFunction, java.lang.String)",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
"public int hashCode()"
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
index ede0355a3a6..ebb8a11fd8a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
@@ -1,7 +1,12 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorType.Dimension;
import java.util.Collections;
import java.util.List;
@@ -12,7 +17,7 @@ import java.util.Objects;
* cosine_similarity(a, b, mydim) == sum(a*b, mydim) / sqrt(sum(a*a, mydim) * sum(b*b, mydim))
* @author arnej
*/
-public class CosineSimilarity<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
+public class CosineSimilarity<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> {
private final TensorFunction<NAMETYPE> arg1;
private final TensorFunction<NAMETYPE> arg2;
@@ -38,6 +43,31 @@ public class CosineSimilarity<NAMETYPE extends Name> extends CompositeTensorFunc
}
@Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ TensorType t1 = arg1.toPrimitive().type(context);
+ TensorType t2 = arg2.toPrimitive().type(context);
+ var d1 = t1.dimension(dimension);
+ var d2 = t2.dimension(dimension);
+ if (d1.isEmpty() || d2.isEmpty()
+ || d1.get().type() != Dimension.Type.indexedBound
+ || d2.get().type() != Dimension.Type.indexedBound
+ || d1.get().size().get() != d2.get().size().get())
+ {
+ throw new IllegalArgumentException("cosine_similarity expects both arguments to have the '"
+ + dimension + "' dimension with same size, but input types were "
+ + t1 + " and " + t2);
+ }
+ // Finds the type this produces by first converting it to a primitive function
+ return toPrimitive().type(context);
+ }
+
+ /** Evaluates this by first converting it to a primitive function */
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ return toPrimitive().evaluate(context);
+ }
+
+ @Override
public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
TensorFunction<NAMETYPE> a = arg1.toPrimitive();
TensorFunction<NAMETYPE> b = arg2.toPrimitive();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
index 25399416c29..f9fc8e195d3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
@@ -1,7 +1,12 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorType.Dimension;
import java.util.Collections;
import java.util.List;
@@ -12,7 +17,7 @@ import java.util.Objects;
* euclidean_distance(a, b, mydim) == sqrt(sum(pow(a-b, 2), mydim))
* @author arnej
*/
-public class EuclideanDistance<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
+public class EuclideanDistance<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> {
private final TensorFunction<NAMETYPE> arg1;
private final TensorFunction<NAMETYPE> arg2;
@@ -38,6 +43,31 @@ public class EuclideanDistance<NAMETYPE extends Name> extends CompositeTensorFun
}
@Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ TensorType t1 = arg1.toPrimitive().type(context);
+ TensorType t2 = arg2.toPrimitive().type(context);
+ var d1 = t1.dimension(dimension);
+ var d2 = t2.dimension(dimension);
+ if (d1.isEmpty() || d2.isEmpty()
+ || d1.get().type() != Dimension.Type.indexedBound
+ || d2.get().type() != Dimension.Type.indexedBound
+ || d1.get().size().get() != d2.get().size().get())
+ {
+ throw new IllegalArgumentException("euclidean_distance expects both arguments to have the '"
+ + dimension + "' dimension with same size, but input types were "
+ + t1 + " and " + t2);
+ }
+ // Finds the type this produces by first converting it to a primitive function
+ return toPrimitive().type(context);
+ }
+
+ /** Evaluates this by first converting it to a primitive function */
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ return toPrimitive().evaluate(context);
+ }
+
+ @Override
public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
TensorFunction<NAMETYPE> primitive1 = arg1.toPrimitive();
TensorFunction<NAMETYPE> primitive2 = arg2.toPrimitive();