diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java | 26 |
1 files changed, 18 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index 32ccdf51336..ad14bc1f1f2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -1,10 +1,12 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.google.common.collect.ImmutableList; import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; +import java.util.Objects; /** * @author bratseth @@ -12,11 +14,20 @@ import java.util.List; public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { private final TensorFunction<NAMETYPE> argument; - private final String dimension; + private final List<String> dimensions; + + public Argmin(TensorFunction<NAMETYPE> argument) { + this(argument, Collections.emptyList()); + } public Argmin(TensorFunction<NAMETYPE> argument, String dimension) { + this(argument, Collections.singletonList(dimension)); + } + + public Argmin(TensorFunction<NAMETYPE> argument, List<String> dimensions) { + Objects.requireNonNull(dimensions, "The dimensions cannot be null"); this.argument = argument; - this.dimension = dimension; + this.dimensions = ImmutableList.copyOf(dimensions); } @Override @@ -24,22 +35,21 @@ public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if ( arguments.size() != 1) + if (arguments.size() != 1) throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size()); - return new Argmin<>(arguments.get(0), dimension); + return new Argmin<>(arguments.get(0), dimensions); } @Override public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); - return new Join<>(primitiveArgument, - new Reduce<>(primitiveArgument, Reduce.Aggregator.min, dimension), - ScalarFunctions.equal()); + TensorFunction<NAMETYPE> reduce = new Reduce<>(primitiveArgument, Reduce.Aggregator.min, dimensions); + return new Join<>(primitiveArgument, reduce, ScalarFunctions.equal()); } @Override public String toString(ToStringContext context) { - return "argmin(" + argument.toString(context) + ", " + dimension + ")"; + return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; } } |