summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java37
1 files changed, 10 insertions, 27 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index abf0d89c2b7..2a93acc19e6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -1,7 +1,6 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
-import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -46,7 +45,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
@Override
- public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argumentA, argumentB); }
@Override
public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
@@ -362,20 +361,11 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
int b = 0;
for (var how : plan.combineHow) {
switch (how) {
- case left:
- labels[out++] = leftOnly.label(a++);
- break;
- case right:
- labels[out++] = rightOnly.label(b++);
- break;
- case both:
- labels[out++] = match.label(m++);
- break;
- case concat:
- labels[out++] = String.valueOf(concatDimIdx);
- break;
- default:
- throw new IllegalArgumentException("cannot handle: "+how);
+ case left -> labels[out++] = leftOnly.label(a++);
+ case right -> labels[out++] = rightOnly.label(b++);
+ case both -> labels[out++] = match.label(m++);
+ case concat -> labels[out++] = String.valueOf(concatDimIdx);
+ default -> throw new IllegalArgumentException("cannot handle: " + how);
}
}
return TensorAddress.of(labels);
@@ -419,17 +409,10 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
int separateIdx = 0;
for (int i = 0; i < how.handleDims.size(); i++) {
switch (how.handleDims.get(i)) {
- case common:
- commonLabels[commonIdx++] = addr.label(i);
- break;
- case separate:
- separateLabels[separateIdx++] = addr.label(i);
- break;
- case concat:
- ccDimIndex = addr.numericLabel(i);
- break;
- default:
- throw new IllegalArgumentException("cannot handle: "+how.handleDims.get(i));
+ case common -> commonLabels[commonIdx++] = addr.label(i);
+ case separate -> separateLabels[separateIdx++] = addr.label(i);
+ case concat -> ccDimIndex = addr.numericLabel(i);
+ default -> throw new IllegalArgumentException("cannot handle: " + how.handleDims.get(i));
}
}
TensorAddress commonAddr = TensorAddress.of(commonLabels);