diff options
author | Lester Solbakken <lesters@oath.com> | 2020-04-15 13:12:58 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-04-15 13:12:58 +0200 |
commit | 600d27808b11f0d339f12891a3e899a1fe61af82 (patch) | |
tree | 70989564206cf5973874daab669ee226874be2bf /vespajlib | |
parent | 6aac938f0d89f644bebcb629cae4efa4536911b5 (diff) |
Properly handle dimensions argument to argmax/argmin
Diffstat (limited to 'vespajlib')
4 files changed, 42 insertions, 16 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 04f859e2802..5cddc82d05a 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1538,7 +1538,9 @@ "public" ], "methods": [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction)", "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.lang.String)", + "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)", "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", @@ -1553,7 +1555,9 @@ "public" ], "methods": [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction)", "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.lang.String)", + "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)", "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index a365f0f4bdc..a4b68a662da 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -1,10 +1,12 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.google.common.collect.ImmutableList; import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; +import java.util.Objects; /** * @author bratseth @@ -12,11 +14,20 @@ import java.util.List; public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { private final TensorFunction<NAMETYPE> argument; - private final String dimension; + private final List<String> dimensions; + + public Argmax(TensorFunction<NAMETYPE> argument) { + this(argument, Collections.emptyList()); + } public Argmax(TensorFunction<NAMETYPE> argument, String dimension) { + this(argument, Collections.singletonList(dimension)); + } + + public Argmax(TensorFunction<NAMETYPE> argument, List<String> dimensions) { + Objects.requireNonNull(dimensions, "The dimensions cannot be null"); this.argument = argument; - this.dimension = dimension; + this.dimensions = ImmutableList.copyOf(dimensions); } @Override @@ -24,22 +35,21 @@ public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if ( arguments.size() != 1) + if (arguments.size() != 1) throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size()); - return new Argmax<>(arguments.get(0), dimension); + return new Argmax<>(arguments.get(0), dimensions); } @Override public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); - return new Join<>(primitiveArgument, - new Reduce<>(primitiveArgument, Reduce.Aggregator.max, dimension), - ScalarFunctions.equal()); + TensorFunction<NAMETYPE> reduce = new Reduce<>(primitiveArgument, Reduce.Aggregator.max, dimensions); + return new Join<>(primitiveArgument, reduce, ScalarFunctions.equal()); } @Override public String toString(ToStringContext context) { - return "argmax(" + argument.toString(context) + ", " + dimension + ")"; + return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index 32ccdf51336..ad14bc1f1f2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -1,10 +1,12 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.google.common.collect.ImmutableList; import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; +import java.util.Objects; /** * @author bratseth @@ -12,11 +14,20 @@ import java.util.List; public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { private final TensorFunction<NAMETYPE> argument; - private final String dimension; + private final List<String> dimensions; + + public Argmin(TensorFunction<NAMETYPE> argument) { + this(argument, Collections.emptyList()); + } public Argmin(TensorFunction<NAMETYPE> argument, String dimension) { + this(argument, Collections.singletonList(dimension)); + } + + public Argmin(TensorFunction<NAMETYPE> argument, List<String> dimensions) { + Objects.requireNonNull(dimensions, "The dimensions cannot be null"); this.argument = argument; - this.dimension = dimension; + this.dimensions = ImmutableList.copyOf(dimensions); } @Override @@ -24,22 +35,21 @@ public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if ( arguments.size() != 1) + if (arguments.size() != 1) throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size()); - return new Argmin<>(arguments.get(0), dimension); + return new Argmin<>(arguments.get(0), dimensions); } @Override public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); - return new Join<>(primitiveArgument, - new Reduce<>(primitiveArgument, Reduce.Aggregator.min, dimension), - ScalarFunctions.equal()); + TensorFunction<NAMETYPE> reduce = new Reduce<>(primitiveArgument, Reduce.Aggregator.min, dimensions); + return new Join<>(primitiveArgument, reduce, ScalarFunctions.equal()); } @Override public String toString(ToStringContext context) { - return "argmin(" + argument.toString(context) + ", " + dimension + ")"; + return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index 625d5d44b19..05f7d27907c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -21,6 +21,8 @@ public class TensorFunctionTestCase { new Diag<>(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); assertTranslated("join(tensor(x{}):{1:1.0,3:5.0,9:3.0}, reduce(tensor(x{}):{1:1.0,3:5.0,9:3.0}, max, x), f(a,b)(a==b))", new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); + assertTranslated("join(tensor(x{}):{1:1.0,3:5.0,9:3.0}, reduce(tensor(x{}):{1:1.0,3:5.0,9:3.0}, max), f(a,b)(a==b))", + new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"))); } private void assertTranslated(String expectedTranslation, TensorFunction<Name> inputFunction) { |