diff options
17 files changed, 396 insertions, 117 deletions
diff --git a/container-search/src/main/java/com/yahoo/fs4/QueryPacket.java b/container-search/src/main/java/com/yahoo/fs4/QueryPacket.java index 16414e8eabd..3a695efa383 100644 --- a/container-search/src/main/java/com/yahoo/fs4/QueryPacket.java +++ b/container-search/src/main/java/com/yahoo/fs4/QueryPacket.java @@ -105,7 +105,8 @@ public class QueryPacket extends Packet { int startOfFieldToSave; boolean sendSessionKey = query.getGroupingSessionCache() || query.getRanking().getQueryCache(); - buffer.putInt(getFeatureInt(sendSessionKey)); + int featureFlag = getFeatureInt(sendSessionKey); + buffer.putInt(featureFlag); IntegerCompressor.putCompressedPositiveNumber(getOffset(), buffer); IntegerCompressor.putCompressedPositiveNumber(getHits(), buffer); @@ -117,7 +118,7 @@ public class QueryPacket extends Packet { Item.putString(query.getRanking().getProfile(), buffer); queryPacketData.setRankProfile(buffer, startOfFieldToSave); - if ( query.hasEncodableProperties()) { + if ( (featureFlag & QF_PROPERTIES) != 0) { startOfFieldToSave = buffer.position(); query.encodeAsProperties(buffer, true); queryPacketData.setPropertyMaps(buffer, startOfFieldToSave); @@ -125,17 +126,18 @@ public class QueryPacket extends Packet { // Language not needed when sending query stacks - if (query.getRanking().getSorting() != null) { + if ((featureFlag & QF_SORTSPEC) != 0) { int sortSpecLengthPosition=buffer.position(); buffer.putInt(0); int sortSpecLength = query.getRanking().getSorting().encode(buffer); buffer.putInt(sortSpecLengthPosition, sortSpecLength); } - if (getGroupingList(query).size() > 0) { + if ( (featureFlag & QF_GROUPSPEC) != 0) { + List<Grouping> groupingList = GroupingExecutor.getGroupingList(query); BufferSerializer gbuf = new BufferSerializer(new GrowableByteBuffer()); - gbuf.putInt(null, getGroupingList(query).size()); - for (Grouping g: getGroupingList(query)){ + gbuf.putInt(null, groupingList.size()); + for (Grouping g: groupingList){ g.serialize(gbuf); } gbuf.getBuf().flip(); @@ -150,7 +152,7 @@ public class QueryPacket extends Packet { buffer.put(query.getSessionId(true).asUtf8String().getBytes()); } - if (query.getRanking().getLocation() != null) { + if ((featureFlag & QF_LOCATION) != 0) { startOfFieldToSave = buffer.position(); int locationLengthPosition=buffer.position(); buffer.putInt(0); @@ -184,16 +186,14 @@ public class QueryPacket extends Packet { static final int QF_SESSIONID = 0x00800000; private int getFeatureInt(boolean sendSessionId) { - int features = QF_PARSEDQUERY; // this bitmask means "parsed query" in query packet. - // we always use a parsed query here + int features = QF_PARSEDQUERY | QF_RANKP; // this bitmask means "parsed query" in query packet. + // And rank properties. Both are always present - features |= QF_RANKP; // hasRankProfile - - features |= (query.getRanking().getSorting() != null) ? QF_SORTSPEC : 0; - features |= (query.getRanking().getLocation() != null) ? QF_LOCATION : 0; - features |= (query.hasEncodableProperties()) ? QF_PROPERTIES : 0; - features |= (getGroupingList(query).size() > 0) ? QF_GROUPSPEC : 0; - features |= (sendSessionId) ? QF_SESSIONID : 0; + features |= (query.getRanking().getSorting() != null) ? QF_SORTSPEC : 0; + features |= (query.getRanking().getLocation() != null) ? QF_LOCATION : 0; + features |= (query.hasEncodableProperties()) ? QF_PROPERTIES : 0; + features |= GroupingExecutor.hasGroupingList(query) ? QF_GROUPSPEC : 0; + features |= (sendSessionId) ? QF_SESSIONID : 0; return features; } @@ -233,10 +233,6 @@ public class QueryPacket extends Packet { return "Query x packet [query: " + query + "]"; } - private static List<Grouping> getGroupingList(Query query) { - return Collections.unmodifiableList(GroupingExecutor.getGroupingList(query)); - } - static int getQueryFlags(Query query) { int flags = 0; diff --git a/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java b/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java index e5e91f21f5f..70942fa2553 100644 --- a/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java +++ b/container-search/src/main/java/com/yahoo/search/grouping/vespa/GroupingExecutor.java @@ -363,6 +363,11 @@ public class GroupingExecutor extends Searcher { return (List<Grouping>)obj; } + public static boolean hasGroupingList(Query query) { + Object obj = query.properties().get(PROP_GROUPINGLIST); + return (obj instanceof List); + } + /** * Sets the list of {@link Grouping} objects assigned to the given query. This method overwrites any grouping * objects already assigned to the query. diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/StoragePolicy.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/StoragePolicy.java index fdb4b8f6339..7f8121c2138 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/StoragePolicy.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/StoragePolicy.java @@ -326,9 +326,13 @@ public class StoragePolicy extends ExternalSlobrokPolicy { } } if (context.usedState != null && newState.getVersion() <= context.usedState.getVersion()) { - reply.setRetryDelay(-1); + if (reply.getRetryDelay() <= 0.0) { + reply.setRetryDelay(-1); + } } else { - reply.setRetryDelay(0); + if (reply.getRetryDelay() <= 0.0) { + reply.setRetryDelay(0); + } } if (context.calculatedDistributor == null) { if (cachedClusterState == null) { diff --git a/messagebus/pom.xml b/messagebus/pom.xml index a3ae2fb54c2..37a1c681497 100644 --- a/messagebus/pom.xml +++ b/messagebus/pom.xml @@ -19,6 +19,12 @@ <scope>test</scope> </dependency> <dependency> + <groupId>org.hamcrest</groupId> + <artifactId>hamcrest-all</artifactId> + <version>1.3</version> + <scope>test</scope> + </dependency> + <dependency> <groupId>com.yahoo.vespa</groupId> <artifactId>vespajlib</artifactId> <version>${project.version}</version> diff --git a/messagebus/src/test/java/com/yahoo/messagebus/network/local/LocalNetworkTest.java b/messagebus/src/test/java/com/yahoo/messagebus/network/local/LocalNetworkTest.java index 95cf7639d09..d8345180952 100644 --- a/messagebus/src/test/java/com/yahoo/messagebus/network/local/LocalNetworkTest.java +++ b/messagebus/src/test/java/com/yahoo/messagebus/network/local/LocalNetworkTest.java @@ -15,6 +15,8 @@ import java.util.concurrent.TimeUnit; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.number.OrderingComparison.lessThan; +import static org.hamcrest.number.OrderingComparison.greaterThanOrEqualTo; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; @@ -109,8 +111,8 @@ public class LocalNetworkTest { assertEquals(ErrorCode.TIMEOUT, res.getError().getCode()); assertTrue(res.getError().getMessage().endsWith("Timed out in sendQ")); long end = System.currentTimeMillis(); - assertTrue(end - start >= (TIMEOUT*0.98)); // Different clocks are used.... - assertTrue(end - start < 2*TIMEOUT); + assertThat(end, greaterThanOrEqualTo(start+TIMEOUT)); + assertThat(end, lessThan(start+2*TIMEOUT)); msg = serverB.messages.poll(60, TimeUnit.SECONDS); assertThat(msg, instanceOf(SimpleMessage.class)); diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java index a3896808177..a1863262ba0 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/config/ConnectionParams.java @@ -4,15 +4,14 @@ package com.yahoo.vespa.http.client.config; import com.google.common.annotations.Beta; import com.google.common.collect.ArrayListMultimap; import com.google.common.collect.Multimap; -import com.yahoo.vespa.http.client.Session; - import net.jcip.annotations.Immutable; import javax.net.ssl.SSLContext; - import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.Map; +import java.util.Objects; import java.util.concurrent.TimeUnit; /** @@ -36,6 +35,7 @@ public final class ConnectionParams { private SSLContext sslContext = null; private long connectionTimeout = TimeUnit.SECONDS.toMillis(60); private final Multimap<String, String> headers = ArrayListMultimap.create(); + private final Map<String, HeaderProvider> headerProviders = new HashMap<>(); private int numPersistentConnectionsPerEndpoint = 8; private String proxyHost = null; private int proxyPort = 8080; @@ -73,6 +73,24 @@ public final class ConnectionParams { } /** + * Adds a header provider for dynamic headers; headers where the value may change during a feeding session + * (e.g. security tokens with limited life time). Only one {@link HeaderProvider} is allowed for a given header name. + * + * @param provider A provider for a dynamic header + * @return pointer to builder. + * @throws IllegalArgumentException if a provider is already registered for the given header name + */ + public Builder addDynamicHeader(String headerName, HeaderProvider provider) { + Objects.requireNonNull(headerName, "Header name cannot be null"); + Objects.requireNonNull(provider, "Header provider cannot be null"); + if (headerProviders.containsKey(headerName)) { + throw new IllegalArgumentException("Provider already registered for name '" + headerName + "'"); + } + headerProviders.put(headerName, provider); + return this; + } + + /** * The number of connections between the http client and the gateways. A very low number can result * in the network not fully utilized and the round-trip time can be a limiting factor. A low number * can cause skew in distribution of load between gateways. A too high number will cause @@ -203,6 +221,7 @@ public final class ConnectionParams { sslContext, connectionTimeout, headers, + headerProviders, numPersistentConnectionsPerEndpoint, proxyHost, proxyPort, @@ -254,6 +273,7 @@ public final class ConnectionParams { private final SSLContext sslContext; private final long connectionTimeout; private final Multimap<String, String> headers = ArrayListMultimap.create(); + private final Map<String, HeaderProvider> headerProviders = new HashMap<>(); private final int numPersistentConnectionsPerEndpoint; private final String proxyHost; private final int proxyPort; @@ -270,6 +290,7 @@ public final class ConnectionParams { SSLContext sslContext, long connectionTimeout, Multimap<String, String> headers, + Map<String, HeaderProvider> headerProviders, int numPersistentConnectionsPerEndpoint, String proxyHost, int proxyPort, @@ -284,6 +305,7 @@ public final class ConnectionParams { this.sslContext = sslContext; this.connectionTimeout = connectionTimeout; this.headers.putAll(headers); + this.headerProviders.putAll(headerProviders); this.numPersistentConnectionsPerEndpoint = numPersistentConnectionsPerEndpoint; this.proxyHost = proxyHost; this.proxyPort = proxyPort; @@ -305,6 +327,10 @@ public final class ConnectionParams { return Collections.unmodifiableCollection(headers.entries()); } + public Map<String, HeaderProvider> getDynamicHeaders() { + return Collections.unmodifiableMap(headerProviders); + } + public int getNumPersistentConnectionsPerEndpoint() { return numPersistentConnectionsPerEndpoint; } @@ -347,4 +373,14 @@ public final class ConnectionParams { return printTraceToStdErr; } + /** + * A header provider that provides a header value. {@link #getHeaderValue()} is called each time a new HTTP request + * is constructed by {@link com.yahoo.vespa.http.client.FeedClient}. + * + * Important: The implementation of {@link #getHeaderValue()} must be thread-safe! + */ + public interface HeaderProvider { + String getHeaderValue(); + } + } diff --git a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java index f3d751f67c7..6fae0765e14 100644 --- a/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java +++ b/vespa-http-client/src/main/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnection.java @@ -16,6 +16,8 @@ import org.apache.http.StatusLine; import org.apache.http.client.HttpClient; import org.apache.http.client.config.RequestConfig; import org.apache.http.client.methods.HttpPost; +import org.apache.http.config.Registry; +import org.apache.http.config.RegistryBuilder; import org.apache.http.conn.socket.ConnectionSocketFactory; import org.apache.http.conn.socket.PlainConnectionSocketFactory; import org.apache.http.conn.ssl.SSLConnectionSocketFactory; @@ -23,9 +25,6 @@ import org.apache.http.entity.InputStreamEntity; import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.impl.conn.PoolingHttpClientConnectionManager; -import org.apache.http.config.Registry; -import org.apache.http.config.RegistryBuilder; - import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; @@ -36,13 +35,13 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.UUID; import java.util.concurrent.TimeUnit; import java.util.logging.Level; import java.util.logging.Logger; import java.util.zip.GZIPOutputStream; -import java.util.UUID; - /** * @author <a href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> * @since 5.1.20 @@ -228,6 +227,13 @@ class ApacheGatewayConnection implements GatewayConnection { for (Map.Entry<String, String> extraHeader : connectionParams.getHeaders()) { httpPost.addHeader(extraHeader.getKey(), extraHeader.getValue()); } + connectionParams.getDynamicHeaders().forEach((headerName, provider) -> { + String headerValue = Objects.requireNonNull( + provider.getHeaderValue(), + provider.getClass().getName() + ".getHeader() returned null as header value!"); + httpPost.addHeader(headerName, headerValue); + }); + if (useCompression) { httpPost.setHeader("Content-Encoding", "gzip"); } diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java index bd894e84ed4..a095cb184a0 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/V3HttpAPITest.java @@ -109,7 +109,9 @@ public class V3HttpAPITest extends TestOnCiBuildingSystemOnly { @Test public void requireThatBadResponseCodeFails() throws Exception { - testServerWithMock(new V3MockParsingRequestHandler(407), true); + testServerWithMock(new V3MockParsingRequestHandler(401/*Unauthorized*/), true); + testServerWithMock(new V3MockParsingRequestHandler(403/*Forbidden*/), true); + testServerWithMock(new V3MockParsingRequestHandler(407/*Proxy Authentication Required*/), true); } @Test diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java index 49ffabbf1d0..cb9270fdd3f 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/config/ConnectionParamsTest.java @@ -7,12 +7,11 @@ import javax.net.ssl.SSLContext; import java.security.NoSuchAlgorithmException; import java.util.Iterator; import java.util.Map; -import java.util.concurrent.TimeUnit; import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.not; import static org.hamcrest.core.Is.is; import static org.hamcrest.core.IsNull.nullValue; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertThat; /** @@ -57,4 +56,18 @@ public class ConnectionParamsTest { assertThat(header3.getValue(), equalTo("Apple")); } + @Test + public void header_providers_are_registered() { + ConnectionParams.HeaderProvider dummyProvider1 = () -> "fooValue"; + ConnectionParams.HeaderProvider dummyProvider2 = () -> "barValue"; + ConnectionParams params = new ConnectionParams.Builder() + .addDynamicHeader("foo", dummyProvider1) + .addDynamicHeader("bar", dummyProvider2) + .build(); + Map<String, ConnectionParams.HeaderProvider> headerProviders = params.getDynamicHeaders(); + assertEquals(2, headerProviders.size()); + assertEquals(dummyProvider1, headerProviders.get("foo")); + assertEquals(dummyProvider2, headerProviders.get("bar")); + } + } diff --git a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java index fa4ad8fa175..456e5184f6e 100644 --- a/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java +++ b/vespa-http-client/src/test/java/com/yahoo/vespa/http/client/core/communication/ApacheGatewayConnectionTest.java @@ -1,19 +1,13 @@ // Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.http.client.core.communication; -import java.io.ByteArrayInputStream; -import java.io.IOException; -import java.io.InputStream; -import java.nio.charset.StandardCharsets; -import java.util.ArrayList; -import java.util.List; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicInteger; - -import com.yahoo.vespa.http.client.core.Headers; import com.yahoo.vespa.http.client.TestUtils; +import com.yahoo.vespa.http.client.config.ConnectionParams; +import com.yahoo.vespa.http.client.config.Endpoint; +import com.yahoo.vespa.http.client.config.FeedParams; import com.yahoo.vespa.http.client.core.Document; +import com.yahoo.vespa.http.client.core.Headers; +import com.yahoo.vespa.http.client.core.ServerResponseException; import org.apache.http.Header; import org.apache.http.HeaderElement; import org.apache.http.HttpEntity; @@ -24,16 +18,29 @@ import org.apache.http.client.HttpClient; import org.apache.http.client.methods.HttpPost; import org.apache.http.entity.InputStreamEntity; import org.junit.Test; -import com.yahoo.vespa.http.client.config.ConnectionParams; -import com.yahoo.vespa.http.client.config.Endpoint; -import com.yahoo.vespa.http.client.config.FeedParams; -import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.core.Is.is; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.stub; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class ApacheGatewayConnectionTest { @@ -47,23 +54,14 @@ public class ApacheGatewayConnectionTest { .setEnableV3Protocol(true) .build(); final List<Document> documents = new ArrayList<>(); - - final ApacheGatewayConnection.HttpClientFactory mockFactory = - mock(ApacheGatewayConnection.HttpClientFactory.class); - final HttpClient httpClientMock = mock(HttpClient.class); - when(mockFactory.createClient()).thenReturn(httpClientMock); - final CountDownLatch verifyContentSentLatch = new CountDownLatch(1); final String vespaDocContent ="Hello, I a JSON doc."; final String docId = "42"; final AtomicInteger requestsReceived = new AtomicInteger(0); - // This is the fake server, takes header client ID and uses this as session Id. - stub(httpClientMock.execute(any())).toAnswer((Answer) invocation -> { - final Object[] args = invocation.getArguments(); - final HttpPost post = (HttpPost) args[0]; + ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> { final Header clientIdHeader = post.getFirstHeader(Headers.CLIENT_ID); verifyContentSentLatch.countDown(); return httpResponse(clientIdHeader.getValue(), "3"); @@ -94,15 +92,8 @@ public class ApacheGatewayConnectionTest { .setEnableV3Protocol(true) .build(); - final ApacheGatewayConnection.HttpClientFactory mockFactory = - mock(ApacheGatewayConnection.HttpClientFactory.class); - final HttpClient httpClientMock = mock(HttpClient.class); - when(mockFactory.createClient()).thenReturn(httpClientMock); - // This is the fake server, returns wrong session Id. - stub(httpClientMock.execute(any())).toAnswer(invocation -> { - return httpResponse("Wrong Id from server", "3"); - }); + ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> httpResponse("Wrong Id from server", "3")); ApacheGatewayConnection apacheGatewayConnection = new ApacheGatewayConnection( @@ -148,11 +139,6 @@ public class ApacheGatewayConnectionTest { .build(); final List<Document> documents = new ArrayList<>(); - final ApacheGatewayConnection.HttpClientFactory mockFactory = - mock(ApacheGatewayConnection.HttpClientFactory.class); - final HttpClient httpClientMock = mock(HttpClient.class); - when(mockFactory.createClient()).thenReturn(httpClientMock); - final CountDownLatch verifyContentSentLatch = new CountDownLatch(1); final String vespaDocContent ="Hello, I a JSON doc."; @@ -161,13 +147,11 @@ public class ApacheGatewayConnectionTest { final AtomicInteger requestsReceived = new AtomicInteger(0); // This is the fake server, checks that DATA_FORMAT header is set properly. - stub(httpClientMock.execute(any())).toAnswer((Answer) invocation -> { - final Object[] args = invocation.getArguments(); - final HttpPost post = (HttpPost) args[0]; + ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> { final Header header = post.getFirstHeader(Headers.DATA_FORMAT); if (requestsReceived.incrementAndGet() == 1) { // This is handshake, it is not json. - assert(header == null); + assert (header == null); return httpResponse("clientId", "3"); } assertNotNull(header); @@ -218,11 +202,6 @@ public class ApacheGatewayConnectionTest { .build(); final List<Document> documents = new ArrayList<>(); - final ApacheGatewayConnection.HttpClientFactory mockFactory = - mock(ApacheGatewayConnection.HttpClientFactory.class); - final HttpClient httpClientMock = mock(HttpClient.class); - when(mockFactory.createClient()).thenReturn(httpClientMock); - final CountDownLatch verifyContentSentLatch = new CountDownLatch(1); final String vespaDocContent ="Hello, I am the document data."; @@ -232,9 +211,7 @@ public class ApacheGatewayConnectionTest { // When sending data on http client, check if it is compressed. If compressed, unzip, check result, // and count down latch. - stub(httpClientMock.execute(any())).toAnswer((Answer) invocation -> { - final Object[] args = invocation.getArguments(); - final HttpPost post = (HttpPost) args[0]; + ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> { final Header header = post.getFirstHeader("Content-Encoding"); if (header != null && header.getValue().equals("gzip")) { final String rawContent = TestUtils.zipStreamToString(post.getEntity().getContent()); @@ -249,6 +226,7 @@ public class ApacheGatewayConnectionTest { } return httpResponse("clientId", "3"); }); + StatusLine statusLineMock = mock(StatusLine.class); when(statusLineMock.getStatusCode()).thenReturn(200); @@ -269,6 +247,68 @@ public class ApacheGatewayConnectionTest { assertTrue(verifyContentSentLatch.await(10, TimeUnit.SECONDS)); } + @Test + public void dynamic_headers_are_added_to_the_response() throws IOException, ServerResponseException, InterruptedException { + ConnectionParams.HeaderProvider headerProvider = mock(ConnectionParams.HeaderProvider.class); + when(headerProvider.getHeaderValue()) + .thenReturn("v1") + .thenReturn("v2") + .thenReturn("v3"); + + ConnectionParams connectionParams = new ConnectionParams.Builder() + .addDynamicHeader("foo", headerProvider) + .build(); + + CountDownLatch verifyContentSentLatch = new CountDownLatch(1); + + AtomicInteger counter = new AtomicInteger(1); + ApacheGatewayConnection.HttpClientFactory mockFactory = mockHttpClientFactory(post -> { + Header[] fooHeader = post.getHeaders("foo"); + assertEquals(1, fooHeader.length); + assertEquals("foo", fooHeader[0].getName()); + assertEquals("v" + counter.getAndIncrement(), fooHeader[0].getValue()); + verifyContentSentLatch.countDown(); + return httpResponse("clientId", "3"); + + }); + + ApacheGatewayConnection apacheGatewayConnection = + new ApacheGatewayConnection( + Endpoint.create("hostname", 666, false), + new FeedParams.Builder().build(), + "", + connectionParams, + mockFactory, + "clientId"); + apacheGatewayConnection.connect(); + apacheGatewayConnection.handshake(); + + List<Document> documents = new ArrayList<>(); + documents.add(createDoc("42", "content", true)); + apacheGatewayConnection.writeOperations(documents); + apacheGatewayConnection.writeOperations(documents); + assertTrue(verifyContentSentLatch.await(10, TimeUnit.SECONDS)); + + verify(headerProvider, times(3)).getHeaderValue(); // 1x connect(), 2x writeOperations() + } + + private static ApacheGatewayConnection.HttpClientFactory mockHttpClientFactory(HttpExecuteMock httpExecuteMock) throws IOException { + ApacheGatewayConnection.HttpClientFactory mockFactory = + mock(ApacheGatewayConnection.HttpClientFactory.class); + HttpClient httpClientMock = mock(HttpClient.class); + when(mockFactory.createClient()).thenReturn(httpClientMock); + stub(httpClientMock.execute(any())).toAnswer((Answer) invocation -> { + Object[] args = invocation.getArguments(); + HttpPost post = (HttpPost) args[0]; + return httpExecuteMock.execute(post); + }); + return mockFactory; + } + + @FunctionalInterface private interface HttpExecuteMock { + HttpResponse execute(HttpPost httpPost) throws IOException; + } + private Document createDoc(final String docId, final String content, boolean useJson) throws IOException { return new Document(docId, content.getBytes(), null /* context */); } @@ -315,4 +355,4 @@ public class ApacheGatewayConnectionTest { when(httpEntityMock.getContent()).thenReturn(inputs); return httpResponseMock; } -}
\ No newline at end of file +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index 76340bb7d8f..daa85cc51e4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -1,5 +1,7 @@ package com.yahoo.tensor; +import com.google.common.annotations.Beta; + import java.util.Arrays; /** @@ -7,6 +9,7 @@ import java.util.Arrays; * * @author bratseth */ +@Beta public final class DimensionSizes { private final int[] sizes; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index d69cf65ee8d..9315922f57a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -52,6 +52,21 @@ public class IndexedTensor implements Tensor { return new CellIterator(); } + /** Returns an iterator over all the cells in this tensor which matches the given partial address */ + // TODO: Move up to Tensor and create a mixed tensor which can implement it (and subspace iterators) efficiently + public SubspaceIterator cellIterator(PartialAddress partialAddress, DimensionSizes iterationSizes) { + int[] startAddress = new int[type().dimensions().size()]; + List<Integer> iterateDimensions = new ArrayList<>(); + for (int i = 0; i < type().dimensions().size(); i++) { + int partialAddressLabel = partialAddress.intLabel(type.dimensions().get(i).name()); + if (partialAddressLabel >= 0) // iterate at this label + startAddress[i] = partialAddressLabel; + else // iterate over this dimension + iterateDimensions.add(i); + } + return new SubspaceIterator(iterateDimensions, startAddress, iterationSizes); + } + /** * Returns an iterator over the values of this. * Values are returned in order of increasing indexes in each dimension, increasing diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java new file mode 100644 index 00000000000..463b7f3e99f --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -0,0 +1,61 @@ +package com.yahoo.tensor; + +import com.google.common.annotations.Beta; + +/** + * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors + * dimensions. + * + * @author bratseth + */ +// Implementation notes: +// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation. +// We also avoid non-essential error checking. +// - We can add support for string labels later without breaking the API +@Beta +public class PartialAddress { + + // Two arrays which contains corresponding dimension=label pairs. + // The sizes of these are always equal. + private final String[] dimensionNames; + private final int[] labels; + + private PartialAddress(Builder builder) { + this.dimensionNames = builder.dimensionNames; + this.labels = builder.labels; + builder.dimensionNames = null; // invalidate builder to safely take over array ownership + builder.labels = null; + } + + /** Returns the int label of this dimension, or -1 if no label is specified for it */ + int intLabel(String dimensionName) { + for (int i = 0; i < dimensionNames.length; i++) + if (dimensionNames[i].equals(dimensionName)) + return labels[i]; + return -1; + } + + public static class Builder { + + private String[] dimensionNames; + private int[] labels; + private int index = 0; + + public Builder(int size) { + dimensionNames = new String[size]; + labels = new int[size]; + } + + public void add(String dimensionName, int label) { + dimensionNames[index] = dimensionName; + labels[index] = label; + index++; + } + + public PartialAddress build() { + return new PartialAddress(this); + } + + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 05999ff1240..fda6e8ef86c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -83,7 +83,7 @@ public class Concat extends PrimitiveTensorFunction { for (Iterator<IndexedTensor.SubspaceIterator> ib = b.subspaceIterator(otherADimensions); ib.hasNext();) { IndexedTensor.SubspaceIterator ibSubspace = ib.next(); while (ibSubspace.hasNext()) { - java.util.Map.Entry<TensorAddress, Double> bCell = ibSubspace.next(); // TODO: Create Cell convenience subclass for Map.Entry + Tensor.Cell bCell = ibSubspace.next(); TensorAddress combinedAddress = combineAddresses(aAddress, aToIndexes, bCell.getKey(), bToIndexes, concatType, offset, dimension); if (combinedAddress == null) continue; // incompatible @@ -125,21 +125,21 @@ public class Concat extends PrimitiveTensorFunction { /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { - DimensionSizes.Builder joinedSizes = new DimensionSizes.Builder(concatType.dimensions().size()); - for (int i = 0; i < joinedSizes.dimensions(); i++) { + DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); + for (int i = 0; i < concatSizes.dimensions(); i++) { String currentDimension = concatType.dimensions().get(i).name(); int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0); int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0); if (currentDimension.equals(concatDimension)) - joinedSizes.set(i, aSize + bSize); + concatSizes.set(i, aSize + bSize); else if (aSize != 0 && bSize != 0 && aSize!=bSize ) throw new IllegalArgumentException("Dimension " + currentDimension + " must be of the same size when " + "concatenating " + a.type() + " and " + b.type() + " along dimension " + concatDimension + ", but was " + aSize + " and " + bSize); else - joinedSizes.set(i, Math.max(aSize, bSize)); + concatSizes.set(i, Math.max(aSize, bSize)); } - return joinedSizes.build(); + return concatSizes.build(); } /** @@ -150,13 +150,13 @@ public class Concat extends PrimitiveTensorFunction { */ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, TensorType concatType, int concatOffset, String concatDimension) { - int[] joinedLabels = new int[concatType.dimensions().size()]; - Arrays.fill(joinedLabels, -1); + int[] combinedLabels = new int[concatType.dimensions().size()]; + Arrays.fill(combinedLabels, -1); int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); - mapContent(a, joinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension - boolean compatible = mapContent(b, joinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here + mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension + boolean compatible = mapContent(b, combinedLabels, bToIndexes, concatDimensionIndex, concatOffset); // ... which is overwritten by the right value here if ( ! compatible) return null; - return TensorAddress.of(joinedLabels); + return TensorAddress.of(combinedLabels); } /** @@ -166,7 +166,7 @@ public class Concat extends PrimitiveTensorFunction { * fromType.dimensions().get(i).name.equals(toType.dimensions().get(n).name()) * If some dimension in fromType is not present in toType, the corresponding index will be -1 */ - // TODO: Stolen from join - put on TensorType? + // TODO: Stolen from join private int[] mapIndexes(TensorType fromType, TensorType toType) { int[] toIndexes = new int[fromType.dimensions().size()]; for (int i = 0; i < fromType.dimensions().size(); i++) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 23865e1cc1c..ceade39ce42 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -2,8 +2,10 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.PartialAddress; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -156,16 +158,20 @@ public class Join extends PrimitiveTensorFunction { } } - private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor subspace, IndexedTensor superspace) { - DimensionSizes.Builder b = new DimensionSizes.Builder(joinedType.dimensions().size()); - for (int i = 0; i < b.dimensions(); i++) { - Optional<Integer> subspaceIndex = subspace.type().indexOfDimension(joinedType.dimensions().get(i).name()); - if (subspaceIndex.isPresent()) - b.set(i, Math.min(superspace.dimensionSizes().size(i), subspace.dimensionSizes().size(subspaceIndex.get()))); - else - b.set(i, superspace.dimensionSizes().size(i)); + private DimensionSizes joinedSize(TensorType joinedType, IndexedTensor a, IndexedTensor b) { + DimensionSizes.Builder builder = new DimensionSizes.Builder(joinedType.dimensions().size()); + for (int i = 0; i < builder.dimensions(); i++) { + String dimensionName = joinedType.dimensions().get(i).name(); + Optional<Integer> aIndex = a.type().indexOfDimension(dimensionName); + Optional<Integer> bIndex = b.type().indexOfDimension(dimensionName); + if (aIndex.isPresent() && bIndex.isPresent()) + builder.set(i, Math.min(b.dimensionSizes().size(bIndex.get()), a.dimensionSizes().size(aIndex.get()))); + else if (aIndex.isPresent()) + builder.set(i, a.dimensionSizes().size(aIndex.get())); + else if (bIndex.isPresent()) + builder.set(i, b.dimensionSizes().size(bIndex.get())); } - return b.build(); + return builder.build(); } private Tensor generalSubspaceJoin(Tensor subspace, Tensor superspace, TensorType joinedType, boolean reversedArgumentOrder) { @@ -200,6 +206,69 @@ public class Join extends PrimitiveTensorFunction { /** Slow join which works for any two tensors */ private Tensor generalJoin(Tensor a, Tensor b, TensorType joinedType) { + if (a instanceof IndexedTensor && b instanceof IndexedTensor) + return indexedGeneralJoin((IndexedTensor) a, (IndexedTensor) b, joinedType); + else + return mappedGeneralJoin(a, b, joinedType); + } + + private Tensor indexedGeneralJoin(IndexedTensor a, IndexedTensor b, TensorType joinedType) { + DimensionSizes joinedSize = joinedSize(joinedType, a, b); + Tensor.Builder builder = Tensor.Builder.of(joinedType, joinedSize); + int[] aToIndexes = mapIndexes(a.type(), joinedType); + int[] bToIndexes = mapIndexes(b.type(), joinedType); + joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder); + joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder); + return builder.build(); + } + + private void joinTo(IndexedTensor a, IndexedTensor b, TensorType joinedType, DimensionSizes joinedSize, + int[] aToIndexes, int[] bToIndexes, boolean reversedOrder, Tensor.Builder builder) { + Set<String> sharedDimensions = Sets.intersection(a.type().dimensionNames(), b.type().dimensionNames()); + Set<String> dimensionsOnlyInA = Sets.difference(a.type().dimensionNames(), b.type().dimensionNames()); + + DimensionSizes aIterateSize = joinedSizeOf(a.type(), joinedType, joinedSize); + DimensionSizes bIterateSize = joinedSizeOf(b.type(), joinedType, joinedSize); + + // for each combination of dimensions only in a + for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(dimensionsOnlyInA, aIterateSize); ia.hasNext(); ) { + IndexedTensor.SubspaceIterator aSubspace = ia.next(); + // for each combination of dimensions in a which is also in b + while (aSubspace.hasNext()) { + Tensor.Cell aCell = aSubspace.next(); + PartialAddress matchingBCells = partialAddress(a.type(), aSubspace.address(), sharedDimensions); + // for each matching combination of dimensions ony in b + for (IndexedTensor.SubspaceIterator bSubspace = b.cellIterator(matchingBCells, bIterateSize); bSubspace.hasNext(); ) { + Tensor.Cell bCell = bSubspace.next(); + TensorAddress joinedAddress = joinAddresses(aCell.getKey(), aToIndexes, bCell.getKey(), bToIndexes, joinedType); + double joinedValue = reversedOrder ? combinator.applyAsDouble(bCell.getValue(), aCell.getValue()) + : combinator.applyAsDouble(aCell.getValue(), bCell.getValue()); + builder.cell(joinedAddress, joinedValue); + } + } + } + } + + private PartialAddress partialAddress(TensorType addressType, TensorAddress address, Set<String> retainDimensions) { + PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); + for (int i = 0; i < addressType.dimensions().size(); i++) + if (retainDimensions.contains(addressType.dimensions().get(i).name())) + builder.add(addressType.dimensions().get(i).name(), address.intLabel(i)); + return builder.build(); + } + + /** Returns the sizes from the joined sizes which are present in the type argument */ + private DimensionSizes joinedSizeOf(TensorType type, TensorType joinedType, DimensionSizes joinedSizes) { + DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); + int dimensionIndex = 0; + for (int i = 0; i < joinedType.dimensions().size(); i++) { + if (type.dimensionNames().contains(joinedType.dimensions().get(i).name())) + builder.set(dimensionIndex++, joinedSizes.size(i)); + } + return builder.build(); + } + + private Tensor mappedGeneralJoin(Tensor a, Tensor b, TensorType joinedType) { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); Tensor.Builder builder = Tensor.Builder.of(joinedType); @@ -207,8 +276,8 @@ public class Join extends PrimitiveTensorFunction { Map.Entry<TensorAddress, Double> aCell = aIterator.next(); for (Iterator<Tensor.Cell> bIterator = b.cellIterator(); bIterator.hasNext(); ) { Map.Entry<TensorAddress, Double> bCell = bIterator.next(); - TensorAddress combinedAddress = combineAddresses(aCell.getKey(), aToIndexes, - bCell.getKey(), bToIndexes, joinedType); + TensorAddress combinedAddress = joinAddresses(aCell.getKey(), aToIndexes, + bCell.getKey(), bToIndexes, joinedType); if (combinedAddress == null) continue; // not combinable builder.cell(combinedAddress, combinator.applyAsDouble(aCell.getValue(), bCell.getValue())); } @@ -230,8 +299,8 @@ public class Join extends PrimitiveTensorFunction { return toIndexes; } - private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, - TensorType joinedType) { + private TensorAddress joinAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, + TensorType joinedType) { String[] joinedLabels = new String[joinedType.dimensions().size()]; mapContent(a, joinedLabels, aToIndexes); boolean compatible = mapContent(b, joinedLabels, bToIndexes); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index de7e49c46e4..2f060239eb1 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -106,41 +106,51 @@ public class TensorFunctionBenchmark { // ---------------- Mapped with extra space (sidesteps current special-case optimizations): // 410 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(20, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); System.out.printf("Mapped vectors, x space time per join: %1$8.3f ms\n", time); // 770 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(20, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, true); System.out.printf("Mapped matrix, x space time per join: %1$8.3f ms\n", time); // ---------------- Mapped: // 2.6 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time); // 6.8 ms - time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); + System.gc(); + time = new TensorFunctionBenchmark().benchmark(1000, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped, false); System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time); // ---------------- Indexed (unbound) with extra space (sidesteps current special-case optimizations): - // 1600 ms - time = new TensorFunctionBenchmark().benchmark(20, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); + // 30 ms + System.gc(); + time = new TensorFunctionBenchmark().benchmark(500, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed vectors, x space time per join: %1$8.3f ms\n", time); - // 1800 ms - time = new TensorFunctionBenchmark().benchmark(20, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); + // 27 ms + System.gc(); + time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, true); System.out.printf("Indexed matrix, x space time per join: %1$8.3f ms\n", time); // ---------------- Indexed unbound: // 0.14 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); System.out.printf("Indexed unbound vectors, time per join: %1$8.3f ms\n", time); // 0.14 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound, false); System.out.printf("Indexed unbound matrix, time per join: %1$8.3f ms\n", time); // ---------------- Indexed bound: // 0.14 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, vectors(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); System.out.printf("Indexed bound vectors, time per join: %1$8.3f ms\n", time); // 0.14 ms + System.gc(); time = new TensorFunctionBenchmark().benchmark(50000, matrix(100, 300, TensorType.Dimension.Type.indexedBound), TensorType.Dimension.Type.indexedBound, false); System.out.printf("Indexed bound matrix, time per join: %1$8.3f ms\n", time); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java index f2b55c74066..bc2d1f21717 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/JoinTestCase.java @@ -34,4 +34,15 @@ public class JoinTestCase { t2.divide(t1)); } + @Test + public void testGeneralJoin() { + assertEquals(Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3 }"), + Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:4, {x:2}:6 }") + .divide(Tensor.from("tensor(y[]):{{y:0}:2}"))); + + assertEquals(Tensor.from("tensor(x[],y[],z[]):{ {x:0,y:0,z:0}:3, {x:1,y:0,z:0}:4, {x:0,y:1,z:0}:5, {x:1,y:1,z:0}:6 }"), + Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:6, {x:1,y:0}:8, {x:0,y:1}:20, {x:1,y:1}:24 }") + .divide(Tensor.from("tensor(y[],z[]):{ {y:0,z:0}:2, {y:1,z:0}:4, {y:2,z:0}:6 }"))); + } + } |