summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
authorgjoranv <gjoranv@gmail.com>2017-12-17 21:44:49 +0100
committerGitHub <noreply@github.com>2017-12-17 21:44:49 +0100
commit03bce1fe1a494f2ac9d4268d4c90b08011b3f600 (patch)
tree180f294d2ac97d641f0266216ffdc328db9bfef8 /vespajlib/src/main/java/com/yahoo/tensor/functions
parentb72e55b87eecae006ed92976151137a80d75be0f (diff)
Revert "Bratseth/tensorflow models"
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java54
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java41
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java97
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java6
16 files changed, 129 insertions, 167 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 191c7988443..8f4dbf014a7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -8,7 +8,7 @@ import com.yahoo.tensor.evaluation.EvaluationContext;
/**
* A composite tensor function is a tensor function which can be expressed (less tersely)
* as a tree of primitive tensor functions.
- *
+ *
* @author bratseth
*/
@Beta
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index faa0ca36cb6..1dbb94fdb20 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -15,7 +15,7 @@ import java.util.stream.Collectors;
/**
* Concatenation of two tensors along an (indexed) dimension
- *
+ *
* @author bratseth
*/
@Beta
@@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction {
concatenateTo(bIndexed, aIndexed, 0, concatType, bToIndexes, aToIndexes, builder);
return builder.build();
}
-
+
private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType,
int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) {
Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet());
@@ -112,7 +112,7 @@ public class Concat extends PrimitiveTensorFunction {
Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
return tensor.multiply(unitTensor);
}
-
+
}
/** Returns the type resulting from concatenating a and b */
@@ -144,7 +144,7 @@ public class Concat extends PrimitiveTensorFunction {
/**
* Combine two addresses, adding the offset to the concat dimension
*
- * @return the combined address or null if the addresses are incompatible
+ * @return the combined address or null if the addresses are incompatible
* (in some other dimension than the concat dimension)
*/
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
@@ -161,7 +161,7 @@ public class Concat extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 14ed38718ce..4ac7b21ba90 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -10,18 +10,18 @@ import java.util.List;
/**
* A function which returns a constant tensor.
- *
+ *
* @author bratseth
*/
@Beta
public class ConstantTensor extends PrimitiveTensorFunction {
private final Tensor constant;
-
+
public ConstantTensor(String tensorString) {
this.constant = Tensor.from(tensorString);
}
-
+
public ConstantTensor(Tensor tensor) {
this.constant = tensor;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index c75d8ee4753..bbdbd5c3df1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -11,19 +11,19 @@ import java.util.stream.Stream;
/**
* A tensor generator which returns a tensor of any dimension filled with 1 in the diagonal and 0 elsewhere.
- *
+ *
* @author bratseth
*/
public class Diag extends CompositeTensorFunction {
private final TensorType type;
private final Function<List<Integer>, Double> diagFunction;
-
+
public Diag(TensorType type) {
this.type = type;
this.diagFunction = ScalarFunctions.equal(dimensionNames().collect(Collectors.toList()));
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -43,7 +43,7 @@ public class Diag extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::name);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index e42d25197e2..6ea73b7f310 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -15,7 +15,7 @@ import java.util.function.Function;
/**
* An indexed tensor whose values are generated by a function
- *
+ *
* @author bratseth
*/
@Beta
@@ -26,7 +26,7 @@ public class Generate extends PrimitiveTensorFunction {
/**
* Creates a generated tensor
- *
+ *
* @param type the type of the tensor
* @param generator the function generating values from a list of ints specifying the indexes of the
* tensor cell which will receive the value
@@ -39,7 +39,7 @@ public class Generate extends PrimitiveTensorFunction {
this.type = type;
this.generator = generator;
}
-
+
private void validateType(TensorType type) {
for (TensorType.Dimension dimension : type.dimensions())
if (dimension.type() != TensorType.Dimension.Type.indexedBound)
@@ -58,7 +58,7 @@ public class Generate extends PrimitiveTensorFunction {
@Override
public PrimitiveTensorFunction toPrimitive() { return this; }
-
+
@Override
public Tensor evaluate(EvaluationContext context) {
Tensor.Builder builder = Tensor.Builder.of(type);
@@ -69,14 +69,14 @@ public class Generate extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private DimensionSizes dimensionSizes(TensorType type) {
DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size());
for (int i = 0; i < b.dimensions(); i++)
b.set(i, type.dimensions().get(i).size().get());
return b.build();
}
-
+
@Override
public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index ff887e3e9a6..8c4dbfb0acb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -28,12 +28,12 @@ import java.util.function.DoubleBinaryOperator;
* The <i>join</i> tensor operation produces a tensor from the argument tensors containing the set of cells
* given by the cross product of the cells of the given tensors, having as values the value produced by
* applying the given combinator function on the values from the two source cells.
- *
+ *
* @author bratseth
*/
@Beta
public class Join extends PrimitiveTensorFunction {
-
+
private final TensorFunction argumentA, argumentB;
private final DoubleBinaryOperator combinator;
@@ -46,30 +46,6 @@ public class Join extends PrimitiveTensorFunction {
this.combinator = combinator;
}
- /** Returns the type resulting from applying Join to the two given types */
- public static TensorType outputType(TensorType a, TensorType b) {
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (int i = 0; i < a.dimensions().size(); ++i) {
- TensorType.Dimension aDim = a.dimensions().get(i);
- for (int j = 0; j < b.dimensions().size(); ++j) {
- TensorType.Dimension bDim = b.dimensions().get(j);
- if (aDim.name().equals(bDim.name())) { // include
- if (aDim.isIndexed() && bDim.isIndexed()) {
- if (aDim.size().isPresent() || bDim.size().isPresent())
- typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
- bDim.size().orElse(Integer.MAX_VALUE)));
- else
- typeBuilder.indexed(aDim.name());
- }
- else {
- typeBuilder.mapped(aDim.name());
- }
- }
- }
- }
- return typeBuilder.build();
- }
-
public TensorFunction argumentA() { return argumentA; }
public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }
@@ -112,11 +88,11 @@ public class Join extends PrimitiveTensorFunction {
else
return generalJoin(a, b, joinedType);
}
-
+
private boolean hasSingleIndexedDimension(Tensor tensor) {
return tensor.type().dimensions().size() == 1 && tensor.type().dimensions().get(0).isIndexed();
}
-
+
private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) {
int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0));
Iterator<Double> aIterator = a.valueIterator();
@@ -138,7 +114,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
/** Join a tensor into a superspace */
private Tensor subspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace instanceof IndexedTensor && superspace instanceof IndexedTensor)
@@ -150,7 +126,7 @@ public class Join extends PrimitiveTensorFunction {
private Tensor indexedSubspaceJoin(IndexedTensor subspace, IndexedTensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
if (subspace.size() == 0 || superspace.size() == 0) // special case empty here to avoid doing it when finding sizes
return Tensor.Builder.of(joinedType, new DimensionSizes.Builder(joinedType.dimensions().size()).build()).build();
-
+
DimensionSizes joinedSizes = joinedSize(joinedType, subspace, superspace);
IndexedTensor.Builder builder = (IndexedTensor.Builder)Tensor.Builder.of(joinedType, joinedSizes);
@@ -158,14 +134,14 @@ public class Join extends PrimitiveTensorFunction {
// Find dimensions which are only in the supertype
Set<String> superDimensionNames = new HashSet<>(superspace.type().dimensionNames());
superDimensionNames.removeAll(subspace.type().dimensionNames());
-
+
for (Iterator<IndexedTensor.SubspaceIterator> i = superspace.subspaceIterator(superDimensionNames, joinedSizes); i.hasNext(); ) {
IndexedTensor.SubspaceIterator subspaceInSuper = i.next();
joinSubspaces(subspace.valueIterator(), subspace.size(),
subspaceInSuper, subspaceInSuper.size(),
reversedArgumentOrder, builder);
}
-
+
return builder.build();
}
@@ -224,7 +200,7 @@ public class Join extends PrimitiveTensorFunction {
subspaceIndexes[i] = supertype.indexOfDimension(subtype.dimensions().get(i).name()).get();
return subspaceIndexes;
}
-
+
private TensorAddress mapAddressToSubspace(TensorAddress superAddress, int[] subspaceIndexes) {
String[] subspaceLabels = new String[subspaceIndexes.length];
for (int i = 0; i < subspaceIndexes.length; i++)
@@ -259,7 +235,7 @@ public class Join extends PrimitiveTensorFunction {
DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
// for each combination of dimensions only in a
- for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
IndexedTensor.SubspaceIterator aSubspace = ia.next();
// for each combination of dimensions in a which is also in b
while (aSubspace.hasNext()) {
@@ -276,7 +252,7 @@ public class Join extends PrimitiveTensorFunction {
}
}
}
-
+
private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
for (int i = 0; i < addressType.dimensions().size(); i++)
@@ -284,7 +260,7 @@ public class Join extends PrimitiveTensorFunction {
builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
return builder.build();
}
-
+
/** Returns the sizes from the joined sizes which are present in the type argument */
private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
@@ -295,7 +271,7 @@ public class Join extends PrimitiveTensorFunction {
}
return builder.build();
}
-
+
private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
@@ -364,7 +340,7 @@ public class Join extends PrimitiveTensorFunction {
/**
* Returns the an array having one entry in order for each dimension of fromType
* containing the index at which toType contains the same dimension name.
- * That is, if the returned array contains n at index i then
+ * That is, if the returned array contains n at index i then
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
@@ -384,7 +360,7 @@ public class Join extends PrimitiveTensorFunction {
return TensorAddress.of(joinedLabels);
}
- /**
+ /**
* Maps the content in the given list to the given array, using the given index map.
*
* @return true if the mapping was successful, false if one of the destination positions was
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index a5e1a016a41..a9872bb42d8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
-import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Collections;
@@ -33,8 +32,6 @@ public class Map extends PrimitiveTensorFunction {
this.mapper = mapper;
}
- public static TensorType outputType(TensorType inputType) { return inputType; }
-
public TensorFunction argument() { return argument; }
public DoubleUnaryOperator mapper() { return mapper; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 4071917c2b5..bb27e937699 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -3,7 +3,6 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
-import com.yahoo.tensor.TensorType;
import java.util.List;
@@ -15,17 +14,13 @@ public class Matmul extends CompositeTensorFunction {
private final TensorFunction argument1, argument2;
private final String dimension;
-
+
public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
this.argument1 = argument1;
this.argument2 = argument2;
this.dimension = dimension;
}
- public static TensorType outputType(TensorType a, TensorType b, String dimension) {
- return Join.outputType(a, b);
- }
-
@Override
public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
@@ -44,7 +39,7 @@ public class Matmul extends CompositeTensorFunction {
Reduce.Aggregator.sum,
dimension);
}
-
+
@Override
public String toString(ToStringContext context) {
return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
index b7c9a5d2342..efb7b9e500c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
@@ -8,10 +8,10 @@ import com.yahoo.tensor.Tensor;
* A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions.
* All tensor implementations must implement all primitive tensor functions.
* Primitive tensor functions are fully inspectable.
- *
+ *
* @author bratseth
*/
@Beta
public abstract class PrimitiveTensorFunction extends TensorFunction {
-
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 958ef85d1dc..457763e97ba 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -22,11 +22,11 @@ import java.util.stream.Stream;
public class Random extends CompositeTensorFunction {
private final TensorType type;
-
+
public Random(TensorType type) {
this.type = type;
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -46,7 +46,7 @@ public class Random extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index a56f82b026a..e2b39a2048d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -12,19 +12,19 @@ import java.util.stream.Stream;
/**
* A tensor generator which returns a tensor of any dimension filled with the sum of the tensor
* indexes of each position.
- *
+ *
* @author bratseth
*/
public class Range extends CompositeTensorFunction {
private final TensorType type;
private final Function<List<Integer>, Double> rangeFunction;
-
+
public Range(TensorType type) {
this.type = type;
this.rangeFunction = ScalarFunctions.sum(dimensionNames().collect(Collectors.toList()));
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.emptyList(); }
@@ -44,7 +44,7 @@ public class Range extends CompositeTensorFunction {
public String toString(ToStringContext context) {
return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction;
}
-
+
private Stream<String> dimensionNames() {
return type.dimensions().stream().map(TensorType.Dimension::toString);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index de9f90a5804..cfc78be7e0c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -19,7 +19,7 @@ import java.util.Objects;
import java.util.Set;
/**
- * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
+ * The <i>reduce</i> tensor operation returns a tensor produced from the argument tensor where some dimensions
* are collapsed to a single value using an aggregator function.
*
* @author bratseth
@@ -45,7 +45,7 @@ public class Reduce extends PrimitiveTensorFunction {
/**
* Creates a reduce function.
- *
+ *
* @param argument the tensor to reduce
* @param aggregator the aggregator function to use
* @param dimensions the list of dimensions to remove. If an empty list is given, all dimensions are reduced,
@@ -61,15 +61,6 @@ public class Reduce extends PrimitiveTensorFunction {
this.dimensions = ImmutableList.copyOf(dimensions);
}
- public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
- TensorType.Builder b = new TensorType.Builder();
- for (TensorType.Dimension dimension : inputType.dimensions()) {
- if ( ! reduceDimensions.contains(dimension.name()))
- b.dimension(dimension);
- }
- return b.build();
- }
-
public TensorFunction argument() { return argument; }
@Override
@@ -91,7 +82,7 @@ public class Reduce extends PrimitiveTensorFunction {
public String toString(ToStringContext context) {
return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
}
-
+
private String commaSeparated(List<String> list) {
StringBuilder b = new StringBuilder();
for (String element : list)
@@ -103,7 +94,7 @@ public class Reduce extends PrimitiveTensorFunction {
public Tensor evaluate(EvaluationContext context) {
Tensor argument = this.argument.evaluate(context);
if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions))
- throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
+ throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " +
dimensions + ": Not all those dimensions are present in this tensor");
// Special case: Reduce all
@@ -112,14 +103,14 @@ public class Reduce extends PrimitiveTensorFunction {
return reduceIndexedVector((IndexedTensor)argument);
else
return reduceAllGeneral(argument);
-
+
// Reduce type
TensorType.Builder builder = new TensorType.Builder();
for (TensorType.Dimension dimension : argument.type().dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
builder.dimension(dimension);
TensorType reducedType = builder.build();
-
+
// Reduce cells
Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>();
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
@@ -131,10 +122,10 @@ public class Reduce extends PrimitiveTensorFunction {
Tensor.Builder reducedBuilder = Tensor.Builder.of(reducedType);
for (Map.Entry<TensorAddress, ValueAggregator> aggregatingCell : aggregatingCells.entrySet())
reducedBuilder.cell(aggregatingCell.getKey(), aggregatingCell.getValue().aggregatedValue());
-
+
return reducedBuilder.build();
}
-
+
private TensorAddress reduceDimensions(TensorAddress address, TensorType argumentType, TensorType reducedType) {
Set<Integer> indexesToRemove = new HashSet<>();
for (String dimensionToRemove : this.dimensions)
@@ -147,7 +138,7 @@ public class Reduce extends PrimitiveTensorFunction {
reducedLabels[reducedLabelIndex++] = address.label(i);
return TensorAddress.of(reducedLabels);
}
-
+
private Tensor reduceAllGeneral(Tensor argument) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
for (Iterator<Double> i = argument.valueIterator(); i.hasNext(); )
@@ -163,7 +154,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
private static abstract class ValueAggregator {
-
+
private static ValueAggregator ofType(Aggregator aggregator) {
switch (aggregator) {
case avg : return new AvgAggregator();
@@ -174,22 +165,22 @@ public class Reduce extends PrimitiveTensorFunction {
case min : return new MinAggregator();
default: throw new UnsupportedOperationException("Aggregator " + aggregator + " is not implemented");
}
-
+
}
/** Add a new value to those aggregated by this */
public abstract void aggregate(double value);
-
+
/** Returns the value aggregated by this */
public abstract double aggregatedValue();
-
+
}
-
+
private static class AvgAggregator extends ValueAggregator {
private int valueCount = 0;
private double valueSum = 0.0;
-
+
@Override
public void aggregate(double value) {
valueCount++;
@@ -197,7 +188,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public double aggregatedValue() {
+ public double aggregatedValue() {
return valueSum / valueCount;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index ec9b762a41c..6b0daf1b49a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -3,6 +3,8 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -17,7 +19,7 @@ import java.util.Objects;
/**
* The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names.
- *
+ *
* @author bratseth
*/
@Beta
@@ -27,10 +29,6 @@ public class Rename extends PrimitiveTensorFunction {
private final List<String> fromDimensions;
private final List<String> toDimensions;
- public Rename(TensorFunction argument, String fromDimension, String toDimension) {
- this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
- }
-
public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
@@ -44,7 +42,7 @@ public class Rename extends PrimitiveTensorFunction {
this.fromDimensions = ImmutableList.copyOf(fromDimensions);
this.toDimensions = ImmutableList.copyOf(toDimensions);
}
-
+
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
@@ -64,7 +62,7 @@ public class Rename extends PrimitiveTensorFunction {
Map<String, String> fromToMap = fromToMap();
TensorType renamedType = rename(tensor.type(), fromToMap);
-
+
// an array which lists the index of each label in the renamed type
int[] toIndexes = new int[tensor.type().dimensions().size()];
for (int i = 0; i < tensor.type().dimensions().size(); i++) {
@@ -72,7 +70,7 @@ public class Rename extends PrimitiveTensorFunction {
String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName);
toIndexes[i] = renamedType.indexOfDimension(newDimensionName).get();
}
-
+
Tensor.Builder builder = Tensor.Builder.of(renamedType);
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Map.Entry<TensorAddress, Double> cell = i.next();
@@ -88,7 +86,7 @@ public class Rename extends PrimitiveTensorFunction {
builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
return builder.build();
}
-
+
private TensorAddress rename(TensorAddress address, int[] toIndexes) {
String[] reorderedLabels = new String[toIndexes.length];
for (int i = 0; i < toIndexes.length; i++)
@@ -97,18 +95,18 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) {
- return "rename(" + argument.toString(context) + ", " +
+ public String toString(ToStringContext context) {
+ return "rename(" + argument.toString(context) + ", " +
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
-
+
private Map<String, String> fromToMap() {
Map<String, String> map = new HashMap<>();
for (int i = 0; i < fromDimensions.size(); i++)
map.put(fromDimensions.get(i), toDimensions.get(i));
return map;
}
-
+
private String toVectorString(List<String> elements) {
if (elements.size() == 1)
return elements.get(0);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index fb5029fbfd6..99f79cb735a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -21,87 +21,101 @@ import java.util.stream.Collectors;
@Beta
public class ScalarFunctions {
- public static DoubleBinaryOperator add() { return new Add(); }
- public static DoubleBinaryOperator divide() { return new Divide(); }
+ public static DoubleBinaryOperator add() { return new Addition(); }
+ public static DoubleBinaryOperator multiply() { return new Multiplication(); }
+ public static DoubleBinaryOperator divide() { return new Division(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
- public static DoubleBinaryOperator multiply() { return new Multiply(); }
-
- public static DoubleUnaryOperator acos() { return new Acos(); }
- public static DoubleUnaryOperator exp() { return new Exp(); }
- public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
public static DoubleUnaryOperator square() { return new Square(); }
-
+ public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
+ public static DoubleUnaryOperator exp() { return new Exponent(); }
public static Function<List<Integer>, Double> random() { return new Random(); }
public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
- // Binary operators -----------------------------------------------------------------------------
+ public static class Addition implements DoubleBinaryOperator {
- public static class Add implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left + right; }
+
@Override
public String toString() { return "f(a,b)(a + b)"; }
- }
- public static class Equal implements DoubleBinaryOperator {
- @Override
- public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
- @Override
- public String toString() { return "f(a,b)(a==b)"; }
}
- public static class Exp implements DoubleUnaryOperator {
- @Override
- public double applyAsDouble(double operand) { return Math.exp(operand); }
- @Override
- public String toString() { return "f(a)(exp(a))"; }
- }
+ public static class Multiplication implements DoubleBinaryOperator {
- public static class Multiply implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left * right; }
+ public double applyAsDouble(double left, double right) { return left * right; }
+
@Override
public String toString() { return "f(a,b)(a * b)"; }
+
}
- public static class Divide implements DoubleBinaryOperator {
+ public static class Division implements DoubleBinaryOperator {
+
@Override
public double applyAsDouble(double left, double right) { return left / right; }
+
@Override
public String toString() { return "f(a,b)(a / b)"; }
}
- // Unary operators ------------------------------------------------------------------------------
+ public static class Equal implements DoubleBinaryOperator {
+
+ @Override
+ public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
+
+ @Override
+ public String toString() { return "f(a,b)(a==b)"; }
+ }
+
+ public static class Square implements DoubleUnaryOperator {
- public static class Acos implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double operand) { return Math.acos(operand); }
+ public double applyAsDouble(double operand) { return operand * operand; }
+
@Override
- public String toString() { return "f(a)(acos(a))"; }
+ public String toString() { return "f(a)(a * a)"; }
+
}
public static class Sqrt implements DoubleUnaryOperator {
+
@Override
public double applyAsDouble(double operand) { return Math.sqrt(operand); }
+
@Override
public String toString() { return "f(a)(sqrt(a))"; }
+
}
- public static class Square implements DoubleUnaryOperator {
+ public static class Exponent implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double operand) { return operand * operand; }
+ public double applyAsDouble(double operand) { return Math.exp(operand); }
@Override
- public String toString() { return "f(a)(a * a)"; }
+ public String toString() { return "f(a)(exp(a))"; }
}
- // Variable-length operators -----------------------------------------------------------------------------
+ public static class Random implements Function<List<Integer>, Double> {
+
+ @Override
+ public Double apply(List<Integer> values) {
+ return ThreadLocalRandom.current().nextDouble();
+ }
+
+ @Override
+ public String toString() { return "random"; }
- public static class EqualElements implements Function<List<Integer>, Double> {
- private final ImmutableList<String> argumentNames;
+ }
+
+ public static class EqualElements implements Function<List<Integer>, Double> {
+
+ private final ImmutableList<String> argumentNames;
+
private EqualElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@@ -114,6 +128,7 @@ public class ScalarFunctions {
return 0.0;
return 1.0;
}
+
@Override
public String toString() {
if (argumentNames.size() == 0) return "1";
@@ -128,19 +143,13 @@ public class ScalarFunctions {
}
return b.toString();
}
- }
- public static class Random implements Function<List<Integer>, Double> {
- @Override
- public Double apply(List<Integer> values) {
- return ThreadLocalRandom.current().nextDouble();
- }
- @Override
- public String toString() { return "random"; }
}
public static class SumElements implements Function<List<Integer>, Double> {
+
private final ImmutableList<String> argumentNames;
+
private SumElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@@ -152,10 +161,12 @@ public class ScalarFunctions {
sum += value;
return (double)sum;
}
+
@Override
public String toString() {
return argumentNames.stream().collect(Collectors.joining("+"));
}
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index c856b548180..bf279eb24d8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -2,8 +2,6 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
-import com.google.common.collect.ImmutableList;
-import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.List;
@@ -21,10 +19,6 @@ public class Softmax extends CompositeTensorFunction {
this.argument = argument;
this.dimension = dimension;
}
-
- public static TensorType outputType(TensorType inputType, String dimension) {
- return Reduce.outputType(inputType, ImmutableList.of(dimension));
- }
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 533a46f87fe..cabcce198d1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -12,7 +12,7 @@ import java.util.List;
* A representation of a tensor function which is able to be translated to a set of primitive
* tensor functions if necessary.
* All tensor functions are immutable.
- *
+ *
* @author bratseth
*/
@Beta
@@ -48,11 +48,11 @@ public abstract class TensorFunction {
/**
* Return a string representation of this context.
- *
+ *
* @param context a context which must be passed to all nexted functions when requesting the string value
*/
public abstract String toString(ToStringContext context);
-
+
@Override
public String toString() { return toString(ToStringContext.empty()); }