summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-02-21 18:47:20 +0100
committerJon Bratseth <bratseth@oath.com>2018-02-21 18:47:20 +0100
commitacfdb9e6c61b9f8d065645a657c130bd2cb49c87 (patch)
treef98e707acaa41496d4d1277dd3535569f623f5a6 /vespajlib/src/main/java
parent31805b7b9640302067713ce05573d9d1e5c92f39 (diff)
Deduce correct concat type
Diffstat (limited to 'vespajlib/src/main/java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/lang/MutableLong.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java22
3 files changed, 59 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java
new file mode 100644
index 00000000000..e0e4a0828a9
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java
@@ -0,0 +1,33 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.lang;
+
+/**
+ * A mutable long
+ *
+ * @author bratseth
+ */
+public class MutableLong {
+
+ private long value;
+
+ public MutableLong(long value) {
+ this.value = value;
+ }
+
+ public long get() { return value; }
+
+ public void set(long value) { this.value = value; }
+
+ /** Adds the increment to the current value and returns the resulting value */
+ public long add(long increment) {
+ value += increment;
+ return value;
+ }
+
+ /** Adds the increment to the current value and returns the resulting value */
+ public long subtract(long increment) {
+ value -= increment;
+ return value;
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 14cd3e70866..bf1825446e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -77,6 +77,13 @@ public class TensorType {
return Optional.empty();
}
+ /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */
+ public Optional<Long> sizeOfDimension(String dimension) {
+ Optional<Dimension> d = dimension(dimension);
+ if ( ! d.isPresent()) return Optional.empty();
+ return d.get().size();
+ }
+
/**
* Returns whether this type can be assigned to the given type,
* i.e if the given type is a generalization of this type.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index a073053bec8..13e7c136feb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -3,6 +3,8 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.lang.MutableInteger;
+import com.yahoo.lang.MutableLong;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -66,13 +68,27 @@ public class Concat extends PrimitiveTensorFunction {
/** Returns the type resulting from concatenating a and b */
private TensorType type(TensorType a, TensorType b) {
+ // TODO: Fail if concat dimension is present but not indexed in a or b
TensorType.Builder builder = new TensorType.Builder(a, b);
- if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
- builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() +
- b.dimension(dimension).get().size().get()));
+ if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) {
+ builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) +
+ b.sizeOfDimension(dimension).orElse(1L)));
+ /*
+ MutableLong concatSize = new MutableLong(0);
+ a.sizeOfDimension(dimension).ifPresent(concatSize::add);
+ b.sizeOfDimension(dimension).ifPresent(concatSize::add);
+ builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
+ */
+ }
return builder.build();
}
+ /** Returns true if this dimension is present and unbound */
+ private boolean unboundIn(TensorType type, String dimensionName) {
+ Optional<TensorType.Dimension> dimension = type.dimension(dimensionName);
+ return dimension.isPresent() && ! dimension.get().size().isPresent();
+ }
+
@Override
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);