summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2019-11-08 16:12:24 +0100
committerGitHub <noreply@github.com>2019-11-08 16:12:24 +0100
commit1b3a300f3eec8df65a98229e1ce6c5526d71661e (patch)
tree7f9430fa6e6f6ddc70ae889133a24394558ffaaa /vespajlib
parent983935b160467ad8521b4059d5ffef33c9e75270 (diff)
parentcdcafc6fc8b4417abab8c72bbce5c503533558ea (diff)
Merge pull request #11259 from vespa-engine/bratseth/dynamic-tensors-in-verbose-form
Bratseth/dynamic tensors in verbose form
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java15
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java41
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java7
6 files changed, 95 insertions, 22 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 6a93a17a8c1..47b066b15a6 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -707,6 +707,7 @@
"final"
],
"methods": [
+ "public static com.yahoo.tensor.DimensionSizes of(com.yahoo.tensor.TensorType)",
"public long size(int)",
"public int dimensions()",
"public long totalSize()",
@@ -820,7 +821,9 @@
"abstract"
],
"methods": [
+ "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType)",
"public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.DimensionSizes)",
+ "public com.yahoo.tensor.TensorAddress toAddress()",
"public long[] indexesCopy()",
"public long[] indexesForReading()",
"public java.util.List toList()",
@@ -1603,6 +1606,7 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
"public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.Map)",
"public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.List)"
],
@@ -1832,6 +1836,23 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.ScalarFunction": {
+ "superClass": "java.lang.Object",
+ "interfaces": [
+ "java.util.function.Function"
+ ],
+ "attributes": [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods": [
+ "public abstract java.lang.Double apply(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public bridge synthetic java.lang.Object apply(java.lang.Object)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.ScalarFunctions$Abs": {
"superClass": "java.lang.Object",
"interfaces": [
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index c0d817459d0..d81c02fb75f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -18,6 +18,21 @@ public final class DimensionSizes {
}
/**
+ * Create sizes from a type containing bound indexed dimensions only.
+ *
+ * @throws IllegalStateException if the type contains dimensions which are not bound and indexed
+ */
+ public static DimensionSizes of(TensorType type) {
+ Builder b = new Builder(type.rank());
+ for (int i = 0; i < type.rank(); i++) {
+ if ( type.dimensions().get(i).type() != TensorType.Dimension.Type.indexedBound)
+ throw new IllegalArgumentException(type + " contains dimensions without a size");
+ b.set(i, type.dimensions().get(i).size().get());
+ }
+ return b.build();
+ }
+
+ /**
* Returns the length of this in the nth dimension
*
* @throws IllegalArgumentException if the index is larger than the number of dimensions in this tensor minus one
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 15476567fb2..176ddfefc13 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -758,6 +758,15 @@ public abstract class IndexedTensor implements Tensor {
protected final long[] indexes;
+ /**
+ * Create indexes from a type containing bound indexed dimensions only.
+ *
+ * @throws IllegalStateException if the type contains dimensions which are not bound and indexed
+ */
+ public static Indexes of(TensorType type) {
+ return of(DimensionSizes.of(type));
+ }
+
public static Indexes of(DimensionSizes sizes) {
return of(sizes, sizes);
}
@@ -824,7 +833,7 @@ public abstract class IndexedTensor implements Tensor {
}
/** Returns the address of the current position of these indexes */
- private TensorAddress toAddress() {
+ public TensorAddress toAddress() {
return TensorAddress.of(indexes);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 9ce2496c65b..b8b644f8b49 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -46,21 +46,28 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
TensorType type() { return type; }
+ @Override
+ public String toString(ToStringContext context) {
+ return type().toString() + ":" + contentToString(context);
+ }
+
+ abstract String contentToString(ToStringContext context);
+
/** Creates a dynamic tensor function. The cell addresses must match the type. */
- public static DynamicTensor from(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) {
+ public static DynamicTensor from(TensorType type, Map<TensorAddress, ScalarFunction> cells) {
return new MappedDynamicTensor(type, cells);
}
/** Creates a dynamic tensor function for a bound, indexed tensor */
- public static DynamicTensor from(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) {
+ public static DynamicTensor from(TensorType type, List<ScalarFunction> cells) {
return new IndexedDynamicTensor(type, cells);
}
private static class MappedDynamicTensor extends DynamicTensor {
- private final ImmutableMap<TensorAddress, Function<EvaluationContext<?> , Double>> cells;
+ private final ImmutableMap<TensorAddress, ScalarFunction> cells;
- MappedDynamicTensor(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) {
+ MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction> cells) {
super(type);
this.cells = ImmutableMap.copyOf(cells);
}
@@ -74,11 +81,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) {
- return type().toString() + ":" + contentToString();
- }
-
- private String contentToString() {
+ String contentToString(ToStringContext context) {
if (type().dimensions().isEmpty()) {
if (cells.isEmpty()) return "{}";
return "{" + cells.values().iterator().next() + "}";
@@ -86,7 +89,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
StringBuilder b = new StringBuilder("{");
for (var cell : cells.entrySet()) {
- b.append(cell.getKey().toString(type())).append(":").append(cell.getValue());
+ b.append(cell.getKey().toString(type())).append(":").append(cell.getValue().toString(context));
b.append(",");
}
if (b.length() > 1)
@@ -100,9 +103,9 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
private static class IndexedDynamicTensor extends DynamicTensor {
- private final List<Function<EvaluationContext<?>, Double>> cells;
+ private final List<ScalarFunction> cells;
- IndexedDynamicTensor(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) {
+ IndexedDynamicTensor(TensorType type, List<ScalarFunction> cells) {
super(type);
if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))
throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " +
@@ -119,24 +122,22 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) {
- return type().toString() + ":" + contentToString();
- }
-
- private String contentToString() {
+ String contentToString(ToStringContext context) {
if (type().dimensions().isEmpty()) {
if (cells.isEmpty()) return "{}";
return "{" + cells.get(0) + "}";
}
- StringBuilder b = new StringBuilder("[");
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type());
+ StringBuilder b = new StringBuilder("{");
for (var cell : cells) {
- b.append(cell);
+ indexes.next();
+ b.append(indexes.toAddress().toString(type())).append(":").append(cell.toString(context));
b.append(",");
}
if (b.length() > 1)
b.setLength(b.length() - 1);
- b.append("]");
+ b.append("}");
return b.toString();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
new file mode 100644
index 00000000000..c6a244b64df
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
@@ -0,0 +1,22 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.evaluation.EvaluationContext;
+
+import java.util.function.Function;
+
+/**
+ * A function which returns a scalar
+ *
+ * @author bratseth
+ */
+public interface ScalarFunction extends Function<EvaluationContext<?>, Double> {
+
+ @Override
+ Double apply(EvaluationContext<?> context);
+
+ default String toString(ToStringContext context) {
+ return toString();
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
index 82652fb0e5d..925da9d3c89 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -24,15 +24,17 @@ public class DynamicTensorTestCase {
DynamicTensor t1 = DynamicTensor.from(dense,
List.of(new Constant(1), new Constant(2), new Constant(3)));
assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate());
+ assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString());
TensorType sparse = TensorType.fromSpec("tensor(x{})");
DynamicTensor t2 = DynamicTensor.from(sparse,
Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
new Constant(5)));
assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate());
+ assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString());
}
- private static class Constant implements Function<EvaluationContext<?>, Double> {
+ private static class Constant implements ScalarFunction {
private final double value;
@@ -41,6 +43,9 @@ public class DynamicTensorTestCase {
@Override
public Double apply(EvaluationContext<?> evaluationContext) { return value; }
+ @Override
+ public String toString() { return String.valueOf(value); }
+
}
}