aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-03-08 16:22:01 +0100
committerJon Bratseth <bratseth@oath.com>2018-03-08 16:22:01 +0100
commit8f21c54b669202cdcc1a04934762dceebb929308 (patch)
treecddb1bf2cb106b5eb92594785f7daef69f41e3b4 /vespajlib/src
parentf19b783d4014f799482daa13f8f8c26d5c4c84d9 (diff)
Add TensorFlow variable converter
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/slime/JsonDecoder.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/slime/JsonFormat.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/slime/SlimeFormat.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java40
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java48
9 files changed, 129 insertions, 16 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/slime/JsonDecoder.java b/vespajlib/src/main/java/com/yahoo/slime/JsonDecoder.java
index d4d00180abd..3f73faf289d 100644
--- a/vespajlib/src/main/java/com/yahoo/slime/JsonDecoder.java
+++ b/vespajlib/src/main/java/com/yahoo/slime/JsonDecoder.java
@@ -11,10 +11,10 @@ import java.util.List;
/**
* A port of the C++ json decoder intended to be fast.
*
- * @author lulf
- * @since 5.1.21
+ * @author Ulf Lilleengen
*/
public class JsonDecoder {
+
private BufferedInput in;
private byte c;
@@ -256,7 +256,6 @@ public class JsonDecoder {
return ret;
}
-
private void next() {
if (!in.eof()) {
c = in.getByte();
@@ -302,4 +301,5 @@ public class JsonDecoder {
public final Cursor insertARRAY() { return target.setArray(key); }
public final Cursor insertOBJECT() { return target.setObject(key); }
}
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/slime/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/slime/JsonFormat.java
index 56bcce922bd..d908359df19 100644
--- a/vespajlib/src/main/java/com/yahoo/slime/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/slime/JsonFormat.java
@@ -13,7 +13,7 @@ import java.io.*;
/**
* Encodes json from a slime object.
*
- * @author lulf
+ * @author Ulf Lilleengen
*/
public final class JsonFormat implements SlimeFormat
{
@@ -41,12 +41,30 @@ public final class JsonFormat implements SlimeFormat
}
@Override
- public void decode(InputStream is, Slime slime) throws IOException {
+ public void decode(InputStream is, Slime slime) {
throw new UnsupportedOperationException("Not implemented");
}
- public static final class Encoder implements ArrayTraverser, ObjectTraverser
- {
+ /** Returns the given slime data as UTF-8-encoded JSON */
+ public static byte[] toJsonBytes(Slime slime) {
+ try {
+ ByteArrayOutputStream baos = new ByteArrayOutputStream();
+ new JsonFormat(true).encode(baos, slime);
+ return baos.toByteArray();
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
+ /** Returns the given UTF-8-encoded JSON as a Slime object */
+ public static Slime jsonToSlime(byte[] json) {
+ Slime slime = new Slime();
+ new JsonDecoder().decode(slime, json);
+ return slime;
+ }
+
+ public static final class Encoder implements ArrayTraverser, ObjectTraverser {
private final Inspector top;
private final AbstractByteWriter out;
private boolean head = true;
diff --git a/vespajlib/src/main/java/com/yahoo/slime/SlimeFormat.java b/vespajlib/src/main/java/com/yahoo/slime/SlimeFormat.java
index 7829d772b00..ab4501a1e69 100644
--- a/vespajlib/src/main/java/com/yahoo/slime/SlimeFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/slime/SlimeFormat.java
@@ -6,21 +6,24 @@ import java.io.InputStream;
import java.io.OutputStream;
/**
- * @author lulf
- * @since 5.1
+ * @author Ulf Lilleengen
*/
public interface SlimeFormat {
+
/**
* Encode a slime object into the provided output stream
+ *
* @param os The outputstream to write to.
* @param slime The slime object to encode.
*/
- public void encode(OutputStream os, Slime slime) throws IOException;
+ void encode(OutputStream os, Slime slime) throws IOException;
/**
* Encode a slime object into the provided output stream
+ *
* @param is The input stream to read from.
* @param slime The slime object to decode into.
*/
- public void decode(InputStream is, Slime slime) throws IOException;
+ void decode(InputStream is, Slime slime) throws IOException;
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 5590ccaad0a..9b3a9328f07 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -18,7 +18,7 @@ class TensorParser {
int colonIndex = tensorString.indexOf(':');
String typeString = tensorString.substring(0, colonIndex);
String valueString = tensorString.substring(colonIndex + 1);
- TensorType typeFromString = new TensorType(TensorTypeParser.fromSpec(typeString));
+ TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
if (type.isPresent() && ! type.get().equals(typeFromString))
throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
"passed type " + type);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 2483280817c..0176dac6821 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -32,7 +32,7 @@ public class TensorType {
/** Sorted list of the dimensions of this */
private final ImmutableList<Dimension> dimensions;
- public TensorType(Collection<Dimension> dimensions) {
+ private TensorType(Collection<Dimension> dimensions) {
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
this.dimensions = ImmutableList.copyOf(dimensionList);
@@ -50,7 +50,7 @@ public class TensorType {
* Example: <code>tensor(x[10],y[20])</code> (a matrix)
*/
public static TensorType fromSpec(String specString) {
- return new TensorType(TensorTypeParser.fromSpec(specString));
+ return TensorTypeParser.fromSpec(specString);
}
/** Returns the number of dimensions of this: dimensions().size() */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index 6ed0b8202f1..e3a194a96d7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -23,7 +23,11 @@ public class TensorTypeParser {
private static final Pattern indexedPattern = Pattern.compile("(\\w+)\\[(\\d*)\\]");
private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");
- public static List<TensorType.Dimension> fromSpec(String specString) {
+ public static TensorType fromSpec(String specString) {
+ return new TensorType.Builder(dimensionsFromSpec(specString)).build();
+ }
+
+ public static List<TensorType.Dimension> dimensionsFromSpec(String specString) {
if ( ! specString.startsWith(START_STRING) || !specString.endsWith(END_STRING)) {
throw new IllegalArgumentException("Tensor type spec must start with '" + START_STRING + "'" +
" and end with '" + END_STRING + "', but was '" + specString + "'");
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
new file mode 100644
index 00000000000..ab68f3a63b2
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -0,0 +1,40 @@
+package com.yahoo.tensor.serialization;
+
+import com.yahoo.slime.Cursor;
+import com.yahoo.slime.Slime;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+
+import java.util.Iterator;
+
+/**
+ * Writes tensors on the JSON format used in Vespa tensor document fields:
+ * A JSON map containing a 'cells' array.
+ * See http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor
+ */
+// TODO: We should probably move reading of this format from the document module to here
+public class JsonFormat {
+
+ /**
+ * Serialize the given tensor into JSON format
+ */
+ public static byte[] encode(Tensor tensor) {
+ Slime slime = new Slime();
+ Cursor root = slime.setObject();
+ Cursor cellsArray = root.setArray("cells");
+ for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
+ Tensor.Cell cell = i.next();
+ Cursor cellObject = cellsArray.addObject();
+ encodeAddress(tensor.type(), cell.getKey(), cellObject.setObject("address"));
+ cellObject.setDouble("value", cell.getValue());
+ }
+ return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
+ }
+
+ private static void encodeAddress(TensorType type, TensorAddress address, Cursor addressObject) {
+ for (int i = 0; i < address.size(); i++)
+ addressObject.setString(type.dimensions().get(i).name(), address.label(i));
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index eef0b090fd1..f7a0a3cdb7d 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -99,7 +99,7 @@ public class TensorTypeTestCase {
private static void assertIllegalTensorType(String typeSpec, String messageSubstring) {
try {
TensorType.fromSpec(typeSpec);
- fail("Expoected exception to be thrown with message: '" + messageSubstring + "'");
+ fail("Expected exception to be thrown with message: '" + messageSubstring + "'");
} catch (IllegalArgumentException e) {
assertThat(e.getMessage(), containsString(messageSubstring));
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
new file mode 100644
index 00000000000..db343e6b343
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -0,0 +1,48 @@
+package com.yahoo.tensor.serialization;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Test;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class JsonFormatTestCase {
+
+ @Test
+ public void testJsonEncodingOfSparseTensor() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
+ builder.cell().label("x", "a").label("y", "b").value(2.0);
+ builder.cell().label("x", "c").label("y", "d").value(3.0);
+ Tensor tensor = builder.build();
+ byte[] json = JsonFormat.encode(tensor);
+ assertEquals("{\"cells\":[" +
+ "{\"address\":{\"x\":\"a\",\"y\":\"b\"},\"value\":2.0}," +
+ "{\"address\":{\"x\":\"c\",\"y\":\"d\"},\"value\":3.0}" +
+ "]}",
+ new String(json, StandardCharsets.UTF_8));
+ }
+
+ @Test
+ public void testJsonEncodingOfDenseTensor() {
+ Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})"));
+ builder.cell().label("x", 0).label("y", 0).value(2.0);
+ builder.cell().label("x", 0).label("y", 1).value(3.0);
+ builder.cell().label("x", 1).label("y", 0).value(5.0);
+ builder.cell().label("x", 1).label("y", 1).value(7.0);
+ Tensor tensor = builder.build();
+ byte[] json = JsonFormat.encode(tensor);
+ assertEquals("{\"cells\":[" +
+ "{\"address\":{\"x\":\"0\",\"y\":\"0\"},\"value\":2.0}," +
+ "{\"address\":{\"x\":\"0\",\"y\":\"1\"},\"value\":3.0}," +
+ "{\"address\":{\"x\":\"1\",\"y\":\"0\"},\"value\":5.0}," +
+ "{\"address\":{\"x\":\"1\",\"y\":\"1\"},\"value\":7.0}" +
+ "]}",
+ new String(json, StandardCharsets.UTF_8));
+ }
+
+}