diff options
author | Harald Musum <musum@yahoo-inc.com> | 2018-02-21 23:43:02 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-21 23:43:02 +0100 |
commit | fdff142dab4a75ace0623c2c8bf513a0a4597aca (patch) | |
tree | 0101a275869dd270a1609d1199381e0137e5d1f6 /vespajlib | |
parent | 2238f0d8d3b8a07e5e9d0ce1a01fc7f3e149cece (diff) |
Revert "Bratseth/typecheck all 3"
Diffstat (limited to 'vespajlib')
16 files changed, 48 insertions, 195 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java deleted file mode 100644 index e0e4a0828a9..00000000000 --- a/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.lang; - -/** - * A mutable long - * - * @author bratseth - */ -public class MutableLong { - - private long value; - - public MutableLong(long value) { - this.value = value; - } - - public long get() { return value; } - - public void set(long value) { this.value = value; } - - /** Adds the increment to the current value and returns the resulting value */ - public long add(long increment) { - value += increment; - return value; - } - - /** Adds the increment to the current value and returns the resulting value */ - public long subtract(long increment) { - value -= increment; - return value; - } - -} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index bf1825446e4..14cd3e70866 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -77,13 +77,6 @@ public class TensorType { return Optional.empty(); } - /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */ - public Optional<Long> sizeOfDimension(String dimension) { - Optional<Dimension> d = dimension(dimension); - if ( ! d.isPresent()) return Optional.empty(); - return d.get().size(); - } - /** * Returns whether this type can be assigned to the given type, * i.e if the given type is a generalization of this type. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 8a969180113..3fb94f1251b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -10,7 +10,7 @@ import com.yahoo.tensor.Tensor; * @author bratseth */ @Beta -public interface EvaluationContext<NAMETYPE extends TypeContext.Name> extends TypeContext<NAMETYPE> { +public interface EvaluationContext extends TypeContext { /** Returns the tensor bound to this name, or null if none */ Tensor getTensor(String name); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index b9394da31e3..9fe6b7d053f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -11,20 +11,17 @@ import java.util.HashMap; * @author bratseth */ @Beta -public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> { +public class MapEvaluationContext implements EvaluationContext { private final java.util.Map<String, Tensor> bindings = new HashMap<>(); + static MapEvaluationContext empty() { return new MapEvaluationContext(); } + public void put(String name, Tensor tensor) { bindings.put(name, tensor); } @Override public TensorType getType(String name) { - return getType(new Name(name)); - } - - @Override - public TensorType getType(Name name) { - Tensor tensor = bindings.get(name.toString()); + Tensor tensor = bindings.get(name); if (tensor == null) return null; return tensor.type(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java index ff2e6318b37..760a225efdf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -8,7 +8,7 @@ import com.yahoo.tensor.TensorType; * * @author bratseth */ -public interface TypeContext<NAMETYPE extends TypeContext.Name> { +public interface TypeContext { /** * Returns the type of the tensor with this name. @@ -16,39 +16,6 @@ public interface TypeContext<NAMETYPE extends TypeContext.Name> { * @return returns the type of the tensor which will be returned by calling getTensor(name) * or null if getTensor will return null. */ - TensorType getType(NAMETYPE name); - - /** - * Returns the type of the tensor with this name by converting from a string name. - * - * @return returns the type of the tensor which will be returned by calling getTensor(name) - * or null if getTensor will return null. - */ TensorType getType(String name); - /** A name which is just a string. Names are value objects. */ - class Name { - - private final String name; - - public Name(String name) { - this.name = name; - } - - @Override - public String toString() { return name; } - - @Override - public int hashCode() { return name.hashCode(); } - - @Override - public boolean equals(Object other) { - if (other == this) return true; - if ( ! (other instanceof Name)) return false; - return ((Name)other).name.equals(this.name); - } - - } - - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index acb2363cba4..34beb465d4c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -44,7 +44,7 @@ public class VariableTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext context) { TensorType givenType = context.getType(name); if (givenType == null) return null; verifyType(givenType); @@ -52,7 +52,7 @@ public class VariableTensor extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext context) { Tensor tensor = context.getTensor(name); if (tensor == null) return null; verifyType(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index bfc0938abcc..2109b730e1a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -18,14 +18,10 @@ public abstract class CompositeTensorFunction extends TensorFunction { /** Finds the type this produces by first converting it to a primitive function */ @Override - public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { - return toPrimitive().type(context); - } + public final TensorType type(TypeContext context) { return toPrimitive().type(context); } /** Evaluates this by first converting it to a primitive function */ @Override - public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { - return toPrimitive().evaluate(context); - } + public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 13e7c136feb..c77ed1c0526 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -3,8 +3,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; -import com.yahoo.lang.MutableInteger; -import com.yahoo.lang.MutableLong; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -62,35 +60,21 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext context) { return type(argumentA.type(context), argumentB.type(context)); } /** Returns the type resulting from concatenating a and b */ private TensorType type(TensorType a, TensorType b) { - // TODO: Fail if concat dimension is present but not indexed in a or b TensorType.Builder builder = new TensorType.Builder(a, b); - if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) { - builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) + - b.sizeOfDimension(dimension).orElse(1L))); - /* - MutableLong concatSize = new MutableLong(0); - a.sizeOfDimension(dimension).ifPresent(concatSize::add); - b.sizeOfDimension(dimension).ifPresent(concatSize::add); - builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); - */ - } + if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size + builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() + + b.dimension(dimension).get().size().get())); return builder.build(); } - /** Returns true if this dimension is present and unbound */ - private boolean unboundIn(TensorType type, String dimensionName) { - Optional<TensorType.Dimension> dimension = type.dimension(dimensionName); - return dimension.isPresent() && ! dimension.get().size().isPresent(); - } - @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); a = ensureIndexedDimension(dimension, a); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index a43de297b9a..50b479da168 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -42,10 +42,10 @@ public class ConstantTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); } + public TensorType type(TypeContext context) { return constant.type(); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } + public Tensor evaluate(EvaluationContext context) { return constant; } @Override public String toString(ToStringContext context) { return constant.toString(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index edfa8253eb9..e70d1de3db7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -61,10 +61,10 @@ public class Generate extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } + public TensorType type(TypeContext context) { return type; } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); for (int i = 0; i < indexes.size(); i++) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 17e1c103ea3..7812c985091 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -95,12 +95,12 @@ public class Join extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext context) { return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 4a338e5501e..53504868ff2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -53,12 +53,12 @@ public class Map extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext context) { return argument.type(context); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index e045effbe7e..76a938b9fe2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -101,12 +101,11 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext context) { return type(argument.type(context)); } private TensorType type(TensorType argumentType) { - if (dimensions.isEmpty()) return TensorType.empty; // means reduce all TensorType.Builder builder = new TensorType.Builder(); for (TensorType.Dimension dimension : argumentType.dimensions()) if ( ! dimensions.contains(dimension.name())) // keep @@ -115,7 +114,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index af4492ca1e4..de3d2be265a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -72,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext context) { return type(argument.type(context)); } @@ -84,7 +84,7 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext context) { Tensor tensor = argument.evaluate(context); TensorType renamedType = type(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index e805e9d87bb..78ab09c7820 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -43,14 +43,14 @@ public abstract class TensorFunction { * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context); + public abstract Tensor evaluate(EvaluationContext context); /** * Returns the type of the tensor this produces given the input types in the context * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context); + public abstract TensorType type(TypeContext context); /** Evaluate with no context */ public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } @@ -58,7 +58,7 @@ public abstract class TensorFunction { /** * Return a string representation of this context. * - * @param context a context which must be passed to all nested functions when requesting the string value + * @param context a context which must be passed to all nexted functions when requesting the string value */ public abstract String toString(ToStringContext context); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java index eafa5c4addf..7e1f292eb7b 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java @@ -2,9 +2,6 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.MapEvaluationContext; -import com.yahoo.tensor.evaluation.TypeContext; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -19,98 +16,51 @@ public class ConcatTestCase { public void testConcatNumbers() { Tensor a = Tensor.from("{1}"); Tensor b = Tensor.from("{2}"); - assertConcat("tensor(x[2]):{ {x:0}:1, {x:1}:2 }", a, b, "x"); - assertConcat("tensor(x[2]):{ {x:0}:2, {x:1}:1 }", b, a , "x"); + assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x")); } @Test public void testConcatEqualShapes() { - Tensor a = Tensor.from("tensor(x[3]):{ {x:0}:1, {x:1}:2, {x:2}:3 }"); - Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); - assertConcat("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }", a, b, "x"); - assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " + - "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }", - a, b, "y"); + Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }"); + Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); + assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " + + "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y")); } @Test public void testConcatNumberAndVector() { Tensor a = Tensor.from("{1}"); - Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:2, {x:1}:3, {x:2}:4 }"); - assertConcat("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x"); - assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + - "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }", - a, b, "y"); - } - - @Test - public void testConcatNumberAndVectorUnbound() { - Tensor a = Tensor.from("{1}"); Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }"); - assertConcat("tensor(x[])","tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x"); - assertConcat("tensor(x[],y[2])", "tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + - "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }", - a, b, "y"); + assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + + "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y")); } @Test public void testUnequalSizesSameDimension() { - Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"); - Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); - assertConcat("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x"); - assertConcat("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y"); - } - - @Test - public void testUnequalSizesSameDimensionUnbound() { Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }"); Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); - assertConcat("tensor(x[])", "tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x"); - assertConcat("tensor(x[],y[2])", "tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y"); + assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }"), a.concat(b, "y")); } @Test public void testUnequalEqualSizesDifferentDimension() { - Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"); - Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }"); - assertConcat("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x"); - assertConcat("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y"); - assertConcat("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z"); - } - - @Test - public void testUnequalEqualSizesDifferentDimensionOneUnbound() { Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }"); - Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }"); - assertConcat("tensor(x[],y[3])", "tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x"); - assertConcat("tensor(x[],y[4])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y"); - assertConcat("tensor(x[],y[3],z[2])", "tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z"); + Tensor b = Tensor.from("tensor(y[]):{ {y:0}:4, {y:1}:5, {y:2}:6 }"); + assertEquals(Tensor.from("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y")); + assertEquals(Tensor.from("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}"), a.concat(b, "z")); } @Test public void testDimensionsubset() { Tensor a = Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:3, {x:1,y:1}:4 }"); Tensor b = Tensor.from("tensor(y[2]):{ {y:0}:5, {y:1}:6 }"); - assertConcat("tensor(x[],y[])", "tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}", a, b, "x"); - assertConcat("tensor(x[],y[])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y"); - } - - private void assertConcat(String expected, Tensor a, Tensor b, String dimension) { - assertConcat(null, expected, a, b, dimension); - } - - private void assertConcat(String expectedType, String expected, Tensor a, Tensor b, String dimension) { - Tensor expectedAsTensor = Tensor.from(expected); - TensorType inferredType = new Concat(new ConstantTensor(a), new ConstantTensor(b), dimension) - .type(new MapEvaluationContext()); - Tensor result = a.concat(b, dimension); - - if (expectedType != null) - assertEquals(TensorType.fromSpec(expectedType), inferredType); - else - assertEquals(expectedAsTensor.type(), inferredType); - - assertEquals(expectedAsTensor, result); + assertEquals(Tensor.from("tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}"), a.concat(b, "x")); + assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y")); } } |