// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; import java.util.Objects; /** * @author bratseth */ public class Argmax extends CompositeTensorFunction { private final TensorFunction argument; private final List dimensions; public Argmax(TensorFunction argument) { this(argument, Collections.emptyList()); } public Argmax(TensorFunction argument, String dimension) { this(argument, Collections.singletonList(dimension)); } public Argmax(TensorFunction argument, List dimensions) { Objects.requireNonNull(dimensions, "The dimensions cannot be null"); this.argument = argument; this.dimensions = ImmutableList.copyOf(dimensions); } @Override public List> arguments() { return Collections.singletonList(argument); } @Override public TensorFunction withArguments(List> arguments) { if (arguments.size() != 1) throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size()); return new Argmax<>(arguments.get(0), dimensions); } @Override public PrimitiveTensorFunction toPrimitive() { TensorFunction primitiveArgument = argument.toPrimitive(); TensorFunction reduce = new Reduce<>(primitiveArgument, Reduce.Aggregator.max, dimensions); return new Join<>(primitiveArgument, reduce, ScalarFunctions.equal()); } @Override public String toString(ToStringContext context) { return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; } }