summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHarald Musum <musum@yahoo-inc.com>2018-02-21 23:43:02 +0100
committerGitHub <noreply@github.com>2018-02-21 23:43:02 +0100
commitfdff142dab4a75ace0623c2c8bf513a0a4597aca (patch)
tree0101a275869dd270a1609d1199381e0137e5d1f6 /vespajlib
parent2238f0d8d3b8a07e5e9d0ce1a01fc7f3e149cece (diff)
Revert "Bratseth/typecheck all 3"
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/lang/MutableLong.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java86
16 files changed, 48 insertions, 195 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java
deleted file mode 100644
index e0e4a0828a9..00000000000
--- a/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java
+++ /dev/null
@@ -1,33 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.lang;
-
-/**
- * A mutable long
- *
- * @author bratseth
- */
-public class MutableLong {
-
- private long value;
-
- public MutableLong(long value) {
- this.value = value;
- }
-
- public long get() { return value; }
-
- public void set(long value) { this.value = value; }
-
- /** Adds the increment to the current value and returns the resulting value */
- public long add(long increment) {
- value += increment;
- return value;
- }
-
- /** Adds the increment to the current value and returns the resulting value */
- public long subtract(long increment) {
- value -= increment;
- return value;
- }
-
-}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index bf1825446e4..14cd3e70866 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -77,13 +77,6 @@ public class TensorType {
return Optional.empty();
}
- /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */
- public Optional<Long> sizeOfDimension(String dimension) {
- Optional<Dimension> d = dimension(dimension);
- if ( ! d.isPresent()) return Optional.empty();
- return d.get().size();
- }
-
/**
* Returns whether this type can be assigned to the given type,
* i.e if the given type is a generalization of this type.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
index 8a969180113..3fb94f1251b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
@@ -10,7 +10,7 @@ import com.yahoo.tensor.Tensor;
* @author bratseth
*/
@Beta
-public interface EvaluationContext<NAMETYPE extends TypeContext.Name> extends TypeContext<NAMETYPE> {
+public interface EvaluationContext extends TypeContext {
/** Returns the tensor bound to this name, or null if none */
Tensor getTensor(String name);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
index b9394da31e3..9fe6b7d053f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
@@ -11,20 +11,17 @@ import java.util.HashMap;
* @author bratseth
*/
@Beta
-public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> {
+public class MapEvaluationContext implements EvaluationContext {
private final java.util.Map<String, Tensor> bindings = new HashMap<>();
+ static MapEvaluationContext empty() { return new MapEvaluationContext(); }
+
public void put(String name, Tensor tensor) { bindings.put(name, tensor); }
@Override
public TensorType getType(String name) {
- return getType(new Name(name));
- }
-
- @Override
- public TensorType getType(Name name) {
- Tensor tensor = bindings.get(name.toString());
+ Tensor tensor = bindings.get(name);
if (tensor == null) return null;
return tensor.type();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
index ff2e6318b37..760a225efdf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
@@ -8,7 +8,7 @@ import com.yahoo.tensor.TensorType;
*
* @author bratseth
*/
-public interface TypeContext<NAMETYPE extends TypeContext.Name> {
+public interface TypeContext {
/**
* Returns the type of the tensor with this name.
@@ -16,39 +16,6 @@ public interface TypeContext<NAMETYPE extends TypeContext.Name> {
* @return returns the type of the tensor which will be returned by calling getTensor(name)
* or null if getTensor will return null.
*/
- TensorType getType(NAMETYPE name);
-
- /**
- * Returns the type of the tensor with this name by converting from a string name.
- *
- * @return returns the type of the tensor which will be returned by calling getTensor(name)
- * or null if getTensor will return null.
- */
TensorType getType(String name);
- /** A name which is just a string. Names are value objects. */
- class Name {
-
- private final String name;
-
- public Name(String name) {
- this.name = name;
- }
-
- @Override
- public String toString() { return name; }
-
- @Override
- public int hashCode() { return name.hashCode(); }
-
- @Override
- public boolean equals(Object other) {
- if (other == this) return true;
- if ( ! (other instanceof Name)) return false;
- return ((Name)other).name.equals(this.name);
- }
-
- }
-
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index acb2363cba4..34beb465d4c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -44,7 +44,7 @@ public class VariableTensor extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext context) {
TensorType givenType = context.getType(name);
if (givenType == null) return null;
verifyType(givenType);
@@ -52,7 +52,7 @@ public class VariableTensor extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext context) {
Tensor tensor = context.getTensor(name);
if (tensor == null) return null;
verifyType(tensor.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index bfc0938abcc..2109b730e1a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -18,14 +18,10 @@ public abstract class CompositeTensorFunction extends TensorFunction {
/** Finds the type this produces by first converting it to a primitive function */
@Override
- public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
- return toPrimitive().type(context);
- }
+ public final TensorType type(TypeContext context) { return toPrimitive().type(context); }
/** Evaluates this by first converting it to a primitive function */
@Override
- public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
- return toPrimitive().evaluate(context);
- }
+ public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 13e7c136feb..c77ed1c0526 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -3,8 +3,6 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
-import com.yahoo.lang.MutableInteger;
-import com.yahoo.lang.MutableLong;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
@@ -62,35 +60,21 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext context) {
return type(argumentA.type(context), argumentB.type(context));
}
/** Returns the type resulting from concatenating a and b */
private TensorType type(TensorType a, TensorType b) {
- // TODO: Fail if concat dimension is present but not indexed in a or b
TensorType.Builder builder = new TensorType.Builder(a, b);
- if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) {
- builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) +
- b.sizeOfDimension(dimension).orElse(1L)));
- /*
- MutableLong concatSize = new MutableLong(0);
- a.sizeOfDimension(dimension).ifPresent(concatSize::add);
- b.sizeOfDimension(dimension).ifPresent(concatSize::add);
- builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
- */
- }
+ if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size
+ builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() +
+ b.dimension(dimension).get().size().get()));
return builder.build();
}
- /** Returns true if this dimension is present and unbound */
- private boolean unboundIn(TensorType type, String dimensionName) {
- Optional<TensorType.Dimension> dimension = type.dimension(dimensionName);
- return dimension.isPresent() && ! dimension.get().size().isPresent();
- }
-
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
a = ensureIndexedDimension(dimension, a);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index a43de297b9a..50b479da168 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -42,10 +42,10 @@ public class ConstantTensor extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); }
+ public TensorType type(TypeContext context) { return constant.type(); }
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; }
+ public Tensor evaluate(EvaluationContext context) { return constant; }
@Override
public String toString(ToStringContext context) { return constant.toString(); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index edfa8253eb9..e70d1de3db7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -61,10 +61,10 @@ public class Generate extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; }
+ public TensorType type(TypeContext context) { return type; }
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext context) {
Tensor.Builder builder = Tensor.Builder.of(type);
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
for (int i = 0; i < indexes.size(); i++) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 17e1c103ea3..7812c985091 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -95,12 +95,12 @@ public class Join extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext context) {
return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 4a338e5501e..53504868ff2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -53,12 +53,12 @@ public class Map extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext context) {
return argument.type(context);
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext context) {
Tensor argument = argument().evaluate(context);
Tensor.Builder builder = Tensor.Builder.of(argument.type());
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index e045effbe7e..76a938b9fe2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -101,12 +101,11 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext context) {
return type(argument.type(context));
}
private TensorType type(TensorType argumentType) {
- if (dimensions.isEmpty()) return TensorType.empty; // means reduce all
TensorType.Builder builder = new TensorType.Builder();
for (TensorType.Dimension dimension : argumentType.dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
@@ -115,7 +114,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index af4492ca1e4..de3d2be265a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -72,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction {
public PrimitiveTensorFunction toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext context) {
return type(argument.type(context));
}
@@ -84,7 +84,7 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext context) {
Tensor tensor = argument.evaluate(context);
TensorType renamedType = type(tensor.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index e805e9d87bb..78ab09c7820 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -43,14 +43,14 @@ public abstract class TensorFunction {
*
* @param context a context which must be passed to all nexted functions when evaluating
*/
- public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context);
+ public abstract Tensor evaluate(EvaluationContext context);
/**
* Returns the type of the tensor this produces given the input types in the context
*
* @param context a context which must be passed to all nexted functions when evaluating
*/
- public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context);
+ public abstract TensorType type(TypeContext context);
/** Evaluate with no context */
public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); }
@@ -58,7 +58,7 @@ public abstract class TensorFunction {
/**
* Return a string representation of this context.
*
- * @param context a context which must be passed to all nested functions when requesting the string value
+ * @param context a context which must be passed to all nexted functions when requesting the string value
*/
public abstract String toString(ToStringContext context);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
index eafa5c4addf..7e1f292eb7b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
@@ -2,9 +2,6 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.MapEvaluationContext;
-import com.yahoo.tensor.evaluation.TypeContext;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -19,98 +16,51 @@ public class ConcatTestCase {
public void testConcatNumbers() {
Tensor a = Tensor.from("{1}");
Tensor b = Tensor.from("{2}");
- assertConcat("tensor(x[2]):{ {x:0}:1, {x:1}:2 }", a, b, "x");
- assertConcat("tensor(x[2]):{ {x:0}:2, {x:1}:1 }", b, a , "x");
+ assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x"));
}
@Test
public void testConcatEqualShapes() {
- Tensor a = Tensor.from("tensor(x[3]):{ {x:0}:1, {x:1}:2, {x:2}:3 }");
- Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
- assertConcat("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }", a, b, "x");
- assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " +
- "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }",
- a, b, "y");
+ Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }");
+ Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
+ assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " +
+ "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y"));
}
@Test
public void testConcatNumberAndVector() {
Tensor a = Tensor.from("{1}");
- Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:2, {x:1}:3, {x:2}:4 }");
- assertConcat("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x");
- assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
- "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }",
- a, b, "y");
- }
-
- @Test
- public void testConcatNumberAndVectorUnbound() {
- Tensor a = Tensor.from("{1}");
Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }");
- assertConcat("tensor(x[])","tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x");
- assertConcat("tensor(x[],y[2])", "tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
- "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }",
- a, b, "y");
+ assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " +
+ "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y"));
}
@Test
public void testUnequalSizesSameDimension() {
- Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }");
- Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
- assertConcat("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x");
- assertConcat("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y");
- }
-
- @Test
- public void testUnequalSizesSameDimensionUnbound() {
Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }");
- assertConcat("tensor(x[])", "tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x");
- assertConcat("tensor(x[],y[2])", "tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y");
+ assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }"), a.concat(b, "y"));
}
@Test
public void testUnequalEqualSizesDifferentDimension() {
- Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }");
- Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }");
- assertConcat("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x");
- assertConcat("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y");
- assertConcat("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z");
- }
-
- @Test
- public void testUnequalEqualSizesDifferentDimensionOneUnbound() {
Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }");
- Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }");
- assertConcat("tensor(x[],y[3])", "tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x");
- assertConcat("tensor(x[],y[4])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y");
- assertConcat("tensor(x[],y[3],z[2])", "tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z");
+ Tensor b = Tensor.from("tensor(y[]):{ {y:0}:4, {y:1}:5, {y:2}:6 }");
+ assertEquals(Tensor.from("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y"));
+ assertEquals(Tensor.from("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}"), a.concat(b, "z"));
}
@Test
public void testDimensionsubset() {
Tensor a = Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:3, {x:1,y:1}:4 }");
Tensor b = Tensor.from("tensor(y[2]):{ {y:0}:5, {y:1}:6 }");
- assertConcat("tensor(x[],y[])", "tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}", a, b, "x");
- assertConcat("tensor(x[],y[])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y");
- }
-
- private void assertConcat(String expected, Tensor a, Tensor b, String dimension) {
- assertConcat(null, expected, a, b, dimension);
- }
-
- private void assertConcat(String expectedType, String expected, Tensor a, Tensor b, String dimension) {
- Tensor expectedAsTensor = Tensor.from(expected);
- TensorType inferredType = new Concat(new ConstantTensor(a), new ConstantTensor(b), dimension)
- .type(new MapEvaluationContext());
- Tensor result = a.concat(b, dimension);
-
- if (expectedType != null)
- assertEquals(TensorType.fromSpec(expectedType), inferredType);
- else
- assertEquals(expectedAsTensor.type(), inferredType);
-
- assertEquals(expectedAsTensor, result);
+ assertEquals(Tensor.from("tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}"), a.concat(b, "x"));
+ assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y"));
}
}