diff options
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java | 2 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java | 11 |
2 files changed, 12 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index d43e9ee74a3..19edfc0269e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -255,7 +255,7 @@ public abstract class IndexedTensor implements Tensor { if (type.valueType() == TensorType.Value.FLOAT) return new IndexedFloatTensor.BoundFloatBuilder(type, sizes); - else if (type.valueType() == TensorType.Value.FLOAT) + else if (type.valueType() == TensorType.Value.DOUBLE) return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); else return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 02d16e6f3e4..b01d171792c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -37,6 +37,17 @@ public class TensorTestCase { } @Test + public void testValueTypes() { + assertEquals(Tensor.from("tensor<double>(x[1]):{{x:0}:5}").getClass(), IndexedDoubleTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<double>(x[1])")).cell(5.0, 0).build().getClass(), + IndexedDoubleTensor.class); + + assertEquals(Tensor.from("tensor<float>(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class); + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[1])")).cell(5.0, 0).build().getClass(), + IndexedFloatTensor.class); + } + + @Test public void testParseError() { try { Tensor.from("--"); |