diff options
author | Jon Bratseth <bratseth@oath.com> | 2019-12-09 14:31:43 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-12-09 14:31:43 -0800 |
commit | 4192f7e82258b6fc6165230cdf1910c384a138c1 (patch) | |
tree | ad80d958954597289c8d838112d6cf092fbaffc3 | |
parent | 676dd43a12db59f96536aa6d8a45369d24d17404 (diff) | |
parent | 7ef64a61b4f04a400428fe58ed2475aa37c43d39 (diff) |
Merge pull request #11528 from vespa-engine/bratseth/tensor-slice
Generalized Slice tensor function.
9 files changed, 574 insertions, 339 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 98b975546e7..debcd11fdbd 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -917,7 +917,7 @@ "public final void labelAndDimension(com.yahoo.tensor.TensorAddress$Builder)", "public final void labelAndDimensionValues(java.util.List)", "public final java.util.List valueAddress()", - "public final com.yahoo.tensor.functions.Value$DimensionValue dimensionValue(java.util.Optional)", + "public final com.yahoo.tensor.functions.Slice$DimensionValue dimensionValue(java.util.Optional)", "public void <init>(java.io.InputStream)", "public void <init>(java.io.InputStream, java.lang.String)", "public void ReInit(java.io.InputStream)", diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 83496c7c5f1..de3ad6b5d8c 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -18,6 +18,7 @@ PARSER_BEGIN(RankingExpressionParser) package com.yahoo.searchlib.rankingexpression.parser; import com.yahoo.searchlib.rankingexpression.rule.*; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.*; @@ -247,7 +248,7 @@ ExpressionNode value() : ( <LBRACE> value = expression() <RBRACE> { value = new EmbracedNode(value); } ) ) ) - [ LOOKAHEAD(2) valueAddress = valueAddress() { value = new TensorFunctionNode(new Value(TensorFunctionNode.wrap(value), valueAddress)); } ] + [ LOOKAHEAD(2) valueAddress = valueAddress() { value = new TensorFunctionNode(new Slice(TensorFunctionNode.wrap(value), valueAddress)); } ] { value = not ? new NotNode(value) : value; value = neg ? new NegativeNode(value) : value; @@ -818,17 +819,17 @@ ConstantNode constantPrimitive() : ( <INTEGER> { value = token.image; } | <FLOAT> { value = token.image; } | <STRING> { value = token.image; } ) - { return new ConstantNode(com.yahoo.searchlib.rankingexpression.evaluation.Value.parse(sign + value),sign + value); } + { return new ConstantNode(Value.parse(sign + value),sign + value); } } -com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue() : +Value primitiveValue() : { String sign = ""; } { ( <SUB> { sign = "-";} ) ? ( <INTEGER> | <FLOAT> | <STRING> ) - { return com.yahoo.searchlib.rankingexpression.evaluation.Value.parse(sign + token.image); } + { return Value.parse(sign + token.image); } } TensorFunctionNode tensorValueBody(TensorType type) : @@ -894,7 +895,7 @@ void labelAndDimension(TensorAddress.Builder addressBuilder) : void labelAndDimensionValues(List addressValues) : { String dimension; - Value.DimensionValue dimensionValue; + Slice.DimensionValue dimensionValue; } { dimension = identifier() <COLON> dimensionValue = dimensionValue(Optional.of(dimension)) @@ -906,11 +907,11 @@ List valueAddress() : { List dimensionValues = new ArrayList(); ExpressionNode valueExpression; - Value.DimensionValue dimensionValue; + Slice.DimensionValue dimensionValue; } { ( - ( <LSQUARE> ( valueExpression = expression() { dimensionValues.add(new Value.DimensionValue(TensorFunctionNode.wrapScalar(valueExpression))); } ) <RSQUARE> ) + ( <LSQUARE> ( valueExpression = expression() { dimensionValues.add(new Slice.DimensionValue(TensorFunctionNode.wrapScalar(valueExpression))); } ) <RSQUARE> ) | LOOKAHEAD(3) ( <LCURLY> ( labelAndDimensionValues(dimensionValues))+ @@ -923,16 +924,16 @@ List valueAddress() : { return dimensionValues;} } -Value.DimensionValue dimensionValue(Optional dimensionName) : +Slice.DimensionValue dimensionValue(Optional dimensionName) : { ExpressionNode value; } { value = expression() { - if (value instanceof ReferenceNode && ((ReferenceNode)value).reference().isIdentifier()) - return new Value.DimensionValue(dimensionName, ((ReferenceNode)value).reference().name()); - else - return new Value.DimensionValue(dimensionName, TensorFunctionNode.wrapScalar(value)); - } + if (value instanceof ReferenceNode && ((ReferenceNode)value).reference().isIdentifier()) + return new Slice.DimensionValue(dimensionName, ((ReferenceNode)value).reference().name()); + else + return new Slice.DimensionValue(dimensionName, TensorFunctionNode.wrapScalar(value)); +} }
\ No newline at end of file diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index cea58d565c2..e991173805f 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1036,6 +1036,7 @@ "methods": [ "public void <init>(int)", "public void add(java.lang.String, long)", + "public void add(java.lang.String, java.lang.String)", "public com.yahoo.tensor.PartialAddress build()" ], "fields": [] @@ -1046,7 +1047,15 @@ "attributes": [ "public" ], - "methods": [], + "methods": [ + "public java.lang.String dimension(int)", + "public long numericLabel(java.lang.String)", + "public java.lang.String label(java.lang.String)", + "public java.lang.String label(int)", + "public int size()", + "public com.yahoo.tensor.TensorAddress asAddress(com.yahoo.tensor.TensorType)", + "public java.lang.String toString()" + ], "fields": [] }, "com.yahoo.tensor.Tensor$Builder$CellBuilder": { @@ -2465,6 +2474,47 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.Slice$DimensionValue": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(java.lang.String, java.lang.String)", + "public void <init>(java.lang.String, int)", + "public void <init>(int)", + "public void <init>(java.lang.String)", + "public void <init>(com.yahoo.tensor.functions.ScalarFunction)", + "public void <init>(java.util.Optional, java.lang.String)", + "public void <init>(java.util.Optional, com.yahoo.tensor.functions.ScalarFunction)", + "public void <init>(java.lang.String, com.yahoo.tensor.functions.ScalarFunction)", + "public java.util.Optional dimension()", + "public java.util.Optional label()", + "public java.util.Optional index()", + "public java.lang.String toString()", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)" + ], + "fields": [] + }, + "com.yahoo.tensor.functions.Slice": { + "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.Slice withArguments(java.util.List)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public bridge synthetic com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.Softmax": { "superClass": "com.yahoo.tensor.functions.CompositeTensorFunction", "interfaces": [], @@ -2531,47 +2581,6 @@ ], "fields": [] }, - "com.yahoo.tensor.functions.Value$DimensionValue": { - "superClass": "java.lang.Object", - "interfaces": [], - "attributes": [ - "public" - ], - "methods": [ - "public void <init>(java.lang.String, java.lang.String)", - "public void <init>(java.lang.String, int)", - "public void <init>(int)", - "public void <init>(java.lang.String)", - "public void <init>(com.yahoo.tensor.functions.ScalarFunction)", - "public void <init>(java.util.Optional, java.lang.String)", - "public void <init>(java.util.Optional, com.yahoo.tensor.functions.ScalarFunction)", - "public void <init>(java.lang.String, com.yahoo.tensor.functions.ScalarFunction)", - "public java.util.Optional dimension()", - "public java.util.Optional label()", - "public java.util.Optional index()", - "public java.lang.String toString()", - "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)" - ], - "fields": [] - }, - "com.yahoo.tensor.functions.Value": { - "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", - "interfaces": [], - "attributes": [ - "public" - ], - "methods": [ - "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)", - "public java.util.List arguments()", - "public com.yahoo.tensor.functions.Value withArguments(java.util.List)", - "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", - "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", - "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", - "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", - "public bridge synthetic com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)" - ], - "fields": [] - }, "com.yahoo.tensor.functions.XwPlusB": { "superClass": "com.yahoo.tensor.functions.CompositeTensorFunction", "interfaces": [], diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 1f3c373c1e8..1cde1fcdbb7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -54,9 +54,8 @@ public class MixedTensor implements Tensor { public double get(TensorAddress address) { long cellIndex = index.indexOf(address); Cell cell = cells.get((int)cellIndex); - if (!address.equals(cell.getKey())) { - throw new IllegalStateException("Unable to find correct cell by direct index."); - } + if ( ! address.equals(cell.getKey())) + throw new IllegalStateException("Unable to find correct cell in " + this + " by direct index " + address); return cell.getValue(); } @@ -375,9 +374,8 @@ public class MixedTensor implements Tensor { public long indexOf(TensorAddress address) { TensorAddress sparsePart = sparsePartialAddress(address); - if ( ! sparseMap.containsKey(sparsePart)) { - throw new IllegalArgumentException("Address not found"); - } + if ( ! sparseMap.containsKey(sparsePart)) + throw new IllegalArgumentException("Address subspace " + sparsePart + " not found in " + this); long base = sparseMap.get(sparsePart); long offset = denseOffset(address); return base + offset; @@ -414,7 +412,7 @@ public class MixedTensor implements Tensor { TensorType.Dimension dimension = type.dimensions().get(i); if (dimension.isIndexed()) { denseSubspaceSize *= dimension.size().orElseThrow(() -> - new IllegalArgumentException("Unknown size of indexed dimension.")); + new IllegalArgumentException("Unknown size of indexed dimension")); } } } @@ -422,15 +420,13 @@ public class MixedTensor implements Tensor { } private TensorAddress sparsePartialAddress(TensorAddress address) { - if (type.dimensions().size() != address.size()) { - throw new IllegalArgumentException("Tensor type and address are not of same size."); - } + if (type.dimensions().size() != address.size()) + throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + address); TensorAddress.Builder builder = new TensorAddress.Builder(sparseType); for (int i = 0; i < type.dimensions().size(); ++i) { TensorType.Dimension dimension = type.dimensions().get(i); - if (!dimension.isIndexed()) { + if ( ! dimension.isIndexed()) builder.add(dimension.name(), address.label(i)); - } } return builder.build(); } @@ -488,6 +484,11 @@ public class MixedTensor implements Tensor { return TensorAddress.of(labels); } + @Override + public String toString() { + return "indexes into " + type; + } + } public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index 9c41d5aad68..4eca9c47402 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import java.util.Arrays; + /** * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors * dimensions. @@ -13,10 +15,10 @@ package com.yahoo.tensor; // - We can add support for string labels later without breaking the API public class PartialAddress { - // Two arrays which contains corresponding dimension=label pairs. + // Two arrays which contains corresponding dimension:label pairs. // The sizes of these are always equal. private final String[] dimensionNames; - private final long[] labels; + private final Object[] labels; private PartialAddress(Builder builder) { this.dimensionNames = builder.dimensionNames; @@ -25,23 +27,99 @@ public class PartialAddress { builder.labels = null; } - /** Returns the int label of this dimension, or -1 if no label is specified for it */ - long numericLabel(String dimensionName) { + public String dimension(int i) { + return dimensionNames[i]; + } + + /** Returns the numeric label of this dimension, or -1 if no label is specified for it */ + public long numericLabel(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) - return labels[i]; + return asLong(labels[i]); return -1; } + /** Returns the label of this dimension, or null if no label is specified for it */ + public String label(String dimensionName) { + for (int i = 0; i < dimensionNames.length; i++) + if (dimensionNames[i].equals(dimensionName)) + return labels[i].toString(); + return null; + } + + /** + * Returns the label at position i + * + * @throws IllegalArgumentException if i is out of bounds + */ + public String label(int i) { + if (i >= size()) + throw new IllegalArgumentException("No label at position " + i + " in " + this); + return labels[i].toString(); + } + + public int size() { return dimensionNames.length; } + + /** Returns this as an address in the given tensor type */ + // We need the type here not just for validation but because this must map to the dimension order given by the type + public TensorAddress asAddress(TensorType type) { + if (type.rank() != size()) + throw new IllegalArgumentException(type + " has a different rank than " + this); + if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) { + long[] numericLabels = new long[labels.length]; + for (int i = 0; i < type.dimensions().size(); i++) { + long label = numericLabel(type.dimensions().get(i).name()); + if (label < 0) + throw new IllegalArgumentException(type + " dimension names does not match " + this); + numericLabels[i] = label; + } + return TensorAddress.of(numericLabels); + } + else { + String[] stringLabels = new String[labels.length]; + for (int i = 0; i < type.dimensions().size(); i++) { + String label = label(type.dimensions().get(i).name()); + if (label == null) + throw new IllegalArgumentException(type + " dimension names does not match " + this); + stringLabels[i] = label; + } + return TensorAddress.of(stringLabels); + } + } + + private long asLong(Object label) { + if (label instanceof Long) { + return (Long) label; + } + else { + try { + return Long.parseLong(label.toString()); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Label '" + label + "' is not numeric"); + } + } + } + + @Override + public String toString() { + StringBuilder b = new StringBuilder("Partial address {"); + for (int i = 0; i < dimensionNames.length; i++) + b.append(dimensionNames[i]).append(":").append(label(i)).append(", "); + if (size() > 0) + b.setLength(b.length() - 2); + return b.toString(); + } + public static class Builder { private String[] dimensionNames; - private long[] labels; + private Object[] labels; private int index = 0; public Builder(int size) { dimensionNames = new String[size]; - labels = new long[size]; + labels = new Object[size]; } public void add(String dimensionName, long label) { @@ -50,6 +128,12 @@ public class PartialAddress { index++; } + public void add(String dimensionName, String label) { + dimensionNames[index] = dimensionName; + labels[index] = label; + index++; + } + public PartialAddress build() { return new PartialAddress(this); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java new file mode 100644 index 00000000000..4d3989b8782 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -0,0 +1,266 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.annotations.Beta; +import com.yahoo.tensor.PartialAddress; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Iterator; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Returns a subspace of a tensor + * + * @author bratseth + */ +@Beta +public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> argument; + private final List<DimensionValue<NAMETYPE>> subspaceAddress; + + /** + * Creates a value function + * + * @param argument the tensor to return a cell value from + * @param subspaceAddress a description of the address of the cell to return the value of. This is not a TensorAddress + * because those require a type, but a type is not resolved until this is evaluated + */ + public Slice(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> subspaceAddress) { + this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); + if (subspaceAddress.size() > 1 && subspaceAddress.stream().anyMatch(c -> c.dimension().isEmpty())) + throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: " + + "Specify dimension names explicitly instead"); + this.subspaceAddress = subspaceAddress; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); } + + @Override + public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if (arguments.size() != 1) + throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); + return new Slice<>(arguments.get(0), subspaceAddress); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor tensor = argument.evaluate(context); + TensorType resultType = resultType(tensor.type()); + + PartialAddress subspaceAddress = subspaceToAddress(tensor.type(), context); + if (resultType.rank() == 0) // shortcut common case + return Tensor.from(tensor.get(subspaceAddress.asAddress(tensor.type()))); + + Tensor.Builder b = Tensor.Builder.of(resultType); + for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) { + Tensor.Cell cell = i.next(); + if (matches(subspaceAddress, cell.getKey(), tensor.type())) + b.cell(remaining(resultType, cell.getKey(), tensor.type()), cell.getValue()); + } + return b.build(); + } + + private PartialAddress subspaceToAddress(TensorType type, EvaluationContext<NAMETYPE> context) { + PartialAddress.Builder b = new PartialAddress.Builder(subspaceAddress.size()); + for (int i = 0; i < subspaceAddress.size(); i++) { + if (subspaceAddress.get(i).label().isPresent()) + b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()), + subspaceAddress.get(i).label().get()); + else + b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()), + subspaceAddress.get(i).index().get().apply(context).intValue()); + } + return b.build(); + } + + private boolean matches(PartialAddress subspaceAddress, + TensorAddress address, TensorType type) { + for (int i = 0; i < subspaceAddress.size(); i++) { + String label = address.label(type.indexOfDimension(subspaceAddress.dimension(i)).get()); + if ( ! label.equals(subspaceAddress.label(i))) + return false; + } + return true; + } + + /** Returns the subset of the given address which is present in the subspace type */ + private TensorAddress remaining(TensorType subspaceType, TensorAddress address, TensorType type) { + TensorAddress.Builder b = new TensorAddress.Builder(subspaceType); + for (int i = 0; i < address.size(); i++) { + String dimension = type.dimensions().get(i).name(); + if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent()) + b.add(dimension, address.label(i)); + } + return b.build(); + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return resultType(argument.type(context)); + } + + private TensorType resultType(TensorType argumentType) { + TensorType.Builder b = new TensorType.Builder(); + + // Special case where a single indexed or mapped dimension is sliced + if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { + if (subspaceAddress.get(0).index().isPresent()) { + if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1) + throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " + + " to " + argumentType + ", which have multiple"); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if ( ! dimension.isIndexed()) + b.dimension(dimension); + } + } + else { + if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1) + throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " + + " to " + argumentType + ", which have multiple"); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if (dimension.isIndexed()) + b.dimension(dimension); + } + + } + } + else { // general slicing + Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet()); + for (TensorType.Dimension dimension : argumentType.dimensions()) { + if (slicedDimensions.contains(dimension.name())) + slicedDimensions.remove(dimension.name()); + else + b.dimension(dimension); + } + if ( ! slicedDimensions.isEmpty()) + throw new IllegalArgumentException(this + " slices " + slicedDimensions + " which are not present in " + + argumentType); + } + return b.build(); + } + + @Override + public String toString(ToStringContext context) { + StringBuilder b = new StringBuilder(argument.toString(context)); + if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { + if (subspaceAddress.get(0).index().isPresent()) + b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]"); + else + b.append("{").append(subspaceAddress.get(0).label().get()).append("}"); + } + else { + b.append("{").append(subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}"); + } + return b.toString(); + } + + public static class DimensionValue<NAMETYPE extends Name> { + + private final Optional<String> dimension; + + /** The label of this, or null if index is set */ + private final String label; + + /** The function returning the index of this, or null if label is set */ + private final ScalarFunction<NAMETYPE> index; + + public DimensionValue(String dimension, String label) { + this(Optional.of(dimension), label, null); + } + + public DimensionValue(String dimension, int index) { + this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index)); + } + + public DimensionValue(int index) { + this(Optional.empty(), null, new ConstantIntegerFunction<>(index)); + } + + public DimensionValue(String label) { + this(Optional.empty(), label, null); + } + + public DimensionValue(ScalarFunction<NAMETYPE> index) { + this(Optional.empty(), null, index); + } + + public DimensionValue(Optional<String> dimension, String label) { + this(dimension, label, null); + } + + public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) { + this(dimension, null, index); + } + + public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) { + this(Optional.of(dimension), null, index); + } + + private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) { + this.dimension = dimension; + this.label = label; + this.index = index; + } + + /** + * Returns the given name of the dimension, or null if dense form is used, such that name + * must be inferred from order + */ + public Optional<String> dimension() { return dimension; } + + /** Returns the label for this dimension or empty if it is provided by an index function */ + public Optional<String> label() { return Optional.ofNullable(label); } + + /** Returns the index expression for this dimension, or empty if it is not a number */ + public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); } + + @Override + public String toString() { + return toString(null); + } + + public String toString(ToStringContext context) { + StringBuilder b = new StringBuilder(); + dimension.ifPresent(d -> b.append(d).append(":")); + if (label != null) + b.append(label); + else + b.append(index.toString(context)); + return b.toString(); + } + + } + + private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> { + + private final int value; + + public ConstantIntegerFunction(int value) { + this.value = value; + } + + @Override + public Double apply(EvaluationContext<NAMETYPE> context) { + return (double)value; + } + + @Override + public String toString() { return String.valueOf(value); } + + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java deleted file mode 100644 index 37a54807673..00000000000 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java +++ /dev/null @@ -1,189 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor.functions; - -import com.google.common.annotations.Beta; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorAddress; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.EvaluationContext; -import com.yahoo.tensor.evaluation.Name; -import com.yahoo.tensor.evaluation.TypeContext; - -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * Returns the value of a cell of a tensor (as a rank 0 tensor). - * - * @author bratseth - */ -@Beta -public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { - - private final TensorFunction<NAMETYPE> argument; - private final List<DimensionValue<NAMETYPE>> cellAddress; - - /** - * Creates a value function - * - * @param argument the tensor to return a cell value from - * @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress - * because those require a type, but a type is not resolved until this is evaluated - */ - public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) { - this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); - if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty())) - throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " + - "Specify dimension names explicitly"); - this.cellAddress = cellAddress; - } - - @Override - public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); } - - @Override - public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { - if (arguments.size() != 1) - throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); - return new Value<>(arguments.get(0), cellAddress); - } - - @Override - public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } - - @Override - public Tensor evaluate(EvaluationContext<NAMETYPE> context) { - Tensor tensor = argument.evaluate(context); - if (tensor.type().rank() != cellAddress.size()) - throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() + - " to a tensor of type " + tensor.type()); - TensorAddress.Builder b = new TensorAddress.Builder(tensor.type()); - for (int i = 0; i < cellAddress.size(); i++) { - if (cellAddress.get(i).label().isPresent()) - b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), - cellAddress.get(i).label().get()); - else - b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), - String.valueOf(cellAddress.get(i).index().get().apply(context).intValue())); - } - return Tensor.from(tensor.get(b.build())); - } - - @Override - public TensorType type(TypeContext<NAMETYPE> context) { - return new TensorType.Builder(argument.type(context).valueType()).build(); - } - - @Override - public String toString(ToStringContext context) { - StringBuilder b = new StringBuilder(argument.toString(context)); - if (cellAddress.size() == 1 && cellAddress.get(0).dimension().isEmpty()) { - if (cellAddress.get(0).index().isPresent()) - b.append("[").append(cellAddress.get(0).index().get().toString(context)).append("]"); - else - b.append("{").append(cellAddress.get(0).label().get()).append("}"); - } - else { - b.append("{").append(cellAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}"); - } - return b.toString(); - } - - public static class DimensionValue<NAMETYPE extends Name> { - - private final Optional<String> dimension; - - /** The label of this, or null if index is set */ - private final String label; - - /** The function returning the index of this, or null if label is set */ - private final ScalarFunction<NAMETYPE> index; - - public DimensionValue(String dimension, String label) { - this(Optional.of(dimension), label, null); - } - - public DimensionValue(String dimension, int index) { - this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index)); - } - - public DimensionValue(int index) { - this(Optional.empty(), null, new ConstantIntegerFunction<>(index)); - } - - public DimensionValue(String label) { - this(Optional.empty(), label, null); - } - - public DimensionValue(ScalarFunction<NAMETYPE> index) { - this(Optional.empty(), null, index); - } - - public DimensionValue(Optional<String> dimension, String label) { - this(dimension, label, null); - } - - public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) { - this(dimension, null, index); - } - - public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) { - this(Optional.of(dimension), null, index); - } - - private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) { - this.dimension = dimension; - this.label = label; - this.index = index; - } - - /** - * Returns the given name of the dimension, or null if dense form is used, such that name - * must be inferred from order - */ - public Optional<String> dimension() { return dimension; } - - /** Returns the label for this dimension or empty if it is provided by an index function */ - public Optional<String> label() { return Optional.ofNullable(label); } - - /** Returns the index expression for this dimension, or empty if it is not a number */ - public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); } - - @Override - public String toString() { - return toString(null); - } - - public String toString(ToStringContext context) { - StringBuilder b = new StringBuilder(); - dimension.ifPresent(d -> b.append(d).append(":")); - if (label != null) - b.append(label); - else - b.append(index.toString(context)); - return b.toString(); - } - - } - - private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> { - - private final int value; - - public ConstantIntegerFunction(int value) { - this.value = value; - } - - @Override - public Double apply(EvaluationContext<NAMETYPE> context) { - return (double)value; - } - - @Override - public String toString() { return String.valueOf(value); } - - } - -} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java new file mode 100644 index 00000000000..55e6151f7e9 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java @@ -0,0 +1,138 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class SliceTestCase { + + private static final double delta = 0.000001; + + @Test + public void testSliceFunctionGeneralFormToRank0() { + Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); + Tensor result = new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>("key", "bar"), + new Slice.DimensionValue<>("x", 0))) + .evaluate(); + assertEquals(0, result.type().rank()); + assertEquals(2.3, result.asDouble(), delta); + } + + @Test + public void testSliceFunctionGeneralFormToRank0ReverseDimensionOrder() { + Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); + Tensor result = new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>("x", 0), + new Slice.DimensionValue<>("key", "bar"))) + .evaluate(); + assertEquals(0, result.type().rank()); + assertEquals(2.3, result.asDouble(), delta); + } + + @Test + public void testSliceFunctionGeneralFormToIndexedRank2to1() { + Tensor input = Tensor.from("tensor(key{},x[2]):{ {key:foo,x:0}:1.3, {key:foo,x:1}:1.4, {key:bar,x:0}:2.3, {key:bar,x:1}:2.4 }"); + Tensor result = new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>("key", "bar"))) + .evaluate(); + assertEquals(1, result.type().rank()); + assertEquals(Tensor.from("tensor(x[2]):[2.3, 2.4]]"), result); + } + + @Test + public void testSliceFunctionGeneralFormToMappedRank2to1() { + Tensor input = Tensor.from("tensor(key{},x[2]):{ {key:foo,x:0}:1.3, {key:foo,x:1}:1.4, {key:bar,x:0}:2.3, {key:bar,x:1}:2.4 }"); + Tensor result = new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>("x", 0))) + .evaluate(); + assertEquals(1, result.type().rank()); + assertEquals(Tensor.from("tensor(key{}):{{key:foo}:1.3, {key:bar}:2.3}"), result); + } + + @Test + public void testSliceFunctionGeneralFormToMappedRank3to1() { + Tensor input = Tensor.from("tensor(key{},x[2],y[1]):{ {key:foo,x:0,y:0}:1.3, {key:foo,x:1,y:0}:1.4, {key:bar,x:0,y:0}:2.3, {key:bar,x:1,y:0}:2.4 }"); + Tensor result = new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>("x", 1), + new Slice.DimensionValue<>("y", 0))) + .evaluate(); + assertEquals(1, result.type().rank()); + assertEquals(Tensor.from("tensor(key{}):{{key:foo}:1.4, {key:bar}:2.4}"), result); + } + + @Test + public void testSliceFunctionGeneralFormToMappedRank3to1ReverseDimensionOrder() { + Tensor input = Tensor.from("tensor(key{},x[2],y[1]):{ {key:foo,x:0,y:0}:1.3, {key:foo,x:1,y:0}:1.4, {key:bar,x:0,y:0}:2.3, {key:bar,x:1,y:0}:2.4 }"); + Tensor result = new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>("y", 0), + new Slice.DimensionValue<>("x", 1))) + .evaluate(); + assertEquals(1, result.type().rank()); + assertEquals(Tensor.from("tensor(key{}):{{key:foo}:1.4, {key:bar}:2.4}"), result); + } + + @Test + public void testSliceFunctionGeneralFormToMappedRank3to2() { + Tensor input = Tensor.from("tensor(key{},x[2],y[1]):{ {key:foo,x:0,y:0}:1.3, {key:foo,x:1,y:0}:1.4, {key:bar,x:0,y:0}:2.3, {key:bar,x:1,y:0}:2.4 }"); + Tensor result = new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>("x", 1))) + .evaluate(); + assertEquals(2, result.type().rank()); + assertEquals(Tensor.from("tensor(key{},y[1]):{{key:foo,y:0}:1.4, {key:bar,y:0}:2.4}"), result); + } + + @Test + public void testSliceFunctionSingleMappedDimensionToRank0() { + Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }"); + Tensor result = new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>("foo"))) + .evaluate(); + assertEquals(0, result.type().rank()); + assertEquals(1.4, result.asDouble(), delta); + } + + @Test + public void testSliceFunctionSingleIndexedDimensionToRank0() { + Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]"); + Tensor result = new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>(2))) + .evaluate(); + assertEquals(0, result.type().rank()); + assertEquals(3.3, result.asDouble(), delta); + } + + @Test + public void testSliceFunctionShortFormWithMultipleDimensionsIsNotAllowed() { + try { + Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); + new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>("bar"), + new Slice.DimensionValue<>(0))) + .evaluate(); + fail("Expected exception"); + } + catch (IllegalArgumentException e) { + assertEquals("Short form of subspace addresses is only supported with a single dimension: Specify dimension names explicitly instead", + e.getMessage()); + } + } + + @Test + public void testToString() { + Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]"); + assertEquals("tensor(key[3]):[1.1, 2.2, 3.3][2]", + new Slice<>(new ConstantTensor<>(input), + List.of(new Slice.DimensionValue<>(2))) + .toString()); + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java deleted file mode 100644 index 227fbffbaa8..00000000000 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.tensor.functions; - -import com.yahoo.tensor.Tensor; -import org.junit.Test; - -import java.util.List; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.fail; - -/** - * @author bratseth - */ -public class ValueTestCase { - - private static final double delta = 0.000001; - - @Test - public void testValueFunctionGeneralForm() { - Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); - Tensor result = new Value<>(new ConstantTensor<>(input), - List.of(new Value.DimensionValue<>("key", "bar"), - new Value.DimensionValue<>("x", 0))) - .evaluate(); - assertEquals(0, result.type().rank()); - assertEquals(2.3, result.asDouble(), delta); - } - - @Test - public void testValueFunctionSingleMappedDimension() { - Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }"); - Tensor result = new Value<>(new ConstantTensor<>(input), - List.of(new Value.DimensionValue<>("foo"))) - .evaluate(); - assertEquals(0, result.type().rank()); - assertEquals(1.4, result.asDouble(), delta); - } - - @Test - public void testValueFunctionSingleIndexedDimension() { - Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]"); - Tensor result = new Value<>(new ConstantTensor<>(input), - List.of(new Value.DimensionValue<>(2))) - .evaluate(); - assertEquals(0, result.type().rank()); - assertEquals(3.3, result.asDouble(), delta); - } - - @Test - public void testValueFunctionShortFormWithMultipleDimensionsIsNotAllowed() { - try { - Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); - new Value<>(new ConstantTensor<>(input), - List.of(new Value.DimensionValue<>("bar"), - new Value.DimensionValue<>(0))) - .evaluate(); - fail("Expected exception"); - } - catch (IllegalArgumentException e) { - assertEquals("Short form of cell addresses is only supported with a single dimension: Specify dimension names explicitly", - e.getMessage()); - } - } - - @Test - public void testToString() { - Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]"); - assertEquals("tensor(key[3]):[1.1, 2.2, 3.3][2]", - new Value<>(new ConstantTensor<>(input), - List.of(new Value.DimensionValue<>(2))) - .toString()); - } - -} |