summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--container-search/src/main/java/com/yahoo/fs4/QueryPacket.java36
-rw-r--r--container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java5
-rw-r--r--documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/StoragePolicy.java8
-rw-r--r--messagebus/pom.xml6
-rw-r--r--messagebus/src/test/java/com/yahoo/messagebus/network/local/LocalNetworkTest.java6
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java42
-rw-r--r--vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java16
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java4
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java17
-rw-r--r--vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java144
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java15
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java61
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java95
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java20
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java11
17 files changed, 396 insertions, 117 deletions
diff --git a/container-search/src/main/java/com/yahoo/fs4/QueryPacket.java b/container-search/src/main/java/com/yahoo/fs4/QueryPacket.java
index 16414e8eabd..3a695efa383 100644
--- a/container-search/src/main/java/com/yahoo/fs4/QueryPacket.java
+++ b/container-search/src/main/java/com/yahoo/fs4/QueryPacket.java
@@ -105,7 +105,8 @@ public class QueryPacket extends Packet {
int startOfFieldToSave;
boolean sendSessionKey = query.getGroupingSessionCache() || query.getRanking().getQueryCache();
- buffer.putInt(getFeatureInt(sendSessionKey));
+ int featureFlag = getFeatureInt(sendSessionKey);
+ buffer.putInt(featureFlag);
IntegerCompressor.putCompressedPositiveNumber(getOffset(), buffer);
IntegerCompressor.putCompressedPositiveNumber(getHits(), buffer);
@@ -117,7 +118,7 @@ public class QueryPacket extends Packet {
Item.putString(query.getRanking().getProfile(), buffer);
queryPacketData.setRankProfile(buffer, startOfFieldToSave);
- if ( query.hasEncodableProperties()) {
+ if ( (featureFlag & QF_PROPERTIES) != 0) {
startOfFieldToSave = buffer.position();
query.encodeAsProperties(buffer, true);
queryPacketData.setPropertyMaps(buffer, startOfFieldToSave);
@@ -125,17 +126,18 @@ public class QueryPacket extends Packet {
// Language not needed when sending query stacks
- if (query.getRanking().getSorting() != null) {
+ if ((featureFlag & QF_SORTSPEC) != 0) {
int sortSpecLengthPosition=buffer.position();
buffer.putInt(0);
int sortSpecLength = query.getRanking().getSorting().encode(buffer);
buffer.putInt(sortSpecLengthPosition, sortSpecLength);
}
- if (getGroupingList(query).size() > 0) {
+ if ( (featureFlag & QF_GROUPSPEC) != 0) {
+ List<Grouping> groupingList = GroupingExecutor.getGroupingList(query);
BufferSerializer gbuf = new BufferSerializer(new GrowableByteBuffer());
- gbuf.putInt(null, getGroupingList(query).size());
- for (Grouping g: getGroupingList(query)){
+ gbuf.putInt(null, groupingList.size());
+ for (Grouping g: groupingList){
g.serialize(gbuf);
}
gbuf.getBuf().flip();
@@ -150,7 +152,7 @@ public class QueryPacket extends Packet {
buffer.put(query.getSessionId(true).asUtf8String().getBytes());
}
- if (query.getRanking().getLocation() != null) {
+ if ((featureFlag & QF_LOCATION) != 0) {
startOfFieldToSave = buffer.position();
int locationLengthPosition=buffer.position();
buffer.putInt(0);
@@ -184,16 +186,14 @@ public class QueryPacket extends Packet {
static final int QF_SESSIONID = 0x00800000;
private int getFeatureInt(boolean sendSessionId) {
- int features = QF_PARSEDQUERY; // this bitmask means "parsed query" in query packet.
- // we always use a parsed query here
+ int features = QF_PARSEDQUERY | QF_RANKP; // this bitmask means "parsed query" in query packet.
+ // And rank properties. Both are always present
- features |= QF_RANKP; // hasRankProfile
-
- features |= (query.getRanking().getSorting() != null) ? QF_SORTSPEC : 0;
- features |= (query.getRanking().getLocation() != null) ? QF_LOCATION : 0;
- features |= (query.hasEncodableProperties()) ? QF_PROPERTIES : 0;
- features |= (getGroupingList(query).size() > 0) ? QF_GROUPSPEC : 0;
- features |= (sendSessionId) ? QF_SESSIONID : 0;
+ features |= (query.getRanking().getSorting() != null) ? QF_SORTSPEC : 0;
+ features |= (query.getRanking().getLocation() != null) ? QF_LOCATION : 0;
+ features |= (query.hasEncodableProperties()) ? QF_PROPERTIES : 0;
+ features |= GroupingExecutor.hasGroupingList(query) ? QF_GROUPSPEC : 0;
+ features |= (sendSessionId) ? QF_SESSIONID : 0;
return features;
}
@@ -233,10 +233,6 @@ public class QueryPacket extends Packet {
return "Query x packet [query: " + query + "]";
}
- private static List<Grouping> getGroupingList(Query query) {
- return Collections.unmodifiableList(GroupingExecutor.getGroupingList(query));
- }
-
static int getQueryFlags(Query query) {
int flags = 0;
diff --git a/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java b/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java
index e5e91f21f5f..70942fa2553 100644
--- a/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java
+++ b/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java
@@ -363,6 +363,11 @@ public class GroupingExecutor extends Searcher {
return (List<Grouping>)obj;
}
+ public static boolean hasGroupingList(Query query) {
+ Object obj = query.properties().get(PROP_GROUPINGLIST);
+ return (obj instanceof List);
+ }
+
/**
* Sets the list of {@link Grouping} objects assigned to the given query. This method overwrites any grouping
* objects already assigned to the query.
diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/StoragePolicy.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/StoragePolicy.java
index fdb4b8f6339..7f8121c2138 100644
--- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/StoragePolicy.java
+++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/StoragePolicy.java
@@ -326,9 +326,13 @@ public class StoragePolicy extends ExternalSlobrokPolicy {
}
}
if (context.usedState != null && newState.getVersion() <= context.usedState.getVersion()) {
- reply.setRetryDelay(-1);
+ if (reply.getRetryDelay() <= 0.0) {
+ reply.setRetryDelay(-1);
+ }
} else {
- reply.setRetryDelay(0);
+ if (reply.getRetryDelay() <= 0.0) {
+ reply.setRetryDelay(0);
+ }
}
if (context.calculatedDistributor == null) {
if (cachedClusterState == null) {
diff --git a/messagebus/pom.xml b/messagebus/pom.xml
index a3ae2fb54c2..37a1c681497 100644
--- a/messagebus/pom.xml
+++ b/messagebus/pom.xml
@@ -19,6 +19,12 @@
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>org.hamcrest</groupId>
+ <artifactId>hamcrest-all</artifactId>
+ <version>1.3</version>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>com.yahoo.vespa</groupId>
<artifactId>vespajlib</artifactId>
<version>${project.version}</version>
diff --git a/messagebus/src/test/java/com/yahoo/messagebus/network/local/LocalNetworkTest.java b/messagebus/src/test/java/com/yahoo/messagebus/network/local/LocalNetworkTest.java
index 95cf7639d09..d8345180952 100644
--- a/messagebus/src/test/java/com/yahoo/messagebus/network/local/LocalNetworkTest.java
+++ b/messagebus/src/test/java/com/yahoo/messagebus/network/local/LocalNetworkTest.java
@@ -15,6 +15,8 @@ import java.util.concurrent.TimeUnit;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
+import static org.hamcrest.number.OrderingComparison.lessThan;
+import static org.hamcrest.number.OrderingComparison.greaterThanOrEqualTo;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
@@ -109,8 +111,8 @@ public class LocalNetworkTest {
assertEquals(ErrorCode.TIMEOUT, res.getError().getCode());
assertTrue(res.getError().getMessage().endsWith("Timed out in sendQ"));
long end = System.currentTimeMillis();
- assertTrue(end - start >= (TIMEOUT*0.98)); // Different clocks are used....
- assertTrue(end - start < 2*TIMEOUT);
+ assertThat(end, greaterThanOrEqualTo(start+TIMEOUT));
+ assertThat(end, lessThan(start+2*TIMEOUT));
msg = serverB.messages.poll(60, TimeUnit.SECONDS);
assertThat(msg, instanceOf(SimpleMessage.class));
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java
index a3896808177..a1863262ba0 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java
@@ -4,15 +4,14 @@ package com.yahoo.vespa.http.client.config;
import com.google.common.annotations.Beta;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.Multimap;
-import com.yahoo.vespa.http.client.Session;
-
import net.jcip.annotations.Immutable;
import javax.net.ssl.SSLContext;
-
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.Map;
+import java.util.Objects;
import java.util.concurrent.TimeUnit;
/**
@@ -36,6 +35,7 @@ public final class ConnectionParams {
private SSLContext sslContext = null;
private long connectionTimeout = TimeUnit.SECONDS.toMillis(60);
private final Multimap<String, String> headers = ArrayListMultimap.create();
+ private final Map<String, HeaderProvider> headerProviders = new HashMap<>();
private int numPersistentConnectionsPerEndpoint = 8;
private String proxyHost = null;
private int proxyPort = 8080;
@@ -73,6 +73,24 @@ public final class ConnectionParams {
}
/**
+ * Adds a header provider for dynamic headers; headers where the value may change during a feeding session
+ * (e.g. security tokens with limited life time). Only one {@link HeaderProvider} is allowed for a given header name.
+ *
+ * @param provider A provider for a dynamic header
+ * @return pointer to builder.
+ * @throws IllegalArgumentException if a provider is already registered for the given header name
+ */
+ public Builder addDynamicHeader(String headerName, HeaderProvider provider) {
+ Objects.requireNonNull(headerName, "Header name cannot be null");
+ Objects.requireNonNull(provider, "Header provider cannot be null");
+ if (headerProviders.containsKey(headerName)) {
+ throw new IllegalArgumentException("Provider already registered for name '" + headerName + "'");
+ }
+ headerProviders.put(headerName, provider);
+ return this;
+ }
+
+ /**
* The number of connections between the http client and the gateways. A very low number can result
* in the network not fully utilized and the round-trip time can be a limiting factor. A low number
* can cause skew in distribution of load between gateways. A too high number will cause
@@ -203,6 +221,7 @@ public final class ConnectionParams {
sslContext,
connectionTimeout,
headers,
+ headerProviders,
numPersistentConnectionsPerEndpoint,
proxyHost,
proxyPort,
@@ -254,6 +273,7 @@ public final class ConnectionParams {
private final SSLContext sslContext;
private final long connectionTimeout;
private final Multimap<String, String> headers = ArrayListMultimap.create();
+ private final Map<String, HeaderProvider> headerProviders = new HashMap<>();
private final int numPersistentConnectionsPerEndpoint;
private final String proxyHost;
private final int proxyPort;
@@ -270,6 +290,7 @@ public final class ConnectionParams {
SSLContext sslContext,
long connectionTimeout,
Multimap<String, String> headers,
+ Map<String, HeaderProvider> headerProviders,
int numPersistentConnectionsPerEndpoint,
String proxyHost,
int proxyPort,
@@ -284,6 +305,7 @@ public final class ConnectionParams {
this.sslContext = sslContext;
this.connectionTimeout = connectionTimeout;
this.headers.putAll(headers);
+ this.headerProviders.putAll(headerProviders);
this.numPersistentConnectionsPerEndpoint = numPersistentConnectionsPerEndpoint;
this.proxyHost = proxyHost;
this.proxyPort = proxyPort;
@@ -305,6 +327,10 @@ public final class ConnectionParams {
return Collections.unmodifiableCollection(headers.entries());
}
+ public Map<String, HeaderProvider> getDynamicHeaders() {
+ return Collections.unmodifiableMap(headerProviders);
+ }
+
public int getNumPersistentConnectionsPerEndpoint() {
return numPersistentConnectionsPerEndpoint;
}
@@ -347,4 +373,14 @@ public final class ConnectionParams {
return printTraceToStdErr;
}
+ /**
+ * A header provider that provides a header value. {@link #getHeaderValue()} is called each time a new HTTP request
+ * is constructed by {@link com.yahoo.vespa.http.client.FeedClient}.
+ *
+ * Important: The implementation of {@link #getHeaderValue()} must be thread-safe!
+ */
+ public interface HeaderProvider {
+ String getHeaderValue();
+ }
+
}
diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
index f3d751f67c7..6fae0765e14 100644
--- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
+++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java
@@ -16,6 +16,8 @@ import org.apache.http.StatusLine;
import org.apache.http.client.HttpClient;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.methods.HttpPost;
+import org.apache.http.config.Registry;
+import org.apache.http.config.RegistryBuilder;
import org.apache.http.conn.socket.ConnectionSocketFactory;
import org.apache.http.conn.socket.PlainConnectionSocketFactory;
import org.apache.http.conn.ssl.SSLConnectionSocketFactory;
@@ -23,9 +25,6 @@ import org.apache.http.entity.InputStreamEntity;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.conn.PoolingHttpClientConnectionManager;
-import org.apache.http.config.Registry;
-import org.apache.http.config.RegistryBuilder;
-
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
@@ -36,13 +35,13 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.zip.GZIPOutputStream;
-import java.util.UUID;
-
/**
* @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a>
* @since 5.1.20
@@ -228,6 +227,13 @@ class ApacheGatewayConnection implements GatewayConnection {
for (Map.Entry<String, String> extraHeader : connectionParams.getHeaders()) {
httpPost.addHeader(extraHeader.getKey(), extraHeader.getValue());
}
+ connectionParams.getDynamicHeaders().forEach((headerName, provider) -> {
+ String headerValue = Objects.requireNonNull(
+ provider.getHeaderValue(),
+ provider.getClass().getName() + ".getHeader() returned null as header value!");
+ httpPost.addHeader(headerName, headerValue);
+ });
+
if (useCompression) {
httpPost.setHeader("Content-Encoding", "gzip");
}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java
index bd894e84ed4..a095cb184a0 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java
@@ -109,7 +109,9 @@ public class V3HttpAPITest extends TestOnCiBuildingSystemOnly {
@Test
public void requireThatBadResponseCodeFails() throws Exception {
- testServerWithMock(new V3MockParsingRequestHandler(407), true);
+ testServerWithMock(new V3MockParsingRequestHandler(401/*Unauthorized*/), true);
+ testServerWithMock(new V3MockParsingRequestHandler(403/*Forbidden*/), true);
+ testServerWithMock(new V3MockParsingRequestHandler(407/*Proxy Authentication Required*/), true);
}
@Test
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java
index 49ffabbf1d0..cb9270fdd3f 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java
@@ -7,12 +7,11 @@ import javax.net.ssl.SSLContext;
import java.security.NoSuchAlgorithmException;
import java.util.Iterator;
import java.util.Map;
-import java.util.concurrent.TimeUnit;
import static org.hamcrest.CoreMatchers.equalTo;
-import static org.hamcrest.CoreMatchers.not;
import static org.hamcrest.core.Is.is;
import static org.hamcrest.core.IsNull.nullValue;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;
/**
@@ -57,4 +56,18 @@ public class ConnectionParamsTest {
assertThat(header3.getValue(), equalTo("Apple"));
}
+ @Test
+ public void header_providers_are_registered() {
+ ConnectionParams.HeaderProvider dummyProvider1 = () -> "fooValue";
+ ConnectionParams.HeaderProvider dummyProvider2 = () -> "barValue";
+ ConnectionParams params = new ConnectionParams.Builder()
+ .addDynamicHeader("foo", dummyProvider1)
+ .addDynamicHeader("bar", dummyProvider2)
+ .build();
+ Map<String, ConnectionParams.HeaderProvider> headerProviders = params.getDynamicHeaders();
+ assertEquals(2, headerProviders.size());
+ assertEquals(dummyProvider1, headerProviders.get("foo"));
+ assertEquals(dummyProvider2, headerProviders.get("bar"));
+ }
+
}
diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java
index fa4ad8fa175..456e5184f6e 100644
--- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java
+++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java
@@ -1,19 +1,13 @@
// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.vespa.http.client.core.communication;
-import java.io.ByteArrayInputStream;
-import java.io.IOException;
-import java.io.InputStream;
-import java.nio.charset.StandardCharsets;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.concurrent.CountDownLatch;
-import java.util.concurrent.TimeUnit;
-import java.util.concurrent.atomic.AtomicInteger;
-
-import com.yahoo.vespa.http.client.core.Headers;
import com.yahoo.vespa.http.client.TestUtils;
+import com.yahoo.vespa.http.client.config.ConnectionParams;
+import com.yahoo.vespa.http.client.config.Endpoint;
+import com.yahoo.vespa.http.client.config.FeedParams;
import com.yahoo.vespa.http.client.core.Document;
+import com.yahoo.vespa.http.client.core.Headers;
+import com.yahoo.vespa.http.client.core.ServerResponseException;
import org.apache.http.Header;
import org.apache.http.HeaderElement;
import org.apache.http.HttpEntity;
@@ -24,16 +18,29 @@ import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.entity.InputStreamEntity;
import org.junit.Test;
-import com.yahoo.vespa.http.client.config.ConnectionParams;
-import com.yahoo.vespa.http.client.config.Endpoint;
-import com.yahoo.vespa.http.client.config.FeedParams;
-import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicInteger;
+
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.core.Is.is;
+import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
-import static org.mockito.Mockito.*;
+import static org.mockito.Mockito.any;
+import static org.mockito.Mockito.mock;
+import static org.mockito.Mockito.stub;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.Mockito.when;
public class ApacheGatewayConnectionTest {
@@ -47,23 +54,14 @@ public class ApacheGatewayConnectionTest {
.setEnableV3Protocol(true)
.build();
final List<Document> documents = new ArrayList<>();
-
- final ApacheGatewayConnection.HttpClientFactory mockFactory =
- mock(ApacheGatewayConnection.HttpClientFactory.class);
- final HttpClient httpClientMock = mock(HttpClient.class);
- when(mockFactory.createClient()).thenReturn(httpClientMock);
-
final CountDownLatch verifyContentSentLatch = new CountDownLatch(1);
final String vespaDocContent ="Hello, I a JSON doc.";
final String docId = "42";
final AtomicInteger requestsReceived = new AtomicInteger(0);
-
// This is the fake server, takes header client ID and uses this as session Id.
- stub(httpClientMock.execute(any())).toAnswer((Answer) invocation -> {
- final Object[] args = invocation.getArguments();
- final HttpPost post = (HttpPost) args[0];
+ ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> {
final Header clientIdHeader = post.getFirstHeader(Headers.CLIENT_ID);
verifyContentSentLatch.countDown();
return httpResponse(clientIdHeader.getValue(), "3");
@@ -94,15 +92,8 @@ public class ApacheGatewayConnectionTest {
.setEnableV3Protocol(true)
.build();
- final ApacheGatewayConnection.HttpClientFactory mockFactory =
- mock(ApacheGatewayConnection.HttpClientFactory.class);
- final HttpClient httpClientMock = mock(HttpClient.class);
- when(mockFactory.createClient()).thenReturn(httpClientMock);
-
// This is the fake server, returns wrong session Id.
- stub(httpClientMock.execute(any())).toAnswer(invocation -> {
- return httpResponse("Wrong Id from server", "3");
- });
+ ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> httpResponse("Wrong Id from server", "3"));
ApacheGatewayConnection apacheGatewayConnection =
new ApacheGatewayConnection(
@@ -148,11 +139,6 @@ public class ApacheGatewayConnectionTest {
.build();
final List<Document> documents = new ArrayList<>();
- final ApacheGatewayConnection.HttpClientFactory mockFactory =
- mock(ApacheGatewayConnection.HttpClientFactory.class);
- final HttpClient httpClientMock = mock(HttpClient.class);
- when(mockFactory.createClient()).thenReturn(httpClientMock);
-
final CountDownLatch verifyContentSentLatch = new CountDownLatch(1);
final String vespaDocContent ="Hello, I a JSON doc.";
@@ -161,13 +147,11 @@ public class ApacheGatewayConnectionTest {
final AtomicInteger requestsReceived = new AtomicInteger(0);
// This is the fake server, checks that DATA_FORMAT header is set properly.
- stub(httpClientMock.execute(any())).toAnswer((Answer) invocation -> {
- final Object[] args = invocation.getArguments();
- final HttpPost post = (HttpPost) args[0];
+ ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> {
final Header header = post.getFirstHeader(Headers.DATA_FORMAT);
if (requestsReceived.incrementAndGet() == 1) {
// This is handshake, it is not json.
- assert(header == null);
+ assert (header == null);
return httpResponse("clientId", "3");
}
assertNotNull(header);
@@ -218,11 +202,6 @@ public class ApacheGatewayConnectionTest {
.build();
final List<Document> documents = new ArrayList<>();
- final ApacheGatewayConnection.HttpClientFactory mockFactory =
- mock(ApacheGatewayConnection.HttpClientFactory.class);
- final HttpClient httpClientMock = mock(HttpClient.class);
- when(mockFactory.createClient()).thenReturn(httpClientMock);
-
final CountDownLatch verifyContentSentLatch = new CountDownLatch(1);
final String vespaDocContent ="Hello, I am the document data.";
@@ -232,9 +211,7 @@ public class ApacheGatewayConnectionTest {
// When sending data on http client, check if it is compressed. If compressed, unzip, check result,
// and count down latch.
- stub(httpClientMock.execute(any())).toAnswer((Answer) invocation -> {
- final Object[] args = invocation.getArguments();
- final HttpPost post = (HttpPost) args[0];
+ ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> {
final Header header = post.getFirstHeader("Content-Encoding");
if (header != null && header.getValue().equals("gzip")) {
final String rawContent = TestUtils.zipStreamToString(post.getEntity().getContent());
@@ -249,6 +226,7 @@ public class ApacheGatewayConnectionTest {
}
return httpResponse("clientId", "3");
});
+
StatusLine statusLineMock = mock(StatusLine.class);
when(statusLineMock.getStatusCode()).thenReturn(200);
@@ -269,6 +247,68 @@ public class ApacheGatewayConnectionTest {
assertTrue(verifyContentSentLatch.await(10, TimeUnit.SECONDS));
}
+ @Test
+ public void dynamic_headers_are_added_to_the_response() throws IOException, ServerResponseException, InterruptedException {
+ ConnectionParams.HeaderProvider headerProvider = mock(ConnectionParams.HeaderProvider.class);
+ when(headerProvider.getHeaderValue())
+ .thenReturn("v1")
+ .thenReturn("v2")
+ .thenReturn("v3");
+
+ ConnectionParams connectionParams = new ConnectionParams.Builder()
+ .addDynamicHeader("foo", headerProvider)
+ .build();
+
+ CountDownLatch verifyContentSentLatch = new CountDownLatch(1);
+
+ AtomicInteger counter = new AtomicInteger(1);
+ ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> {
+ Header[] fooHeader = post.getHeaders("foo");
+ assertEquals(1, fooHeader.length);
+ assertEquals("foo", fooHeader[0].getName());
+ assertEquals("v" + counter.getAndIncrement(), fooHeader[0].getValue());
+ verifyContentSentLatch.countDown();
+ return httpResponse("clientId", "3");
+
+ });
+
+ ApacheGatewayConnection apacheGatewayConnection =
+ new ApacheGatewayConnection(
+ Endpoint.create("hostname", 666, false),
+ new FeedParams.Builder().build(),
+ "",
+ connectionParams,
+ mockFactory,
+ "clientId");
+ apacheGatewayConnection.connect();
+ apacheGatewayConnection.handshake();
+
+ List<Document> documents = new ArrayList<>();
+ documents.add(createDoc("42", "content", true));
+ apacheGatewayConnection.writeOperations(documents);
+ apacheGatewayConnection.writeOperations(documents);
+ assertTrue(verifyContentSentLatch.await(10, TimeUnit.SECONDS));
+
+ verify(headerProvider, times(3)).getHeaderValue(); // 1x connect(), 2x writeOperations()
+ }
+
+ private static ApacheGatewayConnection.HttpClientFactory mockHttpClientFactory(HttpExecuteMock httpExecuteMock) throws IOException {
+ ApacheGatewayConnection.HttpClientFactory mockFactory =
+ mock(ApacheGatewayConnection.HttpClientFactory.class);
+ HttpClient httpClientMock = mock(HttpClient.class);
+ when(mockFactory.createClient()).thenReturn(httpClientMock);
+ stub(httpClientMock.execute(any())).toAnswer((Answer) invocation -> {
+ Object[] args = invocation.getArguments();
+ HttpPost post = (HttpPost) args[0];
+ return httpExecuteMock.execute(post);
+ });
+ return mockFactory;
+ }
+
+ @FunctionalInterface private interface HttpExecuteMock {
+ HttpResponse execute(HttpPost httpPost) throws IOException;
+ }
+
private Document createDoc(final String docId, final String content, boolean useJson) throws IOException {
return new Document(docId, content.getBytes(), null /* context */);
}
@@ -315,4 +355,4 @@ public class ApacheGatewayConnectionTest {
when(httpEntityMock.getContent()).thenReturn(inputs);
return httpResponseMock;
}
-} \ No newline at end of file
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index 76340bb7d8f..daa85cc51e4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -1,5 +1,7 @@
package com.yahoo.tensor;
+import com.google.common.annotations.Beta;
+
import java.util.Arrays;
/**
@@ -7,6 +9,7 @@ import java.util.Arrays;
*
* @author bratseth
*/
+@Beta
public final class DimensionSizes {
private final int[] sizes;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index d69cf65ee8d..9315922f57a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -52,6 +52,21 @@ public class IndexedTensor implements Tensor {
return new CellIterator();
}
+ /** Returns an iterator over all the cells in this tensor which matches the given partial address */
+ // TODO: Move up to Tensor and create a mixed tensor which can implement it (and subspace iterators) efficiently
+ public SubspaceIterator cellIterator(PartialAddress partialAddress, DimensionSizes iterationSizes) {
+ int[] startAddress = new int[type().dimensions().size()];
+ List<Integer> iterateDimensions = new ArrayList<>();
+ for (int i = 0; i < type().dimensions().size(); i++) {
+ int partialAddressLabel = partialAddress.intLabel(type.dimensions().get(i).name());
+ if (partialAddressLabel >= 0) // iterate at this label
+ startAddress[i] = partialAddressLabel;
+ else // iterate over this dimension
+ iterateDimensions.add(i);
+ }
+ return new SubspaceIterator(iterateDimensions, startAddress, iterationSizes);
+ }
+
/**
* Returns an iterator over the values of this.
* Values are returned in order of increasing indexes in each dimension, increasing
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
new file mode 100644
index 00000000000..463b7f3e99f
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
@@ -0,0 +1,61 @@
+package com.yahoo.tensor;
+
+import com.google.common.annotations.Beta;
+
+/**
+ * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors
+ * dimensions.
+ *
+ * @author bratseth
+ */
+// Implementation notes:
+// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation.
+// We also avoid non-essential error checking.
+// - We can add support for string labels later without breaking the API
+@Beta
+public class PartialAddress {
+
+ // Two arrays which contains corresponding dimension=label pairs.
+ // The sizes of these are always equal.
+ private final String[] dimensionNames;
+ private final int[] labels;
+
+ private PartialAddress(Builder builder) {
+ this.dimensionNames = builder.dimensionNames;
+ this.labels = builder.labels;
+ builder.dimensionNames = null; // invalidate builder to safely take over array ownership
+ builder.labels = null;
+ }
+
+ /** Returns the int label of this dimension, or -1 if no label is specified for it */
+ int intLabel(String dimensionName) {
+ for (int i = 0; i < dimensionNames.length; i++)
+ if (dimensionNames[i].equals(dimensionName))
+ return labels[i];
+ return -1;
+ }
+
+ public static class Builder {
+
+ private String[] dimensionNames;
+ private int[] labels;
+ private int index = 0;
+
+ public Builder(int size) {
+ dimensionNames = new String[size];
+ labels = new int[size];
+ }
+
+ public void add(String dimensionName, int label) {
+ dimensionNames[index] = dimensionName;
+ labels[index] = label;
+ index++;
+ }
+
+ public PartialAddress build() {
+ return new PartialAddress(this);
+ }
+
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 05999ff1240..fda6e8ef86c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -83,7 +83,7 @@ public class Concat extends PrimitiveTensorFunction {
for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) {
IndexedTensor.SubspaceIterator ibSubspace = ib.next();
while (ibSubspace.hasNext()) {
- java.util.Map.Entry<TensorAddress, Double> bCell = ibSubspace.next(); // TODO: Create Cell convenience subclass for Map.Entry
+ Tensor.Cell bCell = ibSubspace.next();
TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes,
concatType, offset, dimension);
if (combinedAddress == null) continue; // incompatible
@@ -125,21 +125,21 @@ public class Concat extends PrimitiveTensorFunction {
/** Returns the concrete (not type) dimension sizes resulting from combining a and b */
private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) {
- DimensionSizes.Builder joinedSizes = new DimensionSizes.Builder(concatType.dimensions().size());
- for (int i = 0; i < joinedSizes.dimensions(); i++) {
+ DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size());
+ for (int i = 0; i < concatSizes.dimensions(); i++) {
String currentDimension = concatType.dimensions().get(i).name();
int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0);
int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0);
if (currentDimension.equals(concatDimension))
- joinedSizes.set(i, aSize + bSize);
+ concatSizes.set(i, aSize + bSize);
else if (aSize != 0 && bSize != 0 && aSize!=bSize )
throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when " +
"concatenating " + a.type() + " and " + b.type() + " along dimension " +
concatDimension + ", but was " + aSize + " and " + bSize);
else
- joinedSizes.set(i, Math.max(aSize, bSize));
+ concatSizes.set(i, Math.max(aSize, bSize));
}
- return joinedSizes.build();
+ return concatSizes.build();
}
/**
@@ -150,13 +150,13 @@ public class Concat extends PrimitiveTensorFunction {
*/
private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
TensorType concatType, int concatOffset, String concatDimension) {
- int[] joinedLabels = new int[concatType.dimensions().size()];
- Arrays.fill(joinedLabels, -1);
+ int[] combinedLabels = new int[concatType.dimensions().size()];
+ Arrays.fill(combinedLabels, -1);
int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get();
- mapContent(a, joinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
- boolean compatible = mapContent(b, joinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
+ mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension
+ boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here
if ( ! compatible) return null;
- return TensorAddress.of(joinedLabels);
+ return TensorAddress.of(combinedLabels);
}
/**
@@ -166,7 +166,7 @@ public class Concat extends PrimitiveTensorFunction {
* fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name())
* If some dimension in fromType is not present in toType, the corresponding index will be -1
*/
- // TODO: Stolen from join - put on TensorType?
+ // TODO: Stolen from join
private int[] mapIndexes(TensorType fromType, TensorType toType) {
int[] toIndexes = new int[fromType.dimensions().size()];
for (int i = 0; i < fromType.dimensions().size(); i++)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 23865e1cc1c..ceade39ce42 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -2,8 +2,10 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -156,16 +158,20 @@ public class Join extends PrimitiveTensorFunction {
}
}
- private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor subspace, IndexedTensor superspace) {
- DimensionSizes.Builder b = new DimensionSizes.Builder(joinedType.dimensions().size());
- for (int i = 0; i < b.dimensions(); i++) {
- Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name());
- if (subspaceIndex.isPresent())
- b.set(i, Math.min(superspace.dimensionSizes().size(i), subspace.dimensionSizes().size(subspaceIndex.get())));
- else
- b.set(i, superspace.dimensionSizes().size(i));
+ private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) {
+ DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size());
+ for (int i = 0; i < builder.dimensions(); i++) {
+ String dimensionName = joinedType.dimensions().get(i).name();
+ Optional<Integer> aIndex = a.type().indexOfDimension(dimensionName);
+ Optional<Integer> bIndex = b.type().indexOfDimension(dimensionName);
+ if (aIndex.isPresent() && bIndex.isPresent())
+ builder.set(i, Math.min(b.dimensionSizes().size(bIndex.get()), a.dimensionSizes().size(aIndex.get())));
+ else if (aIndex.isPresent())
+ builder.set(i, a.dimensionSizes().size(aIndex.get()));
+ else if (bIndex.isPresent())
+ builder.set(i, b.dimensionSizes().size(bIndex.get()));
}
- return b.build();
+ return builder.build();
}
private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) {
@@ -200,6 +206,69 @@ public class Join extends PrimitiveTensorFunction {
/** Slow join which works for any two tensors */
private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) {
+ if (a instanceof IndexedTensor && b instanceof IndexedTensor)
+ return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType);
+ else
+ return mappedGeneralJoin(a, b, joinedType);
+ }
+
+ private Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) {
+ DimensionSizes joinedSize = joinedSize(joinedType, a, b);
+ Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize);
+ int[] aToIndexes = mapIndexes(a.type(), joinedType);
+ int[] bToIndexes = mapIndexes(b.type(), joinedType);
+ joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder);
+ joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
+ return builder.build();
+ }
+
+ private void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize,
+ int[] aToIndexes, int[] bToIndexes, boolean reversedOrder, Tensor.Builder builder) {
+ Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames());
+ Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames());
+
+ DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize);
+ DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize);
+
+ // for each combination of dimensions only in a
+ for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) {
+ IndexedTensor.SubspaceIterator aSubspace = ia.next();
+ // for each combination of dimensions in a which is also in b
+ while (aSubspace.hasNext()) {
+ Tensor.Cell aCell = aSubspace.next();
+ PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions);
+ // for each matching combination of dimensions ony in b
+ for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) {
+ Tensor.Cell bCell = bSubspace.next();
+ TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType);
+ double joinedValue = reversedOrder ? combinator.applyAsDouble(bCell.getValue(), aCell.getValue())
+ : combinator.applyAsDouble(aCell.getValue(), bCell.getValue());
+ builder.cell(joinedAddress, joinedValue);
+ }
+ }
+ }
+ }
+
+ private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) {
+ PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size());
+ for (int i = 0; i < addressType.dimensions().size(); i++)
+ if (retainDimensions.contains(addressType.dimensions().get(i).name()))
+ builder.add(addressType.dimensions().get(i).name(), address.intLabel(i));
+ return builder.build();
+ }
+
+ /** Returns the sizes from the joined sizes which are present in the type argument */
+ private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) {
+ DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
+ int dimensionIndex = 0;
+ for (int i = 0; i < joinedType.dimensions().size(); i++) {
+ if (type.dimensionNames().contains(joinedType.dimensions().get(i).name()))
+ builder.set(dimensionIndex++, joinedSizes.size(i));
+ }
+ return builder.build();
+ }
+
+ private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
Tensor.Builder builder = Tensor.Builder.of(joinedType);
@@ -207,8 +276,8 @@ public class Join extends PrimitiveTensorFunction {
Map.Entry<TensorAddress, Double> aCell = aIterator.next();
for (Iterator<Tensor.Cell> bIterator = b.cellIterator(); bIterator.hasNext(); ) {
Map.Entry<TensorAddress, Double> bCell = bIterator.next();
- TensorAddress combinedAddress = combineAddresses(aCell.getKey(), aToIndexes,
- bCell.getKey(), bToIndexes, joinedType);
+ TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes,
+ bCell.getKey(), bToIndexes, joinedType);
if (combinedAddress == null) continue; // not combinable
builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue()));
}
@@ -230,8 +299,8 @@ public class Join extends PrimitiveTensorFunction {
return toIndexes;
}
- private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
- TensorType joinedType) {
+ private TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes,
+ TensorType joinedType) {
String[] joinedLabels = new String[joinedType.dimensions().size()];
mapContent(a, joinedLabels, aToIndexes);
boolean compatible = mapContent(b, joinedLabels, bToIndexes);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index de7e49c46e4..2f060239eb1 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -106,41 +106,51 @@ public class TensorFunctionBenchmark {
// ---------------- Mapped with extra space (sidesteps current special-case optimizations):
// 410 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(20, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true);
System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time);
// 770 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(20, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true);
System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time);
// ---------------- Mapped:
// 2.6 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false);
System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time);
// 6.8 ms
- time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false);
+ System.gc();
+ time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false);
System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time);
// ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations):
- // 1600 ms
- time = new TensorFunctionBenchmark().benchmark(20, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true);
+ // 30 ms
+ System.gc();
+ time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true);
System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time);
- // 1800 ms
- time = new TensorFunctionBenchmark().benchmark(20, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true);
+ // 27 ms
+ System.gc();
+ time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true);
System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time);
// ---------------- Indexed unbound:
// 0.14 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false);
System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time);
// 0.14 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false);
System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time);
// ---------------- Indexed bound:
// 0.14 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false);
System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time);
// 0.14 ms
+ System.gc();
time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false);
System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time);
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
index f2b55c74066..bc2d1f21717 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java
@@ -34,4 +34,15 @@ public class JoinTestCase {
t2.divide(t1));
}
+ @Test
+ public void testGeneralJoin() {
+ assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"),
+ Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:4, {x:2}:6 }")
+ .divide(Tensor.from("tensor(y[]):{{y:0}:2}")));
+
+ assertEquals(Tensor.from("tensor(x[],y[],z[]):{ {x:0,y:0,z:0}:3, {x:1,y:0,z:0}:4, {x:0,y:1,z:0}:5, {x:1,y:1,z:0}:6 }"),
+ Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:6, {x:1,y:0}:8, {x:0,y:1}:20, {x:1,y:1}:24 }")
+ .divide(Tensor.from("tensor(y[],z[]):{ {y:0,z:0}:2, {y:1,z:0}:4, {y:2,z:0}:6 }")));
+ }
+
}