diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-02-01 11:03:17 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-02-01 11:03:17 +0100 |
commit | d2579b9fce1e2d5dd5f509ff767b986129e0973a (patch) | |
tree | 1206c16a24478ed1cc4d5eecb17966acf6f56267 | |
parent | c4dcb35fb2979ea07b4ac576a95089a5c53e68dc (diff) |
- Use numericLabel over label for address manipulation.
- Only use label when actual string representation is needed.
10 files changed, 41 insertions, 59 deletions
diff --git a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java index ef2b40c962d..7e15a729684 100644 --- a/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java +++ b/document/src/main/java/com/yahoo/document/update/TensorModifyUpdate.java @@ -119,7 +119,7 @@ public class TensorModifyUpdate extends ValueUpdate<TensorFieldValue> { for (int i = 0; i < type.dimensions().size(); ++i) { var dim = type.dimensions().get(i); if (dim.isMapped()) { - builder.add(dim.name(), address.label(i)); + builder.add(dim.name(), (int) address.numericLabel(i)); } } return builder.build(); diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 1a9caaa5ca1..7c5e8912e49 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -97,7 +97,7 @@ public class EmbedExpression extends Expression { Tensor.Cell cell = cells.next(); builder.cell() .label(targetType.mappedSubtype().dimensions().get(0).name(), i) - .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().label(0)) + .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().numericLabel(0)) .value(cell.getValue()); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java index 467a7860053..ed672c2dcd7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java @@ -11,11 +11,9 @@ import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; -import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.List; -import java.util.Optional; import java.util.Objects; /** @@ -26,7 +24,7 @@ import java.util.Objects; @Beta public class UnpackBitsNode extends CompositeNode { - private static String operationName = "unpack_bits"; + private static final String operationName = "unpack_bits"; private enum EndianNess { BIG_ENDIAN("big"), LITTLE_ENDIAN("little"); @@ -121,9 +119,9 @@ public class UnpackBitsNode extends CompositeNode { var dim = inputType.dimensions().get(i); if (dim.name().equals(meta.unpackDimension())) { long newIdx = oldAddr.numericLabel(i) * 8 + bitIdx; - addrBuilder.add(dim.name(), String.valueOf(newIdx)); + addrBuilder.add(dim.name(), newIdx); } else { - addrBuilder.add(dim.name(), oldAddr.label(i)); + addrBuilder.add(dim.name(), (int) oldAddr.numericLabel(i)); } } var newAddr = addrBuilder.build(); @@ -152,7 +150,6 @@ public class UnpackBitsNode extends CompositeNode { if (lastDim.size().isEmpty()) { throw new IllegalArgumentException("bad " + operationName + "; last indexed dimension must be bound, but type was: " + inputType); } - List<TensorType.Dimension> outputDims = new ArrayList<>(); var ttBuilder = new TensorType.Builder(targetCellType); for (var dim : inputType.dimensions()) { if (dim.name().equals(lastDim.name())) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 5c2c4d77fad..4fa759668b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -1,6 +1,7 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.yahoo.tensor.impl.Convert; import com.yahoo.tensor.impl.Label; import com.yahoo.tensor.impl.TensorAddressAny; @@ -102,7 +103,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return "'" + label + "'"; } - /** Returns an address with only some of the dimension */ + /** Returns an address with only some of the dimension. Ordering will also be according to indexMap */ public TensorAddress partialCopy(int[] indexMap) { int[] labels = new int[indexMap.length]; for (int i = 0; i < labels.length; ++i) { @@ -197,6 +198,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { labels[labelIndex] = Label.toNumber(label); return this; } + + public Builder add(String dimension, long label) { + return add(dimension, Convert.safe2Int(label)); + } public Builder add(String dimension, int label) { Objects.requireNonNull(dimension, "dimension cannot be null"); int labelIndex = type.indexOfDimensionAsInt(dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index dcfba5ecfad..9125b35ea5d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.TensorAddressAny; import java.util.Arrays; import java.util.HashMap; @@ -354,21 +355,21 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) { - String[] labels = new String[plan.resultType.rank()]; + int[] labels = new int[plan.resultType.rank()]; int out = 0; int m = 0; int a = 0; int b = 0; for (var how : plan.combineHow) { switch (how) { - case left -> labels[out++] = leftOnly.label(a++); - case right -> labels[out++] = rightOnly.label(b++); - case both -> labels[out++] = match.label(m++); - case concat -> labels[out++] = String.valueOf(concatDimIdx); + case left -> labels[out++] = (int) leftOnly.numericLabel(a++); + case right -> labels[out++] = (int) rightOnly.numericLabel(b++); + case both -> labels[out++] = (int) match.numericLabel(m++); + case concat -> labels[out++] = concatDimIdx; default -> throw new IllegalArgumentException("cannot handle: " + how); } } - return TensorAddress.of(labels); + return TensorAddressAny.ofUnsafe(labels); } Tensor merge(CellVectorMapMap a, CellVectorMapMap b) { @@ -398,8 +399,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET CellVectorMapMap decompose(Tensor input, SplitHow how) { var iter = input.cellIterator(); - String[] commonLabels = new String[(int)how.numCommon()]; - String[] separateLabels = new String[(int)how.numSeparate()]; + int[] commonLabels = new int[(int)how.numCommon()]; + int[] separateLabels = new int[(int)how.numSeparate()]; CellVectorMapMap result = new CellVectorMapMap(); while (iter.hasNext()) { var cell = iter.next(); @@ -409,14 +410,14 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET int separateIdx = 0; for (int i = 0; i < how.handleDims.size(); i++) { switch (how.handleDims.get(i)) { - case common -> commonLabels[commonIdx++] = addr.label(i); - case separate -> separateLabels[separateIdx++] = addr.label(i); + case common -> commonLabels[commonIdx++] = (int) addr.numericLabel(i); + case separate -> separateLabels[separateIdx++] = (int) addr.numericLabel(i); case concat -> ccDimIndex = addr.numericLabel(i); default -> throw new IllegalArgumentException("cannot handle: " + how.handleDims.get(i)); } } - TensorAddress commonAddr = TensorAddress.of(commonLabels); - TensorAddress separateAddr = TensorAddress.of(separateLabels); + TensorAddress commonAddr = TensorAddressAny.ofUnsafe(commonLabels); + TensorAddress separateAddr = TensorAddressAny.ofUnsafe(separateLabels); result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue()); } return result; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java index c87ef42976d..aa9602339e9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java @@ -98,9 +98,9 @@ public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction for (int i = 0; i < inputType.dimensions().size(); i++) { var dim = inputType.dimensions().get(i); if (dim.isMapped()) { - mapAddrBuilder.add(dim.name(), fullAddr.label(i)); + mapAddrBuilder.add(dim.name(), fullAddr.numericLabel(i)); } else { - idxAddrBuilder.add(dim.name(), fullAddr.label(i)); + idxAddrBuilder.add(dim.name(), fullAddr.numericLabel(i)); } } var mapAddr = mapAddrBuilder.build(); @@ -123,11 +123,11 @@ public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction var addrBuilder = new TensorAddress.Builder(outputType); for (int i = 0; i < inputTypeMapped.dimensions().size(); i++) { var dim = inputTypeMapped.dimensions().get(i); - addrBuilder.add(dim.name(), mappedAddr.label(i)); + addrBuilder.add(dim.name(), mappedAddr.numericLabel(i)); } for (int i = 0; i < denseOutputDims.size(); i++) { var dim = denseOutputDims.get(i); - addrBuilder.add(dim.name(), denseAddr.label(i)); + addrBuilder.add(dim.name(), denseAddr.numericLabel(i)); } builder.cell(addrBuilder.build(), cell.getValue()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 910c5900495..ed4154464fc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -118,11 +118,8 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET return true; } - private TensorAddress rename(TensorAddress address, int[] toIndexes) { - String[] reorderedLabels = new String[toIndexes.length]; - for (int i = 0; i < toIndexes.length; i++) - reorderedLabels[toIndexes[i]] = address.label(i); - return TensorAddress.of(reorderedLabels); + private static TensorAddress rename(TensorAddress address, int[] toIndexes) { + return address.partialCopy(toIndexes); } private String toVectorString(List<String> elements) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 807f56b1a49..38ac42a5f1f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -131,7 +131,7 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY for (int i = 0; i < address.size(); i++) { String dimension = type.dimensions().get(i).name(); if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent()) - b.add(dimension, address.label(i)); + b.add(dimension, (int)address.numericLabel(i)); } return b.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 771b74633d9..5598690e0bf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -16,13 +16,7 @@ import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.Name; -import com.yahoo.tensor.functions.ConstantTensor; -import com.yahoo.tensor.functions.Slice; - -import java.util.ArrayList; import java.util.Iterator; -import java.util.List; /** * Writes tensors on the JSON format used in Vespa tensor document fields: @@ -140,9 +134,9 @@ public class JsonFormat { } private static void encodeBlocks(MixedTensor tensor, Cursor cursor) { - var mappedDimensions = tensor.type().dimensions().stream().filter(d -> d.isMapped()) + var mappedDimensions = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped) .map(d -> TensorType.Dimension.mapped(d.name())).toList(); - if (mappedDimensions.size() < 1) { + if (mappedDimensions.isEmpty()) { throw new IllegalArgumentException("Should be ensured by caller"); } @@ -176,23 +170,6 @@ public class JsonFormat { cursor.setDouble(field, value); } - private static TensorAddress subAddress(TensorAddress address, TensorType subType, TensorType origType) { - TensorAddress.Builder builder = new TensorAddress.Builder(subType); - for (TensorType.Dimension dim : subType.dimensions()) { - builder.add(dim.name(), address.label(origType.indexOfDimension(dim.name()). - orElseThrow(() -> new IllegalStateException("Could not find mapped dimension index")))); - } - return builder.build(); - } - - private static Tensor sliceSubAddress(Tensor tensor, TensorAddress subAddress, TensorType subType) { - List<Slice.DimensionValue<Name>> sliceDims = new ArrayList<>(subAddress.size()); - for (int i = 0; i < subAddress.size(); ++i) { - sliceDims.add(new Slice.DimensionValue<>(subType.dimensions().get(i).name(), subAddress.label(i))); - } - return new Slice<>(new ConstantTensor<>(tensor), sliceDims).evaluate(); - } - /** Deserializes the given tensor from JSON format */ // NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module public static Tensor decode(TensorType type, byte[] jsonTensorValue) { @@ -420,9 +397,7 @@ public class JsonFormat { if (decoded.length == 0) { throw new IllegalArgumentException("The block value string does not contain any values"); } - for (int i = 0; i < decoded.length; i++) { - values[i] = decoded[i]; - } + System.arraycopy(decoded, 0, values, 0, decoded.length); } else { throw new IllegalArgumentException("Expected a block to contain an array of values"); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java index a24475a6a24..dd40e3105bf 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java @@ -73,4 +73,11 @@ public class TensorAddressTestCase { } } + @Test + void testPartialCopy() { + var abcd = ofLabels("a", "b", "c", "d"); + int[] o_1_3_2 = {1,3,2}; + equal(ofLabels("b", "d", "c"), abcd.partialCopy(o_1_3_2)); + } + } |