aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java41
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java45
2 files changed, 86 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 32b36c5c5cb..0c78c2891d6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -5,6 +5,8 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString;
+
/**
* @author bratseth
*/
@@ -59,6 +61,9 @@ class TensorParser {
return tensorFromDenseValueString(valueString, type, dimensionOrder);
}
else {
+ var t = maybeFromBinaryValueString(valueString, type, dimensionOrder);
+ if (t.isPresent()) { return t.get(); }
+
if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty))
throw new IllegalArgumentException("Got a zero-dimensional tensor value ('" + tensorString +
"') where type " + explicitType.get() + " is required");
@@ -118,6 +123,42 @@ class TensorParser {
}
}
+ private static Optional<Tensor> maybeFromBinaryValueString(
+ String valueString,
+ Optional<TensorType> optType,
+ List<String> dimensionOrder)
+ {
+ if (optType.isEmpty() || dimensionOrder != null) {
+ return Optional.empty();
+ }
+ var type = optType.get();
+ long sz = 1;
+ for (var d : type.dimensions()) {
+ sz *= d.size().orElse(0L);
+ }
+ if (sz == 0
+ || type.dimensions().size() == 0
+ || valueString.length() < sz * 2
+ || valueString.chars().anyMatch(ch -> (Character.digit(ch, 16) == -1)))
+ {
+ return Optional.empty();
+ }
+ try {
+ double[] values = decodeHexString(valueString, type.valueType());
+ if (values.length != sz) {
+ return Optional.empty();
+ }
+ var builder = IndexedTensor.Builder.of(type);
+ var dib = (IndexedTensor.DirectIndexBuilder) builder;
+ for (int i = 0; i < sz; ++i) {
+ dib.cellByDirectIndex(i, values[i]);
+ }
+ return Optional.of(builder.build());
+ } catch (IllegalArgumentException e) {
+ return Optional.empty();
+ }
+ }
+
private static Tensor tensorFromDenseValueString(String valueString,
Optional<TensorType> type,
List<String> dimensionOrder) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
index 53df95ec2e8..6ce9dc4ce65 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
@@ -11,6 +11,11 @@ public class TensorParserTestCase {
@Test
public void testEmpty() {
assertEquals(Tensor.Builder.of(TensorType.empty).cell(1).build(), Tensor.from("tensor():{{}:1}"));
+ assertEquals(Tensor.Builder.of(TensorType.empty).cell(10).build(), Tensor.from("10.0"));
+ assertEquals(Tensor.Builder.of(TensorType.empty).cell(10).build(), Tensor.from(TensorType.empty, "10.0"));
+ // looks like a hex string, but should not be interpreted as such:
+ assertEquals(Tensor.Builder.of(TensorType.empty).cell(10).build(), Tensor.from("0000000000000010"));
+ assertEquals(Tensor.Builder.of(TensorType.empty).cell(10).build(), Tensor.from(TensorType.empty, "0000000000000010"));
}
@Test
@@ -81,6 +86,46 @@ public class TensorParserTestCase {
.cell( 5.0, 2, 0, 0)
.cell(-6.0, 2, 1, 0).build(),
Tensor.from("tensor( x[3],y[2],z[1]) : [1.0, 2.0, 3.0 , 4.0, 5, -6.0]"));
+
+ var int8TT = TensorType.fromSpec("tensor<int8>(x[2],y[3])");
+
+ assertEquals("binary tensor A",
+ Tensor.Builder.of(int8TT)
+ .cell(1, 0, 0)
+ .cell(20, 0, 1)
+ .cell(127, 0, 2)
+ .cell(-1, 1, 0)
+ .cell(50, 1, 1)
+ .cell(-128, 1, 2).build(),
+ Tensor.from(int8TT, "01147FFF3280"));
+
+ assertEquals("binary tensor B",
+ Tensor.Builder.of(int8TT)
+ .cell(26.0, 0, 0)
+ .cell(0.0, 0, 1)
+ .cell(31.0, 0, 2)
+ .cell(-68.0, 1, 0)
+ .cell(-98.0, 1, 1)
+ .cell(-34.0, 1, 2).build(),
+ Tensor.from(int8TT, "1a001fbc9ede"));
+
+ assertEquals("binary tensor C",
+ Tensor.Builder.of(int8TT)
+ .cell(16, 0, 0)
+ .cell(32, 0, 1)
+ .cell(48, 0, 2)
+ .cell(-16, 1, 0)
+ .cell(-32, 1, 1)
+ .cell(-64, 1, 2).build(),
+ Tensor.from(int8TT, "102030F0E0C0"));
+
+ var floatTT = TensorType.fromSpec("tensor<float>(x[3])");
+ assertEquals("float tensor hexdump",
+ Tensor.Builder.of(floatTT)
+ .cell(0, 0)
+ .cell(1.25, 1)
+ .cell(-19.125, 2).build(),
+ Tensor.from(floatTT, "000000003FA00000c1990000"));
}
@Test