// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.serialization; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.util.Optional; /** * Class used by clients for serializing a Tensor object into binary format or * de-serializing binary data into a Tensor object. * * The actual binary format used is not a concern for the client and * is hidden in this class and in the binary data. * * @author geirst */ public class TypedBinaryFormat { private static final int SPARSE_BINARY_FORMAT_TYPE = 1; private static final int DENSE_BINARY_FORMAT_TYPE = 2; private static final int MIXED_BINARY_FORMAT_TYPE = 3; private static final int SPARSE_BINARY_FORMAT_WITH_CELLTYPE = 5; private static final int DENSE_BINARY_FORMAT_WITH_CELLTYPE = 6; private static final int MIXED_BINARY_FORMAT_WITH_CELLTYPE = 7; private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing private static final int FLOAT_VALUE_TYPE = 1; public static byte[] encode(Tensor tensor) { GrowableByteBuffer buffer = new GrowableByteBuffer(); BinaryFormat encoder = getFormatEncoder(buffer, tensor); encoder.encode(buffer, tensor); return asByteArray(buffer); } /** * Decode some data to a tensor * * @param type the type to decode and validate to, or empty to use the type given in the data * @param buffer the buffer containing the data, use GrowableByteByffer.wrap(byte[]) if you have a byte array * @return the resulting tensor * @throws IllegalArgumentException if the tensor data was invalid */ public static Tensor decode(Optional type, GrowableByteBuffer buffer) { BinaryFormat decoder = getFormatDecoder(buffer); return decoder.decode(type, buffer); } private static BinaryFormat getFormatEncoder(GrowableByteBuffer buffer, Tensor tensor) { if (tensor instanceof MixedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) { encodeFormatType(buffer, MIXED_BINARY_FORMAT_TYPE); return new MixedBinaryFormat(); } if (tensor instanceof MixedTensor) { encodeFormatType(buffer, MIXED_BINARY_FORMAT_WITH_CELLTYPE); encodeValueType(buffer, tensor.type().valueType()); return new MixedBinaryFormat(tensor.type().valueType()); } if (tensor instanceof IndexedTensor && tensor.type().valueType() == TensorType.Value.DOUBLE) { encodeFormatType(buffer, DENSE_BINARY_FORMAT_TYPE); return new DenseBinaryFormat(); } if (tensor instanceof IndexedTensor) { encodeFormatType(buffer, DENSE_BINARY_FORMAT_WITH_CELLTYPE); encodeValueType(buffer, tensor.type().valueType()); return new DenseBinaryFormat(tensor.type().valueType()); } if (tensor.type().valueType() == TensorType.Value.DOUBLE) { encodeFormatType(buffer, SPARSE_BINARY_FORMAT_TYPE); return new SparseBinaryFormat(); } encodeFormatType(buffer, SPARSE_BINARY_FORMAT_WITH_CELLTYPE); encodeValueType(buffer, tensor.type().valueType()); return new SparseBinaryFormat(tensor.type().valueType()); } private static BinaryFormat getFormatDecoder(GrowableByteBuffer buffer) { int formatType = decodeFormatType(buffer); switch (formatType) { case SPARSE_BINARY_FORMAT_TYPE: return new SparseBinaryFormat(); case DENSE_BINARY_FORMAT_TYPE: return new DenseBinaryFormat(); case MIXED_BINARY_FORMAT_TYPE: return new MixedBinaryFormat(); case SPARSE_BINARY_FORMAT_WITH_CELLTYPE: return new SparseBinaryFormat(decodeValueType(buffer)); case DENSE_BINARY_FORMAT_WITH_CELLTYPE: return new DenseBinaryFormat(decodeValueType(buffer)); case MIXED_BINARY_FORMAT_WITH_CELLTYPE: return new MixedBinaryFormat(decodeValueType(buffer)); } throw new IllegalArgumentException("Binary format type " + formatType + " is unknown"); } private static void encodeFormatType(GrowableByteBuffer buffer, int formatType) { buffer.putInt1_4Bytes(formatType); } private static int decodeFormatType(GrowableByteBuffer buffer) { return buffer.getInt1_4Bytes(); } private static void encodeValueType(GrowableByteBuffer buffer, TensorType.Value valueType) { switch (valueType) { case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break; case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break; default: throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType); } } private static TensorType.Value decodeValueType(GrowableByteBuffer buffer) { int valueType = buffer.getInt1_4Bytes(); switch (valueType) { case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE; case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT; } throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal."); } private static byte[] asByteArray(GrowableByteBuffer buffer) { buffer.flip(); byte[] result = new byte[buffer.remaining()]; buffer.get(result); return result; } }