diff options
author | Jon Bratseth <bratseth@oath.com> | 2019-06-02 15:20:01 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-06-02 15:20:01 +0200 |
commit | b33f016430c8f659e11d32cee616442cd029bcc2 (patch) | |
tree | de389ef84956ddaf0c46a62e3df79e876ae53660 /vespajlib | |
parent | a255c7c12e62862e4fb57a8ac2adb8af194223d0 (diff) | |
parent | c25c8a52e2328bcff2f5a35496e7568ee5a7c752 (diff) |
Merge pull request #9641 from vespa-engine/bratseth/ranking-expression-models
Bratseth/ranking expression models
Diffstat (limited to 'vespajlib')
4 files changed, 33 insertions, 4 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 9264b0a8255..04e68e60178 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1180,7 +1180,8 @@ "public static boolean approxEquals(double, double)", "public static com.yahoo.tensor.Tensor from(com.yahoo.tensor.TensorType, java.lang.String)", "public static com.yahoo.tensor.Tensor from(java.lang.String, java.lang.String)", - "public static com.yahoo.tensor.Tensor from(java.lang.String)" + "public static com.yahoo.tensor.Tensor from(java.lang.String)", + "public static com.yahoo.tensor.Tensor from(double)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index ebb341147cf..22ff793e6fa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -367,6 +367,13 @@ public interface Tensor { return TensorParser.tensorFrom(tensorString, Optional.empty()); } + /** + * Returns a double as a tensor: A dimensionless tensor containing the value as its cell + */ + static Tensor from(double value) { + return Tensor.Builder.of(TensorType.empty).cell(value).build(); + } + class Cell implements Map.Entry<TensorAddress, Double> { private final TensorAddress address; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 3213982355b..6382361f187 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -1,7 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; +import com.yahoo.slime.ArrayTraverser; import com.yahoo.slime.Cursor; +import com.yahoo.slime.Inspector; +import com.yahoo.slime.JsonDecoder; +import com.yahoo.slime.ObjectTraverser; import com.yahoo.slime.Slime; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -17,9 +21,7 @@ import java.util.Iterator; // TODO: We should probably move reading of this format from the document module to here public class JsonFormat { - /** - * Serialize the given tensor into JSON format - */ + /** Serializes the given tensor into JSON format */ public static byte[] encode(Tensor tensor) { Slime slime = new Slime(); Cursor root = slime.setObject(); @@ -38,4 +40,19 @@ public class JsonFormat { addressObject.setString(type.dimensions().get(i).name(), address.label(i)); } + /** Deserializes the given tensor from JSON format */ + // TODO: Add explicit validation (valid() checks) below + public static Tensor decode(TensorType type, byte[] jsonTensorValue) { + Tensor.Builder tensorBuilder = Tensor.Builder.of(type); + Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get(); + Inspector cells = root.field("cells"); + cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, tensorBuilder.cell())); + return tensorBuilder.build(); + } + + private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) { + cell.field("address").traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString())); + cellBuilder.value(cell.field("value").asDouble()); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 16af413f2f0..5a025b6eb96 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -26,6 +26,8 @@ public class JsonFormatTestCase { "{\"address\":{\"x\":\"c\",\"y\":\"d\"},\"value\":3.0}" + "]}", new String(json, StandardCharsets.UTF_8)); + Tensor decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); } @Test @@ -44,6 +46,8 @@ public class JsonFormatTestCase { "{\"address\":{\"x\":\"1\",\"y\":\"1\"},\"value\":7.0}" + "]}", new String(json, StandardCharsets.UTF_8)); + Tensor decoded = JsonFormat.decode(tensor.type(), json); + assertEquals(tensor, decoded); } } |