diff options
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java | 7 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java | 11 |
2 files changed, 15 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java index 64c808b58a4..37a4bf375d0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java @@ -9,6 +9,8 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; import static com.yahoo.tensor.TensorType.Dimension; import static com.yahoo.tensor.TensorType.Value; @@ -19,6 +21,8 @@ import static com.yahoo.tensor.TensorType.Value; */ public class TypeResolver { + private static final Logger logger = Logger.getLogger(TypeResolver.class.getName()); + static private TensorType scalar() { return TensorType.empty; } @@ -44,7 +48,8 @@ public class TypeResolver { if (map.containsKey(name)) { map.remove(name); } else { - throw new IllegalArgumentException("reducing non-existing dimension "+name+" in type "+inputType); + logger.log(Level.WARNING, "reducing non-existing dimension "+name+" in type "+inputType); + // throw new IllegalArgumentException("reducing non-existing dimension "+name+" in type "+inputType); } } if (map.isEmpty()) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java index 17bad8e6902..8e4205c8c27 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java @@ -120,8 +120,15 @@ public class TypeResolverTestCase { checkReduce("tensor<float>(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()"); checkReduce("tensor<bfloat16>(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()"); checkReduce("tensor<int8>(x[10],y{},z[30])", mkl("x", "y", "z"), "tensor()"); - checkReduceFails("tensor(y{})", "x"); - checkReduceFails("tensor<float>(y[10])", "x"); + // for now, these will just log a warning + //checkReduceFails("tensor()", "x"); + //checkReduceFails("tensor(y{})", "x"); + //checkReduceFails("tensor<float>(y[10])", "x"); + //checkReduceFails("tensor<int8>(y[10])", "x"); + checkReduce("tensor()", mkl("x"), "tensor()"); + checkReduce("tensor(y{})", mkl("x"), "tensor(y{})"); + checkReduce("tensor<float>(y[10])", mkl("x"), "tensor<float>(y[10])"); + checkReduce("tensor<int8>(y[10])", mkl("x"), "tensor<float>(y[10])"); } @Test |