summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2023-09-13 14:25:17 +0200
committerHenning Baldersheim <balder@yahoo-inc.com>2023-09-13 15:06:24 +0200
commit7050087f2ba90981ed80118361d229ea5181918f (patch)
treee14a1090924aade36312523d05ec8e466a89be13 /model-integration
parentcc015e6d4601b9966ec2d092697a146a7fd2c2a3 (diff)
- Use equals when comparing Optional<Long>
- Minor cleanup
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java2
2 files changed, 4 insertions, 4 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index 97bfdda385e..5b7a348ba99 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -55,18 +55,18 @@ public class Gemm extends IntermediateOperation {
TensorType.Dimension dimC0 = cDimensions.get(0);
TensorType.Dimension dimC1 = cDimensions.get(1);
- if ( ! (dimA.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) {
+ if ( ! (dimA.size().equals(dimC0.size()) || dimC0.size().get() == 1) ) {
throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
" is not compatible or not broadcastable to " + result.type());
}
- if ( ! (dimB.size().get().equals(dimC1.size().get()) || dimC1.size().get() == 1) ) {
+ if ( ! (dimB.size().equals(dimC1.size()) || dimC1.size().get() == 1) ) {
throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
" is not compatible or not broadcastable to " + result.type());
}
} else if (cDimensions.size() == 1) {
TensorType.Dimension dimC0 = cDimensions.get(0);
- if ( ! (dimB.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) {
+ if ( ! (dimB.size().equals(dimC0.size()) || dimC0.size().get() == 1) ) {
throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
" is not compatible or not broadcastable to " + result.type());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
index a880bff87be..4934dc9a05c 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
@@ -47,7 +47,7 @@ public class Tile extends IntermediateOperation {
throw new IllegalArgumentException("Tile " + name + ": repeats must be a 1-d tensor.");
OrderedTensorType inputType = inputs.get(0).type().get();
- if (shape.type().dimensions().get(0).size().get() != inputType.rank())
+ if (shape.type().dimensions().get(0).size().get().intValue() != inputType.rank())
throw new IllegalArgumentException("Tile " + name + ": repeats must be the same size as input rank.");
List<Integer> dimSizes = new ArrayList<>(inputType.rank());