summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2019-06-02 15:20:01 +0200
committerGitHub <noreply@github.com>2019-06-02 15:20:01 +0200
commitb33f016430c8f659e11d32cee616442cd029bcc2 (patch)
treede389ef84956ddaf0c46a62e3df79e876ae53660 /vespajlib
parenta255c7c12e62862e4fb57a8ac2adb8af194223d0 (diff)
parentc25c8a52e2328bcff2f5a35496e7568ee5a7c752 (diff)
Merge pull request #9641 from vespa-engine/bratseth/ranking-expression-models
Bratseth/ranking expression models
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java23
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java4
4 files changed, 33 insertions, 4 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 9264b0a8255..04e68e60178 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1180,7 +1180,8 @@
"public static boolean approxEquals(double, double)",
"public static com.yahoo.tensor.Tensor from(com.yahoo.tensor.TensorType, java.lang.String)",
"public static com.yahoo.tensor.Tensor from(java.lang.String, java.lang.String)",
- "public static com.yahoo.tensor.Tensor from(java.lang.String)"
+ "public static com.yahoo.tensor.Tensor from(java.lang.String)",
+ "public static com.yahoo.tensor.Tensor from(double)"
],
"fields": []
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index ebb341147cf..22ff793e6fa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -367,6 +367,13 @@ public interface Tensor {
return TensorParser.tensorFrom(tensorString, Optional.empty());
}
+ /**
+ * Returns a double as a tensor: A dimensionless tensor containing the value as its cell
+ */
+ static Tensor from(double value) {
+ return Tensor.Builder.of(TensorType.empty).cell(value).build();
+ }
+
class Cell implements Map.Entry<TensorAddress, Double> {
private final TensorAddress address;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 3213982355b..6382361f187 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -1,7 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.serialization;
+import com.yahoo.slime.ArrayTraverser;
import com.yahoo.slime.Cursor;
+import com.yahoo.slime.Inspector;
+import com.yahoo.slime.JsonDecoder;
+import com.yahoo.slime.ObjectTraverser;
import com.yahoo.slime.Slime;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -17,9 +21,7 @@ import java.util.Iterator;
// TODO: We should probably move reading of this format from the document module to here
public class JsonFormat {
- /**
- * Serialize the given tensor into JSON format
- */
+ /** Serializes the given tensor into JSON format */
public static byte[] encode(Tensor tensor) {
Slime slime = new Slime();
Cursor root = slime.setObject();
@@ -38,4 +40,19 @@ public class JsonFormat {
addressObject.setString(type.dimensions().get(i).name(), address.label(i));
}
+ /** Deserializes the given tensor from JSON format */
+ // TODO: Add explicit validation (valid() checks) below
+ public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
+ Tensor.Builder tensorBuilder = Tensor.Builder.of(type);
+ Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get();
+ Inspector cells = root.field("cells");
+ cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, tensorBuilder.cell()));
+ return tensorBuilder.build();
+ }
+
+ private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) {
+ cell.field("address").traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString()));
+ cellBuilder.value(cell.field("value").asDouble());
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 16af413f2f0..5a025b6eb96 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -26,6 +26,8 @@ public class JsonFormatTestCase {
"{\"address\":{\"x\":\"c\",\"y\":\"d\"},\"value\":3.0}" +
"]}",
new String(json, StandardCharsets.UTF_8));
+ Tensor decoded = JsonFormat.decode(tensor.type(), json);
+ assertEquals(tensor, decoded);
}
@Test
@@ -44,6 +46,8 @@ public class JsonFormatTestCase {
"{\"address\":{\"x\":\"1\",\"y\":\"1\"},\"value\":7.0}" +
"]}",
new String(json, StandardCharsets.UTF_8));
+ Tensor decoded = JsonFormat.decode(tensor.type(), json);
+ assertEquals(tensor, decoded);
}
}