diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-03-08 16:22:01 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-03-08 16:22:01 +0100 |
commit | 8f21c54b669202cdcc1a04934762dceebb929308 (patch) | |
tree | cddb1bf2cb106b5eb92594785f7daef69f41e3b4 /vespajlib/src/test/java/com/yahoo/tensor | |
parent | f19b783d4014f799482daa13f8f8c26d5c4c84d9 (diff) |
Add TensorFlow variable converter
Diffstat (limited to 'vespajlib/src/test/java/com/yahoo/tensor')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java | 2 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 48 |
2 files changed, 49 insertions, 1 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index eef0b090fd1..f7a0a3cdb7d 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -99,7 +99,7 @@ public class TensorTypeTestCase { private static void assertIllegalTensorType(String typeSpec, String messageSubstring) { try { TensorType.fromSpec(typeSpec); - fail("Expoected exception to be thrown with message: '" + messageSubstring + "'"); + fail("Expected exception to be thrown with message: '" + messageSubstring + "'"); } catch (IllegalArgumentException e) { assertThat(e.getMessage(), containsString(messageSubstring)); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java new file mode 100644 index 00000000000..db343e6b343 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -0,0 +1,48 @@ +package com.yahoo.tensor.serialization; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import java.nio.charset.StandardCharsets; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class JsonFormatTestCase { + + @Test + public void testJsonEncodingOfSparseTensor() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); + builder.cell().label("x", "a").label("y", "b").value(2.0); + builder.cell().label("x", "c").label("y", "d").value(3.0); + Tensor tensor = builder.build(); + byte[] json = JsonFormat.encode(tensor); + assertEquals("{\"cells\":[" + + "{\"address\":{\"x\":\"a\",\"y\":\"b\"},\"value\":2.0}," + + "{\"address\":{\"x\":\"c\",\"y\":\"d\"},\"value\":3.0}" + + "]}", + new String(json, StandardCharsets.UTF_8)); + } + + @Test + public void testJsonEncodingOfDenseTensor() { + Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); + builder.cell().label("x", 0).label("y", 0).value(2.0); + builder.cell().label("x", 0).label("y", 1).value(3.0); + builder.cell().label("x", 1).label("y", 0).value(5.0); + builder.cell().label("x", 1).label("y", 1).value(7.0); + Tensor tensor = builder.build(); + byte[] json = JsonFormat.encode(tensor); + assertEquals("{\"cells\":[" + + "{\"address\":{\"x\":\"0\",\"y\":\"0\"},\"value\":2.0}," + + "{\"address\":{\"x\":\"0\",\"y\":\"1\"},\"value\":3.0}," + + "{\"address\":{\"x\":\"1\",\"y\":\"0\"},\"value\":5.0}," + + "{\"address\":{\"x\":\"1\",\"y\":\"1\"},\"value\":7.0}" + + "]}", + new String(json, StandardCharsets.UTF_8)); + } + +} |