summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-02-01 11:03:17 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-02-01 11:03:17 +0100
commitd2579b9fce1e2d5dd5f509ff767b986129e0973a (patch)
tree1206c16a24478ed1cc4d5eecb17966acf6f56267 /vespajlib
parentc4dcb35fb2979ea07b4ac576a95089a5c53e68dc (diff)
- Use numericLabel over label for address manipulation.
- Only use label when actual string representation is needed.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java25
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java31
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java7
7 files changed, 36 insertions, 51 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 5c2c4d77fad..4fa759668b6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -1,6 +1,7 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.yahoo.tensor.impl.Convert;
import com.yahoo.tensor.impl.Label;
import com.yahoo.tensor.impl.TensorAddressAny;
@@ -102,7 +103,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return "'" + label + "'";
}
- /** Returns an address with only some of the dimension */
+ /** Returns an address with only some of the dimension. Ordering will also be according to indexMap */
public TensorAddress partialCopy(int[] indexMap) {
int[] labels = new int[indexMap.length];
for (int i = 0; i < labels.length; ++i) {
@@ -197,6 +198,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
labels[labelIndex] = Label.toNumber(label);
return this;
}
+
+ public Builder add(String dimension, long label) {
+ return add(dimension, Convert.safe2Int(label));
+ }
public Builder add(String dimension, int label) {
Objects.requireNonNull(dimension, "dimension cannot be null");
int labelIndex = type.indexOfDimensionAsInt(dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index dcfba5ecfad..9125b35ea5d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.impl.TensorAddressAny;
import java.util.Arrays;
import java.util.HashMap;
@@ -354,21 +355,21 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
TensorAddress combine(TensorAddress match, TensorAddress leftOnly, TensorAddress rightOnly, int concatDimIdx) {
- String[] labels = new String[plan.resultType.rank()];
+ int[] labels = new int[plan.resultType.rank()];
int out = 0;
int m = 0;
int a = 0;
int b = 0;
for (var how : plan.combineHow) {
switch (how) {
- case left -> labels[out++] = leftOnly.label(a++);
- case right -> labels[out++] = rightOnly.label(b++);
- case both -> labels[out++] = match.label(m++);
- case concat -> labels[out++] = String.valueOf(concatDimIdx);
+ case left -> labels[out++] = (int) leftOnly.numericLabel(a++);
+ case right -> labels[out++] = (int) rightOnly.numericLabel(b++);
+ case both -> labels[out++] = (int) match.numericLabel(m++);
+ case concat -> labels[out++] = concatDimIdx;
default -> throw new IllegalArgumentException("cannot handle: " + how);
}
}
- return TensorAddress.of(labels);
+ return TensorAddressAny.ofUnsafe(labels);
}
Tensor merge(CellVectorMapMap a, CellVectorMapMap b) {
@@ -398,8 +399,8 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
CellVectorMapMap decompose(Tensor input, SplitHow how) {
var iter = input.cellIterator();
- String[] commonLabels = new String[(int)how.numCommon()];
- String[] separateLabels = new String[(int)how.numSeparate()];
+ int[] commonLabels = new int[(int)how.numCommon()];
+ int[] separateLabels = new int[(int)how.numSeparate()];
CellVectorMapMap result = new CellVectorMapMap();
while (iter.hasNext()) {
var cell = iter.next();
@@ -409,14 +410,14 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
int separateIdx = 0;
for (int i = 0; i < how.handleDims.size(); i++) {
switch (how.handleDims.get(i)) {
- case common -> commonLabels[commonIdx++] = addr.label(i);
- case separate -> separateLabels[separateIdx++] = addr.label(i);
+ case common -> commonLabels[commonIdx++] = (int) addr.numericLabel(i);
+ case separate -> separateLabels[separateIdx++] = (int) addr.numericLabel(i);
case concat -> ccDimIndex = addr.numericLabel(i);
default -> throw new IllegalArgumentException("cannot handle: " + how.handleDims.get(i));
}
}
- TensorAddress commonAddr = TensorAddress.of(commonLabels);
- TensorAddress separateAddr = TensorAddress.of(separateLabels);
+ TensorAddress commonAddr = TensorAddressAny.ofUnsafe(commonLabels);
+ TensorAddress separateAddr = TensorAddressAny.ofUnsafe(separateLabels);
result.lookupCreate(commonAddr).lookupCreate(separateAddr).setValue((int)ccDimIndex, cell.getValue());
}
return result;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java
index c87ef42976d..aa9602339e9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java
@@ -98,9 +98,9 @@ public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction
for (int i = 0; i < inputType.dimensions().size(); i++) {
var dim = inputType.dimensions().get(i);
if (dim.isMapped()) {
- mapAddrBuilder.add(dim.name(), fullAddr.label(i));
+ mapAddrBuilder.add(dim.name(), fullAddr.numericLabel(i));
} else {
- idxAddrBuilder.add(dim.name(), fullAddr.label(i));
+ idxAddrBuilder.add(dim.name(), fullAddr.numericLabel(i));
}
}
var mapAddr = mapAddrBuilder.build();
@@ -123,11 +123,11 @@ public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction
var addrBuilder = new TensorAddress.Builder(outputType);
for (int i = 0; i < inputTypeMapped.dimensions().size(); i++) {
var dim = inputTypeMapped.dimensions().get(i);
- addrBuilder.add(dim.name(), mappedAddr.label(i));
+ addrBuilder.add(dim.name(), mappedAddr.numericLabel(i));
}
for (int i = 0; i < denseOutputDims.size(); i++) {
var dim = denseOutputDims.get(i);
- addrBuilder.add(dim.name(), denseAddr.label(i));
+ addrBuilder.add(dim.name(), denseAddr.numericLabel(i));
}
builder.cell(addrBuilder.build(), cell.getValue());
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 910c5900495..ed4154464fc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -118,11 +118,8 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
return true;
}
- private TensorAddress rename(TensorAddress address, int[] toIndexes) {
- String[] reorderedLabels = new String[toIndexes.length];
- for (int i = 0; i < toIndexes.length; i++)
- reorderedLabels[toIndexes[i]] = address.label(i);
- return TensorAddress.of(reorderedLabels);
+ private static TensorAddress rename(TensorAddress address, int[] toIndexes) {
+ return address.partialCopy(toIndexes);
}
private String toVectorString(List<String> elements) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index 807f56b1a49..38ac42a5f1f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -131,7 +131,7 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
for (int i = 0; i < address.size(); i++) {
String dimension = type.dimensions().get(i).name();
if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent())
- b.add(dimension, address.label(i));
+ b.add(dimension, (int)address.numericLabel(i));
}
return b.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 771b74633d9..5598690e0bf 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -16,13 +16,7 @@ import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.Name;
-import com.yahoo.tensor.functions.ConstantTensor;
-import com.yahoo.tensor.functions.Slice;
-
-import java.util.ArrayList;
import java.util.Iterator;
-import java.util.List;
/**
* Writes tensors on the JSON format used in Vespa tensor document fields:
@@ -140,9 +134,9 @@ public class JsonFormat {
}
private static void encodeBlocks(MixedTensor tensor, Cursor cursor) {
- var mappedDimensions = tensor.type().dimensions().stream().filter(d -> d.isMapped())
+ var mappedDimensions = tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped)
.map(d -> TensorType.Dimension.mapped(d.name())).toList();
- if (mappedDimensions.size() < 1) {
+ if (mappedDimensions.isEmpty()) {
throw new IllegalArgumentException("Should be ensured by caller");
}
@@ -176,23 +170,6 @@ public class JsonFormat {
cursor.setDouble(field, value);
}
- private static TensorAddress subAddress(TensorAddress address, TensorType subType, TensorType origType) {
- TensorAddress.Builder builder = new TensorAddress.Builder(subType);
- for (TensorType.Dimension dim : subType.dimensions()) {
- builder.add(dim.name(), address.label(origType.indexOfDimension(dim.name()).
- orElseThrow(() -> new IllegalStateException("Could not find mapped dimension index"))));
- }
- return builder.build();
- }
-
- private static Tensor sliceSubAddress(Tensor tensor, TensorAddress subAddress, TensorType subType) {
- List<Slice.DimensionValue<Name>> sliceDims = new ArrayList<>(subAddress.size());
- for (int i = 0; i < subAddress.size(); ++i) {
- sliceDims.add(new Slice.DimensionValue<>(subType.dimensions().get(i).name(), subAddress.label(i)));
- }
- return new Slice<>(new ConstantTensor<>(tensor), sliceDims).evaluate();
- }
-
/** Deserializes the given tensor from JSON format */
// NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module
public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
@@ -420,9 +397,7 @@ public class JsonFormat {
if (decoded.length == 0) {
throw new IllegalArgumentException("The block value string does not contain any values");
}
- for (int i = 0; i < decoded.length; i++) {
- values[i] = decoded[i];
- }
+ System.arraycopy(decoded, 0, values, 0, decoded.length);
} else {
throw new IllegalArgumentException("Expected a block to contain an array of values");
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java
index a24475a6a24..dd40e3105bf 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorAddressTestCase.java
@@ -73,4 +73,11 @@ public class TensorAddressTestCase {
}
}
+ @Test
+ void testPartialCopy() {
+ var abcd = ofLabels("a", "b", "c", "d");
+ int[] o_1_3_2 = {1,3,2};
+ equal(ofLabels("b", "d", "c"), abcd.partialCopy(o_1_3_2));
+ }
+
}