diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-02-21 18:47:20 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-02-21 18:47:20 +0100 |
commit | acfdb9e6c61b9f8d065645a657c130bd2cb49c87 (patch) | |
tree | f98e707acaa41496d4d1277dd3535569f623f5a6 /vespajlib/src/main/java | |
parent | 31805b7b9640302067713ce05573d9d1e5c92f39 (diff) |
Deduce correct concat type
Diffstat (limited to 'vespajlib/src/main/java')
3 files changed, 59 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java new file mode 100644 index 00000000000..e0e4a0828a9 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java @@ -0,0 +1,33 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.lang; + +/** + * A mutable long + * + * @author bratseth + */ +public class MutableLong { + + private long value; + + public MutableLong(long value) { + this.value = value; + } + + public long get() { return value; } + + public void set(long value) { this.value = value; } + + /** Adds the increment to the current value and returns the resulting value */ + public long add(long increment) { + value += increment; + return value; + } + + /** Adds the increment to the current value and returns the resulting value */ + public long subtract(long increment) { + value -= increment; + return value; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 14cd3e70866..bf1825446e4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -77,6 +77,13 @@ public class TensorType { return Optional.empty(); } + /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */ + public Optional<Long> sizeOfDimension(String dimension) { + Optional<Dimension> d = dimension(dimension); + if ( ! d.isPresent()) return Optional.empty(); + return d.get().size(); + } + /** * Returns whether this type can be assigned to the given type, * i.e if the given type is a generalization of this type. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index a073053bec8..13e7c136feb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -3,6 +3,8 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.lang.MutableInteger; +import com.yahoo.lang.MutableLong; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -66,13 +68,27 @@ public class Concat extends PrimitiveTensorFunction { /** Returns the type resulting from concatenating a and b */ private TensorType type(TensorType a, TensorType b) { + // TODO: Fail if concat dimension is present but not indexed in a or b TensorType.Builder builder = new TensorType.Builder(a, b); - if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size - builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() + - b.dimension(dimension).get().size().get())); + if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) { + builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) + + b.sizeOfDimension(dimension).orElse(1L))); + /* + MutableLong concatSize = new MutableLong(0); + a.sizeOfDimension(dimension).ifPresent(concatSize::add); + b.sizeOfDimension(dimension).ifPresent(concatSize::add); + builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); + */ + } return builder.build(); } + /** Returns true if this dimension is present and unbound */ + private boolean unboundIn(TensorType type, String dimensionName) { + Optional<TensorType.Dimension> dimension = type.dimension(dimensionName); + return dimension.isPresent() && ! dimension.get().size().isPresent(); + } + @Override public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); |