summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2023-09-13 16:33:14 +0200
committerGitHub <noreply@github.com>2023-09-13 16:33:14 +0200
commitef1b98e1095a9a16638ae7a5acc392be8e4f4938 (patch)
treea0061cc89ecdc4c39e9289f8884b87b7300c1230
parent84fae7748bf666c64c63d5143947121d84fe1732 (diff)
parent7050087f2ba90981ed80118361d229ea5181918f (diff)
Merge pull request #28510 from vespa-engine/balder/compare-values-not-object-references
- Use equals when comparing Optional<Long>
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java71
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java56
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java1
8 files changed, 77 insertions, 90 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
index 92c9eccf2d3..aa71cc35f99 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
@@ -10,14 +10,11 @@ import com.yahoo.tensor.TensorType;
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
-import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import java.util.function.Function;
-import java.util.stream.Collectors;
/**
* ConstantTensorJsonValidator strictly validates a constant tensor in JSON format read from a Reader object
@@ -135,21 +132,11 @@ public class ConstantTensorJsonValidator {
assertCurrentTokenIs(JsonToken.FIELD_NAME);
String fieldName = parser.getCurrentName();
switch (fieldName) {
- case FIELD_TYPE:
- consumeTypeField();
- break;
- case FIELD_VALUES:
- consumeValuesField();
- break;
- case FIELD_CELLS:
- consumeCellsField();
- break;
- case FIELD_BLOCKS:
- consumeBlocksField();
- break;
- default:
- consumeAnyField(fieldName, parser.nextToken());
- break;
+ case FIELD_TYPE -> consumeTypeField();
+ case FIELD_VALUES -> consumeValuesField();
+ case FIELD_CELLS -> consumeCellsField();
+ case FIELD_BLOCKS -> consumeBlocksField();
+ default -> consumeAnyField(fieldName, parser.nextToken());
}
}
if (seenSimpleMapValue) {
@@ -212,17 +199,17 @@ public class ConstantTensorJsonValidator {
assertNextTokenIs(JsonToken.FIELD_NAME);
String fieldName = parser.getCurrentName();
switch (fieldName) {
- case FIELD_ADDRESS:
- validateTensorAddress(new HashSet<>(tensorDimensions.keySet()));
- seenAddress = true;
- break;
- case FIELD_VALUE:
- validateNumeric(FIELD_VALUE, parser.nextToken());
- seenValue = true;
- break;
- default:
- throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a cell object",
- FIELD_ADDRESS, FIELD_VALUE));
+ case FIELD_ADDRESS -> {
+ validateTensorAddress(new HashSet<>(tensorDimensions.keySet()));
+ seenAddress = true;
+ }
+ case FIELD_VALUE -> {
+ validateNumeric(FIELD_VALUE, parser.nextToken());
+ seenValue = true;
+ }
+ default ->
+ throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a cell object",
+ FIELD_ADDRESS, FIELD_VALUE));
}
}
if (! seenAddress) {
@@ -275,7 +262,7 @@ public class ConstantTensorJsonValidator {
private void validateBoundIndex(TensorType.IndexedBoundDimension dimension) throws IOException {
try {
int value = Integer.parseInt(parser.getValueAsString());
- if (value >= dimension.size().get())
+ if (value >= dimension.size().get().intValue())
throw new InvalidConstantTensorException(parser, String.format("Index %s not within limits of bound dimension '%s'", value, dimension.name()));
} catch (NumberFormatException e) {
throwCoordinateIsNotInteger(parser.getValueAsString(), dimension.name());
@@ -415,18 +402,18 @@ public class ConstantTensorJsonValidator {
assertNextTokenIs(JsonToken.FIELD_NAME);
String fieldName = parser.getCurrentName();
switch (fieldName) {
- case FIELD_ADDRESS:
- validateTensorAddress(new HashSet<>(mappedDims));
- seenAddress = true;
- break;
- case FIELD_VALUES:
- assertNextTokenIs(JsonToken.START_ARRAY);
- consumeValuesArray();
- seenValues = true;
- break;
- default:
- throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a block object",
- FIELD_ADDRESS, FIELD_VALUES));
+ case FIELD_ADDRESS -> {
+ validateTensorAddress(new HashSet<>(mappedDims));
+ seenAddress = true;
+ }
+ case FIELD_VALUES -> {
+ assertNextTokenIs(JsonToken.START_ARRAY);
+ consumeValuesArray();
+ seenValues = true;
+ }
+ default ->
+ throw new InvalidConstantTensorException(parser, String.format("Only '%s' or '%s' fields are permitted within a block object",
+ FIELD_ADDRESS, FIELD_VALUES));
}
}
if (! seenAddress) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
index 97bfdda385e..5b7a348ba99 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Gemm.java
@@ -55,18 +55,18 @@ public class Gemm extends IntermediateOperation {
TensorType.Dimension dimC0 = cDimensions.get(0);
TensorType.Dimension dimC1 = cDimensions.get(1);
- if ( ! (dimA.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) {
+ if ( ! (dimA.size().equals(dimC0.size()) || dimC0.size().get() == 1) ) {
throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
" is not compatible or not broadcastable to " + result.type());
}
- if ( ! (dimB.size().get().equals(dimC1.size().get()) || dimC1.size().get() == 1) ) {
+ if ( ! (dimB.size().equals(dimC1.size()) || dimC1.size().get() == 1) ) {
throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
" is not compatible or not broadcastable to " + result.type());
}
} else if (cDimensions.size() == 1) {
TensorType.Dimension dimC0 = cDimensions.get(0);
- if ( ! (dimB.size().get().equals(dimC0.size().get()) || dimC0.size().get() == 1) ) {
+ if ( ! (dimB.size().equals(dimC0.size()) || dimC0.size().get() == 1) ) {
throw new IllegalArgumentException("GEMM: type of optional input C " + inputs.get(2).type().get() +
" is not compatible or not broadcastable to " + result.type());
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
index a880bff87be..4934dc9a05c 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
@@ -47,7 +47,7 @@ public class Tile extends IntermediateOperation {
throw new IllegalArgumentException("Tile " + name + ": repeats must be a 1-d tensor.");
OrderedTensorType inputType = inputs.get(0).type().get();
- if (shape.type().dimensions().get(0).size().get() != inputType.rank())
+ if (shape.type().dimensions().get(0).size().get().intValue() != inputType.rank())
throw new IllegalArgumentException("Tile " + name + ": repeats must be the same size as input rank.");
List<Integer> dimSizes = new ArrayList<>(inputType.rank());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 0c78c2891d6..4e7aa4bd482 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -137,7 +137,7 @@ class TensorParser {
sz *= d.size().orElse(0L);
}
if (sz == 0
- || type.dimensions().size() == 0
+ || type.dimensions().isEmpty()
|| valueString.length() < sz * 2
|| valueString.chars().anyMatch(ch -> (Character.digit(ch, 16) == -1)))
{
@@ -253,14 +253,13 @@ class TensorParser {
try {
String cellValueString = string.substring(position, nextNumberEnd);
try {
- switch (cellValueType) {
- case DOUBLE: return Double.parseDouble(cellValueString);
- case FLOAT: return Float.parseFloat(cellValueString);
- case BFLOAT16: return Float.parseFloat(cellValueString);
- case INT8: return Float.parseFloat(cellValueString);
- default:
- throw new IllegalArgumentException(cellValueType + " is not supported");
- }
+ return switch (cellValueType) {
+ case DOUBLE -> Double.parseDouble(cellValueString);
+ case FLOAT -> Float.parseFloat(cellValueString);
+ case BFLOAT16 -> Float.parseFloat(cellValueString);
+ case INT8 -> Float.parseFloat(cellValueString);
+ default -> throw new IllegalArgumentException(cellValueType + " is not supported");
+ };
} catch (NumberFormatException e) {
throw new IllegalArgumentException("At value position " + position + ": '" +
cellValueString + "' is not a valid " + cellValueType);
@@ -346,10 +345,10 @@ class TensorParser {
protected void consumeNumber() {
Number number = consumeNumber(builder.type().valueType());
switch (builder.type().valueType()) {
- case DOUBLE: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); break;
- case FLOAT: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break;
- case BFLOAT16: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break;
- case INT8: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break;
+ case DOUBLE -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double) number);
+ case FLOAT -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float) number);
+ case BFLOAT16 -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float) number);
+ case INT8 -> builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float) number);
}
}
}
@@ -390,10 +389,10 @@ class TensorParser {
private void consumeNumber() {
Number number = consumeNumber(builder.type().valueType());
switch (builder.type().valueType()) {
- case DOUBLE: builder.cell((Double)number, indexes); break;
- case FLOAT: builder.cell((Float)number, indexes); break;
- case BFLOAT16: builder.cell((Float)number, indexes); break;
- case INT8: builder.cell((Float)number, indexes); break;
+ case DOUBLE -> builder.cell((Double) number, indexes);
+ case FLOAT -> builder.cell((Float) number, indexes);
+ case BFLOAT16 -> builder.cell((Float) number, indexes);
+ case INT8 -> builder.cell((Float) number, indexes);
}
}
@@ -418,7 +417,7 @@ class TensorParser {
private static class MixedValueParser extends ValueParser {
private final Tensor.Builder builder;
- private List<String> dimensionOrder;
+ private final List<String> dimensionOrder;
public MixedValueParser(String string, List<String> dimensionOrder, Tensor.Builder builder) {
super(string);
@@ -450,7 +449,7 @@ class TensorParser {
}
private TensorType.Dimension findMappedDimension() {
- Optional<TensorType.Dimension> mappedDimension = builder.type().dimensions().stream().filter(d -> d.isMapped()).findAny();
+ Optional<TensorType.Dimension> mappedDimension = builder.type().dimensions().stream().filter(TensorType.Dimension::isMapped).findAny();
if (mappedDimension.isPresent()) return mappedDimension.get();
if (builder.type().rank() == 1 && builder.type().dimensions().get(0).size().isEmpty())
return builder.type().dimensions().get(0);
@@ -469,10 +468,10 @@ class TensorParser {
private void consumeNumber(TensorAddress address) {
Number number = consumeNumber(builder.type().valueType());
switch (builder.type().valueType()) {
- case DOUBLE: builder.cell(address, (Double)number); break;
- case FLOAT: builder.cell(address, (Float)number); break;
- case BFLOAT16: builder.cell(address, (Float)number); break;
- case INT8: builder.cell(address, (Float)number); break;
+ case DOUBLE -> builder.cell(address, (Double) number);
+ case FLOAT -> builder.cell(address, (Float) number);
+ case BFLOAT16 -> builder.cell(address, (Float) number);
+ case INT8 -> builder.cell(address, (Float) number);
}
}
}
@@ -507,12 +506,11 @@ class TensorParser {
String cellValueString = string.substring(position, valueEnd).trim();
try {
switch (cellValueType) {
- case DOUBLE: builder.cell(address, Double.parseDouble(cellValueString)); break;
- case FLOAT: builder.cell(address, Float.parseFloat(cellValueString)); break;
- case BFLOAT16: builder.cell(address, Float.parseFloat(cellValueString)); break;
- case INT8: builder.cell(address, Float.parseFloat(cellValueString)); break;
- default:
- throw new IllegalArgumentException(cellValueType + " is not supported");
+ case DOUBLE -> builder.cell(address, Double.parseDouble(cellValueString));
+ case FLOAT -> builder.cell(address, Float.parseFloat(cellValueString));
+ case BFLOAT16 -> builder.cell(address, Float.parseFloat(cellValueString));
+ case INT8 -> builder.cell(address, Float.parseFloat(cellValueString));
+ default -> throw new IllegalArgumentException(cellValueType + " is not supported");
}
}
catch (NumberFormatException e) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 6b010529046..f702dba6739 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -96,11 +96,11 @@ public class TensorType {
Collections.sort(dimensionList);
this.dimensions = List.copyOf(dimensionList);
- if (dimensionList.stream().allMatch(d -> d.isIndexed())) {
+ if (dimensionList.stream().allMatch(Dimension::isIndexed)) {
mappedSubtype = empty;
indexedSubtype = this;
}
- else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) {
+ else if (dimensionList.stream().noneMatch(Dimension::isIndexed)) {
mappedSubtype = this;
indexedSubtype = empty;
}
@@ -158,7 +158,7 @@ public class TensorType {
/** Returns the dimension with this name, or empty if not present */
public Optional<Dimension> dimension(String name) {
- return indexOfDimension(name).map(i -> dimensions.get(i));
+ return indexOfDimension(name).map(dimensions::get);
}
/** Returns the 0-base index of this dimension, or empty if it is not present */
@@ -172,7 +172,7 @@ public class TensorType {
/* Returns the bound of this dimension if it is present and bound in this, empty otherwise */
public Optional<Long> sizeOfDimension(String dimension) {
Optional<Dimension> d = dimension(dimension);
- if ( ! d.isPresent()) return Optional.empty();
+ if (d.isEmpty()) return Optional.empty();
return d.get().size();
}
@@ -213,12 +213,12 @@ public class TensorType {
if (thisDimension.isIndexed() != generalizationDimension.isIndexed()) return false;
if (considerName && ! thisDimension.name().equals(generalizationDimension.name())) return false;
if (generalizationDimension.size().isPresent()) {
- if ( ! thisDimension.size().isPresent()) return false;
+ if (thisDimension.size().isEmpty()) return false;
if (convertible) {
if (thisDimension.size().get() > generalizationDimension.size().get()) return false;
}
else { // assignable
- if (!thisDimension.size().get().equals(generalizationDimension.size().get())) return false;
+ if (!thisDimension.size().equals(generalizationDimension.size())) return false;
}
}
}
@@ -269,7 +269,7 @@ public class TensorType {
if ( ! thisDim.name().equals(otherDim.name())) return Optional.empty();
if (thisDim.isIndexed() && otherDim.isIndexed()) {
if (thisDim.size().isPresent() && otherDim.size().isPresent()) {
- if ( ! thisDim.size().get().equals(otherDim.size().get()))
+ if ( ! thisDim.size().equals(otherDim.size()))
return Optional.empty();
b.dimension(thisDim); // both are equal and bound
}
@@ -314,7 +314,12 @@ public class TensorType {
public final String name() { return name; }
- /** Returns the size of this dimension if it is bound, empty otherwise */
+ /**
+ * Returns the size of this dimension if it is bound, empty otherwise
+ * Beware not use == != when comparing size. Use equals
+ */
+ // TODO Optional<Long> => OptionalLong to avoid mistakes when comparing values
+ // Deprecate if we find an alternative good name for size()
public abstract Optional<Long> size();
public abstract Type type();
@@ -337,7 +342,7 @@ public class TensorType {
* [] + {} = {}
*/
Dimension combineWith(Optional<Dimension> other, boolean allowDifferentSizes) {
- if ( ! other.isPresent()) return this;
+ if (other.isEmpty()) return this;
if (this instanceof MappedDimension) return this;
if (other.get() instanceof MappedDimension) return other.get();
// both are indexed
@@ -600,9 +605,9 @@ public class TensorType {
public Builder dimension(String name, Dimension.Type type) {
switch (type) {
- case mapped : mapped(name); break;
- case indexedUnbound : indexed(name); break;
- default : throw new IllegalArgumentException("This can not create a dimension of type " + type);
+ case mapped -> mapped(name);
+ case indexedUnbound -> indexed(name);
+ default -> throw new IllegalArgumentException("This can not create a dimension of type " + type);
}
return this;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
index 0e5b031c2cc..9bc3c80a230 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CosineSimilarity.java
@@ -8,7 +8,6 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TensorType.Dimension;
-import java.util.Collections;
import java.util.List;
import java.util.Objects;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
index 4c771fe8843..07b572b3f93 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/EuclideanDistance.java
@@ -8,7 +8,6 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TensorType.Dimension;
-import java.util.Collections;
import java.util.List;
import java.util.Objects;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index de1c30e6414..323056c7204 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -12,7 +12,6 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.function.DoubleBinaryOperator;
-import java.util.stream.Collectors;
/**
* An optimization for tensor expressions where a join immediately follows a