diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-14 10:19:42 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-14 10:19:42 +0100 |
commit | 8f066f5697dcc6de1611888943ca9b886eb409ef (patch) | |
tree | 7babb1a6ad14e43ab16f9cf14989259bf0ab9b00 /vespajlib | |
parent | 1e72a39e1cf1462e398b82e8a28c715807753dd4 (diff) |
Apply in correct order and handle zero dimension
Diffstat (limited to 'vespajlib')
5 files changed, 33 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index ea16e7ed2f0..35b0e8b0d60 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -190,9 +190,14 @@ public class IndexedTensor implements Tensor { /** Returns the value at this address, or NaN if there is no value at this address */ @Override public double get(TensorAddress address) { + if (type.dimensions().isEmpty()) // either empty or a sinle value + return firstDimension.values().isEmpty() ? Double.NaN : (double)firstDimension.values().get(0); + IndexedDimension currentDimension = firstDimension; for (int i = 0; i < address.labels().size(); i++) { int index = Integer.parseInt(address.labels().get(i)); + if (index >= currentDimension.values().size()) return Double.NaN; + Object value = currentDimension.values().get(index); if (value == null) return Double.NaN; @@ -353,6 +358,11 @@ public class IndexedTensor implements Tensor { } @Override + public IndexedTensor.Builder.IndexedCellBuilder label(String dimension, int label) { + return label(dimension, String.valueOf(label)); + } + + @Override public IndexedTensor.Builder value(double cellValue) { return IndexedTensor.Builder.this.cell(addressBuilder.build(), cellValue); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 753b9a32e20..3c609acff45 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -134,7 +134,12 @@ public class MappedTensor implements Tensor { addressBuilder.add(dimension, label); return this; } - + + @Override + public MappedCellBuilder label(String dimension, int label) { + return label(dimension, String.valueOf(label)); + } + @Override public Builder value(double cellValue) { return MappedTensor.Builder.this.cell(addressBuilder.build(), cellValue); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 2aa8bd12bc7..d09e41b0ab6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -314,6 +314,8 @@ public interface Tensor { CellBuilder label(String dimension, String label); + CellBuilder label(String dimension, int label); + Builder value(double cellValue); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 5f91ff3034c..200cda694a5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -65,7 +65,6 @@ public class Join extends PrimitiveTensorFunction { public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - TensorType joinedType = a.type().combineWith(b.type()); // Choose join algorithm @@ -74,9 +73,9 @@ public class Join extends PrimitiveTensorFunction { else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size()) return singleSpaceJoin(a, b, joinedType); else if (a.type().dimensions().containsAll(b.type().dimensions())) - return subspaceJoin(b, a, joinedType); + return subspaceJoin(b, a, joinedType, true); else if (b.type().dimensions().containsAll(a.type().dimensions())) - return subspaceJoin(a, b,joinedType); + return subspaceJoin(a, b, joinedType, false); else return generalJoin(a, b, joinedType); } @@ -101,14 +100,16 @@ public class Join extends PrimitiveTensorFunction { } /** Join a tensor into a superspace */ - private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType) { + private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type()); Tensor.Builder builder = Tensor.Builder.of(joinedType); for (Map.Entry<TensorAddress, Double> supercell : superspace.cells().entrySet()) { TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes); double subspaceValue = subspace.get(subaddress); - if (subspaceValue != Double.NaN) - builder.cell(supercell.getKey(), combinator.applyAsDouble(subspaceValue, supercell.getValue())); + if ( ! Double.isNaN(subspaceValue)) + builder.cell(supercell.getKey(), + reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue) + : combinator.applyAsDouble(subspaceValue, supercell.getValue())); } return builder.build(); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 175136db19a..da472d102ff 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -40,8 +40,6 @@ public class TensorTestCase { /** Test the same computation made in various ways which are implemented with special-cvase optimizations */ @Test public void testOptimizedComputation() { - // All ways of making this computation should return the same value - assertEquals("Mapped vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.mapped, 2))); assertEquals("Indexed vector", 42, (int)dotProduct(vector(Type.indexedUnbound), vectors(Type.indexedUnbound, 2))); assertEquals("Mapped matrix", 42, (int)dotProduct(vector(Type.mapped), matrix(Type.mapped, 2))); @@ -54,6 +52,13 @@ public class TensorTestCase { assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.indexedUnbound), vectors(Type.mapped, 2))); assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2))); assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2))); + + // Test the unoptimized path by joining in another dimension + Tensor unitJ = Tensor.Builder.of(new TensorType.Builder().mapped("j").build()).cell().label("j", 0).value(1).build(); + Tensor unitK = Tensor.Builder.of(new TensorType.Builder().mapped("k").build()).cell().label("k", 0).value(1).build(); + Tensor vectorInJSpace = vector(Type.mapped).multiply(unitJ); + Tensor matrixInKSpace = matrix(Type.mapped, 2).get(0).multiply(unitK); + assertEquals("Generic computation implementation", 42, (int)dotProduct(vectorInJSpace, Collections.singletonList(matrixInKSpace))); } private double dotProduct(Tensor tensor, List<Tensor> tensors) { @@ -92,7 +97,7 @@ public class TensorTestCase { /** * Create a matrix of vectors (in dimension i) where each vector has the dimension x. - * Thie matric contains the same vectors as returned by createVectors + * This matrix contains the same vectors as returned by createVectors, in a single list element for convenience. */ private List<Tensor> matrix(TensorType.Dimension.Type dimensionType, int vectorCount) { int vectorSize = 3; |