diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2019-11-08 16:12:24 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-11-08 16:12:24 +0100 |
commit | 1b3a300f3eec8df65a98229e1ce6c5526d71661e (patch) | |
tree | 7f9430fa6e6f6ddc70ae889133a24394558ffaaa /vespajlib | |
parent | 983935b160467ad8521b4059d5ffef33c9e75270 (diff) | |
parent | cdcafc6fc8b4417abab8c72bbce5c503533558ea (diff) |
Merge pull request #11259 from vespa-engine/bratseth/dynamic-tensors-in-verbose-form
Bratseth/dynamic tensors in verbose form
Diffstat (limited to 'vespajlib')
6 files changed, 95 insertions, 22 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 6a93a17a8c1..47b066b15a6 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -707,6 +707,7 @@ "final" ], "methods": [ + "public static com.yahoo.tensor.DimensionSizes of(com.yahoo.tensor.TensorType)", "public long size(int)", "public int dimensions()", "public long totalSize()", @@ -820,7 +821,9 @@ "abstract" ], "methods": [ + "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType)", "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.DimensionSizes)", + "public com.yahoo.tensor.TensorAddress toAddress()", "public long[] indexesCopy()", "public long[] indexesForReading()", "public java.util.List toList()", @@ -1603,6 +1606,7 @@ "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", "public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.Map)", "public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.List)" ], @@ -1832,6 +1836,23 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.ScalarFunction": { + "superClass": "java.lang.Object", + "interfaces": [ + "java.util.function.Function" + ], + "attributes": [ + "public", + "interface", + "abstract" + ], + "methods": [ + "public abstract java.lang.Double apply(com.yahoo.tensor.evaluation.EvaluationContext)", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public bridge synthetic java.lang.Object apply(java.lang.Object)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.ScalarFunctions$Abs": { "superClass": "java.lang.Object", "interfaces": [ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index c0d817459d0..d81c02fb75f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -18,6 +18,21 @@ public final class DimensionSizes { } /** + * Create sizes from a type containing bound indexed dimensions only. + * + * @throws IllegalStateException if the type contains dimensions which are not bound and indexed + */ + public static DimensionSizes of(TensorType type) { + Builder b = new Builder(type.rank()); + for (int i = 0; i < type.rank(); i++) { + if ( type.dimensions().get(i).type() != TensorType.Dimension.Type.indexedBound) + throw new IllegalArgumentException(type + " contains dimensions without a size"); + b.set(i, type.dimensions().get(i).size().get()); + } + return b.build(); + } + + /** * Returns the length of this in the nth dimension * * @throws IllegalArgumentException if the index is larger than the number of dimensions in this tensor minus one diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 15476567fb2..176ddfefc13 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -758,6 +758,15 @@ public abstract class IndexedTensor implements Tensor { protected final long[] indexes; + /** + * Create indexes from a type containing bound indexed dimensions only. + * + * @throws IllegalStateException if the type contains dimensions which are not bound and indexed + */ + public static Indexes of(TensorType type) { + return of(DimensionSizes.of(type)); + } + public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); } @@ -824,7 +833,7 @@ public abstract class IndexedTensor implements Tensor { } /** Returns the address of the current position of these indexes */ - private TensorAddress toAddress() { + public TensorAddress toAddress() { return TensorAddress.of(indexes); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 9ce2496c65b..b8b644f8b49 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -46,21 +46,28 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { TensorType type() { return type; } + @Override + public String toString(ToStringContext context) { + return type().toString() + ":" + contentToString(context); + } + + abstract String contentToString(ToStringContext context); + /** Creates a dynamic tensor function. The cell addresses must match the type. */ - public static DynamicTensor from(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) { + public static DynamicTensor from(TensorType type, Map<TensorAddress, ScalarFunction> cells) { return new MappedDynamicTensor(type, cells); } /** Creates a dynamic tensor function for a bound, indexed tensor */ - public static DynamicTensor from(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) { + public static DynamicTensor from(TensorType type, List<ScalarFunction> cells) { return new IndexedDynamicTensor(type, cells); } private static class MappedDynamicTensor extends DynamicTensor { - private final ImmutableMap<TensorAddress, Function<EvaluationContext<?> , Double>> cells; + private final ImmutableMap<TensorAddress, ScalarFunction> cells; - MappedDynamicTensor(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) { + MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction> cells) { super(type); this.cells = ImmutableMap.copyOf(cells); } @@ -74,11 +81,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { - return type().toString() + ":" + contentToString(); - } - - private String contentToString() { + String contentToString(ToStringContext context) { if (type().dimensions().isEmpty()) { if (cells.isEmpty()) return "{}"; return "{" + cells.values().iterator().next() + "}"; @@ -86,7 +89,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { StringBuilder b = new StringBuilder("{"); for (var cell : cells.entrySet()) { - b.append(cell.getKey().toString(type())).append(":").append(cell.getValue()); + b.append(cell.getKey().toString(type())).append(":").append(cell.getValue().toString(context)); b.append(","); } if (b.length() > 1) @@ -100,9 +103,9 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { private static class IndexedDynamicTensor extends DynamicTensor { - private final List<Function<EvaluationContext<?>, Double>> cells; + private final List<ScalarFunction> cells; - IndexedDynamicTensor(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) { + IndexedDynamicTensor(TensorType type, List<ScalarFunction> cells) { super(type); if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + @@ -119,24 +122,22 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { - return type().toString() + ":" + contentToString(); - } - - private String contentToString() { + String contentToString(ToStringContext context) { if (type().dimensions().isEmpty()) { if (cells.isEmpty()) return "{}"; return "{" + cells.get(0) + "}"; } - StringBuilder b = new StringBuilder("["); + IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type()); + StringBuilder b = new StringBuilder("{"); for (var cell : cells) { - b.append(cell); + indexes.next(); + b.append(indexes.toAddress().toString(type())).append(":").append(cell.toString(context)); b.append(","); } if (b.length() > 1) b.setLength(b.length() - 1); - b.append("]"); + b.append("}"); return b.toString(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java new file mode 100644 index 00000000000..c6a244b64df --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -0,0 +1,22 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.evaluation.EvaluationContext; + +import java.util.function.Function; + +/** + * A function which returns a scalar + * + * @author bratseth + */ +public interface ScalarFunction extends Function<EvaluationContext<?>, Double> { + + @Override + Double apply(EvaluationContext<?> context); + + default String toString(ToStringContext context) { + return toString(); + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index 82652fb0e5d..925da9d3c89 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -24,15 +24,17 @@ public class DynamicTensorTestCase { DynamicTensor t1 = DynamicTensor.from(dense, List.of(new Constant(1), new Constant(2), new Constant(3))); assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate()); + assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString()); TensorType sparse = TensorType.fromSpec("tensor(x{})"); DynamicTensor t2 = DynamicTensor.from(sparse, Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), new Constant(5))); assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); + assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); } - private static class Constant implements Function<EvaluationContext<?>, Double> { + private static class Constant implements ScalarFunction { private final double value; @@ -41,6 +43,9 @@ public class DynamicTensorTestCase { @Override public Double apply(EvaluationContext<?> evaluationContext) { return value; } + @Override + public String toString() { return String.valueOf(value); } + } } |