diff options
author | Arne Juul <arnej@yahooinc.com> | 2022-08-31 11:47:39 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2022-09-02 09:14:53 +0000 |
commit | 3ac9c793b6a6e2278be9cc92b527c38640849764 (patch) | |
tree | 3c972eb0766bc0e925427f7d3b69cddce7e8b8e7 /vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java | |
parent | 36bead13fbbd0b3ce5c5a364b6f07ee1d3555b9b (diff) |
allow simple hex format for dense tensors of known type
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 32b36c5c5cb..0c78c2891d6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -5,6 +5,8 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString; + /** * @author bratseth */ @@ -59,6 +61,9 @@ class TensorParser { return tensorFromDenseValueString(valueString, type, dimensionOrder); } else { + var t = maybeFromBinaryValueString(valueString, type, dimensionOrder); + if (t.isPresent()) { return t.get(); } + if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty)) throw new IllegalArgumentException("Got a zero-dimensional tensor value ('" + tensorString + "') where type " + explicitType.get() + " is required"); @@ -118,6 +123,42 @@ class TensorParser { } } + private static Optional<Tensor> maybeFromBinaryValueString( + String valueString, + Optional<TensorType> optType, + List<String> dimensionOrder) + { + if (optType.isEmpty() || dimensionOrder != null) { + return Optional.empty(); + } + var type = optType.get(); + long sz = 1; + for (var d : type.dimensions()) { + sz *= d.size().orElse(0L); + } + if (sz == 0 + || type.dimensions().size() == 0 + || valueString.length() < sz * 2 + || valueString.chars().anyMatch(ch -> (Character.digit(ch, 16) == -1))) + { + return Optional.empty(); + } + try { + double[] values = decodeHexString(valueString, type.valueType()); + if (values.length != sz) { + return Optional.empty(); + } + var builder = IndexedTensor.Builder.of(type); + var dib = (IndexedTensor.DirectIndexBuilder) builder; + for (int i = 0; i < sz; ++i) { + dib.cellByDirectIndex(i, values[i]); + } + return Optional.of(builder.build()); + } catch (IllegalArgumentException e) { + return Optional.empty(); + } + } + private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type, List<String> dimensionOrder) { |