diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-05 22:49:08 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-05 22:49:08 +0100 |
commit | ed8c274dc76794efa692efba6cf509b058b13648 (patch) | |
tree | c1dcb9fbc70b851be5cfdb8c335089283715f698 /vespajlib/src | |
parent | 64c5daa351557869e64786188afa75ed3b59991b (diff) |
Literal tensors with value expressions
Diffstat (limited to 'vespajlib/src')
4 files changed, 192 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java b/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java index ca0baf95ee2..7db43a7442a 100644 --- a/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java +++ b/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java @@ -19,7 +19,6 @@ import java.util.Set; * * @author bratseth */ -@Beta public class CopyOnWriteHashMap<K,V> extends AbstractMap<K,V> implements Cloneable { private Map<K,V> map; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index f5ef88016ac..15476567fb2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -745,8 +745,6 @@ public abstract class IndexedTensor implements Tensor { } - // TODO: Make dimensionSizes a class - /** * An array of indexes into this tensor which are able to find the next index in the value order. * next() can be called once per element in the dimensions we iterate over. It must be called once diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java new file mode 100644 index 00000000000..9ce2496c65b --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -0,0 +1,146 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +/** + * A function which is a tensor whose values are computed by individual lambda functions on evaluation. + * + * @author bratseth + */ +public abstract class DynamicTensor extends PrimitiveTensorFunction { + + private final TensorType type; + + DynamicTensor(TensorType type) { + this.type = type; + } + + @Override + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } + + @Override + public List<TensorFunction> arguments() { return Collections.emptyList(); } + + @Override + public TensorFunction withArguments(List<TensorFunction> arguments) { + if (arguments.size() != 0) + throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + TensorType type() { return type; } + + /** Creates a dynamic tensor function. The cell addresses must match the type. */ + public static DynamicTensor from(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) { + return new MappedDynamicTensor(type, cells); + } + + /** Creates a dynamic tensor function for a bound, indexed tensor */ + public static DynamicTensor from(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) { + return new IndexedDynamicTensor(type, cells); + } + + private static class MappedDynamicTensor extends DynamicTensor { + + private final ImmutableMap<TensorAddress, Function<EvaluationContext<?> , Double>> cells; + + MappedDynamicTensor(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) { + super(type); + this.cells = ImmutableMap.copyOf(cells); + } + + @Override + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (var cell : cells.entrySet()) + builder.cell(cell.getKey(), cell.getValue().apply(context)); + return builder.build(); + } + + @Override + public String toString(ToStringContext context) { + return type().toString() + ":" + contentToString(); + } + + private String contentToString() { + if (type().dimensions().isEmpty()) { + if (cells.isEmpty()) return "{}"; + return "{" + cells.values().iterator().next() + "}"; + } + + StringBuilder b = new StringBuilder("{"); + for (var cell : cells.entrySet()) { + b.append(cell.getKey().toString(type())).append(":").append(cell.getValue()); + b.append(","); + } + if (b.length() > 1) + b.setLength(b.length() - 1); + b.append("}"); + + return b.toString(); + } + + } + + private static class IndexedDynamicTensor extends DynamicTensor { + + private final List<Function<EvaluationContext<?>, Double>> cells; + + IndexedDynamicTensor(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) { + super(type); + if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) + throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + + "only indexed, bound dimensions, but this has " + type); + this.cells = List.copyOf(cells); + } + + @Override + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); + for (int i = 0; i < cells.size(); i++) + builder.cellByDirectIndex(i, cells.get(i).apply(context)); + return builder.build(); + } + + @Override + public String toString(ToStringContext context) { + return type().toString() + ":" + contentToString(); + } + + private String contentToString() { + if (type().dimensions().isEmpty()) { + if (cells.isEmpty()) return "{}"; + return "{" + cells.get(0) + "}"; + } + + StringBuilder b = new StringBuilder("["); + for (var cell : cells) { + b.append(cell); + b.append(","); + } + if (b.length() > 1) + b.setLength(b.length() - 1); + b.append("]"); + + return b.toString(); + } + + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java new file mode 100644 index 00000000000..82652fb0e5d --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -0,0 +1,46 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import org.junit.Test; + +import java.util.Collections; +import java.util.List; +import java.util.function.Function; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class DynamicTensorTestCase { + + @Test + public void testDynamicTensorFunction() { + TensorType dense = TensorType.fromSpec("tensor(x[3])"); + DynamicTensor t1 = DynamicTensor.from(dense, + List.of(new Constant(1), new Constant(2), new Constant(3))); + assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate()); + + TensorType sparse = TensorType.fromSpec("tensor(x{})"); + DynamicTensor t2 = DynamicTensor.from(sparse, + Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), + new Constant(5))); + assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); + } + + private static class Constant implements Function<EvaluationContext<?>, Double> { + + private final double value; + + public Constant(double value) { this.value = value; } + + @Override + public Double apply(EvaluationContext<?> evaluationContext) { return value; } + + } + +} |