diff options
4 files changed, 258 insertions, 1 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 4ec1a5a234b..307d5df3d6c 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2506,6 +2506,42 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.Value$DimensionValue": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(java.lang.String, java.lang.String)", + "public void <init>(java.lang.String, int)", + "public void <init>(int)", + "public void <init>(java.lang.String)", + "public java.util.Optional dimension()", + "public java.lang.String label()", + "public java.lang.String toString()" + ], + "fields": [] + }, + "com.yahoo.tensor.functions.Value": { + "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.Value withArguments(java.util.List)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public java.lang.String toString()", + "public bridge synthetic com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.XwPlusB": { "superClass": "com.yahoo.tensor.functions.CompositeTensorFunction", "interfaces": [], diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index afd82751137..0523624ea9f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -37,7 +37,7 @@ import java.util.function.Function; * Each cell is is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines * the location of that cell. Both dimensions and labels are string on the form of an identifier or integer. * <p> - * The size of the set of dimensions of a tensor is called its <i>order</i>. + * The size of the set of dimensions of a tensor is called its <i>rank</i>. * <p> * In contrast to regular mathematical formulations of tensors, this definition of a tensor allows <i>sparseness</i> * as there is no built-in notion of a contiguous space, and even in cases where a space is implied (such as when diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java new file mode 100644 index 00000000000..3113d48335a --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java @@ -0,0 +1,155 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.annotations.Beta; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Returns the value of a cell of a tensor (as a rank 0 tensor). + * + * @author bratseth + */ +@Beta +public class Value extends PrimitiveTensorFunction { + + private final TensorFunction argument; + private final List<DimensionValue> cellAddress; + + /** + * Creates a value function + * + * @param argument the tensor to return a cell value from + * @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress + * because those require a type, but a type is not resolved until this is evaluated + */ + public Value(TensorFunction argument, List<DimensionValue> cellAddress) { + this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); + if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty())) + throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " + + "Specify dimension names explicitly"); + this.cellAddress = cellAddress; + } + + @Override + public List<TensorFunction> arguments() { return List.of(argument); } + + @Override + public Value withArguments(List<TensorFunction> arguments) { + if (arguments.size() != 1) + throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); + return new Value(arguments.get(0), cellAddress); + } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + @Override + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor tensor = argument.evaluate(context); + if (tensor.type().rank() != cellAddress.size()) + throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() + + " to a tensor of type " + tensor.type()); + TensorAddress.Builder b = new TensorAddress.Builder(tensor.type()); + for (int i = 0; i < cellAddress.size(); i++) { + b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), + cellAddress.get(i).label()); + } + return Tensor.from(tensor.get(b.build())); + } + + @Override + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + return new TensorType.Builder(argument.type(context).valueType()).build(); + } + + @Override + public String toString(ToStringContext context) { + return toString(); + } + + @Override + public String toString() { + if (cellAddress.size() == 1 && cellAddress.get(0).dimension().isEmpty()) { + if (cellAddress.get(0).index().isPresent()) + return "[" + cellAddress.get(0).index().get() + "]"; + else + return "{" + cellAddress.get(0).index().get() + "}"; + } + else { + return "{" + cellAddress.stream().map(i -> i.toString()).collect(Collectors.joining(", ")) + "}"; + } + } + + public static class DimensionValue { + + private final Optional<String> dimension; + + /** The label of this. Always available, whether or not index is */ + private final String label; + + /** The index of this, or empty if this is a non-integer label */ + private final Optional<Integer> index; + + public DimensionValue(String dimension, String label) { + this(Optional.of(dimension), label, indexOrEmpty(label)); + } + + public DimensionValue(String dimension, int index) { + this(Optional.of(dimension), String.valueOf(index), Optional.of(index)); + } + + public DimensionValue(int index) { + this(Optional.empty(), String.valueOf(index), Optional.of(index)); + } + + public DimensionValue(String label) { + this(Optional.empty(), label, indexOrEmpty(label)); + } + + private DimensionValue(Optional<String> dimension, String label, Optional<Integer> index) { + this.dimension = dimension; + this.label = label; + this.index = index; + } + + /** + * Returns the given name of the dimension, or null if dense form is used, such that name + * must be inferred from order + */ + public Optional<String> dimension() { return dimension; } + + /** Returns the label or index for this dimension as a string */ + public String label() { return label; } + + /** Returns the index for this dimension, or empty if it is not a number */ + Optional<Integer> index() { return index; } + + @Override + public String toString() { + if (dimension.isPresent()) + return dimension.get() + ":" + label; + else + return label; + } + + private static Optional<Integer> indexOrEmpty(String label) { + try { + return Optional.of(Integer.parseInt(label)); + } + catch (IllegalArgumentException e) { + return Optional.empty(); + } + } + + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java new file mode 100644 index 00000000000..ffb5e1433ca --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java @@ -0,0 +1,66 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class ValueTestCase { + + private static final double delta = 0.000001; + + @Test + public void testValueFunctionGeneralForm() { + Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); + Tensor result = new Value(new ConstantTensor(input), + List.of(new Value.DimensionValue("key", "bar"), + new Value.DimensionValue("x", 0))) + .evaluate(); + assertEquals(0, result.type().rank()); + assertEquals(2.3, result.asDouble(), delta); + } + + @Test + public void testValueFunctionSingleMappedDimension() { + Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }"); + Tensor result = new Value(new ConstantTensor(input), + List.of(new Value.DimensionValue("foo"))) + .evaluate(); + assertEquals(0, result.type().rank()); + assertEquals(1.4, result.asDouble(), delta); + } + + @Test + public void testValueFunctionSingleIndexedDimension() { + Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]"); + Tensor result = new Value(new ConstantTensor(input), + List.of(new Value.DimensionValue(2))) + .evaluate(); + assertEquals(0, result.type().rank()); + assertEquals(3.3, result.asDouble(), delta); + } + + @Test + public void testValueFunctionShortFormWithMultipleDimensionsIsNotAllowed() { + try { + Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); + new Value(new ConstantTensor(input), + List.of(new Value.DimensionValue("bar"), + new Value.DimensionValue(0))) + .evaluate(); + fail("Expected exception"); + } + catch (IllegalArgumentException e) { + assertEquals("Short form of cell addresses is only supported with a single dimension: Specify dimension names explicitly", + e.getMessage()); + } + } + +} |