diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2023-09-13 16:33:14 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-09-13 16:33:14 +0200 |
commit | ef1b98e1095a9a16638ae7a5acc392be8e4f4938 (patch) | |
tree | a0061cc89ecdc4c39e9289f8884b87b7300c1230 | |
parent | 84fae7748bf666c64c63d5143947121d84fe1732 (diff) | |
parent | 7050087f2ba90981ed80118361d229ea5181918f (diff) |
Merge pull request #28510 from vespa-engine/balder/compare-values-not-object-references
- Use equals when comparing Optional<Long>
8 files changed, 77 insertions, 90 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java index 92c9eccf2d3..aa71cc35f99 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java @@ -10,14 +10,11 @@ import com.yahoo.tensor.TensorType; import java.io.IOException; import java.io.Reader; import java.util.ArrayList; -import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.Function; -import java.util.stream.Collectors; /** * ConstantTensorJsonValidator strictly validates a constant tensor in JSON format read from a Reader object @@ -135,21 +132,11 @@ public class ConstantTensorJsonValidator { assertCurrentTokenIs(JsonToken.FIELD_NAME); String fieldName = parser.getCurrentName(); switch (fieldName) { - case FIELD_TYPE: - consumeTypeField(); - break; - case FIELD_VALUES: - consumeValuesField(); - break; - case FIELD_CELLS: - consumeCellsField(); - break; - case FIELD_BLOCKS: - consumeBlocksField(); - break; - default: - consumeAnyField(fieldName, parser.nextToken()); - break; + case FIELD_TYPE -> consumeTypeField(); + case FIELD_VALUES -> consumeValuesField(); + case FIELD_CELLS -> consumeCellsField(); + case FIELD_BLOCKS -> consumeBlocksField(); + default -> consumeAnyField(fieldName, parser.nextToken()); } } if (seenSimpleMapValue) { @@ -212,17 +199,17 @@ public class ConstantTensorJsonValidator { assertNextTokenIs(JsonToken.FIELD_NAME); String fieldName = parser.getCurrentName(); switch (fieldName) { - case FIELD_ADDRESS: - validateTensorAddress(new HashSet<>(tensorDimensions.keySet())); - seenAddress = true; - break; - case FIELD_VALUE: - validateNumeric(FIELD_VALUE, parser.nextToken()); - seenValue = true; - break; - default: - throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a cell object", - FIELD_ADDRESS, FIELD_VALUE)); + case FIELD_ADDRESS -> { + validateTensorAddress(new HashSet<>(tensorDimensions.keySet())); + seenAddress = true; + } + case FIELD_VALUE -> { + validateNumeric(FIELD_VALUE, parser.nextToken()); + seenValue = true; + } + default -> + throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a cell object", + FIELD_ADDRESS, FIELD_VALUE)); } } if (! seenAddress) { @@ -275,7 +262,7 @@ public class ConstantTensorJsonValidator { private void validateBoundIndex(TensorType.IndexedBoundDimension dimension) throws IOException { try { int value = Integer.parseInt(parser.getValueAsString()); - if (value >= dimension.size().get()) + if (value >= dimension.size().get().intValue()) throw new InvalidConstantTensorException(parser, String.format("Index %s not within limits of bound dimension '%s'", value, dimension.name())); } catch (NumberFormatException e) { throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name()); @@ -415,18 +402,18 @@ public class ConstantTensorJsonValidator { assertNextTokenIs(JsonToken.FIELD_NAME); String fieldName = parser.getCurrentName(); switch (fieldName) { - case FIELD_ADDRESS: - validateTensorAddress(new HashSet<>(mappedDims)); - seenAddress = true; - break; - case FIELD_VALUES: - assertNextTokenIs(JsonToken.START_ARRAY); - consumeValuesArray(); - seenValues = true; - break; - default: - throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a block object", - FIELD_ADDRESS, FIELD_VALUES)); + case FIELD_ADDRESS -> { + validateTensorAddress(new HashSet<>(mappedDims)); + seenAddress = true; + } + case FIELD_VALUES -> { + assertNextTokenIs(JsonToken.START_ARRAY); + consumeValuesArray(); + seenValues = true; + } + default -> + throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a block object", + FIELD_ADDRESS, FIELD_VALUES)); } } if (! seenAddress) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java index 97bfdda385e..5b7a348ba99 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java @@ -55,18 +55,18 @@ public class Gemm extends IntermediateOperation { TensorType.Dimension dimC0 = cDimensions.get(0); TensorType.Dimension dimC1 = cDimensions.get(1); - if ( ! (dimA.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) { + if ( ! (dimA.size().equals(dimC0.size()) || dimC0.size().get() == 1) ) { throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() + " is not compatible or not broadcastable to " + result.type()); } - if ( ! (dimB.size().get().equals(dimC1.size().get()) || dimC1.size().get() == 1) ) { + if ( ! (dimB.size().equals(dimC1.size()) || dimC1.size().get() == 1) ) { throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() + " is not compatible or not broadcastable to " + result.type()); } } else if (cDimensions.size() == 1) { TensorType.Dimension dimC0 = cDimensions.get(0); - if ( ! (dimB.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) { + if ( ! (dimB.size().equals(dimC0.size()) || dimC0.size().get() == 1) ) { throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() + " is not compatible or not broadcastable to " + result.type()); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java index a880bff87be..4934dc9a05c 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -47,7 +47,7 @@ public class Tile extends IntermediateOperation { throw new IllegalArgumentException("Tile " + name + ": repeats must be a 1-d tensor."); OrderedTensorType inputType = inputs.get(0).type().get(); - if (shape.type().dimensions().get(0).size().get() != inputType.rank()) + if (shape.type().dimensions().get(0).size().get().intValue() != inputType.rank()) throw new IllegalArgumentException("Tile " + name + ": repeats must be the same size as input rank."); List<Integer> dimSizes = new ArrayList<>(inputType.rank()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 0c78c2891d6..4e7aa4bd482 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -137,7 +137,7 @@ class TensorParser { sz *= d.size().orElse(0L); } if (sz == 0 - || type.dimensions().size() == 0 + || type.dimensions().isEmpty() || valueString.length() < sz * 2 || valueString.chars().anyMatch(ch -> (Character.digit(ch, 16) == -1))) { @@ -253,14 +253,13 @@ class TensorParser { try { String cellValueString = string.substring(position, nextNumberEnd); try { - switch (cellValueType) { - case DOUBLE: return Double.parseDouble(cellValueString); - case FLOAT: return Float.parseFloat(cellValueString); - case BFLOAT16: return Float.parseFloat(cellValueString); - case INT8: return Float.parseFloat(cellValueString); - default: - throw new IllegalArgumentException(cellValueType + " is not supported"); - } + return switch (cellValueType) { + case DOUBLE -> Double.parseDouble(cellValueString); + case FLOAT -> Float.parseFloat(cellValueString); + case BFLOAT16 -> Float.parseFloat(cellValueString); + case INT8 -> Float.parseFloat(cellValueString); + default -> throw new IllegalArgumentException(cellValueType + " is not supported"); + }; } catch (NumberFormatException e) { throw new IllegalArgumentException("At value position " + position + ": '" + cellValueString + "' is not a valid " + cellValueType); @@ -346,10 +345,10 @@ class TensorParser { protected void consumeNumber() { Number number = consumeNumber(builder.type().valueType()); switch (builder.type().valueType()) { - case DOUBLE: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); break; - case FLOAT: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break; - case BFLOAT16: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break; - case INT8: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break; + case DOUBLE -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double) number); + case FLOAT -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float) number); + case BFLOAT16 -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float) number); + case INT8 -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float) number); } } } @@ -390,10 +389,10 @@ class TensorParser { private void consumeNumber() { Number number = consumeNumber(builder.type().valueType()); switch (builder.type().valueType()) { - case DOUBLE: builder.cell((Double)number, indexes); break; - case FLOAT: builder.cell((Float)number, indexes); break; - case BFLOAT16: builder.cell((Float)number, indexes); break; - case INT8: builder.cell((Float)number, indexes); break; + case DOUBLE -> builder.cell((Double) number, indexes); + case FLOAT -> builder.cell((Float) number, indexes); + case BFLOAT16 -> builder.cell((Float) number, indexes); + case INT8 -> builder.cell((Float) number, indexes); } } @@ -418,7 +417,7 @@ class TensorParser { private static class MixedValueParser extends ValueParser { private final Tensor.Builder builder; - private List<String> dimensionOrder; + private final List<String> dimensionOrder; public MixedValueParser(String string, List<String> dimensionOrder, Tensor.Builder builder) { super(string); @@ -450,7 +449,7 @@ class TensorParser { } private TensorType.Dimension findMappedDimension() { - Optional<TensorType.Dimension> mappedDimension = builder.type().dimensions().stream().filter(d -> d.isMapped()).findAny(); + Optional<TensorType.Dimension> mappedDimension = builder.type().dimensions().stream().filter(TensorType.Dimension::isMapped).findAny(); if (mappedDimension.isPresent()) return mappedDimension.get(); if (builder.type().rank() == 1 && builder.type().dimensions().get(0).size().isEmpty()) return builder.type().dimensions().get(0); @@ -469,10 +468,10 @@ class TensorParser { private void consumeNumber(TensorAddress address) { Number number = consumeNumber(builder.type().valueType()); switch (builder.type().valueType()) { - case DOUBLE: builder.cell(address, (Double)number); break; - case FLOAT: builder.cell(address, (Float)number); break; - case BFLOAT16: builder.cell(address, (Float)number); break; - case INT8: builder.cell(address, (Float)number); break; + case DOUBLE -> builder.cell(address, (Double) number); + case FLOAT -> builder.cell(address, (Float) number); + case BFLOAT16 -> builder.cell(address, (Float) number); + case INT8 -> builder.cell(address, (Float) number); } } } @@ -507,12 +506,11 @@ class TensorParser { String cellValueString = string.substring(position, valueEnd).trim(); try { switch (cellValueType) { - case DOUBLE: builder.cell(address, Double.parseDouble(cellValueString)); break; - case FLOAT: builder.cell(address, Float.parseFloat(cellValueString)); break; - case BFLOAT16: builder.cell(address, Float.parseFloat(cellValueString)); break; - case INT8: builder.cell(address, Float.parseFloat(cellValueString)); break; - default: - throw new IllegalArgumentException(cellValueType + " is not supported"); + case DOUBLE -> builder.cell(address, Double.parseDouble(cellValueString)); + case FLOAT -> builder.cell(address, Float.parseFloat(cellValueString)); + case BFLOAT16 -> builder.cell(address, Float.parseFloat(cellValueString)); + case INT8 -> builder.cell(address, Float.parseFloat(cellValueString)); + default -> throw new IllegalArgumentException(cellValueType + " is not supported"); } } catch (NumberFormatException e) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 6b010529046..f702dba6739 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -96,11 +96,11 @@ public class TensorType { Collections.sort(dimensionList); this.dimensions = List.copyOf(dimensionList); - if (dimensionList.stream().allMatch(d -> d.isIndexed())) { + if (dimensionList.stream().allMatch(Dimension::isIndexed)) { mappedSubtype = empty; indexedSubtype = this; } - else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) { + else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) { mappedSubtype = this; indexedSubtype = empty; } @@ -158,7 +158,7 @@ public class TensorType { /** Returns the dimension with this name, or empty if not present */ public Optional<Dimension> dimension(String name) { - return indexOfDimension(name).map(i -> dimensions.get(i)); + return indexOfDimension(name).map(dimensions::get); } /** Returns the 0-base index of this dimension, or empty if it is not present */ @@ -172,7 +172,7 @@ public class TensorType { /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */ public Optional<Long> sizeOfDimension(String dimension) { Optional<Dimension> d = dimension(dimension); - if ( ! d.isPresent()) return Optional.empty(); + if (d.isEmpty()) return Optional.empty(); return d.get().size(); } @@ -213,12 +213,12 @@ public class TensorType { if (thisDimension.isIndexed() != generalizationDimension.isIndexed()) return false; if (considerName && ! thisDimension.name().equals(generalizationDimension.name())) return false; if (generalizationDimension.size().isPresent()) { - if ( ! thisDimension.size().isPresent()) return false; + if (thisDimension.size().isEmpty()) return false; if (convertible) { if (thisDimension.size().get() > generalizationDimension.size().get()) return false; } else { // assignable - if (!thisDimension.size().get().equals(generalizationDimension.size().get())) return false; + if (!thisDimension.size().equals(generalizationDimension.size())) return false; } } } @@ -269,7 +269,7 @@ public class TensorType { if ( ! thisDim.name().equals(otherDim.name())) return Optional.empty(); if (thisDim.isIndexed() && otherDim.isIndexed()) { if (thisDim.size().isPresent() && otherDim.size().isPresent()) { - if ( ! thisDim.size().get().equals(otherDim.size().get())) + if ( ! thisDim.size().equals(otherDim.size())) return Optional.empty(); b.dimension(thisDim); // both are equal and bound } @@ -314,7 +314,12 @@ public class TensorType { public final String name() { return name; } - /** Returns the size of this dimension if it is bound, empty otherwise */ + /** + * Returns the size of this dimension if it is bound, empty otherwise + * Beware not use == != when comparing size. Use equals + */ + // TODO Optional<Long> => OptionalLong to avoid mistakes when comparing values + // Deprecate if we find an alternative good name for size() public abstract Optional<Long> size(); public abstract Type type(); @@ -337,7 +342,7 @@ public class TensorType { * [] + {} = {} */ Dimension combineWith(Optional<Dimension> other, boolean allowDifferentSizes) { - if ( ! other.isPresent()) return this; + if (other.isEmpty()) return this; if (this instanceof MappedDimension) return this; if (other.get() instanceof MappedDimension) return other.get(); // both are indexed @@ -600,9 +605,9 @@ public class TensorType { public Builder dimension(String name, Dimension.Type type) { switch (type) { - case mapped : mapped(name); break; - case indexedUnbound : indexed(name); break; - default : throw new IllegalArgumentException("This can not create a dimension of type " + type); + case mapped -> mapped(name); + case indexedUnbound -> indexed(name); + default -> throw new IllegalArgumentException("This can not create a dimension of type " + type); } return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java index 0e5b031c2cc..9bc3c80a230 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java @@ -8,7 +8,6 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TensorType.Dimension; -import java.util.Collections; import java.util.List; import java.util.Objects; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java index 4c771fe8843..07b572b3f93 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java @@ -8,7 +8,6 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TensorType.Dimension; -import java.util.Collections; import java.util.List; import java.util.Objects; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index de1c30e6414..323056c7204 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -12,7 +12,6 @@ import java.util.Arrays; import java.util.List; import java.util.Objects; import java.util.function.DoubleBinaryOperator; -import java.util.stream.Collectors; /** * An optimization for tensor expressions where a join immediately follows a |