diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2023-09-13 14:25:17 +0200 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2023-09-13 15:06:24 +0200 |
commit | 7050087f2ba90981ed80118361d229ea5181918f (patch) | |
tree | e14a1090924aade36312523d05ec8e466a89be13 /model-integration | |
parent | cc015e6d4601b9966ec2d092697a146a7fd2c2a3 (diff) |
- Use equals when comparing Optional<Long>
- Minor cleanup
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java | 6 | ||||
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java | 2 |
2 files changed, 4 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 97bfdda385e..5b7a348ba99 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -55,18 +55,18 @@ public class Gemm extends IntermediateOperation { TensorType.Dimension dimC0 = cDimensions.get(0); TensorType.Dimension dimC1 = cDimensions.get(1); - if ( ! (dimA.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) { + if ( ! (dimA.size().equals(dimC0.size()) || dimC0.size().get() == 1) ) { throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() + " is not compatible or not broadcastable to " + result.type()); } - if ( ! (dimB.size().get().equals(dimC1.size().get()) || dimC1.size().get() == 1) ) { + if ( ! (dimB.size().equals(dimC1.size()) || dimC1.size().get() == 1) ) { throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() + " is not compatible or not broadcastable to " + result.type()); } } else if (cDimensions.size() == 1) { TensorType.Dimension dimC0 = cDimensions.get(0); - if ( ! (dimB.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) { + if ( ! (dimB.size().equals(dimC0.size()) || dimC0.size().get() == 1) ) { throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() + " is not compatible or not broadcastable to " + result.type()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java index a880bff87be..4934dc9a05c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -47,7 +47,7 @@ public class Tile extends IntermediateOperation { throw new IllegalArgumentException("Tile " + name + ": repeats must be a 1-d tensor."); OrderedTensorType inputType = inputs.get(0).type().get(); - if (shape.type().dimensions().get(0).size().get() != inputType.rank()) + if (shape.type().dimensions().get(0).size().get().intValue() != inputType.rank()) throw new IllegalArgumentException("Tile " + name + ": repeats must be the same size as input rank."); List<Integer> dimSizes = new ArrayList<>(inputType.rank()); |