diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java | 37 |
1 files changed, 10 insertions, 27 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index abf0d89c2b7..2a93acc19e6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -1,7 +1,6 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; -import com.google.common.collect.ImmutableList; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -46,7 +45,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } @Override - public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argumentA, argumentB); } @Override public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { @@ -362,20 +361,11 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET int b = 0; for (var how : plan.combineHow) { switch (how) { - case left: - labels[out++] = leftOnly.label(a++); - break; - case right: - labels[out++] = rightOnly.label(b++); - break; - case both: - labels[out++] = match.label(m++); - break; - case concat: - labels[out++] = String.valueOf(concatDimIdx); - break; - default: - throw new IllegalArgumentException("cannot handle: "+how); + case left -> labels[out++] = leftOnly.label(a++); + case right -> labels[out++] = rightOnly.label(b++); + case both -> labels[out++] = match.label(m++); + case concat -> labels[out++] = String.valueOf(concatDimIdx); + default -> throw new IllegalArgumentException("cannot handle: " + how); } } return TensorAddress.of(labels); @@ -419,17 +409,10 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET int separateIdx = 0; for (int i = 0; i < how.handleDims.size(); i++) { switch (how.handleDims.get(i)) { - case common: - commonLabels[commonIdx++] = addr.label(i); - break; - case separate: - separateLabels[separateIdx++] = addr.label(i); - break; - case concat: - ccDimIndex = addr.numericLabel(i); - break; - default: - throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i)); + case common -> commonLabels[commonIdx++] = addr.label(i); + case separate -> separateLabels[separateIdx++] = addr.label(i); + case concat -> ccDimIndex = addr.numericLabel(i); + default -> throw new IllegalArgumentException("cannot handle: " + how.handleDims.get(i)); } } TensorAddress commonAddr = TensorAddress.of(commonLabels); |