summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2016-12-14 10:19:42 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2016-12-14 10:19:42 +0100
commit8f066f5697dcc6de1611888943ca9b886eb409ef (patch)
tree7babb1a6ad14e43ab16f9cf14989259bf0ab9b00 /vespajlib
parent1e72a39e1cf1462e398b82e8a28c715807753dd4 (diff)
Apply in correct order and handle zero dimension
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java13
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java11
5 files changed, 33 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index ea16e7ed2f0..35b0e8b0d60 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -190,9 +190,14 @@ public class IndexedTensor implements Tensor {
/** Returns the value at this address, or NaN if there is no value at this address */
@Override
public double get(TensorAddress address) {
+ if (type.dimensions().isEmpty()) // either empty or a sinle value
+ return firstDimension.values().isEmpty() ? Double.NaN : (double)firstDimension.values().get(0);
+
IndexedDimension currentDimension = firstDimension;
for (int i = 0; i < address.labels().size(); i++) {
int index = Integer.parseInt(address.labels().get(i));
+ if (index >= currentDimension.values().size()) return Double.NaN;
+
Object value = currentDimension.values().get(index);
if (value == null) return Double.NaN;
@@ -353,6 +358,11 @@ public class IndexedTensor implements Tensor {
}
@Override
+ public IndexedTensor.Builder.IndexedCellBuilder label(String dimension, int label) {
+ return label(dimension, String.valueOf(label));
+ }
+
+ @Override
public IndexedTensor.Builder value(double cellValue) {
return IndexedTensor.Builder.this.cell(addressBuilder.build(), cellValue);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
index 753b9a32e20..3c609acff45 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java
@@ -134,7 +134,12 @@ public class MappedTensor implements Tensor {
addressBuilder.add(dimension, label);
return this;
}
-
+
+ @Override
+ public MappedCellBuilder label(String dimension, int label) {
+ return label(dimension, String.valueOf(label));
+ }
+
@Override
public Builder value(double cellValue) {
return MappedTensor.Builder.this.cell(addressBuilder.build(), cellValue);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 2aa8bd12bc7..d09e41b0ab6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -314,6 +314,8 @@ public interface Tensor {
CellBuilder label(String dimension, String label);
+ CellBuilder label(String dimension, int label);
+
Builder value(double cellValue);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 5f91ff3034c..200cda694a5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -65,7 +65,6 @@ public class Join extends PrimitiveTensorFunction {
public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
-
TensorType joinedType = a.type().combineWith(b.type());
// Choose join algorithm
@@ -74,9 +73,9 @@ public class Join extends PrimitiveTensorFunction {
else if (joinedType.dimensions().size() == a.type().dimensions().size() && joinedType.dimensions().size() == b.type().dimensions().size())
return singleSpaceJoin(a, b, joinedType);
else if (a.type().dimensions().containsAll(b.type().dimensions()))
- return subspaceJoin(b, a, joinedType);
+ return subspaceJoin(b, a, joinedType, true);
else if (b.type().dimensions().containsAll(a.type().dimensions()))
- return subspaceJoin(a, b,joinedType);
+ return subspaceJoin(a, b, joinedType, false);
else
return generalJoin(a, b, joinedType);
}
@@ -101,14 +100,16 @@ public class Join extends PrimitiveTensorFunction {
}
/** Join a tensor into a superspace */
- private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType) {
+ private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
int[] subspaceIndexes = subspaceIndexes(superspace.type(), subspace.type());
Tensor.Builder builder = Tensor.Builder.of(joinedType);
for (Map.Entry<TensorAddress, Double> supercell : superspace.cells().entrySet()) {
TensorAddress subaddress = mapAddressToSubspace(supercell.getKey(), subspaceIndexes);
double subspaceValue = subspace.get(subaddress);
- if (subspaceValue != Double.NaN)
- builder.cell(supercell.getKey(), combinator.applyAsDouble(subspaceValue, supercell.getValue()));
+ if ( ! Double.isNaN(subspaceValue))
+ builder.cell(supercell.getKey(),
+ reversedArgumentOrder ? combinator.applyAsDouble(supercell.getValue(), subspaceValue)
+ : combinator.applyAsDouble(subspaceValue, supercell.getValue()));
}
return builder.build();
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 175136db19a..da472d102ff 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -40,8 +40,6 @@ public class TensorTestCase {
/** Test the same computation made in various ways which are implemented with special-cvase optimizations */
@Test
public void testOptimizedComputation() {
- // All ways of making this computation should return the same value
-
assertEquals("Mapped vector", 42, (int)dotProduct(vector(Type.mapped), vectors(Type.mapped, 2)));
assertEquals("Indexed vector", 42, (int)dotProduct(vector(Type.indexedUnbound), vectors(Type.indexedUnbound, 2)));
assertEquals("Mapped matrix", 42, (int)dotProduct(vector(Type.mapped), matrix(Type.mapped, 2)));
@@ -54,6 +52,13 @@ public class TensorTestCase {
assertEquals("Mixed vector", 42, (int)dotProduct(vector(Type.indexedUnbound), vectors(Type.mapped, 2)));
assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2)));
assertEquals("Mixed matrix", 42, (int)dotProduct(vector(Type.indexedUnbound), matrix(Type.mapped, 2)));
+
+ // Test the unoptimized path by joining in another dimension
+ Tensor unitJ = Tensor.Builder.of(new TensorType.Builder().mapped("j").build()).cell().label("j", 0).value(1).build();
+ Tensor unitK = Tensor.Builder.of(new TensorType.Builder().mapped("k").build()).cell().label("k", 0).value(1).build();
+ Tensor vectorInJSpace = vector(Type.mapped).multiply(unitJ);
+ Tensor matrixInKSpace = matrix(Type.mapped, 2).get(0).multiply(unitK);
+ assertEquals("Generic computation implementation", 42, (int)dotProduct(vectorInJSpace, Collections.singletonList(matrixInKSpace)));
}
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
@@ -92,7 +97,7 @@ public class TensorTestCase {
/**
* Create a matrix of vectors (in dimension i) where each vector has the dimension x.
- * Thie matric contains the same vectors as returned by createVectors
+ * This matrix contains the same vectors as returned by createVectors, in a single list element for convenience.
*/
private List<Tensor> matrix(TensorType.Dimension.Type dimensionType, int vectorCount) {
int vectorSize = 3;