aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-02-21 18:47:20 +0100
committerJon Bratseth <bratseth@oath.com>2018-02-21 18:47:20 +0100
commitacfdb9e6c61b9f8d065645a657c130bd2cb49c87 (patch)
treef98e707acaa41496d4d1277dd3535569f623f5a6 /vespajlib/src
parent31805b7b9640302067713ce05573d9d1e5c92f39 (diff)
Deduce correct concat type
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/lang/MutableLong.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java86
4 files changed, 127 insertions, 21 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java
new file mode 100644
index 00000000000..e0e4a0828a9
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java
@@ -0,0 +1,33 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.lang;
+
+/**
+ * A mutable long
+ *
+ * @author bratseth
+ */
+public class MutableLong {
+
+ private long value;
+
+ public MutableLong(long value) {
+ this.value = value;
+ }
+
+ public long get() { return value; }
+
+ public void set(long value) { this.value = value; }
+
+ /** Adds the increment to the current value and returns the resulting value */
+ public long add(long increment) {
+ value += increment;
+ return value;
+ }
+
+ /** Adds the increment to the current value and returns the resulting value */
+ public long subtract(long increment) {
+ value -= increment;
+ return value;
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 14cd3e70866..bf1825446e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -77,6 +77,13 @@ public class TensorType {
return Optional.empty();
}
+ /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */
+ public Optional<Long> sizeOfDimension(String dimension) {
+ Optional<Dimension> d = dimension(dimension);
+ if ( ! d.isPresent()) return Optional.empty();
+ return d.get().size();
+ }
+
/**
* Returns whether this type can be assigned to the given type,
* i.e if the given type is a generalization of this type.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index a073053bec8..13e7c136feb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -3,6 +3,8 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.lang.MutableInteger;
+import com.yahoo.lang.MutableLong;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -66,13 +68,27 @@ public class Concat extends PrimitiveTensorFunction {
/** Returns the type resulting from concatenating a and b */
private TensorType type(TensorType a, TensorType b) {
+ // TODO: Fail if concat dimension is present but not indexed in a or b
TensorType.Builder builder = new TensorType.Builder(a, b);
- if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
- builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() +
- b.dimension(dimension).get().size().get()));
+ if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) {
+ builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) +
+ b.sizeOfDimension(dimension).orElse(1L)));
+ /*
+ MutableLong concatSize = new MutableLong(0);
+ a.sizeOfDimension(dimension).ifPresent(concatSize::add);
+ b.sizeOfDimension(dimension).ifPresent(concatSize::add);
+ builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
+ */
+ }
return builder.build();
}
+ /** Returns true if this dimension is present and unbound */
+ private boolean unboundIn(TensorType type, String dimensionName) {
+ Optional<TensorType.Dimension> dimension = type.dimension(dimensionName);
+ return dimension.isPresent() && ! dimension.get().size().isPresent();
+ }
+
@Override
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
index 7e1f292eb7b..eafa5c4addf 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
@@ -2,6 +2,9 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -16,51 +19,98 @@ public class ConcatTestCase {
public void testConcatNumbers() {
Tensor a = Tensor.from("{1}");
Tensor b = Tensor.from("{2}");
- assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x"));
+ assertConcat("tensor(x[2]):{ {x:0}:1, {x:1}:2 }", a, b, "x");
+ assertConcat("tensor(x[2]):{ {x:0}:2, {x:1}:1 }", b, a , "x");
}
@Test
public void testConcatEqualShapes() {
- Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }");
- Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
- assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " +
- "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y"));
+ Tensor a = Tensor.from("tensor(x[3]):{ {x:0}:1, {x:1}:2, {x:2}:3 }");
+ Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
+ assertConcat("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }", a, b, "x");
+ assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " +
+ "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }",
+ a, b, "y");
}
@Test
public void testConcatNumberAndVector() {
Tensor a = Tensor.from("{1}");
+ Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:2, {x:1}:3, {x:2}:4 }");
+ assertConcat("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x");
+ assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
+ "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }",
+ a, b, "y");
+ }
+
+ @Test
+ public void testConcatNumberAndVectorUnbound() {
+ Tensor a = Tensor.from("{1}");
Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }");
- assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
- "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y"));
+ assertConcat("tensor(x[])","tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x");
+ assertConcat("tensor(x[],y[2])", "tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
+ "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }",
+ a, b, "y");
}
@Test
public void testUnequalSizesSameDimension() {
+ Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }");
+ Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
+ assertConcat("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x");
+ assertConcat("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y");
+ }
+
+ @Test
+ public void testUnequalSizesSameDimensionUnbound() {
Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
- assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }"), a.concat(b, "y"));
+ assertConcat("tensor(x[])", "tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x");
+ assertConcat("tensor(x[],y[2])", "tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y");
}
@Test
public void testUnequalEqualSizesDifferentDimension() {
+ Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }");
+ Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }");
+ assertConcat("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x");
+ assertConcat("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y");
+ assertConcat("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z");
+ }
+
+ @Test
+ public void testUnequalEqualSizesDifferentDimensionOneUnbound() {
Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
- Tensor b = Tensor.from("tensor(y[]):{ {y:0}:4, {y:1}:5, {y:2}:6 }");
- assertEquals(Tensor.from("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y"));
- assertEquals(Tensor.from("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}"), a.concat(b, "z"));
+ Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }");
+ assertConcat("tensor(x[],y[3])", "tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x");
+ assertConcat("tensor(x[],y[4])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y");
+ assertConcat("tensor(x[],y[3],z[2])", "tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z");
}
@Test
public void testDimensionsubset() {
Tensor a = Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:3, {x:1,y:1}:4 }");
Tensor b = Tensor.from("tensor(y[2]):{ {y:0}:5, {y:1}:6 }");
- assertEquals(Tensor.from("tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}"), a.concat(b, "x"));
- assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y"));
+ assertConcat("tensor(x[],y[])", "tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}", a, b, "x");
+ assertConcat("tensor(x[],y[])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y");
+ }
+
+ private void assertConcat(String expected, Tensor a, Tensor b, String dimension) {
+ assertConcat(null, expected, a, b, dimension);
+ }
+
+ private void assertConcat(String expectedType, String expected, Tensor a, Tensor b, String dimension) {
+ Tensor expectedAsTensor = Tensor.from(expected);
+ TensorType inferredType = new Concat(new ConstantTensor(a), new ConstantTensor(b), dimension)
+ .type(new MapEvaluationContext());
+ Tensor result = a.concat(b, dimension);
+
+ if (expectedType != null)
+ assertEquals(TensorType.fromSpec(expectedType), inferredType);
+ else
+ assertEquals(expectedAsTensor.type(), inferredType);
+
+ assertEquals(expectedAsTensor, result);
}
}