diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-02-21 18:47:20 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-02-21 18:47:20 +0100 |
commit | acfdb9e6c61b9f8d065645a657c130bd2cb49c87 (patch) | |
tree | f98e707acaa41496d4d1277dd3535569f623f5a6 /vespajlib | |
parent | 31805b7b9640302067713ce05573d9d1e5c92f39 (diff) |
Deduce correct concat type
Diffstat (limited to 'vespajlib')
4 files changed, 127 insertions, 21 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java new file mode 100644 index 00000000000..e0e4a0828a9 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java @@ -0,0 +1,33 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.lang; + +/** + * A mutable long + * + * @author bratseth + */ +public class MutableLong { + + private long value; + + public MutableLong(long value) { + this.value = value; + } + + public long get() { return value; } + + public void set(long value) { this.value = value; } + + /** Adds the increment to the current value and returns the resulting value */ + public long add(long increment) { + value += increment; + return value; + } + + /** Adds the increment to the current value and returns the resulting value */ + public long subtract(long increment) { + value -= increment; + return value; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 14cd3e70866..bf1825446e4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -77,6 +77,13 @@ public class TensorType { return Optional.empty(); } + /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */ + public Optional<Long> sizeOfDimension(String dimension) { + Optional<Dimension> d = dimension(dimension); + if ( ! d.isPresent()) return Optional.empty(); + return d.get().size(); + } + /** * Returns whether this type can be assigned to the given type, * i.e if the given type is a generalization of this type. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index a073053bec8..13e7c136feb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -3,6 +3,8 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.lang.MutableInteger; +import com.yahoo.lang.MutableLong; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -66,13 +68,27 @@ public class Concat extends PrimitiveTensorFunction { /** Returns the type resulting from concatenating a and b */ private TensorType type(TensorType a, TensorType b) { + // TODO: Fail if concat dimension is present but not indexed in a or b TensorType.Builder builder = new TensorType.Builder(a, b); - if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size - builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() + - b.dimension(dimension).get().size().get())); + if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) { + builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) + + b.sizeOfDimension(dimension).orElse(1L))); + /* + MutableLong concatSize = new MutableLong(0); + a.sizeOfDimension(dimension).ifPresent(concatSize::add); + b.sizeOfDimension(dimension).ifPresent(concatSize::add); + builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); + */ + } return builder.build(); } + /** Returns true if this dimension is present and unbound */ + private boolean unboundIn(TensorType type, String dimensionName) { + Optional<TensorType.Dimension> dimension = type.dimension(dimensionName); + return dimension.isPresent() && ! dimension.get().size().isPresent(); + } + @Override public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java index 7e1f292eb7b..eafa5c4addf 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java @@ -2,6 +2,9 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -16,51 +19,98 @@ public class ConcatTestCase { public void testConcatNumbers() { Tensor a = Tensor.from("{1}"); Tensor b = Tensor.from("{2}"); - assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x")); + assertConcat("tensor(x[2]):{ {x:0}:1, {x:1}:2 }", a, b, "x"); + assertConcat("tensor(x[2]):{ {x:0}:2, {x:1}:1 }", b, a , "x"); } @Test public void testConcatEqualShapes() { - Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }"); - Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); - assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " + - "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y")); + Tensor a = Tensor.from("tensor(x[3]):{ {x:0}:1, {x:1}:2, {x:2}:3 }"); + Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); + assertConcat("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }", a, b, "x"); + assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " + + "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }", + a, b, "y"); } @Test public void testConcatNumberAndVector() { Tensor a = Tensor.from("{1}"); + Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:2, {x:1}:3, {x:2}:4 }"); + assertConcat("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x"); + assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + + "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }", + a, b, "y"); + } + + @Test + public void testConcatNumberAndVectorUnbound() { + Tensor a = Tensor.from("{1}"); Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }"); - assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + - "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y")); + assertConcat("tensor(x[])","tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x"); + assertConcat("tensor(x[],y[2])", "tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + + "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }", + a, b, "y"); } @Test public void testUnequalSizesSameDimension() { + Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"); + Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); + assertConcat("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x"); + assertConcat("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y"); + } + + @Test + public void testUnequalSizesSameDimensionUnbound() { Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }"); Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); - assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }"), a.concat(b, "y")); + assertConcat("tensor(x[])", "tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x"); + assertConcat("tensor(x[],y[2])", "tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y"); } @Test public void testUnequalEqualSizesDifferentDimension() { + Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"); + Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }"); + assertConcat("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x"); + assertConcat("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y"); + assertConcat("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z"); + } + + @Test + public void testUnequalEqualSizesDifferentDimensionOneUnbound() { Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }"); - Tensor b = Tensor.from("tensor(y[]):{ {y:0}:4, {y:1}:5, {y:2}:6 }"); - assertEquals(Tensor.from("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y")); - assertEquals(Tensor.from("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}"), a.concat(b, "z")); + Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }"); + assertConcat("tensor(x[],y[3])", "tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x"); + assertConcat("tensor(x[],y[4])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y"); + assertConcat("tensor(x[],y[3],z[2])", "tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z"); } @Test public void testDimensionsubset() { Tensor a = Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:3, {x:1,y:1}:4 }"); Tensor b = Tensor.from("tensor(y[2]):{ {y:0}:5, {y:1}:6 }"); - assertEquals(Tensor.from("tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y")); + assertConcat("tensor(x[],y[])", "tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}", a, b, "x"); + assertConcat("tensor(x[],y[])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y"); + } + + private void assertConcat(String expected, Tensor a, Tensor b, String dimension) { + assertConcat(null, expected, a, b, dimension); + } + + private void assertConcat(String expectedType, String expected, Tensor a, Tensor b, String dimension) { + Tensor expectedAsTensor = Tensor.from(expected); + TensorType inferredType = new Concat(new ConstantTensor(a), new ConstantTensor(b), dimension) + .type(new MapEvaluationContext()); + Tensor result = a.concat(b, dimension); + + if (expectedType != null) + assertEquals(TensorType.fromSpec(expectedType), inferredType); + else + assertEquals(expectedAsTensor.type(), inferredType); + + assertEquals(expectedAsTensor, result); } } |