summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-04-15 13:12:58 +0200
committerLester Solbakken <lesters@oath.com>2020-04-15 13:12:58 +0200
commit600d27808b11f0d339f12891a3e899a1fe61af82 (patch)
tree70989564206cf5973874daab669ee226874be2bf /vespajlib
parent6aac938f0d89f644bebcb629cae4efa4536911b5 (diff)
Properly handle dimensions argument to argmax/argmin
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java26
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java2
4 files changed, 42 insertions, 16 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 04f859e2802..5cddc82d05a 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1538,7 +1538,9 @@
"public"
],
"methods": [
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction)",
"public void <init>(com.yahoo.tensor.functions.TensorFunction, java.lang.String)",
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
@@ -1553,7 +1555,9 @@
"public"
],
"methods": [
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction)",
"public void <init>(com.yahoo.tensor.functions.TensorFunction, java.lang.String)",
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
index a365f0f4bdc..a4b68a662da 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
@@ -1,10 +1,12 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -12,11 +14,20 @@ import java.util.List;
public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
private final TensorFunction<NAMETYPE> argument;
- private final String dimension;
+ private final List<String> dimensions;
+
+ public Argmax(TensorFunction<NAMETYPE> argument) {
+ this(argument, Collections.emptyList());
+ }
public Argmax(TensorFunction<NAMETYPE> argument, String dimension) {
+ this(argument, Collections.singletonList(dimension));
+ }
+
+ public Argmax(TensorFunction<NAMETYPE> argument, List<String> dimensions) {
+ Objects.requireNonNull(dimensions, "The dimensions cannot be null");
this.argument = argument;
- this.dimension = dimension;
+ this.dimensions = ImmutableList.copyOf(dimensions);
}
@Override
@@ -24,22 +35,21 @@ public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
@Override
public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if ( arguments.size() != 1)
+ if (arguments.size() != 1)
throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size());
- return new Argmax<>(arguments.get(0), dimension);
+ return new Argmax<>(arguments.get(0), dimensions);
}
@Override
public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
- return new Join<>(primitiveArgument,
- new Reduce<>(primitiveArgument, Reduce.Aggregator.max, dimension),
- ScalarFunctions.equal());
+ TensorFunction<NAMETYPE> reduce = new Reduce<>(primitiveArgument, Reduce.Aggregator.max, dimensions);
+ return new Join<>(primitiveArgument, reduce, ScalarFunctions.equal());
}
@Override
public String toString(ToStringContext context) {
- return "argmax(" + argument.toString(context) + ", " + dimension + ")";
+ return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
index 32ccdf51336..ad14bc1f1f2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
@@ -1,10 +1,12 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
+import java.util.Objects;
/**
* @author bratseth
@@ -12,11 +14,20 @@ import java.util.List;
public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
private final TensorFunction<NAMETYPE> argument;
- private final String dimension;
+ private final List<String> dimensions;
+
+ public Argmin(TensorFunction<NAMETYPE> argument) {
+ this(argument, Collections.emptyList());
+ }
public Argmin(TensorFunction<NAMETYPE> argument, String dimension) {
+ this(argument, Collections.singletonList(dimension));
+ }
+
+ public Argmin(TensorFunction<NAMETYPE> argument, List<String> dimensions) {
+ Objects.requireNonNull(dimensions, "The dimensions cannot be null");
this.argument = argument;
- this.dimension = dimension;
+ this.dimensions = ImmutableList.copyOf(dimensions);
}
@Override
@@ -24,22 +35,21 @@ public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
@Override
public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if ( arguments.size() != 1)
+ if (arguments.size() != 1)
throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size());
- return new Argmin<>(arguments.get(0), dimension);
+ return new Argmin<>(arguments.get(0), dimensions);
}
@Override
public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
- return new Join<>(primitiveArgument,
- new Reduce<>(primitiveArgument, Reduce.Aggregator.min, dimension),
- ScalarFunctions.equal());
+ TensorFunction<NAMETYPE> reduce = new Reduce<>(primitiveArgument, Reduce.Aggregator.min, dimensions);
+ return new Join<>(primitiveArgument, reduce, ScalarFunctions.equal());
}
@Override
public String toString(ToStringContext context) {
- return "argmin(" + argument.toString(context) + ", " + dimension + ")";
+ return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
index 625d5d44b19..05f7d27907c 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
@@ -21,6 +21,8 @@ public class TensorFunctionTestCase {
new Diag<>(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build()));
assertTranslated("join(tensor(x{}):{1:1.0,3:5.0,9:3.0}, reduce(tensor(x{}):{1:1.0,3:5.0,9:3.0}, max, x), f(a,b)(a==b))",
new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x"));
+ assertTranslated("join(tensor(x{}):{1:1.0,3:5.0,9:3.0}, reduce(tensor(x{}):{1:1.0,3:5.0,9:3.0}, max), f(a,b)(a==b))",
+ new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }")));
}
private void assertTranslated(String expectedTranslation, TensorFunction<Name> inputFunction) {