aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2022-08-31 11:47:39 +0000
committerArne Juul <arnej@yahooinc.com>2022-08-31 12:27:20 +0000
commit45eda4e9feae8d045ea05f72e49021de62842e0f (patch)
tree0ead6d47b7e8b82b4bd49304b4a7964859800d43 /vespajlib/src/main
parent6dd09abd549f28ab65bfa2ffe38f69228c3d9b12 (diff)
allow simple hex format for dense tensors of known type
Diffstat (limited to 'vespajlib/src/main')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java37
1 files changed, 37 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 32b36c5c5cb..67c81921803 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -5,6 +5,8 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString;
+
/**
* @author bratseth
*/
@@ -59,6 +61,9 @@ class TensorParser {
return tensorFromDenseValueString(valueString, type, dimensionOrder);
}
else {
+ var t = maybeFromBinaryValueString(valueString, type, dimensionOrder);
+ if (t.isPresent()) { return t.get(); }
+
if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty))
throw new IllegalArgumentException("Got a zero-dimensional tensor value ('" + tensorString +
"') where type " + explicitType.get() + " is required");
@@ -118,6 +123,38 @@ class TensorParser {
}
}
+ private static Optional<Tensor> maybeFromBinaryValueString(
+ String valueString,
+ Optional<TensorType> type,
+ List<String> dimensionOrder)
+ {
+ if (type.isEmpty() || dimensionOrder != null) {
+ return Optional.empty();
+ }
+ long sz = 1;
+ for (var d : type.get().dimensions()) {
+ sz *= d.size().orElse(0L);
+ }
+ if (sz == 0 || valueString.length() < sz * 2) {
+ return Optional.empty();
+ }
+ try {
+ double[] values = decodeHexString(valueString, type.get().valueType());
+ if (values.length != sz) {
+ return Optional.empty();
+ }
+ var builder = IndexedTensor.Builder.of(type.get());
+ var dib = (IndexedTensor.DirectIndexBuilder) builder;
+ for (int i = 0; i < sz; ++i) {
+ System.out.println("idx "+i+" -> "+values[i]);
+ dib.cellByDirectIndex(i, values[i]);
+ }
+ return Optional.of(builder.build());
+ } catch (NumberFormatException e) {
+ return Optional.empty();
+ }
+ }
+
private static Tensor tensorFromDenseValueString(String valueString,
Optional<TensorType> type,
List<String> dimensionOrder) {