From 7d48aa76c6c89851bf5d99109e41d2b485bc87ab Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 10 Jan 2017 14:47:44 +0100 Subject: Maintain TensorType in documents --- .../src/main/java/com/yahoo/tensor/TensorType.java | 19 +++++++++++++++++++ .../yahoo/tensor/serialization/TypedBinaryFormat.java | 1 + .../java/com/yahoo/vespa/objects/Identifiable.java | 4 ++-- .../serialization/SparseBinaryFormatTestCase.java | 2 +- 4 files changed, 23 insertions(+), 3 deletions(-) (limited to 'vespajlib') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 82f36972a47..e58a08c8d31 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -77,6 +77,25 @@ public class TensorType { return Optional.empty(); } + /** + * Returns whether a tensor of the given type can be assigned to this type, + * i.e of this type is a generalization of the given type. + */ + public boolean isAssignableTo(TensorType other) { + if (other.dimensions().size() != this.dimensions().size()) return false; + for (int i = 0; i < other.dimensions().size(); i++) { + Dimension thisDimension = this.dimensions().get(i); + Dimension otherDimension = other.dimensions().get(i); + if (thisDimension.isIndexed() != other.isIndexed()) return false; + if ( ! thisDimension.name().equals(otherDimension.name())) return false; + if (thisDimension.size().isPresent()) { + if ( ! otherDimension.size().isPresent()) return false; + if (otherDimension.size().get() > thisDimension.size().get() ) return false; + } + } + return true; + } + @Override public String toString() { return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java index 5a45f20b6d8..5bb93b9da83 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java @@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization; import com.google.common.annotations.Beta; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; /** * Class used by clients for serializing a Tensor object into binary format or diff --git a/vespajlib/src/main/java/com/yahoo/vespa/objects/Identifiable.java b/vespajlib/src/main/java/com/yahoo/vespa/objects/Identifiable.java index e0edc6f4e64..d303a69a68d 100644 --- a/vespajlib/src/main/java/com/yahoo/vespa/objects/Identifiable.java +++ b/vespajlib/src/main/java/com/yahoo/vespa/objects/Identifiable.java @@ -16,7 +16,7 @@ import java.util.HashMap; * methods. * * @author baldersheim - * @author Simon Thoresen + * @author Simon Thoresen */ public class Identifiable extends Selectable implements Cloneable { @@ -177,7 +177,7 @@ public class Identifiable extends Selectable implements Cloneable { * * @param id The class identifier to register with. * @param spec The class to register. - * @return The identifier argument. + * @return the identifier argument. */ protected static int registerClass(int id, Class spec) { if (registry == null) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java index ad908101329..8f96edf7dd8 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java @@ -32,7 +32,7 @@ public class SparseBinaryFormatTestCase { private static void assertSerialization(Tensor tensor) { byte[] encodedTensor = TypedBinaryFormat.encode(tensor); - Tensor decodedTensor = TypedBinaryFormat.decode(encodedTensor); + Tensor decodedTensor = TypedBinaryFormat.decode(null, encodedTensor); assertEquals(tensor, decodedTensor); } -- cgit v1.2.3