diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo')
5 files changed, 243 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index c1a24abd878..e99e7da7415 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -443,7 +443,7 @@ public class IndexedTensor implements Tensor { public boolean equals(Object o) { if (o == this) return true; if ( ! ( o instanceof Map.Entry)) return false; - Map.Entry other = (Map.Entry)o; + Map.Entry<?,?> other = (Map.Entry)o; if ( ! this.getValue().equals(other.getValue())) return false; if ( ! this.getKey().equals(other.getKey())) return false; return true; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java new file mode 100644 index 00000000000..0bb92bc2a6f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -0,0 +1,56 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.CompositeTensorFunction; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.functions.ToStringContext; + +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere. + * + * @author bratseth + */ +public class Diag extends CompositeTensorFunction { + + private final TensorType type; + private final Function<List<Integer>, Double> diagFunction; + + public Diag(TensorType type) { + this.type = type; + this.diagFunction = ScalarFunctions.equalArguments(dimensionNames().collect(Collectors.toList())); + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + return new Generate(type, diagFunction); + } + + @Override + public String toString(ToStringContext context) { + return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction; + } + + private Stream<String> dimensionNames() { + return type.dimensions().stream().map(TensorType.Dimension::toString); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java new file mode 100644 index 00000000000..ba34c0d9748 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -0,0 +1,53 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.CompositeTensorFunction; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.tensor.functions.ToStringContext; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * A tensor generator which returns a tensor of any dimension filled with random numbers between 0 and 1. + * + * @author bratseth + */ +public class Random extends CompositeTensorFunction { + + private final TensorType type; + + public Random(TensorType type) { + this.type = type; + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + return new Generate(type, ScalarFunctions.random()); + } + + @Override + public String toString(ToStringContext context) { + return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; + } + + private Stream<String> dimensionNames() { + return type.dimensions().stream().map(TensorType.Dimension::toString); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java new file mode 100644 index 00000000000..e18edd48127 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -0,0 +1,51 @@ +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.TensorType; + +import java.util.Collections; +import java.util.List; +import java.util.function.Function; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * A tensor generator which returns a tensor of any dimension filled with the sum of the tensor + * indexes of each position. + * + * @author bratseth + */ +public class Range extends CompositeTensorFunction { + + private final TensorType type; + private final Function<List<Integer>, Double> rangeFunction; + + public Range(TensorType type) { + this.type = type; + this.rangeFunction = ScalarFunctions.sumArguments(dimensionNames().collect(Collectors.toList())); + } + + @Override + public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + + @Override + public TensorFunction replaceArguments(List<TensorFunction> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + return new Generate(type, rangeFunction); + } + + @Override + public String toString(ToStringContext context) { + return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction; + } + + private Stream<String> dimensionNames() { + return type.dimensions().stream().map(TensorType.Dimension::toString); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index c1b1cb2243d..a0b60f53df3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -1,9 +1,15 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import java.util.List; +import java.util.concurrent.ThreadLocalRandom; import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleSupplier; import java.util.function.DoubleUnaryOperator; +import java.util.function.Function; +import java.util.stream.Collectors; /** * Factory of scalar Java functions. @@ -21,6 +27,13 @@ public class ScalarFunctions { public static DoubleUnaryOperator square() { return new Square(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator exp() { return new Exponent(); } + public static Function<List<Integer>, Double> random() { return new Random(); } + public static Function<List<Integer>, Double> equalArguments(List<String> argumentNames) { + return new EqualArguments(argumentNames); + } + public static Function<List<Integer>, Double> sumArguments(List<String> argumentNames) { + return new SumArguments(argumentNames); + } public static class Addition implements DoubleBinaryOperator { @@ -81,4 +94,73 @@ public class ScalarFunctions { } + public static class Random implements Function<List<Integer>, Double> { + + @Override + public Double apply(List<Integer> values) { + return ThreadLocalRandom.current().nextDouble(); + } + + @Override + public String toString() { return "random()"; } + + } + + public static class EqualArguments implements Function<List<Integer>, Double> { + + private final ImmutableList<String> argumentNames; + + private EqualArguments(List<String> argumentNames) { + this.argumentNames = ImmutableList.copyOf(argumentNames); + } + + @Override + public Double apply(List<Integer> values) { + if (values.isEmpty()) return 1.0; + for (Integer value : values) + if ( ! value.equals(values.get(0))) + return 0.0; + return 1.0; + } + + @Override + public String toString() { + if (argumentNames.size() == 0) return "(1)"; + if (argumentNames.size() == 1) return "(1)"; + if (argumentNames.size() == 2) return "(" + argumentNames.get(0) + "==" + argumentNames.get(1) + ")"; + + StringBuilder b = new StringBuilder("("); + for (int i = 0; i < argumentNames.size() -1; i++) { + b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")"); + if ( i < argumentNames.size() -2) + b.append("*"); + } + return b.toString(); + } + + } + + public static class SumArguments implements Function<List<Integer>, Double> { + + private final ImmutableList<String> argumentNames; + + private SumArguments(List<String> argumentNames) { + this.argumentNames = ImmutableList.copyOf(argumentNames); + } + + @Override + public Double apply(List<Integer> values) { + int sum = 0; + for (Integer value : values) + sum += value; + return (double)sum; + } + + @Override + public String toString() { + return "(" + argumentNames.stream().collect(Collectors.joining("+")) + ")"; + } + + } + } |