diff options
192 files changed, 4072 insertions, 2955 deletions
diff --git a/CMakeLists.txt b/CMakeLists.txt index 1fe91c233ab..22ebd6abb4c 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,6 +99,7 @@ add_subdirectory(streamingvisitors) add_subdirectory(vbench) add_subdirectory(vdslib) add_subdirectory(vdstestlib) +add_subdirectory(vespa-athenz) add_subdirectory(vespa-http-client) add_subdirectory(vespa_jersey2) add_subdirectory(vespabase) diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/SecretStoreKeyProvider.java b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java index ac8c0eabf31..2f2cd5a8495 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/SecretStoreKeyProvider.java +++ b/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/instanceproviderservice/impl/CkmsKeyProvider.java @@ -1,10 +1,10 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl; import com.google.inject.Inject; import com.yahoo.athenz.auth.util.Crypto; import com.yahoo.config.provision.Zone; -import com.yahoo.jdisc.http.SecretStore; +import com.yahoo.container.jdisc.Ckms; import com.yahoo.vespa.hosted.athenz.instanceproviderservice.KeyProvider; import com.yahoo.vespa.hosted.athenz.instanceproviderservice.config.AthenzProviderServiceConfig; @@ -18,19 +18,20 @@ import static com.yahoo.vespa.hosted.athenz.instanceproviderservice.impl.Utils.g /** * @author mortent + * @author bjorncs */ @SuppressWarnings("unused") // Injected component -public class SecretStoreKeyProvider implements KeyProvider { +public class CkmsKeyProvider implements KeyProvider { - private final SecretStore secretStore; + private final Ckms ckms; private final String secretName; private final Map<Integer, KeyPair> secrets; @Inject - public SecretStoreKeyProvider(SecretStore secretStore, - Zone zone, - AthenzProviderServiceConfig config) { - this.secretStore = secretStore; + public CkmsKeyProvider(Ckms ckms, + Zone zone, + AthenzProviderServiceConfig config) { + this.ckms = ckms; this.secretName = getZoneConfig(config, zone).secretName(); this.secrets = new HashMap<>(); } @@ -59,7 +60,7 @@ public class SecretStoreKeyProvider implements KeyProvider { // TODO: Consider moving to cryptoutils private KeyPair readKeyPair(int version) { - PrivateKey privateKey = Crypto.loadPrivateKey(secretStore.getSecret(secretName, version)); + PrivateKey privateKey = Crypto.loadPrivateKey(ckms.getSecret(secretName, version)); PublicKey publicKey = Crypto.extractPublicKey(privateKey); return new KeyPair(publicKey, privateKey); } diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java index 0ff59c26c13..74ddd941afb 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/ContentCluster.java @@ -165,19 +165,6 @@ public class ContentCluster { } } - public StorageNodeStats getStorageNodeStats(int storageNodeIndex) { - LatencyStats aggregatePutLatencyStats = new LatencyStats(); - StorageNodeStats aggregateStats = new StorageNodeStats(aggregatePutLatencyStats); - for (DistributorNodeInfo distributor : clusterInfo.getDistributorNodeInfo()) { - StorageNodeStats statsFromDistributor = distributor.getStorageNodeStatsOrNull(storageNodeIndex); - if (statsFromDistributor != null) { - aggregateStats.add(statsFromDistributor); - } - } - - return aggregateStats; - } - /** * Checks if a node can be upgraded * diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/DistributorNodeInfo.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/DistributorNodeInfo.java index 575b965c0e5..a21fbd22213 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/DistributorNodeInfo.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/DistributorNodeInfo.java @@ -15,28 +15,8 @@ import com.yahoo.vespa.clustercontroller.core.hostinfo.StorageNodeStatsBridge; */ public class DistributorNodeInfo extends NodeInfo { - private StorageNodeStatsContainer storageNodeStatsContainer = null; - public DistributorNodeInfo(ContentCluster cluster, int index, String rpcAddress, Distribution distribution) { super(cluster, new Node(NodeType.DISTRIBUTOR, index), false, rpcAddress, distribution); } - @Override - public void setHostInfo(HostInfo hostInfo) { - // This affects getHostInfo(), and makes the host info available through NodeInfo. - super.setHostInfo(hostInfo); - storageNodeStatsContainer = StorageNodeStatsBridge.traverseHostInfo(hostInfo); - } - - /** - * @return Stats this distributor has about a storage node, or null if unknown. - */ - public StorageNodeStats getStorageNodeStatsOrNull(int storageNodeIndex) { - if (storageNodeStatsContainer == null) { - return null; - } - - return storageNodeStatsContainer.get(storageNodeIndex); - } - } diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/LatencyStats.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/LatencyStats.java deleted file mode 100644 index 581cc244a20..00000000000 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/LatencyStats.java +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.clustercontroller.core; - -/** - * LatencyStats handles adding latencies and counts. - * - * @author hakonhall - */ -public class LatencyStats { - - private long latencyMsSum; - private long count; - - public LatencyStats() { this(0, 0); } - - /** - * @param latencyMsSum The sum of the latencies of all RPCs (or whatever) in milliseconds. - * @param count The number of RPC calls (or whatever). - */ - public LatencyStats(long latencyMsSum, long count) { - this.latencyMsSum = latencyMsSum; - this.count = count; - } - - void add(LatencyStats latencyToAdd) { - latencyMsSum += latencyToAdd.latencyMsSum; - count += latencyToAdd.count; - } - - public long getLatencyMsSum() { return latencyMsSum; } - public long getCount() { return count; } - -} diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodes.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodes.java deleted file mode 100644 index 8df5820bc49..00000000000 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodes.java +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.clustercontroller.core; - -import java.util.Map; - -/** - * Contains stats for a set of storage nodes. This is used to store the stats returned - * by Distributors from their getnodestate RPCs. The stats for a single storage node - * is represented by the StorageNodeStats class. - * - * @author hakonhall - */ -public class StatsForStorageNodes { - - final private Map<Integer, StorageNodeStats> storageNodesByIndex; - - StatsForStorageNodes(Map<Integer, StorageNodeStats> storageNodesByIndex) { - this.storageNodesByIndex = storageNodesByIndex; - } - - StorageNodeStats getStatsForStorageNode(int nodeIndex) { - return storageNodesByIndex.get(nodeIndex); - } - -} diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStats.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStats.java deleted file mode 100644 index d0afc1fa4b7..00000000000 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStats.java +++ /dev/null @@ -1,22 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.clustercontroller.core; - -/** - * Contains stats related to a single storage node. - * - * @author hakonhall - */ -public class StorageNodeStats { - - final private LatencyStats distributorPutLatency; - - /** - * @param distributorPutLatency the "put" latency from the point of view of the distributor. - */ - public StorageNodeStats(LatencyStats distributorPutLatency) { this.distributorPutLatency = distributorPutLatency; } - public LatencyStats getDistributorPutLatency() { return distributorPutLatency; } - public void add(StorageNodeStats statsToAdd) { - distributorPutLatency.add(statsToAdd.distributorPutLatency); - } - -} diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainer.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainer.java deleted file mode 100644 index 1fb24e72218..00000000000 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainer.java +++ /dev/null @@ -1,28 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.clustercontroller.core; - -import java.util.HashMap; -import java.util.Map; - -/** - * Contains stats for a set of storage nodes. This is used to store the stats returned - * by Distributors from their getnodestate RPCs. The stats for a single storage node - * is represented by the StorageNodeStats class. - * - * @author hakonhall - */ -public class StorageNodeStatsContainer { - - final private Map<Integer, StorageNodeStats> storageNodesByIndex = new HashMap<>(); - - public void put(int nodeIndex, StorageNodeStats nodeStats) { - storageNodesByIndex.put(nodeIndex, nodeStats); - } - - public StorageNodeStats get(int nodeIndex) { - return storageNodesByIndex.get(nodeIndex); - } - - public int size() { return storageNodesByIndex.size(); } - -} diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridge.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridge.java index 55b7e4bb8c1..30ef0c69fe3 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridge.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridge.java @@ -16,31 +16,6 @@ public class StorageNodeStatsBridge { private StorageNodeStatsBridge() { } - public static StorageNodeStatsContainer traverseHostInfo(HostInfo hostInfo) { - StorageNodeStatsContainer container = new StorageNodeStatsContainer(); - List<StorageNode> storageNodes = hostInfo.getDistributor().getStorageNodes(); - for (StorageNode storageNode : storageNodes) { - Integer storageNodeIndex = storageNode.getIndex(); - if (storageNodeIndex == null) { - continue; - } - StorageNode.OpsLatency opsLatency = storageNode.getOpsLatenciesOrNull(); - if (opsLatency == null) { - continue; - } - StorageNode.Put putLatency = opsLatency.getPut(); - Long putLatencyMsSum = putLatency.getLatencyMsSum(); - Long putLatencyCount = putLatency.getCount(); - if (putLatencyMsSum == null || putLatencyCount == null) { - continue; - } - LatencyStats putLatencyStats = new LatencyStats(putLatencyMsSum, putLatencyCount); - StorageNodeStats nodeStats = new StorageNodeStats(putLatencyStats); - container.put(storageNodeIndex, nodeStats); - } - return container; - } - public static ContentClusterStats generate(Distributor distributor) { Map<Integer, ContentNodeStats> mapToNodeStats = new HashMap<>(); for (StorageNode storageNode : distributor.getStorageNodes()) { diff --git a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java index 9c7143aed4a..669042c2fd8 100644 --- a/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java +++ b/clustercontroller-core/src/main/java/com/yahoo/vespa/clustercontroller/core/restapiv2/requests/NodeStateRequest.java @@ -1,10 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.clustercontroller.core.restapiv2.requests; -import com.yahoo.vespa.clustercontroller.core.LatencyStats; import com.yahoo.vespa.clustercontroller.core.NodeInfo; import com.yahoo.vespa.clustercontroller.core.RemoteClusterControllerTask; -import com.yahoo.vespa.clustercontroller.core.StorageNodeStats; import com.yahoo.vespa.clustercontroller.core.restapiv2.Id; import com.yahoo.vespa.clustercontroller.core.restapiv2.Request; import com.yahoo.vespa.clustercontroller.core.restapiv2.Response; @@ -41,13 +39,6 @@ public class NodeStateRequest extends Request<Response.NodeResponse> { result.addState("unit", new Response.UnitStateImpl(info.getReportedState())); result.addState("user", new Response.UnitStateImpl(info.getWantedState())); - if (info.isStorage() && verboseReports.contains(VerboseReport.STATISTICS)) { - StorageNodeStats storageStats = context.cluster.getStorageNodeStats(info.getNodeIndex()); - LatencyStats latencyStats = storageStats.getDistributorPutLatency(); - result.addMetric("distributor-put-latency-ms-sum", latencyStats.getLatencyMsSum()); - result.addMetric("distributor-put-latency-count", latencyStats.getCount()); - } - for (int i=0; i<info.getReportedState().getDiskCount(); ++i) { Id.Partition partitionId = new Id.Partition(id, i); if (recursive > 0) { diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodeTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodeTest.java deleted file mode 100644 index 2e88c147095..00000000000 --- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StatsForStorageNodeTest.java +++ /dev/null @@ -1,44 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.clustercontroller.core; - -import org.junit.Test; - -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; - -/** - * @author hakonhall - */ -public class StatsForStorageNodeTest { - @Test - public void testStatsForStorage() { - Map<Integer, StorageNodeStats> statsMap = new HashMap<>(); - - LatencyStats putLatencyForA = new LatencyStats(1, 2); - StorageNodeStats nodeStatsForA = new StorageNodeStats(putLatencyForA); - statsMap.put(5, nodeStatsForA); - - LatencyStats putLatencyForB = new LatencyStats(3, 4); - StorageNodeStats nodeStatsForB = new StorageNodeStats(putLatencyForB); - statsMap.put(6, nodeStatsForB); - - StatsForStorageNodes stats = new StatsForStorageNodes(statsMap); - - StorageNodeStats nodeStats = stats.getStatsForStorageNode(5); - assertNotNull(nodeStats); - assertEquals(1, nodeStatsForA.getDistributorPutLatency().getLatencyMsSum()); - assertEquals(2, nodeStatsForA.getDistributorPutLatency().getCount()); - - nodeStats = stats.getStatsForStorageNode(6); - assertNotNull(nodeStats); - assertEquals(3, nodeStatsForB.getDistributorPutLatency().getLatencyMsSum()); - assertEquals(4, nodeStatsForB.getDistributorPutLatency().getCount()); - - nodeStats = stats.getStatsForStorageNode(7); - assertNull(nodeStats); - } -} diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainerTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainerTest.java deleted file mode 100644 index 5107792dbff..00000000000 --- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsContainerTest.java +++ /dev/null @@ -1,43 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.clustercontroller.core; - -import org.junit.Test; - -import java.util.HashMap; -import java.util.Map; - -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertNotNull; -import static org.junit.Assert.assertNull; - -/** - * @author hakonhall - */ -public class StorageNodeStatsContainerTest { - @Test - public void testStatsForStorage() { - StorageNodeStatsContainer statsContainer = new StorageNodeStatsContainer(); - Map<Integer, StorageNodeStats> statsMap = new HashMap<>(); - - LatencyStats putLatencyForA = new LatencyStats(1, 2); - StorageNodeStats nodeStatsForA = new StorageNodeStats(putLatencyForA); - statsContainer.put(5, nodeStatsForA); - - LatencyStats putLatencyForB = new LatencyStats(3, 4); - StorageNodeStats nodeStatsForB = new StorageNodeStats(putLatencyForB); - statsContainer.put(6, nodeStatsForB); - - StorageNodeStats nodeStats = statsContainer.get(5); - assertNotNull(nodeStats); - assertEquals(1, nodeStatsForA.getDistributorPutLatency().getLatencyMsSum()); - assertEquals(2, nodeStatsForA.getDistributorPutLatency().getCount()); - - nodeStats = statsContainer.get(6); - assertNotNull(nodeStats); - assertEquals(3, nodeStatsForB.getDistributorPutLatency().getLatencyMsSum()); - assertEquals(4, nodeStatsForB.getDistributorPutLatency().getCount()); - - nodeStats = statsContainer.get(7); - assertNull(nodeStats); - } -} diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsTest.java deleted file mode 100644 index 4defb015e76..00000000000 --- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/StorageNodeStatsTest.java +++ /dev/null @@ -1,25 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.clustercontroller.core; - -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -/** - * @author hakonhall - */ -public class StorageNodeStatsTest { - @Test - public void testStorageNodeStats() { - LatencyStats putLatency = new LatencyStats(1, 2); - StorageNodeStats stats = new StorageNodeStats(putLatency); - assertEquals(1, stats.getDistributorPutLatency().getLatencyMsSum()); - assertEquals(2, stats.getDistributorPutLatency().getCount()); - - LatencyStats putLatencyToAdd = new LatencyStats(3, 4); - StorageNodeStats statsToAdd = new StorageNodeStats(putLatencyToAdd); - stats.add(statsToAdd); - assertEquals(1 + 3, stats.getDistributorPutLatency().getLatencyMsSum()); - assertEquals(2 + 4, stats.getDistributorPutLatency().getCount()); - } -} diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridgeTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridgeTest.java index 5319d741503..51e73b333c5 100644 --- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridgeTest.java +++ b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/hostinfo/StorageNodeStatsBridgeTest.java @@ -3,8 +3,6 @@ package com.yahoo.vespa.clustercontroller.core.hostinfo; import com.yahoo.vespa.clustercontroller.core.ContentNodeStats; import com.yahoo.vespa.clustercontroller.core.ContentClusterStats; -import com.yahoo.vespa.clustercontroller.core.StorageNodeStats; -import com.yahoo.vespa.clustercontroller.core.StorageNodeStatsContainer; import org.junit.Test; import java.io.IOException; @@ -31,24 +29,6 @@ public class StorageNodeStatsBridgeTest { } @Test - public void testStorageNodeStatsContainer() throws IOException { - String data = getJsonString(); - HostInfo hostInfo = HostInfo.createHostInfo(data); - StorageNodeStatsContainer container = StorageNodeStatsBridge.traverseHostInfo(hostInfo); - assertEquals(2, container.size()); - - StorageNodeStats node0 = container.get(0); - assertNotNull(node0); - assertEquals(15, node0.getDistributorPutLatency().getLatencyMsSum()); - assertEquals(16, node0.getDistributorPutLatency().getCount()); - - StorageNodeStats node1 = container.get(1); - assertNotNull(node1); - assertEquals(17, node1.getDistributorPutLatency().getLatencyMsSum()); - assertEquals(18, node1.getDistributorPutLatency().getCount()); - } - - @Test public void testContentNodeStats() throws IOException { String data = getJsonString(); HostInfo hostInfo = HostInfo.createHostInfo(data); diff --git a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/restapiv2/NodeTest.java b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/restapiv2/NodeTest.java index 1421e901048..de28867520b 100644 --- a/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/restapiv2/NodeTest.java +++ b/clustercontroller-core/src/test/java/com/yahoo/vespa/clustercontroller/core/restapiv2/NodeTest.java @@ -60,14 +60,6 @@ public class NodeTest extends StateRestApiTest { " \"reason\": \"\"\n" + " }\n" + " },\n" + - " \"metrics\": {\n" + - // Why 24 and 28? There are 4 distributor nodes seen in slobrok (see StateRestApiTest). - // Each gets a host info with distributor-put-latency-ms-sum 6 and - // distributor-put-latency-count 7 (see StateRestApiTest.getHostInfo()). - // Therefore, in aggregate, 4*6 is 24, and 4*7 is 28. - " \"distributor-put-latency-ms-sum\": 24,\n" + - " \"distributor-put-latency-count\": 28\n" + - " },\n" + " \"partition\": {\n" + " \"0\": {\"link\": \"\\/cluster\\/v2\\/music\\/storage\\/1\\/0\"},\n" + " \"1\": {\"link\": \"\\/cluster\\/v2\\/music\\/storage\\/1\\/1\"}\n" + @@ -97,14 +89,6 @@ public class NodeTest extends StateRestApiTest { " \"reason\": \"\"\n" + " }\n" + " },\n" + - " \"metrics\": {\n" + - // Why 24 and 28? There are 4 distributor nodes seen in slobrok (see StateRestApiTest). - // Each gets a host info with distributor-put-latency-ms-sum 6 and - // distributor-put-latency-count 7 (see StateRestApiTest.getHostInfo()). - // Therefore, in aggregate, 4*6 is 24, and 4*7 is 28. - " \"distributor-put-latency-ms-sum\": 24,\n" + - " \"distributor-put-latency-count\": 28\n" + - " },\n" + " \"partition\": {\n" + " \"0\": {\n" + " \"state\": {\"generated\": {\n" + @@ -158,10 +142,6 @@ public class NodeTest extends StateRestApiTest { " \"state\": \"up\",\n" + " \"reason\": \"\"\n" + " }\n" + - " },\n" + - " \"metrics\": {\n" + - " \"distributor-put-latency-ms-sum\": 0,\n" + - " \"distributor-put-latency-count\": 0\n" + " }\n" + "}"; assertEquals(expected, jsonWriter.createJson(response).toString(2)); diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index d6b916680d8..bd94f67e4a7 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -323,7 +323,7 @@ public class DeployState implements ConfigDefinitionStore { closeIgnoreException(reader.getReader()); } } - builder.build(logger, queryProfiles); + builder.build(logger); return SearchDocumentModel.fromBuilderAndNames(builder, names); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java b/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java index dd03cb8b2a7..dc59d9cb3e5 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/FeatureNames.java @@ -5,11 +5,10 @@ */ package com.yahoo.searchdefinition; -import java.util.Arrays; -import java.util.List; +import com.yahoo.searchlib.rankingexpression.Reference; + import java.util.Optional; import java.util.regex.Pattern; -import java.util.stream.Collectors; /** * Utility methods for query, document and constant rank feature names @@ -20,85 +19,16 @@ public class FeatureNames { private static final Pattern identifierRegexp = Pattern.compile("[A-Za-z0-9_][A-Za-z0-9_-]*"); - /** - * <p>Returns the given query, document or constant feature in canonical form. - * A feature name consists of a feature type name (query, attribute or constant), - * followed by one argument enclosed in quotes. - * The argument may be an identifier or any string single or double quoted.</p> - * - * <p>Argument string values may not contain comma, single quote nor double quote characters.</p> - * - * <p><i>The canonical form use no quotes for arguments which are identifiers, and double quotes otherwise.</i></p> - * - * <p>Note that the above definition is not true for features in general, which accept any ranking expression - * as argument.</p> - * - * @throws IllegalArgumentException if the feature name is not valid - */ - // Note that this implementation is more general than what is described above: - // It accepts any number of arguments and an optional output - public static String canonicalize(String feature) { - return canonicalizeIfValid(feature).orElseThrow(() -> - new IllegalArgumentException("A feature name must be on the form query(name), attribute(name) or " + - "constant(name), but was '" + feature + "'" - )); - } - - /** - * Canonicalizes the given argument as in canonicalize, but returns empty instead of throwing an exception if - * the argument is not a valid feature - */ - public static Optional<String> canonicalizeIfValid(String feature) { - int startParenthesis = feature.indexOf('('); - if (startParenthesis < 0) - return Optional.empty(); - int endParenthesis = feature.lastIndexOf(')'); - String featureType = feature.substring(0, startParenthesis); - if ( ! ( featureType.equals("query") || featureType.equals("attribute") || featureType.equals("constant"))) - return Optional.empty(); - if (startParenthesis < 1) return Optional.of(feature); // No arguments - if (endParenthesis < startParenthesis) - return Optional.empty(); - String argumentString = feature.substring(startParenthesis + 1, endParenthesis); - List<String> canonicalizedArguments = - Arrays.stream(argumentString.split(",")) - .map(FeatureNames::canonicalizeArgument) - .collect(Collectors.toList()); - return Optional.of(featureType + "(" + - canonicalizedArguments.stream().collect(Collectors.joining(",")) + - feature.substring(endParenthesis)); - } - - /** Canomicalizes a single argument */ - private static String canonicalizeArgument(String argument) { - if (argument.startsWith("'")) { - if ( ! argument.endsWith("'")) - throw new IllegalArgumentException("Feature arguments starting by a single quote " + - "must end by a single quote, but was \"" + argument + "\""); - argument = argument.substring(1, argument.length() - 1); - } - if (argument.startsWith("\"")) { - if ( ! argument.endsWith("\"")) - throw new IllegalArgumentException("Feature arguments starting by a double quote " + - "must end by a double quote, but was '" + argument + "'"); - argument = argument.substring(1, argument.length() - 1); - } - if (identifierRegexp.matcher(argument).matches()) - return argument; - else - return "\"" + argument + "\""; - } - - public static String asConstantFeature(String constantName) { - return canonicalize("constant(\"" + constantName + "\")"); + public static Reference asConstantFeature(String constantName) { + return Reference.simple("constant", quoteIfNecessary(constantName)); } - public static String asAttributeFeature(String attributeName) { - return canonicalize("attribute(\"" + attributeName + "\")"); + public static Reference asAttributeFeature(String attributeName) { + return Reference.simple("attribute", quoteIfNecessary(attributeName)); } - public static String asQueryFeature(String propertyName) { - return canonicalize("query(\"" + propertyName + "\")"); + public static Reference asQueryFeature(String propertyName) { + return Reference.simple("query", quoteIfNecessary(propertyName)); } /** @@ -106,15 +36,21 @@ public class FeatureNames { * or empty if it is not a valid query, attribute or constant feature name */ public static Optional<String> argumentOf(String feature) { - return canonicalizeIfValid(feature).map(f -> { - int startParenthesis = f.indexOf("("); - int endParenthesis = f.indexOf(")"); - String possiblyQuotedArgument = f.substring(startParenthesis + 1, endParenthesis); - if (possiblyQuotedArgument.startsWith("\"")) - return possiblyQuotedArgument.substring(1, possiblyQuotedArgument.length() - 1); - else - return possiblyQuotedArgument; - }); + Optional<Reference> reference = Reference.simple(feature); + if ( ! reference.isPresent()) return Optional.empty(); + if ( ! ( reference.get().name().equals("attribute") || + reference.get().name().equals("constant") || + reference.get().name().equals("query"))) + return Optional.empty(); + + return Optional.of(reference.get().arguments().expressions().get(0).toString()); + } + + private static String quoteIfNecessary(String s) { + if (identifierRegexp.matcher(s).matches()) + return s; + else + return "\"" + s + "\""; } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java new file mode 100644 index 00000000000..cf6d90db7fa --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -0,0 +1,167 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition; + +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext; +import com.yahoo.searchlib.rankingexpression.rule.NameNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +/** + * A context which only contains type information. + * This returns empty tensor types (double) for unknown features which are not + * query, attribute or constant features, as we do not have information about which such + * features exist (but we know those that exist are doubles). + * + * @author bratseth + */ +public class MapEvaluationTypeContext extends FunctionReferenceContext implements TypeContext<Reference> { + + private final Map<Reference, TensorType> featureTypes = new HashMap<>(); + + public MapEvaluationTypeContext(Collection<ExpressionFunction> functions) { + super(functions); + } + + public MapEvaluationTypeContext(Map<String, ExpressionFunction> functions, + Map<String, String> bindings, + Map<Reference, TensorType> featureTypes) { + super(functions, bindings); + this.featureTypes.putAll(featureTypes); + } + + public void setType(Reference reference, TensorType type) { + featureTypes.put(reference, type); + } + + @Override + public TensorType getType(String reference) { + throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + } + + @Override + public TensorType getType(Reference reference) { + Optional<String> binding = boundIdentifier(reference); + if (binding.isPresent()) { + try { + // This is not pretty, but changing to bind expressions rather + // than their string values requires deeper changes + return new RankingExpression(binding.get()).type(this); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + + if (isSimpleFeature(reference)) { + // The argument may be a local identifier bound to the actual value + String argument = simpleArgument(reference.arguments()).get(); + reference = Reference.simple(reference.name(), bindings.getOrDefault(argument, argument)); + return featureTypes.getOrDefault(reference, defaultTypeOf(reference)); + } + + Optional<ExpressionFunction> function = functionInvocation(reference); + if (function.isPresent()) { + return function.get().getBody().type(this.withBindings(bind(function.get().arguments(), reference.arguments()))); + } + + // We do not know what this is - since we do not have complete knowledge abut the match features + // in Java we must assume this is a match feature and return the double type - which is the type of all + // all match features + return TensorType.empty; + } + + /** + * Returns the default type for this simple feature, or nullif it does not have a default + */ + public TensorType defaultTypeOf(Reference reference) { + if ( ! isSimpleFeature(reference)) + throw new IllegalArgumentException("This can only be called for simple references, not " + reference); + if (reference.name().equals("query")) // we do not require all query features to be declared, only non-doubles + return TensorType.empty; + return null; + } + + /** + * Returns the binding if this reference is a simple identifier which is bound in this context. + * Returns empty otherwise. + */ + private Optional<String> boundIdentifier(Reference reference) { + if ( ! reference.arguments().isEmpty()) return Optional.empty(); + if ( reference.output() != null) return Optional.empty(); + return Optional.ofNullable(bindings.get(reference.name())); + } + + /** + * Return whether the reference (discarding the output) is a simple feature + * ("attribute(name)", "constant(name)" or "query(name)"). + * We disregard the output because all outputs under a simple feature have the same type. + */ + private boolean isSimpleFeature(Reference reference) { + Optional<String> argument = simpleArgument(reference.arguments()); + if ( ! argument.isPresent()) return false; + return reference.name().equals("attribute") || + reference.name().equals("constant") || + reference.name().equals("query"); + } + + /** + * If these arguments contains one simple argument string, it is returned. + * Otherwise null is returned. + */ + private Optional<String> simpleArgument(Arguments arguments) { + if (arguments.expressions().size() != 1) return Optional.empty(); + ExpressionNode argument = arguments.expressions().get(0); + + if ( ! (argument instanceof ReferenceNode)) return Optional.empty(); + ReferenceNode refArgument = (ReferenceNode)argument; + + if ( ! refArgument.reference().isIdentifier()) return Optional.empty(); + + return Optional.of(refArgument.getName()); + } + + private Optional<ExpressionFunction> functionInvocation(Reference reference) { + if (reference.output() != null) return Optional.empty(); + ExpressionFunction function = functions().get(reference.name()); + if (function == null) return Optional.empty(); + if (function.arguments().size() != reference.arguments().size()) return Optional.empty(); + return Optional.of(function); + } + + /** Binds the given list of formal arguments to their actual values */ + private Map<String, String> bind(List<String> formalArguments, + Arguments invocationArguments) { + Map<String, String> bindings = new HashMap<>(formalArguments.size()); + for (int i = 0; i < formalArguments.size(); i++) { + String identifier = invocationArguments.expressions().get(i).toString(); + identifier = super.bindings.getOrDefault(identifier, identifier); + bindings.put(formalArguments.get(i), identifier); + } + return bindings; + } + + public Map<Reference, TensorType> featureTypes() { + return Collections.unmodifiableMap(featureTypes); + } + + @Override + public MapEvaluationTypeContext withBindings(Map<String, String> bindings) { + if (bindings.isEmpty() && this.bindings.isEmpty()) return this; + return new MapEvaluationTypeContext(functions(), bindings, featureTypes); + } + +} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index bcbc7cc99e2..064897de8dc 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -2,24 +2,18 @@ package com.yahoo.searchdefinition; import com.yahoo.config.application.api.ApplicationPackage; -import com.yahoo.config.model.deploy.DeployState; -import com.yahoo.io.reader.NamedReader; -import com.yahoo.processing.request.CompoundName; -import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.search.query.profile.config.QueryProfileXMLReader; import com.yahoo.search.query.profile.types.FieldDescription; import com.yahoo.search.query.profile.types.QueryProfileType; -import com.yahoo.search.query.profile.types.TensorFieldType; import com.yahoo.search.query.ranking.Diversity; -import com.yahoo.searchdefinition.document.SDField; +import com.yahoo.searchdefinition.document.ImmutableSDField; import com.yahoo.searchdefinition.expressiontransforms.RankProfileTransformContext; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.FeatureList; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TypeMapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.TensorType; @@ -39,7 +33,10 @@ import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; /** * Represents a rank profile - a named set of ranking settings @@ -363,14 +360,14 @@ public class RankProfile implements Serializable, Cloneable { /** Returns a read-only view of the summary features to use in this profile. This is never null */ public Set<ReferenceNode> getSummaryFeatures() { - if (summaryFeatures!=null) return Collections.unmodifiableSet(summaryFeatures); - if (getInherited()!=null) return getInherited().getSummaryFeatures(); + if (summaryFeatures != null) return Collections.unmodifiableSet(summaryFeatures); + if (getInherited() != null) return getInherited().getSummaryFeatures(); return Collections.emptySet(); } public void addSummaryFeature(ReferenceNode feature) { - if (summaryFeatures==null) - summaryFeatures=new LinkedHashSet<>(); + if (summaryFeatures == null) + summaryFeatures = new LinkedHashSet<>(); summaryFeatures.add(feature); } @@ -585,8 +582,11 @@ public class RankProfile implements Serializable, Cloneable { } /** - * Will take the parser-set textual ranking expressions and turn into objects + * Will take the parser-set textual ranking expressions and turn into ranking expression objects, + * if not already done */ + // TODO: There doesn't appear to be any good reason to defer parsing of ranking expressions + // until this is called. Simplify by parsing them right away. public void parseExpressions() { try { parseRankingExpressions(); @@ -604,20 +604,23 @@ public class RankProfile implements Serializable, Cloneable { for (Map.Entry<String, Macro> e : getMacros().entrySet()) { String macroName = e.getKey(); Macro macro = e.getValue(); - RankingExpression expr = parseRankingExpression(macroName, macro.getTextualExpression()); - macro.setRankingExpression(expr); - macro.setTextualExpression(expr.getRoot().toString()); + if (macro.getRankingExpression() == null) { + RankingExpression expr = parseRankingExpression(macroName, macro.getTextualExpression()); + macro.setRankingExpression(expr); + macro.setTextualExpression(expr.getRoot().toString()); + } } } /** * Passes ranking expressions on to parser + * * @throws ParseException if either of the ranking expressions could not be parsed */ private void parseRankingExpressions() throws ParseException { - if (getFirstPhaseRankingString() != null) + if (getFirstPhaseRankingString() != null && firstPhaseRanking == null) setFirstPhaseRanking(parseRankingExpression("firstphase", getFirstPhaseRankingString())); - if (getSecondPhaseRankingString() != null) + if (getSecondPhaseRankingString() != null && secondPhaseRanking == null) setSecondPhaseRanking(parseRankingExpression("secondphase", getSecondPhaseRankingString())); } @@ -748,37 +751,50 @@ public class RankProfile implements Serializable, Cloneable { * referable from this rank profile. */ public TypeContext typeContext(QueryProfileRegistry queryProfiles) { - TypeMapContext context = new TypeMapContext(); + MapEvaluationTypeContext context = new MapEvaluationTypeContext(getMacros().values().stream() + .map(Macro::asExpressionFunction) + .collect(Collectors.toList())); - // Add small constants + // Add small and large constants, respectively getConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.type())); - // Add large constants getSearch().getRankingConstants().forEach((k, v) -> context.setType(FeatureNames.asConstantFeature(k), v.getTensorType())); // Add attributes - for (SDField field : getSearch().allConcreteFields()) { - field.getAttributes().forEach((k, a) -> context.setType(FeatureNames.asAttributeFeature(k), a.tensorType().orElse(TensorType.empty))); - } + getSearch().allFields().forEach(field -> addAttributeFeatureTypes(field, context)); + getSearch().allImportedFields().forEach(field -> addAttributeFeatureTypes(field, context)); // Add query features from rank profile types reached from the "default" profile for (QueryProfileType queryProfileType : queryProfiles.getTypeRegistry().allComponents()) { for (FieldDescription field : queryProfileType.declaredFields().values()) { TensorType type = field.getType().asTensorType(); - String feature = FeatureNames.asQueryFeature(field.getName()); - TensorType existingType = context.getType(feature); - if (existingType != null) + Optional<Reference> feature = Reference.simple(field.getName()); + if ( ! feature.isPresent() || ! feature.get().name().equals("query")) continue; + + TensorType existingType = context.getType(feature.get()); + if ( ! Objects.equals(existingType, context.defaultTypeOf(feature.get()))) type = existingType.dimensionwiseGeneralizationWith(type).orElseThrow( () -> - new IllegalArgumentException(queryProfileType + " contains query feature " + feature + + new IllegalArgumentException(queryProfileType + " contains query feature " + feature.get() + " with type " + field.getType().asTensorType() + ", but this is already defined " + - "in another query profile with type " + context.getType(feature))); - context.setType(feature, type); + "in another query profile with type " + + context.getType(feature.get()))); + context.setType(feature.get(), type); } } return context; } + private void addAttributeFeatureTypes(ImmutableSDField field, MapEvaluationTypeContext context) { + field.getAttributes().forEach((k, a) -> { + String name = k; + if (k.equals(field.getBackingField().getName())) // this attribute should take the fields name + name = field.getName(); // switch to that - it is separate for imported fields + context.setType(FeatureNames.asAttributeFeature(name), + a.tensorType().orElse(TensorType.empty)); + }); + } + /** * A rank setting. The identity of a rank setting is its field name and type (not value). * A rank setting is immutable. @@ -910,7 +926,7 @@ public class RankProfile implements Serializable, Cloneable { */ public static class Macro implements Serializable, Cloneable { - private String name=null; + private final String name; private String textualExpression=null; private RankingExpression expression=null; private List<String> formalParams = new ArrayList<>(); @@ -955,7 +971,7 @@ public class RankProfile implements Serializable, Cloneable { return inline && formalParams.size() == 0; // only inline no-arg macros; } - public ExpressionFunction toExpressionMacro() { + public ExpressionFunction asExpressionFunction() { return new ExpressionFunction(getName(), getFormalParams(), getRankingExpression()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java index a075b9d00fa..7b4d70d85b1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfileRegistry.java @@ -16,8 +16,7 @@ import java.util.Set; * Having both of these mappings consolidated here make it easier to remove dependencies on these mappings at * run time, since it is essentially only used when building rank profile config at deployment time. * - * TODO: Reconsider the difference between local and global maps. Right now, the local maps might better be - * served from a different class owned by SearchBuilder. + * TODO: Rank profiles should be stored under its owning Search instance. * * @author Ulf Lilleengen */ @@ -31,9 +30,6 @@ public class RankProfileRegistry { /* These rank profiles can be overridden: 'default' rank profile, as that is documented to work. And 'unranked'. */ static final Set<String> overridableRankProfileNames = new HashSet<>(Arrays.asList("default", "unranked")); - public RankProfileRegistry() { - } - public static RankProfileRegistry createRankProfileRegistryWithBuiltinRankProfiles(Search search) { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); rankProfileRegistry.addRankProfile(new DefaultRankProfile(search, rankProfileRegistry)); @@ -47,7 +43,7 @@ public class RankProfileRegistry { * @param rankProfile the rank profile to add */ public void addRankProfile(RankProfile rankProfile) { - if (!rankProfiles.containsKey(rankProfile.getSearch())) { + if ( ! rankProfiles.containsKey(rankProfile.getSearch())) { rankProfiles.put(rankProfile.getSearch(), new LinkedHashMap<>()); } checkForDuplicateRankProfile(rankProfile); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java index f4a0365e36e..1ab76afc9c0 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/Search.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/Search.java @@ -199,9 +199,7 @@ public class Search implements Serializable, ImmutableSearch { @Override public ImmutableSDField getField(String name) { ImmutableSDField field = getConcreteField(name); - if (field != null) { - return field; - } + if (field != null) return field; return allImportedFields() .filter(f -> f.getName().equals(name)) .findFirst() @@ -248,8 +246,6 @@ public class Search implements Serializable, ImmutableSearch { * Returns a list of all the fields of this search definition, that is all fields in all documents, in the documents * they inherit, and all extra fields. The caller receives ownership to the list - subsequent changes to it will not * impact this - * - * @return the list of fields in this searchdefinition */ public List<SDField> allConcreteFields() { List<SDField> allFields = new ArrayList<>(); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java index 762c0fec838..e7cd21ac834 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java @@ -18,6 +18,7 @@ import com.yahoo.searchdefinition.parser.TokenMgrError; import com.yahoo.searchdefinition.processing.Processing; import com.yahoo.vespa.documentmodel.DocumentModel; import com.yahoo.vespa.model.container.search.QueryProfiles; +import com.yahoo.yolean.Exceptions; import java.io.File; import java.io.IOException; @@ -34,7 +35,6 @@ import java.util.List; * expressions, using the setRankXXX() methods, 3) invoke the {@link #build()} method, and 4) retrieve the built * search objects using the {@link #getSearch(String)} method. */ -// TODO: This should be cleaned up and more or maybe completely taken over by MockApplicationPackage public class SearchBuilder { private final DocumentTypeManager docTypeMgr = new DocumentTypeManager(); @@ -154,7 +154,7 @@ public class SearchBuilder { } catch (TokenMgrError e) { throw new ParseException("Unknown symbol: " + e.getMessage()); } catch (ParseException pe) { - throw new ParseException(stream.formatException(pe.getMessage())); + throw new ParseException(stream.formatException(Exceptions.toMessageString(pe))); } return importRawSearch(search); } @@ -196,11 +196,7 @@ public class SearchBuilder { * @throws IllegalStateException Thrown if this method has already been called. */ public void build() { - build(new BaseDeployLogger(), new QueryProfiles()); - } - - public void build(DeployLogger logger) { - build(logger, new QueryProfiles()); + build(new BaseDeployLogger()); } /** @@ -209,12 +205,10 @@ public class SearchBuilder { * * @throws IllegalStateException Thrown if this method has already been called. * @param deployLogger The logger to use during build - * @param queryProfiles The query profiles contained in the application this search is part of. */ - public void build(DeployLogger deployLogger, QueryProfiles queryProfiles) { - if (isBuilt) { - throw new IllegalStateException("Searches already built."); - } + public void build(DeployLogger deployLogger) { + if (isBuilt) throw new IllegalStateException("Model already built"); + List<Search> built = new ArrayList<>(); List<SDDocumentType> sdocs = new ArrayList<>(); sdocs.add(SDDocumentType.VESPA_DOCUMENT); @@ -240,7 +234,7 @@ public class SearchBuilder { for (Search search : new SearchOrderer().order(searchList)) { new FieldOperationApplierForSearch().process(search); // These two needed for a couple of old unit tests, ideally these are just read from app - process(search, deployLogger, queryProfiles); + process(search, deployLogger, new QueryProfiles(queryProfileRegistry)); built.add(search); } builder.addToModel(searchList); @@ -254,8 +248,6 @@ public class SearchBuilder { /** * Processes and returns the given {@link Search} object. This method has been factored out of the {@link * #build()} method so that subclasses can choose not to build anything. - * - * @param search The object to build. */ protected void process(Search search, DeployLogger deployLogger, QueryProfiles queryProfiles) { Processing.process(search, deployLogger, rankProfileRegistry, queryProfiles); @@ -352,7 +344,7 @@ public class SearchBuilder { rankProfileRegistry, queryprofileRegistry); builder.importFile(fileName); - builder.build(deployLogger, new QueryProfiles()); + builder.build(deployLogger); return builder; } @@ -368,7 +360,7 @@ public class SearchBuilder { for (Iterator<Path> i = Files.list(new File(dir).toPath()).filter(p -> p.getFileName().toString().endsWith(".sd")).iterator(); i.hasNext(); ) { builder.importFile(i.next()); } - builder.build(new BaseDeployLogger(), new QueryProfiles()); + builder.build(new BaseDeployLogger()); return builder; } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java deleted file mode 100644 index 40e9db1413f..00000000000 --- a/config-model/src/main/java/com/yahoo/searchdefinition/TypeMapContext.java +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; - -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -/** - * A context which only contains type information. - * - * @author bratseth - */ -public class TypeMapContext implements TypeContext { - - private final Map<String, TensorType> featureTypes = new HashMap<>(); - - public void setType(String name, TensorType type) { - featureTypes.put(FeatureNames.canonicalize(name), type); - } - - @Override - public TensorType getType(String name) { - return featureTypes.get(FeatureNames.canonicalize(name)); - } - - /** Returns an unmodifiable map of the bindings in this */ - public Map<String, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); } - -} diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index ea02f960800..b02362154d9 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -188,7 +188,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { if (macros.isEmpty()) return; Map<String, ExpressionFunction> expressionMacros = new LinkedHashMap<>(); for (Map.Entry<String, RankProfile.Macro> macro : macros.entrySet()) { - expressionMacros.put(macro.getKey(), macro.getValue().toExpressionMacro()); + expressionMacros.put(macro.getKey(), macro.getValue().asExpressionFunction()); } Map<String, String> macroProperties = new LinkedHashMap<>(); @@ -223,7 +223,7 @@ public class RawRankProfile implements RankProfilesConfig.Producer { // Is the feature a macro? if (context.getFunction(referenceNode.getName()) != null) { context.addFunctionSerialization(RankingExpression.propertyName(referenceNode.getName()), - referenceNode.toString(context, null, null)); + referenceNode.toString(context, null, null)); ReferenceNode newReferenceNode = new ReferenceNode("rankingExpression(" + referenceNode.getName() + ")", referenceNode.getArguments().expressions(), referenceNode.getOutput()); macroSummaryFeatures.put(referenceNode.getName(), newReferenceNode); i.remove(); // Will add the expanded one in next block diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java index 8b6df1a87db..4502468379f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableImportedSDField.java @@ -63,6 +63,9 @@ public class ImmutableImportedSDField implements ImmutableSDField { } @Override + public ImmutableSDField getBackingField() { return importedField.targetField(); } + + @Override public boolean isIndexStructureField() { return importedField.targetField().isIndexStructureField(); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableSDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableSDField.java index 152690a6f56..70553d4b57c 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableSDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/ImmutableSDField.java @@ -19,6 +19,7 @@ import java.util.Map; * @author bjorncs */ public interface ImmutableSDField { + <T extends Expression> boolean containsExpression(Class<T> searchFor); boolean doesAttributing(); @@ -33,6 +34,12 @@ public interface ImmutableSDField { boolean isImportedField(); + /** + * Returns the field backing this - the field itself if this is a regular field, + * and the target field if this is imported. + */ + ImmutableSDField getBackingField(); + default boolean isConcreteField() { return !isImportedField(); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java index 593edc33370..6e7582a98c8 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/document/SDField.java @@ -209,6 +209,9 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, } @Override + public ImmutableSDField getBackingField() { return this; } + + @Override public boolean doesAttributing() { return containsExpression(AttributeExpression.class); } @@ -623,8 +626,7 @@ public class SDField extends Field implements TypedKey, FieldOperationContainer, public RankType getRankType() { return this.rankType; } /** - * <p>Returns the search-time attribute settings of this field - * or null if none is set.</p> + * Returns the search-time attribute settings of this field or null if none is set. * * <p>TODO: Make unmodifiable.</p> */ diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 2b997aa25f2..f16697b5ba6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -208,6 +208,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil throw new IllegalArgumentException("Model refers Placeholder '" + macroName + "' of type " + requiredType + " but this macro is not present in " + profile); + // TODO: We should verify this in the (function reference(s) this is invoked (starting from first/second + // phase and summary features), as it may only resolve correctly given those bindings + // Or, probably better, annotate the macros with type constraints here and verify during general + // type verification TensorType actualType = macro.getRankingExpression().getRoot().type(profile.typeContext(queryProfiles)); if ( actualType == null) throw new IllegalArgumentException("Model refers Placeholder '" + macroName + diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java index ee65c9bec02..cc634abef01 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/IndexingValues.java @@ -13,7 +13,7 @@ import com.yahoo.vespa.indexinglanguage.expressions.OutputExpression; import com.yahoo.vespa.model.container.search.QueryProfiles; /** - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen Hult</a> + * @author Simon Thoresen Hult */ public class IndexingValues extends Processor { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java index 90183848094..061a803cb48 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/Processing.java @@ -76,8 +76,9 @@ public class Processing { ImportedFieldsInSummayValidator::new, FastAccessValidator::new, ReservedMacroNames::new, + RankingExpressionTypeValidator::new, - // These two should be last. + // These should be last. IndexingValidation::new, IndexingValues::new); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java new file mode 100644 index 00000000000..baacceea667 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidator.java @@ -0,0 +1,82 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.config.application.api.DeployLogger; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.vespa.model.container.search.QueryProfiles; + +/** + * Validates the types of all ranking expressions under a search instance: + * Some operators constrain the types of inputs, and first-and second-phase expressions + * must return scalar values. In addition, the existence of all referred attribute, query and constant + * features is ensured. + * + * @author bratseth + */ +public class RankingExpressionTypeValidator extends Processor { + + private final QueryProfileRegistry queryProfiles; + + public RankingExpressionTypeValidator(Search search, + DeployLogger deployLogger, + RankProfileRegistry rankProfileRegistry, + QueryProfiles queryProfiles) { + super(search, deployLogger, rankProfileRegistry, queryProfiles); + this.queryProfiles = queryProfiles.getRegistry(); + } + + @Override + public void process() { + for (RankProfile profile : rankProfileRegistry.localRankProfiles(search)) { + try { + validate(profile); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("In " + search + ", " + profile, e); + } + } + } + + /** Throws an IllegalArgumentException if the given rank profile does not produce valid type */ + private void validate(RankProfile profile) { + profile.parseExpressions(); + TypeContext context = profile.typeContext(queryProfiles); + profile.getSummaryFeatures().forEach(f -> ensureValid(f, "summary feature " + f, context)); + ensureValidDouble(profile.getFirstPhaseRanking(), "first-phase expression", context); + ensureValidDouble(profile.getSecondPhaseRanking(), "second-phase expression", context); + } + + private TensorType ensureValid(RankingExpression expression, String expressionDescription, TypeContext context) { + if (expression == null) return null; + return ensureValid(expression.getRoot(), expressionDescription, context); + } + + private TensorType ensureValid(ExpressionNode expression, String expressionDescription, TypeContext context) { + TensorType type; + try { + type = expression.type(context); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("The " + expressionDescription + " is invalid", e); + } + if (type == null) // Not expected to happen + throw new IllegalStateException("Could not determine the type produced by " + expressionDescription); + return type; + } + + private void ensureValidDouble(RankingExpression expression, String expressionDescription, TypeContext context) { + if (expression == null) return; + TensorType type = ensureValid(expression, expressionDescription, context); + if ( ! type.equals(TensorType.empty)) + throw new IllegalArgumentException("The " + expressionDescription + " must produce a double " + + "(a tensor with no dimensions), but produces " + type); + } + +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java b/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java index 640a85d9b50..15b482ee60c 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/IdentityProvider.java @@ -4,21 +4,24 @@ package com.yahoo.vespa.model.container; import com.yahoo.config.provision.AthenzDomain; import com.yahoo.config.provision.AthenzService; import com.yahoo.config.provision.HostName; +import com.yahoo.container.bundle.BundleInstantiationSpecification; import com.yahoo.container.core.identity.IdentityConfig; +import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.vespa.model.container.component.SimpleComponent; /** * @author mortent */ public class IdentityProvider extends SimpleComponent implements IdentityConfig.Producer { - public static final String CLASS = "com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl"; + public static final String CLASS = "com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl"; + public static final String BUNDLE = "vespa-athenz"; private final AthenzDomain domain; private final AthenzService service; private final HostName loadBalancerName; public IdentityProvider(AthenzDomain domain, AthenzService service, HostName loadBalancerName) { - super(CLASS); + super(new ComponentModel(BundleInstantiationSpecification.getFromStrings(CLASS, CLASS, BUNDLE))); this.domain = domain; this.service = service; this.loadBalancerName = loadBalancerName; diff --git a/config-model/src/test/derived/rankexpression/rank-profiles.cfg b/config-model/src/test/derived/rankexpression/rank-profiles.cfg index e890b75770b..f5652c31d2a 100644 --- a/config-model/src/test/derived/rankexpression/rank-profiles.cfg +++ b/config-model/src/test/derived/rankexpression/rank-profiles.cfg @@ -24,7 +24,7 @@ rankprofile[0].fef.property[10].value "4" rankprofile[0].fef.property[11].name "vespa.dump.feature" rankprofile[0].fef.property[11].value "attribute(foo1).out" rankprofile[0].fef.property[12].name "vespa.dump.feature" -rankprofile[0].fef.property[12].value "attribute(bar1.out)" +rankprofile[0].fef.property[12].value "attribute(bar1)" rankprofile[0].fef.property[13].name "vespa.dump.feature" rankprofile[0].fef.property[13].value "attribute(foo2).out" rankprofile[0].fef.property[14].name "vespa.dump.feature" @@ -64,7 +64,7 @@ rankprofile[2].fef.property[2].value "10 + feature(arg1).out.out" rankprofile[2].fef.property[3].name "vespa.summary.feature" rankprofile[2].fef.property[3].value "attribute(foo1).out" rankprofile[2].fef.property[4].name "vespa.summary.feature" -rankprofile[2].fef.property[4].value "attribute(bar1.out)" +rankprofile[2].fef.property[4].value "attribute(bar1)" rankprofile[2].fef.property[5].name "vespa.summary.feature" rankprofile[2].fef.property[5].value "attribute(foo2).out" rankprofile[2].fef.property[6].name "vespa.summary.feature" diff --git a/config-model/src/test/derived/rankexpression/rankexpression.sd b/config-model/src/test/derived/rankexpression/rankexpression.sd index 8ed1f2bab4c..d3e0057cfe1 100644 --- a/config-model/src/test/derived/rankexpression/rankexpression.sd +++ b/config-model/src/test/derived/rankexpression/rankexpression.sd @@ -5,12 +5,10 @@ search rankexpression { field artist type string { indexing: summary | index - # index-to: artist, default } field title type string { indexing: summary | index - # index-to: title, default } field surl type string { @@ -21,6 +19,38 @@ search rankexpression { indexing: summary | attribute } + field foo1 type int { + indexing: attribute + } + + field foo2 type int { + indexing: attribute + } + + field foo3 type int { + indexing: attribute + } + + field foo4 type int { + indexing: attribute + } + + field bar1 type int { + indexing: attribute + } + + field bar2 type int { + indexing: attribute + } + + field bar3 type int { + indexing: attribute + } + + field bar4 type int { + indexing: attribute + } + } rank-profile default { @@ -33,7 +63,7 @@ search rankexpression { expression: if(3>2,4,2) rerank-count: 10 } - rank-features: attribute(foo1).out attribute(bar1.out) + rank-features: attribute(foo1).out attribute(bar1) rank-features { attribute(foo2).out attribute(bar2).out } rank-features { attribute(foo3).out attribute(bar3).out } @@ -65,7 +95,7 @@ search rankexpression { file:rankexpression } } - summary-features: attribute(foo1).out attribute(bar1.out) + summary-features: attribute(foo1).out attribute(bar1) summary-features { attribute(foo2).out attribute(bar2).out } summary-features { attribute(foo3).out attribute(bar3).out } diff --git a/config-model/src/test/derived/rankexpression/summary.cfg b/config-model/src/test/derived/rankexpression/summary.cfg index 00df2e87144..9752a9f55e3 100644 --- a/config-model/src/test/derived/rankexpression/summary.cfg +++ b/config-model/src/test/derived/rankexpression/summary.cfg @@ -15,9 +15,25 @@ classes[0].fields[5].name "summaryfeatures" classes[0].fields[5].type "featuredata" classes[0].fields[6].name "documentid" classes[0].fields[6].type "longstring" -classes[1].id 1787488393 +classes[1].id 1736696699 classes[1].name "attributeprefetch" classes[1].fields[0].name "year" +classes[].fields[].type "integer" +classes[].fields[].name "foo1" +classes[].fields[].type "integer" +classes[].fields[].name "foo2" +classes[].fields[].type "integer" +classes[].fields[].name "foo3" +classes[].fields[].type "integer" +classes[].fields[].name "foo4" +classes[].fields[].type "integer" +classes[].fields[].name "bar1" +classes[].fields[].type "integer" +classes[].fields[].name "bar2" +classes[].fields[].type "integer" +classes[].fields[].name "bar3" +classes[].fields[].type "integer" +classes[].fields[].name "bar4" classes[1].fields[0].type "integer" classes[1].fields[1].name "rankfeatures" classes[1].fields[1].type "featuredata" diff --git a/config-model/src/test/derived/rankexpression/summarymap.cfg b/config-model/src/test/derived/rankexpression/summarymap.cfg index c810f7282ba..21e6cdf346f 100644 --- a/config-model/src/test/derived/rankexpression/summarymap.cfg +++ b/config-model/src/test/derived/rankexpression/summarymap.cfg @@ -7,4 +7,28 @@ override[1].command "rankfeatures" override[1].arguments "" override[2].field "summaryfeatures" override[2].command "summaryfeatures" -override[2].arguments ""
\ No newline at end of file +override[2].arguments "" +override[].field "foo1" +override[].command "attribute" +override[].arguments "foo1" +override[].field "foo2" +override[].command "attribute" +override[].arguments "foo2" +override[].field "foo3" +override[].command "attribute" +override[].arguments "foo3" +override[].field "foo4" +override[].command "attribute" +override[].arguments "foo4" +override[].field "bar1" +override[].command "attribute" +override[].arguments "bar1" +override[].field "bar2" +override[].command "attribute" +override[].arguments "bar2" +override[].field "bar3" +override[].command "attribute" +override[].arguments "bar3" +override[].field "bar4" +override[].command "attribute" +override[].arguments "bar4"
\ No newline at end of file diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index 2b231e0cda2..b6ad5372c05 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -35,7 +35,7 @@ rankprofile[3].name "profile2" rankprofile[3].fef.property[0].name "vespa.rank.firstphase" rankprofile[3].fef.property[0].value "rankingExpression(firstphase)" rankprofile[3].fef.property[1].name "rankingExpression(firstphase).rankingScript" -rankprofile[3].fef.property[1].value "reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x)" +rankprofile[3].fef.property[1].value "reduce(reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)" rankprofile[3].fef.property[2].name "vespa.type.attribute.f2" rankprofile[3].fef.property[2].value "tensor(x[2],y[])" rankprofile[3].fef.property[3].name "vespa.type.attribute.f3" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index a6a9a98db3a..3d64f6b807e 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -28,7 +28,7 @@ search tensor { rank-profile profile2 { first-phase { - expression: matmul(attribute(f4), diag(x[2],y[2],z[3]), x) + expression: sum(matmul(attribute(f4), diag(x[2],y[2],z[3]), x)) } } diff --git a/config-model/src/test/examples/rankpropvars.sd b/config-model/src/test/examples/rankpropvars.sd index 40f9e73f35a..28959edbc09 100644 --- a/config-model/src/test/examples/rankpropvars.sd +++ b/config-model/src/test/examples/rankpropvars.sd @@ -18,8 +18,8 @@ first-phase { second-phase { expression { if (attribute(artist) == query(testvar1), - 0.0 * fieldMatch(title) + 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist), - 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) + 0.0 * fieldMatch(title) + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist), + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) } } @@ -42,8 +42,8 @@ first-phase { second-phase { expression { if (attribute(artist) == query(testvar1), - 0.0 * fieldMatch(title) + 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist), - 0.0 * attribute(popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) + 0.0 * fieldMatch(title) + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist), + 0.0 * attribute(Popularity) + 0.0 * fieldMatch(artist) + 0.0 * fieldMatch(title)) } } } diff --git a/config-model/src/test/examples/simple.sd b/config-model/src/test/examples/simple.sd index 4fda7f5039e..96b0fa98098 100644 --- a/config-model/src/test/examples/simple.sd +++ b/config-model/src/test/examples/simple.sd @@ -116,7 +116,7 @@ search simple { first-phase { keep-rank-count:200 rank-score-drop-limit: -13.0 - expression: attribute(year) + expression: attribute(popularity) } second-phase { rerank-count: 99 diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java index 1f60ad870ec..aa01070d296 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/FeatureNamesTestCase.java @@ -18,17 +18,6 @@ import static org.junit.Assert.assertFalse; public class FeatureNamesTestCase { @Test - public void testCanonicalization() { - assertFalse(FeatureNames.canonicalizeIfValid("foo").isPresent()); - assertEquals("query(bar)", FeatureNames.canonicalize("query(bar)")); - assertEquals("query(bar)", FeatureNames.canonicalize("query('bar')")); - assertEquals("constant(bar)", FeatureNames.canonicalize("constant(\"bar\")")); - assertEquals("query(\"ba.r\")", FeatureNames.canonicalize("query(ba.r)")); - assertEquals("query(\"ba.r\")", FeatureNames.canonicalize("query('ba.r')")); - assertEquals("attribute(\"ba.r\")", FeatureNames.canonicalize("attribute(\"ba.r\")")); - } - - @Test public void testArgument() { assertFalse(FeatureNames.argumentOf("foo(bar)").isPresent()); assertFalse(FeatureNames.argumentOf("foo(bar.baz)").isPresent()); @@ -42,17 +31,20 @@ public class FeatureNamesTestCase { @Test public void testConstantFeature() { - assertEquals("constant(\"foo/bar\")", FeatureNames.asConstantFeature("foo/bar")); + assertEquals("constant(\"foo/bar\")", + FeatureNames.asConstantFeature("foo/bar").toString()); } @Test public void testAttributeFeature() { - assertEquals("attribute(foo)", FeatureNames.asAttributeFeature("foo")); + assertEquals("attribute(foo)", + FeatureNames.asAttributeFeature("foo").toString()); } @Test public void testQueryFeature() { - assertEquals("query(\"foo.bar\")", FeatureNames.asQueryFeature("foo.bar")); + assertEquals("query(\"foo.bar\")", + FeatureNames.asQueryFeature("foo.bar").toString()); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java index 442c8bd41bd..11093d9f008 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankProfileTestCase.java @@ -135,13 +135,13 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { @Test public void requireThatConfigIsDerivedForQueryFeatureTypeSettings() throws ParseException { RankProfileRegistry registry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(registry); + SearchBuilder builder = new SearchBuilder(registry, setupQueryProfileTypes()); builder.importString("search test {\n" + " document test { } \n" + " rank-profile p1 {}\n" + " rank-profile p2 {}\n" + "}"); - builder.build(new BaseDeployLogger(), setupQueryProfileTypes()); + builder.build(new BaseDeployLogger()); Search search = builder.getSearch(); assertEquals(4, registry.allRankProfiles().size()); @@ -151,7 +151,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { assertQueryFeatureTypeSettings(registry.getRankProfile(search, "p2"), search); } - private static QueryProfiles setupQueryProfileTypes() { + private static QueryProfileRegistry setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); @@ -164,7 +164,7 @@ public class RankProfileTestCase extends SearchDefinitionTestCase { type.addField(new FieldDescription("ranking.features.query(numeric)", FieldType.fromString("integer", typeRegistry)), typeRegistry); typeRegistry.register(type); - return new QueryProfiles(registry); + return registry; } private static void assertQueryFeatureTypeSettings(RankProfile profile, Search search) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index e94880e61c7..82b9f5ac043 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -207,6 +207,9 @@ public class RankingExpressionConstantsTestCase extends SearchDefinitionTestCase builder.importString( "search test {\n" + " document test { \n" + + " field rating_yelp type int {" + + " indexing: attribute" + + " }" + " }\n" + " \n" + " rank-profile test {\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 5100ac15c40..ed1b00e2875 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -2,7 +2,10 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; +import com.yahoo.search.query.profile.QueryProfile; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.types.FieldDescription; +import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; @@ -149,11 +152,12 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase censorBindingHash(testRankProperties.get(4).toString())); } - @Test public void testNeuralNetworkSetup() throws ParseException { + // Note: the type assigned to query profile and constant tensors here is not the correct type RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + QueryProfileRegistry queryProfiles = queryProfileWith("query(q)", "tensor(x[])"); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles); builder.importString( "search test {\n" + " document test { \n" + @@ -176,13 +180,28 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase " expression: sum(final_layer)\n" + " }\n" + " }\n" + - "\n" + + " constant W_hidden {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant b_input {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant W_final {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + + " constant b_final {\n" + + " type: tensor(x[])\n" + + " file: ignored.json\n" + + " }\n" + "}\n"); builder.build(); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, - new QueryProfileRegistry(), + queryProfiles, new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(relu).rankingScript,max(1.0,x))", testRankProperties.get(0).toString()); @@ -198,6 +217,17 @@ public class RankingExpressionShadowingTestCase extends SearchDefinitionTestCase testRankProperties.get(5).toString()); } + private QueryProfileRegistry queryProfileWith(String field, String type) { + QueryProfileType queryProfileType = new QueryProfileType("root"); + queryProfileType.addField(new FieldDescription(field, type)); + QueryProfileRegistry queryProfileRegistry = new QueryProfileRegistry(); + queryProfileRegistry.getTypeRegistry().register(queryProfileType); + QueryProfile profile = new QueryProfile("default"); + profile.setType(queryProfileType); + queryProfileRegistry.register(profile); + return queryProfileRegistry; + } + private String censorBindingHash(String s) { StringBuilder b = new StringBuilder(); boolean areInHash = false; diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 800697b3430..0ce6129ef7f 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -38,7 +38,8 @@ class RankProfileSearchFixture { RankProfileSearchFixture(ApplicationPackage applicationpackage, QueryProfileRegistry queryProfileRegistry, String rankProfiles, String constant, String field) throws ParseException { - SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, new QueryProfileRegistry()); + this.queryProfileRegistry = queryProfileRegistry; + SearchBuilder builder = new SearchBuilder(applicationpackage, rankProfileRegistry, queryProfileRegistry); String sdContent = "search test {\n" + " " + (constant != null ? constant : "") + "\n" + " document test {\n" + @@ -50,7 +51,6 @@ class RankProfileSearchFixture { builder.importString(sdContent); builder.build(); search = builder.getSearch(); - this.queryProfileRegistry = queryProfileRegistry; } public void assertFirstPhaseExpression(String expExpression, String rankProfile) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java new file mode 100644 index 00000000000..5f5b40e545f --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeValidatorTestCase.java @@ -0,0 +1,239 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.yolean.Exceptions; +import org.junit.Test; + +import java.util.Map; +import java.util.stream.Collectors; + +import static com.yahoo.config.model.test.TestUtil.joinLines; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class RankingExpressionTypeValidatorTestCase { + + @Test + public void tensorFirstPhaseMustProduceDouble() throws Exception { + try { + SearchBuilder builder = new SearchBuilder(); + builder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: attribute(a)", + " }", + " }", + "}" + )); + builder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void tensorSecondPhaseMustProduceDouble() throws Exception { + try { + SearchBuilder builder = new SearchBuilder(); + builder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: sum(attribute(a))", + " }", + " second-phase {", + " expression: attribute(a)", + " }", + " }", + "}" + )); + builder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The second-phase expression must produce a double (a tensor with no dimensions), but produces tensor(x[],y[])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void tensorConditionsMustHaveTypeCompatibleBranches() throws Exception { + try { + SearchBuilder searchBuilder = new SearchBuilder(); + searchBuilder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " field b type tensor(z[10]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: sum(if(1>0, attribute(a), attribute(b)))", + " }", + " }", + "}" + )); + searchBuilder.build(); + fail("Expected exception"); + } + catch (IllegalArgumentException expected) { + assertEquals("In search definition 'test', rank profile 'my_rank_profile': The first-phase expression is invalid: An if expression must produce compatible types in both alternatives, but the 'true' type is tensor(x[],y[]) while the 'false' type is tensor(z[10])", + Exceptions.toMessageString(expected)); + } + } + + @Test + public void testMacroInvocationTypes() throws Exception { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " field b type tensor(z[10]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " macro macro1(attribute_to_use) {", + " expression: attribute(attribute_to_use)", + " }", + " summary-features {", + " macro1(a)", + " macro1(b)", + " }", + " }", + "}" + )); + builder.build(); + RankProfile profile = + builder.getRankProfileRegistry().getRankProfile(builder.getSearch(), "my_rank_profile"); + assertEquals(TensorType.fromSpec("tensor(x[],y[])"), + summaryFeatures(profile).get("macro1(a)").type(profile.typeContext(builder.getQueryProfileRegistry()))); + assertEquals(TensorType.fromSpec("tensor(z[10])"), + summaryFeatures(profile).get("macro1(b)").type(profile.typeContext(builder.getQueryProfileRegistry()))); + } + + @Test + public void testTensorMacroInvocationTypes_Nested() throws Exception { + SearchBuilder builder = new SearchBuilder(); + builder.importString(joinLines( + "search test {", + " document test { ", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " field b type tensor(z[10]) {", + " indexing: attribute", + " }", + " }", + " rank-profile my_rank_profile {", + " macro return_a() {", + " expression: return_first(attribute(a), attribute(b))", + " }", + " macro return_b() {", + " expression: return_second(attribute(a), attribute(b))", + " }", + " macro return_first(e1, e2) {", + " expression: e1", + " }", + " macro return_second(e1, e2) {", + " expression: return_first(e2, e1)", + " }", + " summary-features {", + " return_a", + " return_b", + " }", + " }", + "}" + )); + builder.build(); + RankProfile profile = + builder.getRankProfileRegistry().getRankProfile(builder.getSearch(), "my_rank_profile"); + assertEquals(TensorType.fromSpec("tensor(x[],y[])"), + summaryFeatures(profile).get("return_a").type(profile.typeContext(builder.getQueryProfileRegistry()))); + assertEquals(TensorType.fromSpec("tensor(z[10])"), + summaryFeatures(profile).get("return_b").type(profile.typeContext(builder.getQueryProfileRegistry()))); + } + + @Test + public void importedFieldsAreAvailable() throws Exception { + SearchBuilder builder = new SearchBuilder(); + builder.importString(joinLines( + "search parent {", + " document parent {", + " field a type tensor(x[],y[]) {", + " indexing: attribute", + " }", + " }", + "}" + )); + builder.importString(joinLines( + "search child {", + " document child { ", + " field ref type reference<parent> {", + "indexing: attribute | summary", + " }", + " }", + " import field ref.a as imported_a {}", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: sum(attribute(imported_a))", + " }", + " }", + "}" + )); + builder.build(); + } + + @Test + public void undeclaredQueryFeaturesAreAccepted() throws Exception { + SearchBuilder builder = new SearchBuilder(); + builder.importString(joinLines( + "search test {", + " document test { ", + " }", + " rank-profile my_rank_profile {", + " first-phase {", + " expression: query(foo)", + " }", + " }", + "}" + )); + builder.build(); + } + + private Map<String, ReferenceNode> summaryFeatures(RankProfile profile) { + return profile.getSummaryFeatures().stream().collect(Collectors.toMap(f -> f.toString(), f -> f)); + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 4693ac5cf4d..1e376824b7b 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -42,7 +42,7 @@ import static org.junit.Assert.*; public class RankingExpressionWithTensorFlowTestCase { private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/"); - private final String vespaExpression = "join(rename(reduce(join(Placeholder, rename(constant(\"layer_Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"layer_Variable_1\"), d0, d1), f(a,b)(a + b))"; + private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b))"; @After public void removeGeneratedConstantTensorFiles() { @@ -54,8 +54,8 @@ public class RankingExpressionWithTensorFlowTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -65,15 +65,15 @@ public class RankingExpressionWithTensorFlowTestCase { "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test public void testTensorFlowReferenceWithQueryFeature() { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='mytensor' type='tensor(d0[3],d1[784])'/>" + + " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, @@ -85,8 +85,8 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -99,15 +99,15 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test public void testTensorFlowReferenceWithFeatureCombination() { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='mytensor' type='tensor(d0[3],d1[784],d2[10])'/>" + + " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, @@ -119,8 +119,8 @@ public class RankingExpressionWithTensorFlowTestCase { "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -128,8 +128,8 @@ public class RankingExpressionWithTensorFlowTestCase { RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "5 + sum(tensorflow('mnist_softmax/saved'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); } @Test @@ -224,8 +224,8 @@ public class RankingExpressionWithTensorFlowTestCase { "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); - assertLargeConstant("layer_Variable_1", search, Optional.of(10L)); - assertLargeConstant("layer_Variable", search, Optional.of(7840L)); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); // At this point the expression is stored - copy application to another location which do not have a models dir Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); @@ -243,8 +243,8 @@ public class RankingExpressionWithTensorFlowTestCase { searchFromStored.assertFirstPhaseExpression(vespaExpression, "my_profile"); // Verify that the constants exists, but don't verify the content as we are not // simulating file distribution in this test - assertLargeConstant("layer_Variable_1", searchFromStored, Optional.empty()); - assertLargeConstant("layer_Variable", searchFromStored, Optional.empty()); + assertLargeConstant("layer_Variable_1_read", searchFromStored, Optional.empty()); + assertLargeConstant("layer_Variable_read", searchFromStored, Optional.empty()); } finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); @@ -253,7 +253,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(rename(reduce(join(join(join(rename(constant(\"dnn_hidden2_Const\"), d0, d1), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(join(join(0.009999999776482582, join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(a * b)), join(rename(reduce(join(input, rename(constant(\"dnn_hidden1_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden1_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_hidden2_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_hidden2_bias\"), d0, d1), f(a,b)(a + b)), f(a,b)(max(a,b))), rename(constant(\"dnn_outputs_weights\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"dnn_outputs_bias\"), d0, d1), f(a,b)(a + b))"; + final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", "tensorflow('mnist/saved')", diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java index b001db69768..054c9220225 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -17,98 +17,129 @@ import com.yahoo.searchdefinition.SearchDefinitionTestCase; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; import com.yahoo.searchdefinition.parser.ParseException; -import com.yahoo.vespa.model.container.search.QueryProfiles; import org.junit.Test; import java.util.List; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class TensorTransformTestCase extends SearchDefinitionTestCase { @Test public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException { - assertContainsExpression("max(1.0,2.0)", "max(1.0,2.0)"); - assertContainsExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)"); - assertContainsExpression("max(attribute(double_field),attribute(double_array_field))", "max(attribute(double_field),attribute(double_array_field))"); - assertContainsExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))"); - assertContainsExpression("max(attribute(tensor_field_1),attribute(tensor_field_2))", "max(attribute(tensor_field_1),attribute(tensor_field_2))"); - assertContainsExpression("min(test_constant_tensor,1.0)", "min(constant(test_constant_tensor),1.0)"); - assertContainsExpression("max(base_constant_tensor,1.0)", "max(constant(base_constant_tensor),1.0)"); - assertContainsExpression("min(constant(file_constant_tensor),1.0)", "min(constant(file_constant_tensor),1.0)"); - assertContainsExpression("max(query(q),1.0)", "max(query(q),1.0)"); - assertContainsExpression("max(query(n),1.0)", "max(query(n),1.0)"); + assertTransformedExpression("max(1.0,2.0)", + "max(1.0,2.0)"); + assertTransformedExpression("min(attribute(double_field),x)", + "min(attribute(double_field),x)"); + assertTransformedExpression("max(attribute(double_field),attribute(double_array_field))", + "max(attribute(double_field),attribute(double_array_field))"); + assertTransformedExpression("min(attribute(tensor_field_1),attribute(double_field))", + "min(attribute(tensor_field_1),attribute(double_field))"); + assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)", + "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),sum)"); + assertTransformedExpression("min(constant(test_constant_tensor),1.0)", + "min(test_constant_tensor,1.0)"); + assertTransformedExpression("max(constant(base_constant_tensor),1.0)", + "max(base_constant_tensor,1.0)"); + assertTransformedExpression("min(constant(file_constant_tensor),1.0)", + "min(constant(file_constant_tensor),1.0)"); + assertTransformedExpression("max(query(q),1.0)", + "max(query(q),1.0)"); + assertTransformedExpression("max(query(n),1.0)", + "max(query(n),1.0)"); } @Test public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException { - assertContainsExpression("max(attribute(tensor_field_1),x)", "reduce(attribute(tensor_field_1),max,x)"); - assertContainsExpression("1 + max(attribute(tensor_field_1),x)", "1+reduce(attribute(tensor_field_1),max,x)"); - assertContainsExpression("if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)", "if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)"); - assertContainsExpression("max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); - assertContainsExpression("max(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); - assertContainsExpression("max(max(attribute(tensor_field_1),x),x)", "max(reduce(attribute(tensor_field_1),max,x),x)"); // will result in deploy error. - assertContainsExpression("max(max(attribute(tensor_field_2),x),y)", "reduce(reduce(attribute(tensor_field_2),max,x),max,y)"); + assertTransformedExpression("reduce(attribute(tensor_field_1),max,x)", + "max(attribute(tensor_field_1),x)"); + assertTransformedExpression("1+reduce(attribute(tensor_field_1),max,x)", + "1 + max(attribute(tensor_field_1),x)"); + assertTransformedExpression("if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)", + "if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)"); + assertTransformedExpression("reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)", + "max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)"); + assertTransformedExpression("reduce(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),max,x)", + "max(if(attribute(double_field),attribute(tensor_field_2),attribute(tensor_field_2)),x)"); + assertTransformedExpression("max(reduce(attribute(tensor_field_1),max,x),x)", + "max(max(attribute(tensor_field_1),x),x)"); // will result in deploy error. + assertTransformedExpression("reduce(reduce(attribute(tensor_field_2),max,x),max,y)", + "max(max(attribute(tensor_field_2),x),y)"); } @Test public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException { - assertContainsExpression("max(test_constant_tensor,x)", "reduce(constant(test_constant_tensor),max,x)"); - assertContainsExpression("max(base_constant_tensor,x)", "reduce(constant(base_constant_tensor),max,x)"); - assertContainsExpression("min(constant(file_constant_tensor),x)", "reduce(constant(file_constant_tensor),min,x)"); + assertTransformedExpression("reduce(constant(test_constant_tensor),max,x)", + "max(test_constant_tensor,x)"); + assertTransformedExpression("reduce(constant(base_constant_tensor),max,x)", + "max(base_constant_tensor,x)"); + assertTransformedExpression("reduce(constant(file_constant_tensor),min,x)", + "min(constant(file_constant_tensor),x)"); } @Test public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException { - assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)"); - assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)"); - assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)"); - assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) - assertContainsExpression("min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)"); + assertTransformedExpression("reduce(attribute(double_field)+attribute(tensor_field_1),min,x)", + "min(attribute(double_field) + attribute(tensor_field_1),x)"); + assertTransformedExpression("reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)", + "min(attribute(tensor_field_1) * attribute(tensor_field_2),x)"); + assertTransformedExpression("reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)", + "min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)"); + assertTransformedExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", + "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); // because tensor fields are not in attribute(...) + assertTransformedExpression("reduce(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),min,x)", + "min(join(attribute(tensor_field_1),backend_rank_feature,f(x,y)(x*y)),x)"); } @Test public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException { - assertContainsExpression("max(tensorFromLabels(attribute(double_array_field)),double_array_field)", "reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)"); - assertContainsExpression("max(tensorFromLabels(attribute(double_array_field),x),x)", "reduce(tensorFromLabels(attribute(double_array_field),x),max,x)"); - assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)", "reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)"); - assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field),x),x)", "reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)"); + assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)", + "max(tensorFromLabels(attribute(double_array_field)),double_array_field)"); + assertTransformedExpression("reduce(tensorFromLabels(attribute(double_array_field),x),max,x)", + "max(tensorFromLabels(attribute(double_array_field),x),x)"); + assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)", + "max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)"); + assertTransformedExpression("reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)", + "max(tensorFromWeightedSet(attribute(weightedset_field),x),x)"); } @Test public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException { - assertContainsExpression("max(query(q),x)", "reduce(query(q),max,x)"); - assertContainsExpression("max(query(n),x)", "max(query(n),x)"); + assertTransformedExpression("reduce(query(q),max,x)", "max(query(q),x)"); + assertTransformedExpression("max(query(n),x)", "max(query(n),x)"); } @Test public void requireThatMaxAndMinWithTensoresReturnedFromMacrosAreReplaced() throws ParseException { - assertContainsExpression("max(returns_tensor,x)", "reduce(rankingExpression(returns_tensor),max,x)"); - assertContainsExpression("max(wraps_returns_tensor,x)", "reduce(rankingExpression(wraps_returns_tensor),max,x)"); - assertContainsExpression("max(tensor_inheriting,x)", "reduce(rankingExpression(tensor_inheriting),max,x)"); - assertContainsExpression("max(returns_tensor_with_arg(attribute(tensor_field_1)),x)", "reduce(rankingExpression(returns_tensor_with_arg@),max,x)"); + assertTransformedExpression("reduce(rankingExpression(returns_tensor),max,x)", + "max(returns_tensor,x)"); + assertTransformedExpression("reduce(rankingExpression(wraps_returns_tensor),max,x)", + "max(wraps_returns_tensor,x)"); + assertTransformedExpression("reduce(rankingExpression(tensor_inheriting),max,x)", + "max(tensor_inheriting,x)"); + assertTransformedExpression("reduce(rankingExpression(returns_tensor_with_arg@),max,x)", + "max(returns_tensor_with_arg(attribute(tensor_field_1)),x)"); } - private void assertContainsExpression(String expr, String transformedExpression) throws ParseException { - assertTrue("Expected expression '" + transformedExpression + "' found", - containsExpression(expr, transformedExpression)); - } - - private boolean containsExpression(String expr, String transformedExpression) throws ParseException { - for (Pair<String, String> rankPropertyExpression : buildSearch(expr)) { + private void assertTransformedExpression(String expected, String original) throws ParseException { + for (Pair<String, String> rankPropertyExpression : buildSearch(original)) { String rankProperty = rankPropertyExpression.getFirst(); if (rankProperty.equals("rankingExpression(firstphase).rankingScript")) { String rankExpression = censorBindingHash(rankPropertyExpression.getSecond().replace(" ","")); - return rankExpression.equals(transformedExpression); + assertEquals(expected, rankExpression); + return; } } - return false; + fail("No 'rankingExpression(firstphase).rankingScript' property produced"); } private List<Pair<String, String>> buildSearch(String expression) throws ParseException { RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + QueryProfileRegistry queryProfiles = setupQueryProfileTypes(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry, queryProfiles); builder.importString( "search test {\n" + " document test { \n" + @@ -167,16 +198,16 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { " }\n" + " }\n" + "}\n"); - builder.build(new BaseDeployLogger(), setupQueryProfileTypes()); + builder.build(new BaseDeployLogger()); Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(new QueryProfileRegistry()); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(queryProfiles); List<Pair<String, String>> testRankProperties = new RawRankProfile(test, - new QueryProfileRegistry(), + queryProfiles, new AttributeFields(s)).configProperties(); return testRankProperties; } - private static QueryProfiles setupQueryProfileTypes() { + private static QueryProfileRegistry setupQueryProfileTypes() { QueryProfileRegistry registry = new QueryProfileRegistry(); QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); @@ -185,7 +216,7 @@ public class TensorTransformTestCase extends SearchDefinitionTestCase { type.addField(new FieldDescription("ranking.features.query(n)", FieldType.fromString("integer", typeRegistry)), typeRegistry); typeRegistry.register(type); - return new QueryProfiles(registry); + return registry; } private String censorBindingHash(String s) { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/search/DocumentTypeChangeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/search/DocumentTypeChangeValidatorTest.java index a4ab5ebdb5e..aeddd05209f 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/search/DocumentTypeChangeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/change/search/DocumentTypeChangeValidatorTest.java @@ -25,7 +25,6 @@ import static org.junit.Assert.assertTrue; * Test validation of changes between a current and next document type used in a document database. * * @author toregge - * @since 2014-11-25 */ public class DocumentTypeChangeValidatorTest { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index 9bf2a858476..d3edd1c0ca5 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -737,7 +737,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { // TODO: Propagate all filters Optional<Hostname> hostname = Optional.ofNullable(request.getProperty("hostname")).map(Hostname::new); - + controller.applications().restart(deploymentId, hostname); // TODO: Change to return JSON return new StringResponse("Requested restart of " + path(TenantResource.API_PATH, tenantName, diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java index 0e703cf4cec..5be7fe03319 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/filter/ControllerAuthorizationFilter.java @@ -8,6 +8,7 @@ import com.yahoo.jdisc.handler.ResponseHandler; import com.yahoo.jdisc.http.HttpRequest.Method; import com.yahoo.jdisc.http.filter.DiscFilterRequest; import com.yahoo.jdisc.http.filter.SecurityRequestFilter; +import com.yahoo.log.LogLevel; import com.yahoo.vespa.athenz.api.AthenzDomain; import com.yahoo.vespa.athenz.api.AthenzIdentity; import com.yahoo.vespa.athenz.api.AthenzPrincipal; @@ -30,6 +31,7 @@ import javax.ws.rs.WebApplicationException; import java.util.Arrays; import java.util.List; import java.util.Optional; +import java.util.logging.Logger; import static com.yahoo.jdisc.http.HttpRequest.Method.GET; import static com.yahoo.jdisc.http.HttpRequest.Method.HEAD; @@ -49,6 +51,8 @@ public class ControllerAuthorizationFilter implements SecurityRequestFilter { private static final List<Method> WHITELISTED_METHODS = Arrays.asList(GET, OPTIONS, HEAD); + private static final Logger log = Logger.getLogger(ControllerAuthorizationFilter.class.getName()); + private final AthenzClientFactory clientFactory; private final Controller controller; private final EntityService entityService; @@ -261,7 +265,10 @@ public class ControllerAuthorizationFilter implements SecurityRequestFilter { public void handle(ResponseHandler responseHandler, DiscFilterRequest request, WebApplicationException exception) { - sendErrorResponse(responseHandler, exception.getResponse().getStatus(), exception.getMessage()); + int statusCode = exception.getResponse().getStatus(); + String errorMessage = exception.getMessage(); + log.log(LogLevel.WARNING, String.format("Access denied(%d): %s", statusCode, errorMessage), exception); + sendErrorResponse(responseHandler, statusCode, errorMessage); } } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResource.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResource.java index f5852b9dfcf..d0154ace4e0 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResource.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResource.java @@ -4,7 +4,7 @@ package com.yahoo.vespa.hosted.restapi.impl; import com.fasterxml.jackson.databind.JsonNode; import com.google.inject.Inject; import com.yahoo.container.jaxrs.annotation.Component; -import com.yahoo.vespa.hosted.controller.api.integration.security.KeyService; +import com.yahoo.container.jdisc.Ckms; import javax.ws.rs.Path; import javax.ws.rs.Produces; @@ -24,20 +24,20 @@ import javax.ws.rs.core.UriBuilder; public class StatusPageResource implements com.yahoo.vespa.hosted.controller.api.statuspage.StatusPageResource { private final Client client; - private final KeyService keyService; + private final Ckms ckms; @Inject - public StatusPageResource(@Component KeyService keyService) { - this(keyService, ClientBuilder.newClient()); + public StatusPageResource(@Component Ckms ckms) { + this(ckms, ClientBuilder.newClient()); } - protected StatusPageResource(KeyService keyService, Client client) { - this.keyService = keyService; + protected StatusPageResource(Ckms ckms, Client client) { + this.ckms = ckms; this.client = client; } protected UriBuilder statusPageURL(String page, String since) { - String[] secrets = keyService.getSecret("vespa_hosted.controller.statuspage_api_key").split(":"); + String[] secrets = ckms.getSecret("vespa_hosted.controller.statuspage_api_key").split(":"); UriBuilder uriBuilder = UriBuilder.fromUri("https://" + secrets[0] + ".statuspage.io/api/v2/" + page + ".json?api_key=" + secrets[1]); if (since != null) { uriBuilder.queryParam("since", since); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResourceTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResourceTest.java index 4e2e4bb15b4..b116ba3b5ee 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResourceTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/restapi/impl/StatusPageResourceTest.java @@ -3,7 +3,7 @@ package com.yahoo.vespa.hosted.restapi.impl; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; -import com.yahoo.vespa.hosted.controller.api.integration.security.KeyService; +import com.yahoo.container.jdisc.Ckms; import org.junit.Before; import org.junit.Test; import org.mockito.Mockito; @@ -30,15 +30,15 @@ public class StatusPageResourceTest { Client mockClient = Mockito.mock(Client.class); WebTarget mockTarget = Mockito.mock(WebTarget.class); Invocation.Builder mockRequest = Mockito.mock(Invocation.Builder.class); - KeyService keyService = Mockito.mock(KeyService.class); + Ckms ckms = Mockito.mock(Ckms.class); Mockito.when(mockClient.target(Mockito.any(UriBuilder.class))).thenReturn(mockTarget); Mockito.when(mockTarget.request()).thenReturn(mockRequest); Mockito.when(mockRequest.get(JsonNode.class)).thenReturn( new ObjectMapper().readTree("{\"page\":{\"name\":\"Vespa\"}}")); - Mockito.when(keyService.getSecret(Mockito.any(String.class))).thenReturn("testpage:testkey"); + Mockito.when(ckms.getSecret(Mockito.any(String.class))).thenReturn("testpage:testkey"); - statusPage = new StatusPageResource(keyService, mockClient); + statusPage = new StatusPageResource(ckms, mockClient); } diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerImpl.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerImpl.java index e81c6325922..2da18e12e40 100644 --- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerImpl.java +++ b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerImpl.java @@ -15,15 +15,14 @@ import com.github.dockerjava.api.model.Image; import com.github.dockerjava.api.model.Network; import com.github.dockerjava.api.model.Statistics; import com.github.dockerjava.core.DefaultDockerClientConfig; +import com.github.dockerjava.core.DockerClientConfig; import com.github.dockerjava.core.DockerClientImpl; -import com.github.dockerjava.core.RemoteApiVersion; import com.github.dockerjava.core.async.ResultCallbackTemplate; import com.github.dockerjava.core.command.BuildImageResultCallback; import com.github.dockerjava.core.command.ExecStartResultCallback; import com.github.dockerjava.core.command.PullImageResultCallback; import com.github.dockerjava.jaxrs.JerseyDockerCmdExecFactory; import com.google.inject.Inject; -import com.yahoo.log.LogLevel; import com.yahoo.vespa.hosted.dockerapi.metrics.CounterWrapper; import com.yahoo.vespa.hosted.dockerapi.metrics.Dimensions; import com.yahoo.vespa.hosted.dockerapi.metrics.MetricReceiverWrapper; @@ -34,7 +33,6 @@ import java.io.File; import java.io.IOException; import java.net.Inet6Address; import java.net.InetAddress; -import java.net.URI; import java.time.Duration; import java.util.Arrays; import java.util.Collections; @@ -58,13 +56,11 @@ public class DockerImpl implements Docker { public static final String DOCKER_CUSTOM_MACVLAN_NETWORK_NAME = "vespa-macvlan"; static final String LABEL_NAME_MANAGEDBY = "com.yahoo.vespa.managedby"; - - private final int SECONDS_TO_WAIT_BEFORE_KILLING; - private final boolean fallbackTo123OnErrors; private static final String FRAMEWORK_CONTAINER_PREFIX = "/"; + private final DockerConfig config; - private final boolean inProduction; - private Optional<DockerImageGarbageCollector> dockerImageGC = Optional.empty(); + private final Optional<DockerImageGarbageCollector> dockerImageGC; + private final int secondsToWaitBeforeKilling; private CounterWrapper numberOfDockerDaemonFails; private boolean started = false; @@ -76,63 +72,40 @@ public class DockerImpl implements Docker { DockerClient dockerClient; @Inject - public DockerImpl(final DockerConfig config, MetricReceiverWrapper metricReceiver) { - this(config, - true, /* fallback to 1.23 on errors */ - metricReceiver, - !config.isRunningLocally()); - } - - private DockerImpl(final DockerConfig config, - boolean fallbackTo123OnErrors, - MetricReceiverWrapper metricReceiverWrapper, - boolean inProduction) { + public DockerImpl(DockerConfig config, MetricReceiverWrapper metricReceiverWrapper) { this.config = config; - this.fallbackTo123OnErrors = fallbackTo123OnErrors; - this.inProduction = inProduction; - if (config == null) { - this.SECONDS_TO_WAIT_BEFORE_KILLING = 10; - } else { - SECONDS_TO_WAIT_BEFORE_KILLING = config.secondsToWaitBeforeKillingContainer(); - } - if (metricReceiverWrapper != null) { - setMetrics(metricReceiverWrapper); - } + + secondsToWaitBeforeKilling = Optional.ofNullable(config) + .map(DockerConfig::secondsToWaitBeforeKillingContainer) + .orElse(10); + + dockerImageGC = Optional.ofNullable(config) + .map(DockerConfig::imageGCMinTimeToLiveMinutes) + .map(Duration::ofMinutes) + .map(DockerImageGarbageCollector::new); + + Optional.ofNullable(metricReceiverWrapper).ifPresent(this::setMetrics); } // For testing DockerImpl(final DockerClient dockerClient) { - this(null, false, null, false); + this(null, null); this.dockerClient = dockerClient; } - // For testing - DockerImpl(final DockerConfig config, - boolean fallbackTo123OnErrors, - MetricReceiverWrapper metricReceiverWrapper) { - this(config, fallbackTo123OnErrors, metricReceiverWrapper, false); - } - @Override public void start() { if (started) return; started = true; if (config != null) { - if (dockerClient == null) { - dockerClient = initDockerConnection(); - } - if (inProduction) { - Duration minAgeToDelete = Duration.ofMinutes(config.imageGCMinTimeToLiveMinutes()); - dockerImageGC = Optional.of(new DockerImageGarbageCollector(minAgeToDelete)); - + dockerClient = createDockerClient(config); - if (!config.networkNATed()) { - try { - setupDockerNetworkIfNeeded(); - } catch (Exception e) { - throw new DockerException("Could not setup docker network", e); - } + if (!config.networkNATed()) { + try { + setupDockerNetworkIfNeeded(); + } catch (Exception e) { + throw new DockerException("Could not setup docker network", e); } } } @@ -143,21 +116,6 @@ public class DockerImpl implements Docker { return config.networkNATed(); } - static DefaultDockerClientConfig.Builder buildDockerClientConfig(DockerConfig config) { - DefaultDockerClientConfig.Builder dockerConfigBuilder = new DefaultDockerClientConfig.Builder() - .withDockerHost(config.uri()); - - if (URI.create(config.uri()).getScheme().equals("tcp") && !config.caCertPath().isEmpty()) { - // In current version of docker-java (3.0.2), withDockerTlsVerify() only effect is when using it together - // with withDockerCertPath(), where setting withDockerTlsVerify() must be set to true, otherwise the - // cert path parameter will be ignored. - // withDockerTlsVerify() has no effect when used with withCustomSslConfig() - dockerConfigBuilder.withCustomSslConfig(new VespaSSLConfig(config)); - } - - return dockerConfigBuilder; - } - private void setupDockerNetworkIfNeeded() throws IOException { if (!dockerClient.listNetworksCmd().withNameFilter(DOCKER_CUSTOM_MACVLAN_NETWORK_NAME).exec().isEmpty()) return; @@ -366,7 +324,7 @@ public class DockerImpl implements Docker { @Override public void stopContainer(final ContainerName containerName) { try { - dockerClient.stopContainerCmd(containerName.asString()).withTimeout(SECONDS_TO_WAIT_BEFORE_KILLING).exec(); + dockerClient.stopContainerCmd(containerName.asString()).withTimeout(secondsToWaitBeforeKilling).exec(); } catch (NotModifiedException ignored) { // If is already stopped, ignore } catch (RuntimeException e) { @@ -545,36 +503,18 @@ public class DockerImpl implements Docker { } } - private DockerClient initDockerConnection() { + private static DockerClient createDockerClient(DockerConfig config) { JerseyDockerCmdExecFactory dockerFactory = new JerseyDockerCmdExecFactory() .withMaxPerRouteConnections(config.maxPerRouteConnections()) .withMaxTotalConnections(config.maxTotalConnections()) .withConnectTimeout(config.connectTimeoutMillis()) .withReadTimeout(config.readTimeoutMillis()); - RemoteApiVersion remoteApiVersion; - try { - remoteApiVersion = RemoteApiVersion.parseConfig(DockerClientImpl.getInstance( - buildDockerClientConfig(config).build()) - .withDockerCmdExecFactory(dockerFactory).versionCmd().exec().getApiVersion()); - logger.info("Found version of remote docker API: " + remoteApiVersion); - // From version 1.24 a field was removed which causes trouble with the current docker java code. - // When this is fixed, we can remove this and do not specify version. - if (remoteApiVersion.isGreaterOrEqual(RemoteApiVersion.VERSION_1_24)) { - remoteApiVersion = RemoteApiVersion.VERSION_1_23; - logger.info("Found version 1.24 or newer of remote API, using 1.23."); - } - } catch (Exception e) { - if (!fallbackTo123OnErrors) { - throw e; - } - logger.log(LogLevel.ERROR, "Failed when trying to figure out remote API version of docker, using 1.23", e); - remoteApiVersion = RemoteApiVersion.VERSION_1_23; - } - return DockerClientImpl.getInstance( - buildDockerClientConfig(config) - .withApiVersion(remoteApiVersion) - .build()) + DockerClientConfig dockerClientConfig = new DefaultDockerClientConfig.Builder() + .withDockerHost(config.uri()) + .build(); + + return DockerClientImpl.getInstance(dockerClientConfig) .withDockerCmdExecFactory(dockerFactory); } diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerTestUtils.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerTestUtils.java deleted file mode 100644 index 549af0d85cb..00000000000 --- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/DockerTestUtils.java +++ /dev/null @@ -1,94 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.dockerapi; - -import com.github.dockerjava.api.model.Network; -import com.yahoo.metrics.simple.MetricReceiver; -import com.yahoo.vespa.hosted.dockerapi.metrics.MetricReceiverWrapper; - -import java.io.File; - -/** - * Helper class for testing full integration with docker daemon, requires running daemon. To run these tests: - * - * MAC: - * 1. Install Docker Toolbox, and start it (Docker Quickstart Terminal) (you can close terminal window afterwards) - * 2. For network test, we need to make docker containers visible for Mac: sudo route add 172.18.0.0/16 192.168.99.100 - * - * @author freva - */ -public class DockerTestUtils { - private static final OS operatingSystem = getSystemOS(); - private static final String prefix = "/Users/" + System.getProperty("user.name") + "/.docker/machine/machines/default/"; - private static final DockerConfig dockerConfig = new DockerConfig(new DockerConfig.Builder() - .caCertPath( operatingSystem == OS.Mac_OS_X ? prefix + "ca.pem" : "") - .clientCertPath(operatingSystem == OS.Mac_OS_X ? prefix + "cert.pem" : "") - .clientKeyPath( operatingSystem == OS.Mac_OS_X ? prefix + "key.pem" : "") - .uri( operatingSystem == OS.Mac_OS_X ? "tcp://192.168.99.100:2376" : "tcp://localhost:2376") - .secondsToWaitBeforeKillingContainer(0)); - private static DockerImpl docker; - - public static boolean dockerDaemonIsPresent() { - if (docker != null) return true; - if (operatingSystem == OS.Unsupported) { - System.err.println("This test does not support " + System.getProperty("os.name") + " yet, ignoring test."); - return false; - } - - try { - getDocker(); // Will throw an exception if docker is not installed/incorrectly configured - return true; - } catch (Exception e) { - System.err.println("Please install Docker Toolbox and start Docker Quick Start Terminal once, ignoring test."); - System.err.println(e.getMessage()); - return false; - } - } - - public static DockerImpl getDocker() { - if (docker == null) { - DockerImpl tmpDocker = new DockerImpl( - dockerConfig, - false, /* fallback to 1.23 on errors */ - new MetricReceiverWrapper(MetricReceiver.nullImplementation)); - tmpDocker.start(); - createDockerTestNetworkIfNeeded(tmpDocker); - docker = tmpDocker; - } - - return docker; - } - - public static void createDockerTestNetworkIfNeeded(DockerImpl docker) { - if (! docker.dockerClient.listNetworksCmd().withNameFilter(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME).exec().isEmpty()) return; - - Network.Ipam ipam = new Network.Ipam().withConfig(new Network.Ipam.Config() - .withSubnet("172.18.0.0/16") - .withGateway("172.18.0.1")); - docker.dockerClient.createNetworkCmd() - .withName(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME).withDriver("bridge").withIpam(ipam).exec(); - } - - public static void buildSimpleHttpServerDockerImage(DockerImpl docker, DockerImage dockerImage) { - try { - docker.deleteImage(dockerImage); - } catch (Exception e) { - if (! e.getMessage().equals("Failed to delete docker image " + dockerImage.asString())) { - throw e; - } - } - - // Build the image locally - File dockerFileStream = new File("src/test/resources/simple-ipv6-server"); - docker.buildImage(dockerFileStream, dockerImage); - } - - public enum OS { Linux, Mac_OS_X, Unsupported } - - public static OS getSystemOS() { - switch (System.getProperty("os.name").toLowerCase()) { - case "linux": return OS.Linux; - case "mac os x": return OS.Mac_OS_X; - default: return OS.Unsupported; - } - } -} diff --git a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/VespaSSLConfig.java b/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/VespaSSLConfig.java deleted file mode 100644 index e9bc0181dd7..00000000000 --- a/docker-api/src/main/java/com/yahoo/vespa/hosted/dockerapi/VespaSSLConfig.java +++ /dev/null @@ -1,228 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.dockerapi; - -import com.github.dockerjava.api.exception.DockerClientException; -import com.github.dockerjava.core.SSLConfig; -import org.bouncycastle.asn1.ASN1ObjectIdentifier; -import org.bouncycastle.asn1.pkcs.PrivateKeyInfo; -import org.bouncycastle.cert.X509CertificateHolder; -import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; -import org.bouncycastle.jce.provider.BouncyCastleProvider; -import org.bouncycastle.openssl.PEMKeyPair; -import org.bouncycastle.openssl.PEMParser; -import org.glassfish.jersey.SslConfigurator; - -import javax.net.ssl.SSLContext; -import java.io.BufferedReader; -import java.io.IOException; -import java.io.Reader; -import java.io.StringReader; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.security.KeyFactory; -import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.PrivateKey; -import java.security.Security; -import java.security.cert.Certificate; -import java.security.cert.CertificateException; -import java.security.spec.InvalidKeySpecException; -import java.security.spec.PKCS8EncodedKeySpec; -import java.util.ArrayList; -import java.util.List; - -import static java.util.Objects.requireNonNull; - - -/** - * This class is based off {@link com.github.dockerjava.core.LocalDirectorySSLConfig}, but with the ability to - * specify path to each of the certificates instead of directory path. Additionally it includes - * {@link com.github.dockerjava.core.util.CertificateUtils} because of version conflict of with - * com.google.code.findbugs.annotations - */ -public class VespaSSLConfig implements SSLConfig { - private final DockerConfig config; - - public VespaSSLConfig(DockerConfig config) { - this.config = config; - } - - @Override - public SSLContext getSSLContext() { - try { - Security.addProvider(new BouncyCastleProvider()); - - // properties acrobatics not needed for java > 1.6 - String httpProtocols = System.getProperty("https.protocols"); - System.setProperty("https.protocols", "TLSv1"); - SslConfigurator sslConfig = SslConfigurator.newInstance(true); - if (httpProtocols != null) { - System.setProperty("https.protocols", httpProtocols); - } - - String keypem = new String(Files.readAllBytes(Paths.get(config.clientKeyPath()))); - String certpem = new String(Files.readAllBytes(Paths.get(config.clientCertPath()))); - String capem = new String(Files.readAllBytes(Paths.get(config.caCertPath()))); - - sslConfig.keyStore(createKeyStore(keypem, certpem)); - sslConfig.keyStorePassword("docker"); - sslConfig.trustStore(createTrustStore(capem)); - - return sslConfig.createSSLContext(); - } catch (Exception e) { - throw new DockerClientException(e.getMessage(), e); - } - } - - public static KeyStore createKeyStore(final String keypem, final String certpem) throws NoSuchAlgorithmException, - IOException, CertificateException, KeyStoreException { - PrivateKey privateKey = loadPrivateKey(keypem); - requireNonNull(privateKey); - List<Certificate> privateCertificates = loadCertificates(certpem); - - KeyStore keyStore = KeyStore.getInstance("JKS"); - keyStore.load(null); - - keyStore.setKeyEntry("docker", - privateKey, - "docker".toCharArray(), - privateCertificates.toArray(new Certificate[privateCertificates.size()]) - ); - - return keyStore; - } - - /** - * from "cert.pem" String - */ - private static List<Certificate> loadCertificates(final String certpem) throws IOException, - CertificateException { - final StringReader certReader = new StringReader(certpem); - try (BufferedReader reader = new BufferedReader(certReader)) { - return loadCertificates(reader); - } - } - - /** - * "cert.pem" from reader - */ - private static List<Certificate> loadCertificates(final Reader reader) throws IOException, - CertificateException { - try (PEMParser pemParser = new PEMParser(reader)) { - List<Certificate> certificates = new ArrayList<>(); - - JcaX509CertificateConverter certificateConverter = new JcaX509CertificateConverter().setProvider("BC"); - Object certObj = pemParser.readObject(); - - if (certObj instanceof X509CertificateHolder) { - X509CertificateHolder certificateHolder = (X509CertificateHolder) certObj; - certificates.add(certificateConverter.getCertificate(certificateHolder)); - } - - return certificates; - } - } - - - /** - * Return private key ("key.pem") from Reader - */ - private static PrivateKey loadPrivateKey(final Reader reader) throws IOException, NoSuchAlgorithmException { - try (PEMParser pemParser = new PEMParser(reader)) { - Object readObject = pemParser.readObject(); - while (readObject != null) { - if (readObject instanceof PEMKeyPair) { - PEMKeyPair pemKeyPair = (PEMKeyPair) readObject; - PrivateKey privateKey = guessKey(pemKeyPair.getPrivateKeyInfo().getEncoded()); - if (privateKey != null) { - return privateKey; - } - } else if (readObject instanceof PrivateKeyInfo) { - PrivateKeyInfo privateKeyInfo = (PrivateKeyInfo) readObject; - PrivateKey privateKey = guessKey(privateKeyInfo.getEncoded()); - if (privateKey != null) { - return privateKey; - } - } else if (readObject instanceof ASN1ObjectIdentifier) { - // no idea how it can be used - final ASN1ObjectIdentifier asn1ObjectIdentifier = (ASN1ObjectIdentifier) readObject; - } - - readObject = pemParser.readObject(); - } - } - - return null; - } - - private static PrivateKey guessKey(byte[] encodedKey) throws NoSuchAlgorithmException { - //no way to know, so iterate - for (String guessFactory : new String[]{"RSA", "ECDSA"}) { - try { - KeyFactory factory = KeyFactory.getInstance(guessFactory); - - PKCS8EncodedKeySpec privateKeySpec = new PKCS8EncodedKeySpec(encodedKey); - return factory.generatePrivate(privateKeySpec); - } catch (InvalidKeySpecException ignore) { - } - } - - return null; - } - - /** - * Return KeyPair from "key.pem" - */ - private static PrivateKey loadPrivateKey(final String keypem) throws IOException, NoSuchAlgorithmException { - try (StringReader certReader = new StringReader(keypem); - BufferedReader reader = new BufferedReader(certReader)) { - return loadPrivateKey(reader); - } - } - - /** - * "ca.pem" from String - */ - public static KeyStore createTrustStore(String capem) throws IOException, CertificateException, - KeyStoreException, NoSuchAlgorithmException { - try (Reader certReader = new StringReader(capem)) { - return createTrustStore(certReader); - } - } - - /** - * "ca.pem" from Reader - */ - public static KeyStore createTrustStore(final Reader certReader) throws IOException, CertificateException, - KeyStoreException, NoSuchAlgorithmException { - try (PEMParser pemParser = new PEMParser(certReader)) { - X509CertificateHolder certificateHolder = (X509CertificateHolder) pemParser.readObject(); - Certificate caCertificate = new JcaX509CertificateConverter() - .setProvider("BC") - .getCertificate(certificateHolder); - - KeyStore trustStore = KeyStore.getInstance("JKS"); - trustStore.load(null); - trustStore.setCertificateEntry("ca", caCertificate); - - return trustStore; - } - } - - @Override - public boolean equals(Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - VespaSSLConfig that = (VespaSSLConfig) o; - - return config.equals(that.config); - - } - - @Override - public int hashCode() { - return config.hashCode(); - } -} diff --git a/docker-api/src/main/resources/configdefinitions/docker.def b/docker-api/src/main/resources/configdefinitions/docker.def index b4585318cd8..83fee05dff6 100644 --- a/docker-api/src/main/resources/configdefinitions/docker.def +++ b/docker-api/src/main/resources/configdefinitions/docker.def @@ -1,9 +1,6 @@ # Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. namespace=vespa.hosted.dockerapi -caCertPath string default = "" -clientCertPath string default = "" -clientKeyPath string default = "" uri string default = "unix:///host/var/run/docker.sock" secondsToWaitBeforeKillingContainer int default = 10 diff --git a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerImplTest.java b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerImplTest.java index 12e52dde494..654b5df3f3b 100644 --- a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerImplTest.java +++ b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerImplTest.java @@ -12,7 +12,6 @@ import com.github.dockerjava.api.command.InspectImageCmd; import com.github.dockerjava.api.command.InspectImageResponse; import com.github.dockerjava.api.command.PullImageCmd; import com.github.dockerjava.api.exception.NotFoundException; -import com.github.dockerjava.core.DefaultDockerClientConfig; import com.github.dockerjava.core.command.ExecStartResultCallback; import com.yahoo.metrics.simple.MetricReceiver; import com.yahoo.vespa.hosted.dockerapi.metrics.MetricReceiverWrapper; @@ -20,12 +19,6 @@ import org.junit.Test; import org.mockito.ArgumentCaptor; import org.mockito.Matchers; -import java.io.IOException; -import java.security.KeyManagementException; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.UnrecoverableKeyException; - import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertThat; @@ -39,60 +32,6 @@ import static org.mockito.Mockito.when; * @author tonytv */ public class DockerImplTest { - @Test - public void testDockerConfigWithUnixPath() throws UnrecoverableKeyException, NoSuchAlgorithmException, KeyStoreException, KeyManagementException { - String dockerUri = "unix:///var/run/docker.sock"; - DockerConfig config = createConfig(dockerUri, null, null, null); - DefaultDockerClientConfig clientConfig = DockerImpl.buildDockerClientConfig(config).build(); - - assertTrue("Docker uri incorrectly set", clientConfig.getDockerHost().toString().equals(dockerUri)); - assertTrue("SSL config was set when using socket", clientConfig.getSSLConfig() == null); - } - - @Test - public void testDockerConfigWithTcpPathWithoutSSL() { - String dockerUri = "tcp://127.0.0.1:2376"; - DockerConfig config = createConfig(dockerUri, null, null, null); - DefaultDockerClientConfig clientConfig = DockerImpl.buildDockerClientConfig(config).build(); - - assertTrue("Docker uri incorrectly set", clientConfig.getDockerHost().toString().equals(dockerUri)); - assertTrue("SSL config was set", clientConfig.getSSLConfig() == null); - } - - @Test - public void testDockerConfigWithTcpPathWithSslConfig() throws IOException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyStoreException, KeyManagementException { - String dockerUri = "tcp://127.0.0.1:2376"; - DockerConfig config = createConfig(dockerUri, "/some/path/ca", "/some/path/cert", "/some/path/key"); - DefaultDockerClientConfig clientConfig = DockerImpl.buildDockerClientConfig(config).build(); - - assertTrue("Docker uri incorrectly set", clientConfig.getDockerHost().toString().equals(dockerUri)); - assertTrue("SSL config was not set", clientConfig.getSSLConfig() != null); - } - - @Test(expected=RuntimeException.class) - public void testDockerConfigWithTcpPathWithInvalidSslConfig() throws IOException, UnrecoverableKeyException, NoSuchAlgorithmException, KeyStoreException, KeyManagementException { - String dockerUri = "tcp://127.0.0.1:2376"; - DockerConfig config = createConfig(dockerUri, "/some/path/ca", "/some/path/cert", "/some/path/key"); - DefaultDockerClientConfig clientConfig = DockerImpl.buildDockerClientConfig(config).build(); - - assertTrue("Docker uri incorrectly set", clientConfig.getDockerHost().toString().equals(dockerUri)); - assertTrue("SSL config was not set", clientConfig.getSSLConfig() != null); - - // SSL certificates are read during the getSSLContext(), the invalid paths should cause a RuntimeException - clientConfig.getSSLConfig().getSSLContext(); - } - - private static DockerConfig createConfig(String uri, String caCertPath, String clientCertPath, String clientKeyPath) { - DockerConfig.Builder configBuilder = new DockerConfig.Builder(); - - if (uri != null) configBuilder.uri(uri); - if (caCertPath != null) configBuilder.caCertPath(caCertPath); - if (clientCertPath != null) configBuilder.clientCertPath(clientCertPath); - if (clientKeyPath != null) configBuilder.clientKeyPath(clientKeyPath); - - return new DockerConfig(configBuilder); - } - @Test public void testExecuteCompletes() throws Exception { diff --git a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerTest.java b/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerTest.java deleted file mode 100644 index 18f87e5ae17..00000000000 --- a/docker-api/src/test/java/com/yahoo/vespa/hosted/dockerapi/DockerTest.java +++ /dev/null @@ -1,197 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.dockerapi; - -import org.apache.commons.io.IOUtils; -import org.junit.Before; -import org.junit.Ignore; -import org.junit.Test; - -import java.io.IOException; -import java.net.InetAddress; -import java.net.URL; -import java.util.Optional; -import java.util.concurrent.ExecutionException; - -import static org.hamcrest.CoreMatchers.is; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; -import static org.junit.Assume.assumeTrue; - -/** - * Requires docker daemon, see {@link com.yahoo.vespa.hosted.dockerapi.DockerTestUtils} for more details. - * - * @author freva - * @author dybdahl - */ -public class DockerTest { - private DockerImpl docker; - private static final DockerImage dockerImage = new DockerImage("simple-ipv6-server:Dockerfile"); - private static final String MANAGER_NAME = "docker-test"; - - // Ignored because the test is very slow (several minutes) when swap is enabled, to disable: (Linux) - // $ sudo swapoff -a - @Ignore - @Test - public void testOutOfMemoryDoesNotAffectOtherContainers() throws InterruptedException, ExecutionException, IOException { - String hostName1 = "docker10.test.yahoo.com"; - String hostName2 = "docker11.test.yahoo.com"; - ContainerName containerName1 = new ContainerName("docker-test-1"); - ContainerName containerName2 = new ContainerName("docker-test-2"); - InetAddress inetAddress1 = InetAddress.getByName("172.18.10.10"); - InetAddress inetAddress2 = InetAddress.getByName("172.18.10.11"); - - docker.createContainerCommand(dockerImage, ContainerResources.from(0, 0.1), containerName1, hostName1) - .withManagedBy(MANAGER_NAME) - .withNetworkMode(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME) - .withIpAddress(inetAddress1) - .create(); - docker.startContainer(containerName1); - - docker.createContainerCommand(dockerImage, ContainerResources.from(0, 0.1), containerName2, hostName2) - .withManagedBy(MANAGER_NAME) - .withNetworkMode(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME) - .withIpAddress(inetAddress2) - .create(); - docker.startContainer(containerName2); - - // 137 = 128 + 9 = kill -9 (SIGKILL), doesn't need to be run as "root", but "yahoo" does not exist in this basic image - assertThat(docker.executeInContainerAsRoot(containerName2, "python", "/pysrc/fillmem.py", "90").getExitStatus(), is(137)); - - // Verify that both HTTP servers are still up - testReachabilityFromHost("http://" + inetAddress1.getHostAddress() + "/ping"); - testReachabilityFromHost("http://" + inetAddress2.getHostAddress() + "/ping"); - - docker.stopContainer(containerName1); - docker.deleteContainer(containerName1); - - docker.stopContainer(containerName2); - docker.deleteContainer(containerName2); - } - - @Test - public void testContainerCycle() throws IOException, InterruptedException, ExecutionException { - final ContainerName containerName = new ContainerName("docker-test-foo"); - final String containerHostname = "hostName1"; - - docker.createContainerCommand(dockerImage, ContainerResources.UNLIMITED, containerName, containerHostname) - .withManagedBy(MANAGER_NAME).create(); - Optional<Container> container = docker.getContainer(containerName); - assertTrue(container.isPresent()); - assertEquals(container.get().state, Container.State.CREATED); - - docker.startContainer(containerName); - container = docker.getContainer(containerName); - assertTrue(container.isPresent()); - assertEquals(container.get().state, Container.State.RUNNING); - - docker.dockerClient.pauseContainerCmd(containerName.asString()).exec(); - container = docker.getContainer(containerName); - assertTrue(container.isPresent()); - assertEquals(container.get().state, Container.State.PAUSED); - - docker.dockerClient.unpauseContainerCmd(containerName.asString()).exec(); - docker.stopContainer(containerName); - container = docker.getContainer(containerName); - assertTrue(container.isPresent()); - assertEquals(container.get().state, Container.State.EXITED); - - docker.deleteContainer(containerName); - assertThat(docker.listAllContainersManagedBy(MANAGER_NAME).isEmpty(), is(true)); - } - - /** - * Test the expected behavior for exec when it times out - it should throw an exception when it times out, - * and before the process completes. - * - * The test timeout value is set quite high to avoid noise if screwdriver is slow but lower than the process time. - */ - @Test(expected = DockerExecTimeoutException.class, timeout = 2000) - public void testContainerExecHounorsTimeout() throws IOException, InterruptedException, ExecutionException { - final ContainerName containerName = new ContainerName("docker-test-foo"); - final String containerHostname = "hostName1"; - - docker.createContainerCommand(dockerImage, ContainerResources.UNLIMITED, containerName, containerHostname) - .withManagedBy(MANAGER_NAME).create(); - docker.startContainer(containerName); - docker.executeInContainerAsRoot(containerName, 1L, "sh", "-c", "sleep 5"); - } - - /** - * Test the expected behavior for exec that completes before specified timeout - it should return when the process finishes and not - * wait for the timeout. Some previous tests indicated that this was not behaving correctly. - * - * No timeout implies infinite timeout. - * - * The test timeout value is set quite high to avoid noise if screwdriver is slow - */ - @Test(timeout = 4000) - public void testContainerExecDoesNotBlockUntilTimeoutWhenCommandFinishesBeforeTimeout() throws IOException, InterruptedException, ExecutionException { - final ContainerName containerName = new ContainerName("docker-test-foo"); - final String containerHostname = "hostName1"; - - docker.createContainerCommand(dockerImage, ContainerResources.UNLIMITED, containerName, containerHostname) - .withManagedBy(MANAGER_NAME).create(); - docker.startContainer(containerName); - docker.executeInContainerAsRoot(containerName, 2L, "sh", "-c", "echo hei"); - - // Also test that this is the behavoir when not specifying timeout - docker.executeInContainerAsRoot(containerName,"sh", "-c", "echo hei"); - } - - @Test - public void testDockerNetworking() throws InterruptedException, ExecutionException, IOException { - String hostName1 = "docker10.test.yahoo.com"; - String hostName2 = "docker11.test.yahoo.com"; - ContainerName containerName1 = new ContainerName("docker-test-1"); - ContainerName containerName2 = new ContainerName("docker-test-2"); - InetAddress inetAddress1 = InetAddress.getByName("172.18.10.10"); - InetAddress inetAddress2 = InetAddress.getByName("172.18.10.11"); - - docker.createContainerCommand(dockerImage, ContainerResources.UNLIMITED, containerName1, hostName1) - .withManagedBy(MANAGER_NAME) - .withNetworkMode(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME).withIpAddress(inetAddress1).create(); - docker.startContainer(containerName1); - - docker.createContainerCommand(dockerImage, ContainerResources.UNLIMITED, containerName2, hostName2) - .withManagedBy(MANAGER_NAME) - .withNetworkMode(DockerImpl.DOCKER_CUSTOM_MACVLAN_NETWORK_NAME).withIpAddress(inetAddress2).create(); - docker.startContainer(containerName2); - - testReachabilityFromHost("http://" + inetAddress1.getHostAddress() + "/ping"); - testReachabilityFromHost("http://" + inetAddress2.getHostAddress() + "/ping"); - - String[] curlFromNodeToNode = new String[]{"curl", "-g", "http://" + inetAddress2.getHostAddress() + "/ping"}; - ProcessResult result = docker.executeInContainerAsRoot(containerName1, curlFromNodeToNode); - assertThat("Could not reach " + containerName2.asString() + " from " + containerName1.asString(), - result.getOutput(), is("pong\n")); - - docker.stopContainer(containerName1); - docker.deleteContainer(containerName1); - - docker.stopContainer(containerName2); - docker.deleteContainer(containerName2); - } - - @Before - public void setup() throws InterruptedException, ExecutionException, IOException { - if (docker == null) { - assumeTrue(DockerTestUtils.dockerDaemonIsPresent()); - - docker = DockerTestUtils.getDocker(); - DockerTestUtils.buildSimpleHttpServerDockerImage(docker, dockerImage); - } - - // Clean up any non deleted containers from previous tests - docker.getAllContainersManagedBy(MANAGER_NAME).forEach(container -> { - if (container.state.isRunning()) docker.stopContainer(container.name); - docker.deleteContainer(container.name); - }); - } - - private void testReachabilityFromHost(String target) throws IOException, InterruptedException { - URL url = new URL(target); - String containerServer = IOUtils.toString(url.openStream()); - assertThat(containerServer, is("pong\n")); - } -} diff --git a/docker-api/src/test/resources/simple-ipv6-server/Dockerfile b/docker-api/src/test/resources/simple-ipv6-server/Dockerfile deleted file mode 100644 index ee33894dbeb..00000000000 --- a/docker-api/src/test/resources/simple-ipv6-server/Dockerfile +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -FROM gliderlabs/alpine:3.4 - -# Install python and curl -RUN apk-install python curl - -# Copy source -ADD src/ pysrc - -# Run http server on port 80 -EXPOSE 80 -CMD ["python", "/pysrc/server.py"] diff --git a/docker-api/src/test/resources/simple-ipv6-server/README b/docker-api/src/test/resources/simple-ipv6-server/README deleted file mode 100644 index 0cb96035c42..00000000000 --- a/docker-api/src/test/resources/simple-ipv6-server/README +++ /dev/null @@ -1,10 +0,0 @@ -Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -This is the source for a basic docker image that runs a python HTTP server listening at IPv6 port 80. -The server serves two basic paths: - /ip - returns IP address of the requester - /ping - returns string "pong" - - -To build the image run: -$ sudo docker build -t "simple-ipv6-server:Dockerfile" <path to this directory> diff --git a/docker-api/src/test/resources/simple-ipv6-server/src/fillmem.py b/docker-api/src/test/resources/simple-ipv6-server/src/fillmem.py deleted file mode 100644 index b3990bea859..00000000000 --- a/docker-api/src/test/resources/simple-ipv6-server/src/fillmem.py +++ /dev/null @@ -1,11 +0,0 @@ -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -import sys -import time - -megabyte = [0] * (1024 * 1024 / 8) -data = megabyte * int(sys.argv[1]) - -while True: - time.sleep(1) - data.extend(megabyte) diff --git a/docker-api/src/test/resources/simple-ipv6-server/src/server.py b/docker-api/src/test/resources/simple-ipv6-server/src/server.py deleted file mode 100644 index 9b4d543d4ed..00000000000 --- a/docker-api/src/test/resources/simple-ipv6-server/src/server.py +++ /dev/null @@ -1,43 +0,0 @@ -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -import socket -from BaseHTTPServer import HTTPServer -from SimpleHTTPServer import SimpleHTTPRequestHandler - - -class MyHandler(SimpleHTTPRequestHandler): - def do_GET(self): - if self.path == '/ip': - self.send_response(200) - self.send_header('Content-type', 'text/html') - self.end_headers() - self.wfile.write('Your IP address is %s\n' % self.client_address[0]) - return - - elif self.path == '/ping': - self.send_response(200) - self.send_header('Content-type', 'text/html') - self.end_headers() - self.wfile.write('pong\n') - return - - else: - self.send_response(404) - self.send_header('Content-type', 'text/html') - self.end_headers() - self.wfile.write('Could not find ' + self.path + '! Try /ping or /ip.\n') - return - - -class DualHTTPServer(HTTPServer): - def __init__(self, address, handler): - self.address_family = socket.AF_INET6 if (':' in address[0]) else socket.AF_INET - HTTPServer.__init__(self, address, handler) - - -def main(ipv6): - server = DualHTTPServer(('::' if ipv6 else '', 80), MyHandler) - server.serve_forever() - - -if __name__ == '__main__': - main(False) diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/ConfigServerClientsImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/RealConfigServerClients.java index 43a2c66a9e5..b8e16ee5910 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/ConfigServerClientsImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/RealConfigServerClients.java @@ -3,7 +3,7 @@ package com.yahoo.vespa.hosted.node.admin.configserver; import com.yahoo.vespa.hosted.node.admin.component.Environment; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeRepository; -import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.NodeRepositoryImpl; +import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.RealNodeRepository; import com.yahoo.vespa.hosted.node.admin.configserver.orchestrator.Orchestrator; import com.yahoo.vespa.hosted.node.admin.configserver.orchestrator.OrchestratorImpl; @@ -12,25 +12,25 @@ import java.util.Optional; /** * @author freva */ -public class ConfigServerClientsImpl implements ConfigServerClients { +public class RealConfigServerClients implements ConfigServerClients { private final Optional<ConfigServerApi> configServerApi; private final NodeRepository nodeRepository; private final Orchestrator orchestrator; - public ConfigServerClientsImpl(Environment environment) { + public RealConfigServerClients(Environment environment) { this(new SslConfigServerApiImpl(environment)); } - public ConfigServerClientsImpl(NodeRepository nodeRepository, Orchestrator orchestrator) { + public RealConfigServerClients(NodeRepository nodeRepository, Orchestrator orchestrator) { this(nodeRepository, orchestrator, Optional.empty()); } - private ConfigServerClientsImpl(ConfigServerApi configServerApi) { - this(new NodeRepositoryImpl(configServerApi), new OrchestratorImpl(configServerApi), Optional.of(configServerApi)); + private RealConfigServerClients(ConfigServerApi configServerApi) { + this(new RealNodeRepository(configServerApi), new OrchestratorImpl(configServerApi), Optional.of(configServerApi)); } - private ConfigServerClientsImpl(NodeRepository nodeRepository, Orchestrator orchestrator, + private RealConfigServerClients(NodeRepository nodeRepository, Orchestrator orchestrator, Optional<ConfigServerApi> configServerApi) { this.nodeRepository = nodeRepository; this.orchestrator = orchestrator; diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepositoryImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java index f2152dffc0c..5b22866fa15 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepositoryImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepository.java @@ -1,18 +1,20 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.hosted.node.admin.configserver.noderepository; -import com.yahoo.vespa.hosted.node.admin.ContainerAclSpec; -import com.yahoo.vespa.hosted.node.admin.ContainerNodeSpec; import com.yahoo.vespa.hosted.dockerapi.ContainerName; import com.yahoo.vespa.hosted.dockerapi.DockerImage; +import com.yahoo.vespa.hosted.node.admin.ContainerAclSpec; +import com.yahoo.vespa.hosted.node.admin.ContainerNodeSpec; +import com.yahoo.vespa.hosted.node.admin.component.Environment; import com.yahoo.vespa.hosted.node.admin.configserver.ConfigServerApi; -import com.yahoo.vespa.hosted.node.admin.nodeagent.NodeAttributes; +import com.yahoo.vespa.hosted.node.admin.configserver.HttpException; +import com.yahoo.vespa.hosted.node.admin.configserver.SslConfigServerApiImpl; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.bindings.GetAclResponse; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.bindings.GetNodesResponse; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.bindings.NodeMessageResponse; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.bindings.UpdateNodeAttributesRequestBody; import com.yahoo.vespa.hosted.node.admin.configserver.noderepository.bindings.UpdateNodeAttributesResponse; -import com.yahoo.vespa.hosted.node.admin.configserver.HttpException; +import com.yahoo.vespa.hosted.node.admin.nodeagent.NodeAttributes; import com.yahoo.vespa.hosted.node.admin.util.PrefixLogger; import com.yahoo.vespa.hosted.provision.Node; @@ -26,15 +28,19 @@ import java.util.stream.Collectors; /** * @author stiankri, dybis */ -public class NodeRepositoryImpl implements NodeRepository { - private static final PrefixLogger NODE_ADMIN_LOGGER = PrefixLogger.getNodeAdminLogger(NodeRepositoryImpl.class); +public class RealNodeRepository implements NodeRepository { + private static final PrefixLogger NODE_ADMIN_LOGGER = PrefixLogger.getNodeAdminLogger(RealNodeRepository.class); private final ConfigServerApi configServerApi; - public NodeRepositoryImpl(ConfigServerApi configServerApi) { + public RealNodeRepository(ConfigServerApi configServerApi) { this.configServerApi = configServerApi; } + public RealNodeRepository(Environment environment) { + this(new SslConfigServerApiImpl(environment)); + } + @Override public List<ContainerNodeSpec> getContainersToRun(String baseHostName) { final GetNodesResponse nodesForHost = configServerApi.get( diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java index 2d261195213..bc8a45f2dfb 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/docker/DockerOperationsImpl.java @@ -33,8 +33,6 @@ import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.Stream; -import static com.yahoo.vespa.defaults.Defaults.getDefaults; - /** * Class that wraps the Docker class and have some tools related to running programs in docker. * diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminMain.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminMain.java index d19f64a2bc3..4b806c905d9 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminMain.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeadmin/NodeAdminMain.java @@ -11,7 +11,7 @@ import com.yahoo.vespa.hosted.node.admin.component.AdminComponent; import com.yahoo.vespa.hosted.node.admin.component.Environment; import com.yahoo.vespa.hosted.node.admin.config.ConfigServerConfig; import com.yahoo.vespa.hosted.node.admin.component.DockerAdminComponent; -import com.yahoo.vespa.hosted.node.admin.configserver.ConfigServerClientsImpl; +import com.yahoo.vespa.hosted.node.admin.configserver.RealConfigServerClients; import com.yahoo.vespa.hosted.node.admin.provider.NodeAdminStateUpdater; import java.io.File; @@ -66,7 +66,7 @@ public class NodeAdminMain implements AutoCloseable { docker, metricReceiver, classLocking, - new ConfigServerClientsImpl(new Environment(configServerConfig))); + new RealConfigServerClients(new Environment(configServerConfig))); } logger.log(LogLevel.INFO, () -> { diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAttributes.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAttributes.java index 2d93dff80a4..4851ad71ebb 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAttributes.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAttributes.java @@ -6,7 +6,7 @@ import com.yahoo.vespa.hosted.dockerapi.DockerImage; import java.util.Objects; -// It somewhat sucks that this class almost duplicates a binding class used by NodeRepositoryImpl, +// It somewhat sucks that this class almost duplicates a binding class used by RealNodeRepository, // but using the binding class here would be a layer violation, and would also tie this logic to // serialization-related dependencies it needs not have. public class NodeAttributes { diff --git a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepositoryImplTest.java b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java index 85e101714e8..fb3416615da 100644 --- a/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/NodeRepositoryImplTest.java +++ b/node-admin/src/test/java/com/yahoo/vespa/hosted/node/admin/configserver/noderepository/RealNodeRepositoryTest.java @@ -34,7 +34,7 @@ import static org.junit.Assert.fail; * * @author dybdahl */ -public class NodeRepositoryImplTest { +public class RealNodeRepositoryTest { private JDisc container; private ConfigServerApiImpl configServerApi; @@ -74,7 +74,7 @@ public class NodeRepositoryImplTest { private void waitForJdiscContainerToServe() throws InterruptedException { Instant start = Instant.now(); - NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi); + NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi); while (Instant.now().minusSeconds(120).isBefore(start)) { try { nodeRepositoryApi.getContainersToRun("foobar"); @@ -96,7 +96,7 @@ public class NodeRepositoryImplTest { @Test public void testGetContainersToRunApi() throws InterruptedException { waitForJdiscContainerToServe(); - NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi); + NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi); String dockerHostHostname = "dockerhost1.yahoo.com"; final List<ContainerNodeSpec> containersToRun = nodeRepositoryApi.getContainersToRun(dockerHostHostname); @@ -115,7 +115,7 @@ public class NodeRepositoryImplTest { @Test public void testGetContainer() throws InterruptedException, IOException { waitForJdiscContainerToServe(); - NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi); + NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi); String hostname = "host4.yahoo.com"; Optional<ContainerNodeSpec> nodeSpec = nodeRepositoryApi.getContainerNodeSpec(hostname); assertThat(nodeSpec.isPresent(), is(true)); @@ -125,7 +125,7 @@ public class NodeRepositoryImplTest { @Test public void testGetContainerForNonExistingNode() throws InterruptedException, IOException { waitForJdiscContainerToServe(); - NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi); + NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi); String hostname = "host-that-does-not-exist"; Optional<ContainerNodeSpec> nodeSpec = nodeRepositoryApi.getContainerNodeSpec(hostname); assertFalse(nodeSpec.isPresent()); @@ -134,7 +134,7 @@ public class NodeRepositoryImplTest { @Test public void testUpdateNodeAttributes() throws InterruptedException, IOException { waitForJdiscContainerToServe(); - NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi); + NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi); String hostname = "host4.yahoo.com"; nodeRepositoryApi.updateNodeAttributes( hostname, @@ -147,7 +147,7 @@ public class NodeRepositoryImplTest { @Test(expected = RuntimeException.class) public void testUpdateNodeAttributesWithBadValue() throws InterruptedException, IOException { waitForJdiscContainerToServe(); - NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi); + NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi); String hostname = "host4.yahoo.com"; nodeRepositoryApi.updateNodeAttributes( hostname, @@ -159,7 +159,7 @@ public class NodeRepositoryImplTest { @Test public void testMarkAsReady() throws InterruptedException, IOException { - NodeRepository nodeRepositoryApi = new NodeRepositoryImpl(configServerApi); + NodeRepository nodeRepositoryApi = new RealNodeRepository(configServerApi); waitForJdiscContainerToServe(); nodeRepositoryApi.markAsDirty("host5.yahoo.com"); diff --git a/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.cpp b/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.cpp index 466df61f8d0..a0cc89d15c6 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.cpp @@ -5,6 +5,43 @@ #include <vespa/log/log.h> LOG_SETUP(".queryperf"); +namespace { + +struct MyLogTask : vespalib::Executor::Task { + uint32_t queueLen; + uint32_t activeCnt; + uint32_t queryCnt; + uint32_t dropCnt; + uint32_t timeoutCnt; + double avgQueryTime; + MyLogTask(uint32_t queueLen_in, + uint32_t activeCnt_in, + uint32_t queryCnt_in, + uint32_t dropCnt_in, + uint32_t timeoutCnt_in, + double avgQueryTime_in) + : queueLen(queueLen_in), + activeCnt(activeCnt_in), + queryCnt(queryCnt_in), + dropCnt(dropCnt_in), + timeoutCnt(timeoutCnt_in), + avgQueryTime(avgQueryTime_in) + { + } + void run() override { + EV_VALUE("queued_queries", queueLen); + EV_VALUE("active_queries", activeCnt); + EV_COUNT("queries", queryCnt); + EV_COUNT("dropped_queries", dropCnt); + EV_COUNT("timedout_queries", timeoutCnt); + if (avgQueryTime > 0.0) { + EV_VALUE("query_eval_time_avg_s", avgQueryTime); + } + } +}; + +} // namespace <unnamed> + FastS_QueryPerf::FastS_QueryPerf() : queueLen(0), activeCnt(0), @@ -28,19 +65,20 @@ FastS_QueryPerf::reset() timeoutCnt = 0; } -void -FastS_QueryPerf::log() +vespalib::Executor::Task::UP +FastS_QueryPerf::make_log_task() { - EV_VALUE("queued_queries", queueLen); - EV_VALUE("active_queries", activeCnt); - EV_COUNT("queries", queryCnt); - EV_COUNT("dropped_queries", dropCnt); - EV_COUNT("timedout_queries", timeoutCnt); + double avgQueryTime = 0.0; if (queryCnt > _lastQueryCnt) { - double avgQueryTime = (queryTime - _lastQueryTime) - / ((double)(queryCnt - _lastQueryCnt)); - EV_VALUE("query_eval_time_avg_s", avgQueryTime); + avgQueryTime = (queryTime - _lastQueryTime) + / ((double)(queryCnt - _lastQueryCnt)); } _lastQueryCnt = queryCnt; _lastQueryTime = queryTime; + return std::make_unique<MyLogTask>(queueLen, + activeCnt, + queryCnt, + dropCnt, + timeoutCnt, + avgQueryTime); } diff --git a/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.h b/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.h index c4f20bc3cef..ee31a8e58b2 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.h +++ b/searchcore/src/vespa/searchcore/fdispatch/common/queryperf.h @@ -3,6 +3,7 @@ #pragma once #include <cstdint> +#include <vespa/vespalib/util/executor.h> struct FastS_QueryPerf { @@ -20,7 +21,7 @@ struct FastS_QueryPerf * prepare the object for reuse logging wise. **/ void reset(); - void log(); + vespalib::Executor::Task::UP make_log_task(); private: uint32_t _lastQueryCnt; diff --git a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp index d9f2b4ecd4f..b68566c3c9b 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.cpp @@ -92,6 +92,7 @@ Fdispatch::~Fdispatch() LOG(debug, "Will close threadpool"); _mypool->Close(); + _executor.shutdown().sync(); LOG(debug, "Has closed threadpool"); _transportServer.reset(); _engineAdapter.reset(); @@ -194,7 +195,8 @@ Fdispatch::CheckTempFail() * Set up stuff as specified in the fdispatch-rc-file. */ Fdispatch::Fdispatch(const config::ConfigUri &configUri) - : _mypool(), + : _executor(1, 128 * 1024), + _mypool(), _engineAdapter(), _transportServer(), _componentConfig(), @@ -391,7 +393,7 @@ Fdispatch::Init() void Fdispatch::logPerformance() { - _nodeManager->logPerformance(); + _nodeManager->logPerformance(_executor); } uint32_t diff --git a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.h b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.h index a0294e22655..6cfb4bfb5a1 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.h +++ b/searchcore/src/vespa/searchcore/fdispatch/program/fdispatch.h @@ -11,6 +11,7 @@ #include <vespa/config/helper/configfetcher.h> #include <vespa/vespalib/net/simple_component_config_producer.h> #include <vespa/vespalib/util/random.h> +#include <vespa/vespalib/util/threadstackexecutor.h> class FastS_NodeManager; class FastS_fdispatch_RPC; @@ -62,6 +63,7 @@ private: Fdispatch(const Fdispatch &); Fdispatch& operator=(const Fdispatch &); + vespalib::ThreadStackExecutor _executor; std::unique_ptr<FastOS_ThreadPool> _mypool; std::unique_ptr<EngineAdapter> _engineAdapter; std::unique_ptr<TransportServer> _transportServer; diff --git a/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp b/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp index 4b272a615a6..302f92cef39 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp @@ -391,7 +391,7 @@ FastS_NodeManager::getChildInfo() void -FastS_NodeManager::logPerformance() +FastS_NodeManager::logPerformance(vespalib::Executor &executor) { _queryPerf.reset(); FastS_DataSetCollection *dsc = GetDataSetCollection(); @@ -403,7 +403,7 @@ FastS_NodeManager::logPerformance() } dsc->subRef(); - _queryPerf.log(); + executor.execute(_queryPerf.make_log_task()); } diff --git a/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.h b/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.h index e0396b46748..77d4482fba7 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.h +++ b/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.h @@ -8,6 +8,7 @@ #include <vespa/searchcore/fdispatch/common/queryperf.h> #include <vespa/vespalib/net/simple_component_config_producer.h> #include <vespa/config/subscription/configuri.h> +#include <vespa/vespalib/util/executor.h> #include <mutex> using vespa::config::search::core::PartitionsConfig; @@ -92,7 +93,7 @@ public: * log query performance. This method should only be invoked from * the FNET thread. **/ - void logPerformance(); + void logPerformance(vespalib::Executor &executor); void CheckEvents(FastS_TimeKeeper *timeKeeper); // invoked by FNET thread }; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 2e2858da238..262aba89f27 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression; import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.text.Utf8; @@ -11,9 +12,9 @@ import java.security.NoSuchAlgorithmException; import java.util.*; /** - * <p>A function defined by a ranking expression</p> + * A function defined by a ranking expression * - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen * @author bratseth */ public class ExpressionFunction { @@ -23,7 +24,7 @@ public class ExpressionFunction { private final RankingExpression body; /** - * <p>Constructs a new function</p> + * Constructs a new function * * @param name the name of this function * @param arguments its argument names @@ -43,28 +44,27 @@ public class ExpressionFunction { public RankingExpression getBody() { return body; } /** - * <p>Create and return an instance of this function based on the given - * arguments. If function calls are nested, this call might produce - * additional scripts.</p> + * Creates and returns an instance of this function based on the given + * arguments. If function calls are nested, this call may produce + * additional functions. * * @param context the context used to expand this - * @param arguments the arguments to instantiate on. + * @param argumentValues the arguments to instantiate on. * @param path the expansion path leading to this. * @return the script function instance created. */ - public Instance expand(SerializationContext context, List<ExpressionNode> arguments, Deque<String> path) { + public Instance expand(SerializationContext context, List<ExpressionNode> argumentValues, Deque<String> path) { Map<String, String> argumentBindings = new HashMap<>(); - for (int i = 0; i < this.arguments.size() && i < arguments.size(); ++i) { - argumentBindings.put(this.arguments.get(i), arguments.get(i).toString(context, path, null)); + for (int i = 0; i < arguments.size() && i < arguments.size(); ++i) { + argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(context, path, null)); } - return new Instance(toSymbol(argumentBindings), body.getRoot().toString(context.createBinding(argumentBindings), path, null)); + return new Instance(toSymbol(argumentBindings), body.getRoot().toString(context.withBindings(argumentBindings), path, null)); } /** * Returns a symbolic string that represents this function with a given * list of arguments. The arguments are mangled by hashing the string - * representation of the argument expressions, so we might need to revisit - * this if we start seeing collisions. + * representation of the argument expressions. * * @param argumentBindings the bound arguments to include in the symbolic name. * @return the symbolic name for an instance of this function @@ -85,8 +85,8 @@ public class ExpressionFunction { /** - * <p>Returns a more unique hash code than what Java's own {@link - * String#hashCode()} method would produce.</p> + * Returns a more unique hash code than what Java's own {@link + * String#hashCode()} method would produce. * * @param str The string to hash. * @return A 64 bit long hash code. @@ -136,4 +136,5 @@ public class ExpressionFunction { } } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java index 49466f1974d..f0532d9d433 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/FeatureList.java @@ -91,8 +91,8 @@ public class FeatureList implements Iterable<ReferenceNode> { /** * Returns the feature at the given index. * - * @param i The index of the feature to return. - * @return The featuer at the given index. + * @param i the index of the feature to return. + * @return the feature at the given index. */ public ReferenceNode get(int i) { return features.get(i); @@ -137,4 +137,5 @@ public class FeatureList implements Iterable<ReferenceNode> { public Iterator<ReferenceNode> iterator() { return features.iterator(); } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java index c8d90e8c4e8..6b2422d7cb2 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java @@ -244,10 +244,6 @@ public class RankingExpression implements Serializable { * @return a list of named rank properties required to implement this expression. */ public Map<String, String> getRankProperties(List<ExpressionFunction> macros) { - Map<String, ExpressionFunction> arg = new HashMap<>(); - for (ExpressionFunction function : macros) { - arg.put(function.getName(), function); - } Deque<String> path = new LinkedList<>(); SerializationContext context = new SerializationContext(macros); String serializedRoot = root.toString(context, path, null); @@ -272,7 +268,7 @@ public class RankingExpression implements Serializable { * * @throws IllegalArgumentException if this expression is not type correct in this context */ - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return root.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java new file mode 100644 index 00000000000..6277721e8f5 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java @@ -0,0 +1,121 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression; + +import com.yahoo.searchlib.rankingexpression.rule.Arguments; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Deque; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * A reference to a feature, function, or value in ranking expressions + * + * @author bratseth + */ +public class Reference extends TypeContext.Name { + + private final String name; + private final Arguments arguments; + + /** + * The output, or null if none + */ + private final String output; + + public Reference(String name, Arguments arguments, String output) { + super(name); + Objects.requireNonNull(name, "name cannot be null"); + Objects.requireNonNull(arguments, "arguments cannot be null"); + this.name = name; + this.arguments = arguments; + this.output = output; + } + + public String name() { return name; } + + public Arguments arguments() { return arguments; } + + public String output() { return output; } + + /** + * Creates a reference to a simple feature consisting of a name and a single argument + */ + public static Reference simple(String name, String argumentValue) { + return new Reference(name, + new Arguments(new ReferenceNode(argumentValue)), + null); + } + + /** + * Returns the given simple feature as a reference, or empty if it is not a valid simple + * feature string on the form name(argument). + */ + public static Optional<Reference> simple(String feature) { + int startParenthesis = feature.indexOf('('); + if (startParenthesis < 0) + return Optional.empty(); + int endParenthesis = feature.lastIndexOf(')'); + String featureName = feature.substring(0, startParenthesis); + if (startParenthesis < 1 || endParenthesis < startParenthesis) return Optional.empty(); + String argument = feature.substring(startParenthesis + 1, endParenthesis); + if (argument.startsWith("'") || argument.startsWith("\"")) + argument = argument.substring(1); + if (argument.endsWith("'") || argument.endsWith("\"")) + argument = argument.substring(0, argument.length() - 1); + return Optional.of(simple(featureName, argument)); + } + + /** + * Returns whether this is a simple identifier - no arguments or output + */ + public boolean isIdentifier() { + return this.arguments.expressions().size() == 0 && output == null; + } + + public Reference withArguments(Arguments arguments) { + return new Reference(name, arguments, output); + } + + public Reference withOutput(String output) { + return new Reference(name, arguments, output); + } + + @Override + public boolean equals(Object o) { + if (o == this) return true; + if (!(o instanceof Reference)) return false; + Reference other = (Reference) o; + if (!Objects.equals(other.name, this.name)) return false; + if (!Objects.equals(other.arguments, this.arguments)) return false; + if (!Objects.equals(other.output, this.output)) return false; + return true; + } + + @Override + public int hashCode() { + return Objects.hash(name, arguments, output); + } + + @Override + public String toString() { + return toString(new SerializationContext(), null, null); + } + + public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { + StringBuilder b = new StringBuilder(name); + if (arguments != null && arguments.expressions().size() > 0) + b.append("(").append(arguments.expressions().stream() + .map(node -> node.toString(context, path, parent)) + .collect(Collectors.joining(","))).append(")"); + if (output != null) + b.append(".").append(output); + return b.toString(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java index 5f8daa69ecf..ee5952d9aea 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/ArrayContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import java.util.Arrays; @@ -82,8 +83,8 @@ public class ArrayContext extends AbstractArrayContext implements Cloneable { } @Override - public TensorType getType(String name) { - Integer index = nameToIndex().get(name); + public TensorType getType(Reference reference) { + Integer index = nameToIndex().get(reference.toString()); if (index == null) return null; return values[index].type(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 861f9565d66..4e046df11ca 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -1,9 +1,11 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -14,7 +16,7 @@ import java.util.stream.Collectors; * * @author bratseth */ -public abstract class Context implements EvaluationContext { +public abstract class Context implements EvaluationContext<Reference> { /** * Returns the value of a simple variable name. @@ -24,6 +26,11 @@ public abstract class Context implements EvaluationContext { */ public abstract Value get(String name); + @Override + public TensorType getType(String reference) { + throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + } + /** Returns a variable as a tensor */ @Override public Tensor getTensor(String name) { return get(name).asTensor(); } @@ -46,6 +53,7 @@ public abstract class Context implements EvaluationContext { * calculation to output several), or null to output the * "main" (or only) value. */ + // TODO: Remove/change to use reference? public Value get(String name, Arguments arguments, String output) { if (arguments != null && arguments.expressions().size() > 0) name = name + "(" + arguments.expressions().stream().map(ExpressionNode::toString).collect(Collectors.joining(",")) + ")"; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java index 0625e8506cc..0004036da4b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleOnlyArrayContext.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; /** @@ -68,7 +69,9 @@ public class DoubleOnlyArrayContext extends AbstractArrayContext { } @Override - public TensorType getType(String name) { return TensorType.empty; } + public TensorType getType(Reference reference) { + return TensorType.empty; // Double only + } /** Perform a slow lookup by name */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java index a81d0c89f8f..4ef24d60bba 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapContext.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; import java.util.Collections; @@ -15,7 +16,7 @@ import java.util.Set; */ public class MapContext extends Context { - private Map<String, Value> bindings = new HashMap<>(); + private Map<String, Value> bindings = new HashMap<>(); // TODO: Change String to Reference private boolean frozen = false; @@ -42,8 +43,8 @@ public class MapContext extends Context { /** Returns the type of the given value key, or null if it is not bound. */ @Override - public TensorType getType(String key) { - Value value = bindings.get(key); + public TensorType getType(Reference key) { + Value value = bindings.get(key.toString()); if (value == null) return null; return value.type(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java new file mode 100644 index 00000000000..2a42e2d92f7 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java @@ -0,0 +1,38 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A context which only contains type information. + * + * @author bratseth + */ +public class MapTypeContext implements TypeContext<Reference> { + + private final Map<Reference, TensorType> featureTypes = new HashMap<>(); + + public void setType(Reference reference, TensorType type) { + featureTypes.put(reference, type); + } + + @Override + public TensorType getType(String reference) { + throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + } + + @Override + public TensorType getType(Reference reference) { + return featureTypes.get(reference); + } + + /** Returns an unmodifiable map of the bindings in this */ + public Map<Reference, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java deleted file mode 100644 index ff2088263d8..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeMapContext.java +++ /dev/null @@ -1,32 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.evaluation;// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.TypeContext; - -import java.util.Collections; -import java.util.HashMap; -import java.util.Map; - -/** - * A context which only contains type information. - * - * @author bratseth - */ -public class TypeMapContext implements TypeContext { - - private final Map<String, TensorType> featureTypes = new HashMap<>(); - - public void setType(String name, TensorType type) { - featureTypes.put(name, type); - } - - @Override - public TensorType getType(String name) { - return featureTypes.get(name); - } - - /** Returns an unmodifiable map of the bindings in this */ - public Map<String, TensorType> bindings() { return Collections.unmodifiableMap(featureTypes); } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java index 8ee4cdbf297..649c70122f1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -26,7 +27,7 @@ public class GBDTForestNode extends ExpressionNode { } @Override - public final TensorType type(TypeContext context) { return TensorType.empty; } + public final TensorType type(TypeContext<Reference> context) { return TensorType.empty; } @Override public final Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java index aac635b2545..53a286f09f6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -51,7 +52,7 @@ public final class GBDTNode extends ExpressionNode { public final double[] values() { return values; } @Override - public final TensorType type(TypeContext context) { return TensorType.empty; } + public final TensorType type(TypeContext<Reference> context) { return TensorType.empty; } @Override public final Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java deleted file mode 100644 index 5f0c016881a..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.TensorProto; -import org.tensorflow.framework.TensorShapeProto; - -/** - * @author lesters - */ -public class AttrValueConverter { - - public static Tensor toVespaTensor(NodeDef tfNode, String attr) { - if (!tfNode.getAttrMap().containsKey(attr)) { - throw new IllegalArgumentException(tfNode.getName() + " has no attribute called " + attr); - } - AttrValue attrValue = tfNode.getAttrMap().get(attr); - switch (attrValue.getValueCase()) { - case TENSOR: - return buildFromTensor(attrValue); - case B: - return buildFromSingleValue(attrValue.getB() ? 1.0 : 0.0); - case F: - return buildFromSingleValue(attrValue.getF()); - case I: - return buildFromSingleValue(attrValue.getI()); - } - - throw new IllegalArgumentException(tfNode.getName() + - ": unsupported attribute type: '" + attrValue.getValueCase().toString() + "'"); - } - - private static Tensor buildFromSingleValue(double value) { - TensorType type = new TensorType.Builder().build(); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); - builder.cellByDirectIndex(0, value); - return builder.build(); - } - - private static Tensor buildFromTensor(AttrValue attrValue) { - TensorProto tensorProto = attrValue.getTensor(); - TensorType type = toVespaTensorType(tensorProto.getTensorShape()); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); - Values values = valuesOf(tensorProto); - for (int i = 0; i < values.size(); ++i) { - builder.cellByDirectIndex(i, values.get(i)); - } - Tensor tensor = builder.build(); - return tensor; - } - - private static Values valuesOf(TensorProto tensorProto) { - switch (tensorProto.getDtype()) { - case DT_BOOL: - return new BoolValues(tensorProto); - case DT_HALF: - return new HalfValues(tensorProto); - case DT_INT16: - case DT_INT32: - return new IntValues(tensorProto); - case DT_INT64: - return new Int64Values(tensorProto); - case DT_FLOAT: - return new FloatValues(tensorProto); - case DT_DOUBLE: - return new DoubleValues(tensorProto); - } - - throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); - } - - public static TensorType toVespaTensorType(TensorShapeProto shapeProto) { - TensorType.Builder b = new TensorType.Builder(); - for (TensorShapeProto.Dim dimension : shapeProto.getDimList()) { - int dimensionSize = (int)dimension.getSize(); - if (dimensionSize >= 0) - b.indexed("d" + b.rank(), dimensionSize); - else - b.indexed("d" + b.rank()); // unbound size - } - return b.build(); - } - - private static abstract class Values { - protected final TensorProto tensorProto; - protected Values(TensorProto tensorProto) { this.tensorProto = tensorProto; } - abstract double get(int i); - abstract int size(); - } - - private static class BoolValues extends Values { - BoolValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; } - @Override int size() { return tensorProto.getBoolValCount(); } - } - - private static class HalfValues extends Values { - HalfValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getHalfVal(i); } - @Override int size() { return tensorProto.getHalfValCount(); } - } - - private static class IntValues extends Values { - IntValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getIntVal(i); } - @Override int size() { return tensorProto.getIntValCount(); } - } - - private static class Int64Values extends Values { - Int64Values(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getInt64Val(i); } - @Override int size() { return tensorProto.getInt64ValCount(); } - } - - private static class FloatValues extends Values { - FloatValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getFloatVal(i); } - @Override int size() { return tensorProto.getFloatValCount(); } - } - - private static class DoubleValues extends Values { - DoubleValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getDoubleVal(i); } - @Override int size() { return tensorProto.getDoubleValCount(); } - } - - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java deleted file mode 100644 index ef82045e771..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ /dev/null @@ -1,715 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.google.common.collect.ImmutableList; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; -import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.Matmul; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.Softmax; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.Session; -import org.tensorflow.framework.AttrValue; - -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Optional; -import java.util.function.DoubleBinaryOperator; -import java.util.function.DoubleUnaryOperator; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -/** - * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions. - * - * @author bratseth - * @author lesters - */ -class OperationMapper { - - // A note on conversion from implicitly numbered to explicitly named dimensions: - // - // Vespa tensor dimensions are explicitly named and thus have an explicit notion of being - // 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation - // comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation - // around dimension renaming operations which mirrors those built into the TF operation definitions. - // - // To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost' - // dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation - // and the result is then renamed again (if necessary) to recover this convention across a full nested - // computation. - // - // This requires us to track tensor types throughout the conversion. - - - // Supported TensorFlow operations - enum Operation { - - // TODO: move the implementations to specific files as we support more operations - - /* - * array ops - */ - CONST (OperationMapper::constant), - EXPANDDIMS (OperationMapper::expandDims), - IDENTITY (OperationMapper::identity), - PLACEHOLDER (OperationMapper::placeholder), - PLACEHOLDERWITHDEFAULT (OperationMapper::placeholderWithDefault), - RESHAPE (OperationMapper::reshape), - SQUEEZE (OperationMapper::squeeze), - - /* - * control flow - */ - MERGE (OperationMapper::merge), - SWITCH (OperationMapper::switchOp), - - /* - * math ops - */ - ADD (OperationMapper::add), - ADD_N (OperationMapper::add), - ACOS (OperationMapper::acos), - DIV (OperationMapper::div), - REALDIV (OperationMapper::div), - FLOOR (OperationMapper::floor), - MATMUL (OperationMapper::matmul), - MAXIMUM (OperationMapper::maximum), - MEAN (OperationMapper::mean), - REDUCEMEAN (OperationMapper::mean), - MUL (OperationMapper::mul), - MULTIPLY (OperationMapper::mul), - RSQRT (OperationMapper::rsqrt), - SELECT (OperationMapper::select), - WHERE3 (OperationMapper::select), - SIGMOID (OperationMapper::sigmoid), - SQUAREDDIFFERENCE (OperationMapper::squaredDifference), - SUB (OperationMapper::sub), - SUBTRACT (OperationMapper::sub), - - /* - * nn ops - */ - BIASADD (OperationMapper::add), - ELU (OperationMapper::elu), - RELU (OperationMapper::relu), - SELU (OperationMapper::selu), - SOFTMAX (OperationMapper::softMax), - - /* - * state ops - */ - VARIABLE (OperationMapper::variable), - VARIABLEV2 (OperationMapper::variable), - - /* - * evaluation no-ops - */ - STOPGRADIENT (OperationMapper::identity), - NOOP (OperationMapper::noOp); - - - private final Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func; - - Operation(Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func) { - this.func = func; - } - - Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) { - return func.apply(params); - } - - } - - static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) { - Optional<Operation> operation = Stream.of(Operation.values()) - .filter(op -> op.name().equalsIgnoreCase(params.node().getOp())) - .findFirst(); - if (operation.isPresent()) { - return operation.get().map(params); - } - params.signature().importWarning("TensorFlow operation '" + params.node().getOp() + - "' in node '" + params.node().getName() + "' is not supported."); - return Optional.empty(); - } - - - // Operations --------------------------------- - - private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) { - Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value"); - if (value.type().rank() == 0) { - TypedTensorFunction output = new TypedTensorFunction(value.type(), - new TensorFunctionNode.TensorFunctionExpressionNode( - new ConstantNode(new DoubleValue(value.asDouble())))); - return Optional.of(output); - } - return createConstant(params, value); - } - - private static Optional<TypedTensorFunction> expandDims(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 2)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - - Tensor axis = getConstantTensor(params, params.node().getInput(1)); - if (axis.type().rank() != 0) { - throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar"); - } - - TensorFunction inputFunction = inputs.get(0).get().function(); - TensorType inputType = inputs.get(0).get().type(); - - int dimensionToInsert = (int)axis.asDouble(); - if (dimensionToInsert < 0) { - dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; - } - - TensorType.Builder outputTypeBuilder = new TensorType.Builder(); - int dimensionIndex = 0; - for (int i = 0; i < inputType.dimensions().size() + 1; ++i) { - String name = String.format("temp_%d", i); - Long size; - if (i == dimensionToInsert) { - size = 1L; - } else { - size = dimensionSize(inputType.dimensions().get(dimensionIndex)); - dimensionIndex++; - } - outputTypeBuilder.indexed(name, size); - } - - return reshape(inputFunction, inputType, outputTypeBuilder.build()); - } - - private static Optional<TypedTensorFunction> identity(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 1)) { - return Optional.empty(); - } - return params.inputs().get(0); - } - - private static Optional<TypedTensorFunction> placeholder(TensorFlowImporter.Parameters params) { - String name = params.node().getName(); - String vespaName = toVespaName(params.node().getName()); - TensorType type = params.result().arguments().get(name); - if (type == null) { - throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + - "', but there is no such placeholder"); - } - params.result().requiredMacro(vespaName, type); - // Included literally in the expression and so must be produced by a separate macro in the rank profile - TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(vespaName, type)); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) { - String name = toVespaName(params.node().getInput(0)); - Tensor defaultValue = getConstantTensor(params, params.node().getInput(0)); - params.result().largeConstant(name, defaultValue); - params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")"))); - // The default value will be provided by the macro. Users can override macro to change value. - TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name)); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) { - if ( ! checkInputs(params, 2)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - Tensor shape = getConstantTensor(params, params.node().getInput(1)); - - TensorFunction inputFunction = inputs.get(0).get().function(); - TensorType inputType = inputs.get(0).get().type(); - - TensorType.Builder outputTypeBuilder = new TensorType.Builder(); - int dimensionIndex = 0; - for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int size = cell.getValue().intValue(); - if (size < 0) { - size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue(); - } - outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size); - dimensionIndex++; - } - return reshape(inputFunction, inputType, outputTypeBuilder.build()); - } - - private static Optional<TypedTensorFunction> squeeze(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 1)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - - TensorFunction inputFunction = inputs.get(0).get().function(); - TensorType inputType = inputs.get(0).get().type(); - List<String> squeezeDimensions; - - AttrValue squeezeDimsAttr = params.node().getAttrMap().get("squeeze_dims"); - if (squeezeDimsAttr == null) { - squeezeDimensions = inputType.dimensions().stream(). - filter(dim -> dimensionSize(dim) == 1). - map(TensorType.Dimension::name). - collect(Collectors.toList()); - } else { - squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). - map(i -> i < 0 ? inputType.dimensions().size() - i : i). - map(i -> inputType.dimensions().get(i.intValue())). - filter(dim -> dimensionSize(dim) == 1). - map(TensorType.Dimension::name). - collect(Collectors.toList()); - } - - if (squeezeDimensions.isEmpty()) { - return inputs.get(0); - } - - TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); - TensorType outputType = Reduce.outputType(inputType, squeezeDimensions); - TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> merge(TensorFlowImporter.Parameters params) { - return params.inputs().stream() - .filter(Optional::isPresent) - .findFirst() - .orElse(Optional.empty()); - } - - private static Optional<TypedTensorFunction> switchOp(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 2)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - Tensor predicate = getConstantTensor(params, params.node().getInput(1)); - if (predicate.type().rank() != 0) { - throw new IllegalArgumentException("'switch': predicate must be a scalar"); - } - double pred = predicate.asDouble(); - int output = params.port().length() > 0 ? Integer.parseInt(params.port()) : 0; - if (output < 0 || output > 1) { - throw new IllegalArgumentException("'switch': predicate is not boolean"); - } - if (pred == output) { - return inputs.get(0); - } - return Optional.empty(); - } - - private static Optional<TypedTensorFunction> add(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.add()); - } - - private static Optional<TypedTensorFunction> acos(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.acos()); - } - - private static Optional<TypedTensorFunction> div(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.divide()); - } - - private static Optional<TypedTensorFunction> floor(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.floor()); - } - - private static Optional<TypedTensorFunction> matmul(TensorFlowImporter.Parameters params) { - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - if (!checkInputs(params, 2)) { - return Optional.empty(); - } - - TypedTensorFunction a = inputs.get(0).get(); - TypedTensorFunction b = inputs.get(1).get(); - if (a.type().rank() < 2 || b.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (a.type().rank() != b.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - - String afterLastDim = "d" + (a.type().rank() + 1); - // Let the first dimension of the second tensor be the same as the second dimension of the first - // and the second dimension of the second argument be not present in the first argument, while leaving the - // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication. - - // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly - - Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"), - ImmutableList.of("d1", afterLastDim)); - Matmul matmul = new Matmul(a.function(), renamedB, "d1"); - TypedTensorFunction output = new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"), - new Rename(matmul, afterLastDim, "d1")); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> maximum(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.max()); - } - - private static Optional<TypedTensorFunction> mean(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 2)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - TensorFunction inputFunction = inputs.get(0).get().function(); - TensorType inputType = inputs.get(0).get().type(); - - Tensor reductionIndices = getConstantTensor(params, params.node().getInput(1)); - List<String> reduceDimensions = new ArrayList<>(); - for (Iterator<Tensor.Cell> cellIterator = reductionIndices.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int dimensionIndex = cell.getValue().intValue(); - if (dimensionIndex < 0) { - dimensionIndex = inputType.dimensions().size() - dimensionIndex; - } - reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); - } - - TensorType outputType = Reduce.outputType(inputType, reduceDimensions); - TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); - - if (shouldKeepDimensions(params)) { - return reshape(outputFunction, outputType, keepDimensionType(inputType, reduceDimensions)); - } - TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> mul(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.multiply()); - } - - private static Optional<TypedTensorFunction> rsqrt(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.rsqrt()); - } - - private static Optional<TypedTensorFunction> select(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 3)) { - return Optional.empty(); - } - Tensor condition = getConstantTensor(params, params.node().getInput(0)); - - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - TypedTensorFunction x = inputs.get(1).get(); - TypedTensorFunction y = inputs.get(2).get(); - if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) { - throw new IllegalArgumentException("'Select': input tensors must have the same shape"); - } - - if (condition.type().rank() == 0) { - return Optional.of((int)condition.asDouble() == 0 ? y : x); - } - if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { - return Optional.of(condition.cellIterator().next().getValue().intValue() == 0 ? y : x); - } - - // The task is to select cells from 'x' or 'y' based on 'condition'. - // If 'condition' is 0 (false), select from 'y', if 1 (true) select - // from 'x'. We do this by individually joining 'x' and 'y' with - // 'condition', and then joining the resulting two tensors. - - Optional<TypedTensorFunction> conditionFunction = importConstantTensor(params, params.node().getInput(0)); - if (!conditionFunction.isPresent()) { - return Optional.empty(); - } - TensorFunction xCond = new Join(x.function(), conditionFunction.get().function(), ScalarFunctions.multiply()); - TensorFunction yCond = new Join(y.function(), conditionFunction.get().function(), new DoubleBinaryOperator() { - @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } - @Override public String toString() { return "f(a,b)(a * (1-b))"; } - }); - TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add()); - TypedTensorFunction output = new TypedTensorFunction(x.type(), outputFunction); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> sigmoid(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.sigmoid()); - } - - private static Optional<TypedTensorFunction> squaredDifference(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.squareddifference()); - } - - private static Optional<TypedTensorFunction> sub(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.subtract()); - } - - private static Optional<TypedTensorFunction> elu(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.elu()); - } - - private static Optional<TypedTensorFunction> relu(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.relu()); - } - - private static Optional<TypedTensorFunction> selu(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.selu()); - } - - private static Optional<TypedTensorFunction> softMax(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 1)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - TypedTensorFunction a = inputs.get(0).get(); - // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1 - String dimension = "d" + (a.type().rank() - 1); - Softmax softmax = new Softmax(a.function(), dimension); - TypedTensorFunction output = new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> variable(TensorFlowImporter.Parameters params) { - return importConstantTensor(params, params.node().getName()); - } - - private static Optional<TypedTensorFunction> noOp(TensorFlowImporter.Parameters params) { - return Optional.empty(); - } - - /* - * Utility - */ - - private static Optional<TypedTensorFunction> join(TensorFlowImporter.Parameters params, DoubleBinaryOperator doubleFunction) { - if (!checkInputs(params, 2)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - - TypedTensorFunction a = inputs.get(0).get(); - TypedTensorFunction b = inputs.get(1).get(); - - if (a.type().rank() == 0 && b.type().rank() > 0) { - return Optional.of(new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction))); - } - if (b.type().rank() == 0 && a.type().rank() > 0) { - return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction))); - } - if (a.type().rank() == b.type().rank()) { - return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction))); - } - - // Well now we have entered the wonderful world of "broadcasting" - // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html - // I'm not able to extract from that any unambiguous specification of which dimensions - // should be "stretched" when the tensor do not have the same number of dimensions. - // From trying this with TensorFlow it appears that the second tensor is matched to the - // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. - // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). - - if (a.type().rank() > b.type().rank()) { - TensorFunction renameFunction = renameForBroadcast(a, b); - return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction))); - } - TensorFunction renameFunction = renameForBroadcast(b, a); - return Optional.of(new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction))); - } - - private static TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) { - List<String> renameFrom = new ArrayList<>(); - List<String> renameTo = new ArrayList<>(); - int sizeDifference = a.type().rank() - b.type().rank(); - for (int i = 0; i < b.type().rank(); i++) { - renameFrom.add(b.type().dimensions().get(i).name()); - renameTo.add("d" + (sizeDifference + i)); - } - return new Rename(b.function(), renameFrom, renameTo); - } - - private static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params, DoubleUnaryOperator doubleFunction) { - if (!checkInputs(params, 1)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - TypedTensorFunction a = inputs.get(0).get(); - TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type()); - com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction); - return Optional.of(new TypedTensorFunction(resultType, function)); - } - - private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) { - String name = toVespaName(params.node().getName()); - if (constant.type().rank() == 0 || constant.size() <= 1) { - params.result().smallConstant(name, constant); - } else { - params.result().largeConstant(name, constant); - } - TypedTensorFunction output = new TypedTensorFunction(constant.type(), - new TensorFunctionNode.TensorFunctionExpressionNode( - new ReferenceNode("constant(\"" + name + "\")"))); - return Optional.of(output); - } - - private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) { - String vespaName = toVespaName(name); - if (params.result().smallConstants().containsKey(vespaName)) { - return params.result().smallConstants().get(vespaName); - } - if (params.result().largeConstants().containsKey(vespaName)) { - return params.result().largeConstants().get(vespaName); - } - Session.Runner fetched = params.model().session().runner().fetch(name); - List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); - if (importedTensors.size() != 1) - throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " + - importedTensors.size()); - return TensorConverter.toVespaTensor(importedTensors.get(0)); - } - - private static Optional<TypedTensorFunction> importConstantTensor(TensorFlowImporter.Parameters params, String name) { - AttrValue shapes = params.node().getAttrMap().get("_output_shapes"); - if (shapes == null) - throw new IllegalArgumentException("'" + name + "' is missing a tensor shape"); - Tensor constant = getConstantTensor(params, name); - return createConstant(params, constant); - } - - private static Optional<TypedTensorFunction> reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if (!tensorSize(inputType).equals(tensorSize(outputType))) { - throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); - } - - // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, - // then use the dimension order of the new shape to roll back into a tensor. - // Here we create a transformation tensor that is multiplied with the from tensor to map into - // the new shape. We have to introduce temporary dimension names and rename back if dimension names - // in the new and old tensor type overlap. - - ExpressionNode unrollFrom = unrollTensorExpression(inputType); - ExpressionNode unrollTo = unrollTensorExpression(outputType); - ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo); - - TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); - Generate transformTensor = new Generate(transformationType, - new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); - - TensorFunction outputFunction = new Reduce( - new Join(inputFunction, transformTensor, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); - TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return Optional.of(output); - } - - private static ExpressionNode unrollTensorExpression(TensorType type) { - if (type.rank() == 0) { - return new ConstantNode(DoubleValue.zero); - } - List<ExpressionNode> children = new ArrayList<>(); - List<ArithmeticOperator> operators = new ArrayList<>(); - int size = 1; - for (int i = type.dimensions().size() - 1; i >= 0; --i) { - TensorType.Dimension dimension = type.dimensions().get(i); - children.add(0, new ReferenceNode(dimension.name())); - if (size > 1) { - operators.add(0, ArithmeticOperator.MULTIPLY); - children.add(0, new ConstantNode(new DoubleValue(size))); - } - size *= dimensionSize(dimension); - if (i > 0) { - operators.add(0, ArithmeticOperator.PLUS); - } - } - return new ArithmeticNode(children, operators); - } - - private static boolean shouldKeepDimensions(TensorFlowImporter.Parameters params) { - AttrValue keepDimsAttr = params.node().getAttrMap().get("keep_dims"); - return keepDimsAttr != null && keepDimsAttr.getB(); - } - - private static TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) { - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension: inputType.dimensions()) { - String name = dimension.name(); - Long size = dimensionSize(dimension); - if (reduceDimensions.contains(name)) { - size = 1L; - } - builder.indexed(name, size); - } - return builder.build(); - } - - private static TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) { - for (int i = 0; i < type.dimensions().size(); ++i) { - String correct = String.format("d%d", i); - String current = type.dimensions().get(i).name(); - if (!current.equals(correct)) { - return fixNamingConvention(type, function); - } - } - return new TypedTensorFunction(type, function); - } - - private static TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) { - TensorType.Builder correctType = new TensorType.Builder(); - List<String> from = new ArrayList<>(); - List<String> to = new ArrayList<>(); - for (int i = 0; i < type.dimensions().size(); ++i) { - String correct = String.format("d%d", i); - String current = type.dimensions().get(i).name(); - if (!current.equals(correct)) { - from.add(current); - to.add(correct); - } - correctType.indexed(correct, dimensionSize(type.dimensions().get(i))); - } - if (from.size() > 0) { - function = new Rename(function, from, to); - type = correctType.build(); - } - return new TypedTensorFunction(type, function); - } - - private static Long tensorSize(TensorType type) { - Long size = 1L; - for (TensorType.Dimension dimension : type.dimensions()) { - size *= dimensionSize(dimension); - } - return size; - } - - private static Long dimensionSize(TensorType.Dimension dim) { - return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); - } - - private static boolean checkInputs(TensorFlowImporter.Parameters params, int expected) { - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - if (!inputs.stream().allMatch(Optional::isPresent)) { - return false; - } - if (inputs.size() != expected) { - params.signature().importWarning("Expected " + expected + - " arguments to " + params.node().getOp() + ", but got " + inputs.size()); - return false; - } - return true; - } - - public static String toVespaName(String name) { - return name != null ? name.replace('/', '_') : null; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java deleted file mode 100644 index b88ffce275a..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; - - -/** - * Converts TensorFlow tensors into Vespa tensors. - * - * @author bratseth - */ -public class TensorConverter { - - public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { - TensorType type = toVespaTensorType(tfTensor.shape()); - Values values = readValuesOf(tfTensor); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); - for (int i = 0; i < values.size(); i++) - builder.cellByDirectIndex(i, values.get(i)); - return builder.build(); - } - - private static TensorType toVespaTensorType(long[] shape) { - TensorType.Builder b = new TensorType.Builder(); - int dimensionIndex = 0; - for (long dimensionSize : shape) { - if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... - b.indexed("d" + (dimensionIndex++), dimensionSize); - } - return b.build(); - } - - private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { - switch (tfTensor.dataType()) { - case DOUBLE: return new DoubleValues(tfTensor); - case FLOAT: return new FloatValues(tfTensor); - case BOOL: return new BoolValues(tfTensor); - case UINT8: return new IntValues(tfTensor); - case INT32: return new IntValues(tfTensor); - case INT64: return new LongValues(tfTensor); - default: - throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + - tfTensor.dataType() + " to a Vespa tensor"); - } - } - - /** Allows reading values from buffers of various numeric types as bytes */ - private static abstract class Values { - - private final int size; - - protected Values(int size) { - this.size = size; - } - - abstract double get(int i); - - int size() { return size; } - - } - - private static class DoubleValues extends Values { - - private final DoubleBuffer values; - - DoubleValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = DoubleBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - - private static class FloatValues extends Values { - - private final FloatBuffer values; - - FloatValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = FloatBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - - private static class BoolValues extends Values { - - private final ByteBuffer values; - - BoolValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = ByteBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - - private static class IntValues extends Values { - - private final IntBuffer values; - - IntValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = IntBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - - private static class LongValues extends Values { - - private final LongBuffer values; - - LongValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = LongBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index c97ee2b1514..7116d430502 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -2,10 +2,20 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.MetaGraphDef; import org.tensorflow.framework.NodeDef; @@ -24,6 +34,7 @@ import java.util.stream.Collectors; * Converts a saved TensorFlow model into a ranking expression and set of constants. * * @author bratseth + * @author lesters */ public class TensorFlowImporter { @@ -57,196 +68,303 @@ public class TensorFlowImporter { } } - private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) { - TensorFlowModel result = new TensorFlowModel(); + /** + * Imports the TensorFlow graph by first importing the tensor types, then + * finding a suitable set of dimensions names for each + * placeholder/constant/variable, then importing the expressions. + */ + private static TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle bundle) { + TensorFlowModel model = new TensorFlowModel(); + OperationIndex index = new OperationIndex(); + + importSignatures(graph, model); + importNodes(graph, model, index); + findDimensionNames(model, index); + importExpressions(model, index, bundle); + + // nodes with multiple outputs are calculated multiple times. consider adding macros for those. + + reportWarnings(model, index); + + return model; + } + + private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) { for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - TensorFlowModel.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName" + String signatureName = signatureEntry.getKey(); + TensorFlowModel.Signature signature = model.signature(signatureName); + + Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap(); + for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) { + String inputName = input.getKey(); + signature.input(inputName, namePartOf(input.getValue().getName())); + } - importInputs(signatureEntry.getValue().getInputsMap(), signature); - for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { + Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap(); + for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) { String outputName = output.getKey(); - try { - NodeDef node = getNode(namePartOf(output.getValue().getName()), graph.getGraphDef()); - Parameters params = createParameters(graph.getGraphDef(), model, result, signature, node, ""); - - // Commonly, there are multiple paths through a TensorFlow graph, for instance for - // training and testing/evaluation. Examples are dropout and batch norm. For Vespa - // we are not concerned with training paths, so we can ignore non-supported operations - // as long as they are on a path that will not be evaluated run time. Operations - // that fail import will not have a value present in the optionals. However, the - // final output node must have value present. It is an error if it does not. - - Optional<TypedTensorFunction> outputFunction = importNode(params); - if (!outputFunction.isPresent()) { - throw new IllegalArgumentException(signature.importWarnings().stream().collect(Collectors.joining("\n"))); - } - signature.output(outputName, namePartOf(output.getValue().getName())); - } - catch (IllegalArgumentException e) { - signature.skippedOutput(outputName, Exceptions.toMessageString(e)); - } + signature.output(outputName, namePartOf(output.getValue().getName())); } } - return result; } - private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) { - inputInfoMap.forEach((key, value) -> { - String argumentName = namePartOf(value.getName()); - TensorType argumentType = AttrValueConverter.toVespaTensorType(value.getTensorShape()); - // Arguments are (Placeholder) nodes, so not local to the signature: - signature.owner().argument(argumentName, argumentType); - signature.input(key, argumentName); - }); + private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String inputName : signature.inputs().values()) { + if (inputName.equals(operation.node().getName())) { + return true; + } + } + } + return false; } - /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private Optional<TypedTensorFunction> importNode(Parameters params) { - String nodeName = params.node().getName(); - if (params.imported().containsKey(nodeName)) { - return Optional.of(params.imported().get(nodeName)); + private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + if (outputName.equals(operation.node().getName())) { + return true; + } + } } + return false; + } - Optional<TypedTensorFunction> function = OperationMapper.map(params); - if ( ! function.isPresent()) { - return Optional.empty(); - } - if ( ! controlDependenciesArePresent(params)) { - return Optional.empty(); + private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + importNode(outputName, graph.getGraphDef(), index); + } } - params.imported().put(nodeName, function.get()); + } - try { - // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output - // will be used. We parse the TensorFunction here to convert it to a RankingExpression tree - params.result().expression(nodeName, - new RankingExpression(nodeName, function.get().function().toString())); - return function; + private static TensorFlowOperation importNode(String name, GraphDef graph, OperationIndex index) { + if (index.alreadyImported(name)) { + return index.get(name); } - catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function.get().function() + - " cannot be parsed as a ranking expression", e); + NodeDef node = getTensorFlowNodeFromGraph(namePartOf(name), graph); + List<TensorFlowOperation> inputs = importNodeInputs(node, graph, index); + TensorFlowOperation operation = OperationMapper.get(node, inputs, portPartOf(name)); + index.put(name, operation); + + List<TensorFlowOperation> controlInputs = importControlInputs(node, graph, index); + if (controlInputs.size() > 0) { + operation.setControlInputs(controlInputs); } - } - private boolean controlDependenciesArePresent(Parameters params) { - return params.node().getInputList().stream() - .filter(TensorFlowImporter::isControlDependency) - .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName)))) - .allMatch(Optional::isPresent); + return operation; } - private static boolean isControlDependency(String nodeName) { - return nodeName.startsWith("^"); + private static List<TensorFlowOperation> importNodeInputs(NodeDef node, GraphDef graph, OperationIndex index) { + return node.getInputList().stream() + .filter(name -> ! isControlDependency(name)) + .map(name -> importNode(name, graph, index)) + .collect(Collectors.toList()); } - private List<Optional<TypedTensorFunction>> importArguments(Parameters params) { - return params.node().getInputList().stream() - .filter(nodeName -> !isControlDependency(nodeName)) - .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName)))) + private static List<TensorFlowOperation> importControlInputs(NodeDef node, GraphDef graph, OperationIndex index) { + return node.getInputList().stream() + .filter(name -> isControlDependency(name)) + .map(name -> importNode(name, graph, index)) .collect(Collectors.toList()); } - private NodeDef getNode(String name, GraphDef graph) { - return graph.getNodeList().stream() - .filter(node -> node.getName().equals(name)) - .findFirst() - .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'")); + private static boolean isControlDependency(String name) { + return name.startsWith("^"); } - /** - * A method signature input and output has the form name:index. - * This returns the name part without the index. - */ - private static String namePartOf(String name) { - name = name.startsWith("^") ? name.substring(1) : name; - return name.split(":")[0]; + /** Find dimension names to avoid excessive renaming while evaluating the model. */ + private static void findDimensionNames(TensorFlowModel model, OperationIndex index) { + DimensionRenamer renamer = new DimensionRenamer(); + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String output : signature.outputs().values()) { + addDimensionNameConstraints(index.get(output), renamer); + } + } + renamer.solve(); + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String output : signature.outputs().values()) { + renameDimensions(index.get(output), renamer); + } + } } - /** - * This return the index part. Indexes are used for nodes with - * multiple outputs. - */ - private static String indexPartOf(String name) { - int i = name.indexOf(":"); - return i < 0 ? "" : name.substring(i + 1); + private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.addDimensionNameConstraints(renamer); + } } + private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.renameDimensions(renamer); + } + } - private Parameters createParameters(GraphDef graph, - SavedModelBundle model, - TensorFlowModel result, - TensorFlowModel.Signature signature, - NodeDef node, - String port) { - return new Parameters(this, graph, model, result, signature, new HashMap<>(), node, port); + private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + try { + Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle); + if (!function.isPresent()) { + signature.skippedOutput(outputName, "No valid output function could be found."); + } + } + catch (IllegalArgumentException e) { + signature.skippedOutput(outputName, Exceptions.toMessageString(e)); + } + } + } } - /** Parameter object to hold important data while importing */ - static final class Parameters { + private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) { + if (!operation.type().isPresent()) { + return Optional.empty(); + } + if (operation.isConstant()) { + return importConstant(model, operation, bundle); + } - private final TensorFlowImporter owner; - private final GraphDef graph; - private final SavedModelBundle model; - private final TensorFlowModel result; - private final TensorFlowModel.Signature signature; - private final Map<String, TypedTensorFunction> imported; - private final NodeDef node; - private final String port; + importInputExpressions(operation, model, bundle); + importRankingExpression(model, operation); + importInputExpression(model, operation); + importMacroExpression(model, operation); - private Parameters(TensorFlowImporter owner, - GraphDef graph, - SavedModelBundle model, - TensorFlowModel result, - TensorFlowModel.Signature signature, - Map<String, TypedTensorFunction> imported, - NodeDef node, - String port) { - this.owner = owner; - this.graph = graph; - this.model = model; - this.result = result; - this.signature = signature; - this.imported = imported; - this.node = node; - this.port = port; - } + return operation.function(); + } - GraphDef graph() { - return this.graph; + private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) { + operation.inputs().forEach(input -> importExpression(input, model, bundle)); + } + + private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) { + if (operation.macro().isPresent()) { + model.macro(operation.vespaName(), operation.macro().get()); } + } - SavedModelBundle model() { - return this.model; + private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, SavedModelBundle bundle) { + String name = operation.vespaName(); + if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + return operation.function(); } - TensorFlowModel result() { - return this.result; + Tensor tensor; + if (operation.getConstantValue().isPresent()) { + Value value = operation.getConstantValue().get(); + if ( ! (value instanceof TensorValue)) { + return operation.function(); // scalar values are inserted directly into the expression + } + tensor = value.asTensor(); + } else { + Session.Runner fetched = bundle.session().runner().fetch(operation.node().getName()); + List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); + if (importedTensors.size() != 1) { + throw new IllegalStateException("Expected 1 tensor from fetching " + operation.node().getName() + ", but got " + + importedTensors.size()); + } + // Here we use the type from the operation, which will have correct dimension names after name resolving + tensor = TensorConverter.toVespaTensor(importedTensors.get(0), operation.type().get()); + operation.setConstantValue(new TensorValue(tensor)); } - TensorFlowModel.Signature signature() { - return this.signature; + if (tensor.type().rank() == 0 || tensor.size() <= 1) { + model.smallConstant(operation.vespaName(), tensor); + } else { + model.largeConstant(operation.vespaName(), tensor); } + return operation.function(); + } + + private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) { + if (operation.function().isPresent()) { + String name = operation.node().getName(); + if (!model.expressions().containsKey(operation.node().getName())) { + TensorFunction function = operation.function().get(); + + // Make sure output adheres to standard naming convention + if (isSignatureOutput(model, operation)) { + OrderedTensorType operationType = operation.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node()); + if ( ! operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + function = new Rename(function, renameFrom, renameTo); + } + } - Map<String, TypedTensorFunction> imported() { - return this.imported; + try { + // We add all intermediate nodes imported as separate expressions. Only + // those referenced in a signature output will be used. We parse the + // TensorFunction here to convert it to a RankingExpression tree. + model.expression(name, new RankingExpression(name, function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Tensorflow function " + function + + " cannot be parsed as a ranking expression", e); + } + } } + } - NodeDef node() { - return node; + private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) { + if (operation.isInput() && isSignatureInput(model, operation)) { + // All inputs must have dimensions with standard naming convention: d0, d1, ... + OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node()); + model.argument(operation.node().getName(), standardNamingConvention.type()); + model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); } + } - String port() { - return port; + private static void reportWarnings(TensorFlowModel model, OperationIndex index) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String output : signature.outputs().values()) { + reportWarnings(index.get(output), signature); + } } + } - Parameters copy(NodeDef node, String port) { - return new Parameters(this.owner, this.graph, this.model, this.result, this.signature, this.imported, node, port); + private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) { + for (String warning : operation.warnings()) { + signature.importWarning(warning); } + } - List<Optional<TypedTensorFunction>> inputs() { - return owner.importArguments(this); + private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) { + for (NodeDef node : graph.getNodeList()) { + if (node.getName().equals(name)) { + return node; + } } + throw new IllegalArgumentException("Could not find node '" + name + "'"); + } + + /** + * A method signature input and output has the form name:index. + * This returns the name part without the index. + */ + private static String namePartOf(String name) { + name = name.startsWith("^") ? name.substring(1) : name; + return name.split(":")[0]; + } + + /** + * This return the output port part. Indexes are used for nodes with + * multiple outputs. + */ + private static int portPartOf(String name) { + int i = name.indexOf(":"); + return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); + } + + + private static class OperationIndex { + private final Map<String, TensorFlowOperation> index = new HashMap<>(); + public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); } + public TensorFlowOperation get(String key) { return index.get(key); } + public boolean alreadyImported(String key) { return index.containsKey(key); } } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java deleted file mode 100644 index 600225bfe76..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; - -/** - * A tensor function returning a specific tensor type - * - * @author bratseth - */ -final class TypedTensorFunction { - - private final TensorType type; - private final TensorFunction function; - - public TypedTensorFunction(TensorType type, TensorFunction function) { - this.type = type; - this.function = function; - } - - public TensorType type() { - return type; - } - - public TensorFunction function() { - return function; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java new file mode 100644 index 00000000000..c1665d066a4 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java @@ -0,0 +1,210 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * A constraint satisfier to find suitable dimension names to reduce the + * amount of necessary renaming during evaluation of an imported model. + * + * @author lesters + */ +public class DimensionRenamer { + + private final String dimensionPrefix; + private final Map<String, List<Integer>> variables = new HashMap<>(); + private final Map<Arc, Constraint> constraints = new HashMap<>(); + private final Map<String, Integer> renames = new HashMap<>(); + + private int iterations = 0; + + public DimensionRenamer() { + this("d"); + } + + public DimensionRenamer(String dimensionPrefix) { + this.dimensionPrefix = dimensionPrefix; + } + + /** + * Add a dimension name variable. + */ + public void addDimension(String name) { + variables.computeIfAbsent(name, d -> new ArrayList<>()); + } + + /** + * Add a constraint between dimension names. + */ + public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) { + Arc arc = new Arc(from, to, operation); + Arc opposite = arc.opposite(); + constraints.put(arc, pred); + constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric + } + + /** + * Retrieve resulting name of dimension after solving for constraints. + */ + public Optional<String> dimensionNameOf(String name) { + if (!renames.containsKey(name)) { + return Optional.empty(); + } + return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); + } + + /** + * Perform iterative arc consistency until we have found a solution. After + * an initial iteration, the variables (dimensions) will have multiple + * valid values. Find a single valid assignment by iteratively locking one + * dimension after another, and running the arc consistency algorithm + * multiple times. + * + * This requires having constraints that result in an absolute ordering: + * equals, lesserThan and greaterThan do that, but adding notEquals does + * not typically result in a guaranteed ordering. If that is needed, the + * algorithm below needs to be adapted with a backtracking (tree) search + * to find solutions. + */ + public void solve(int maxIterations) { + initialize(); + + // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts + + for (String dimension : variables.keySet()) { + List<Integer> values = variables.get(dimension); + if (values.size() > 1) { + if (!ac3()) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution."); + } + values.sort(Integer::compare); + variables.put(dimension, Collections.singletonList(values.get(0))); + } + renames.put(dimension, variables.get(dimension).get(0)); + if (iterations > maxIterations) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + + maxIterations + " iterations"); + } + } + + // Todo: handle failure more gracefully: + // If a solution can't be found, look at the operation node in the arc + // with the most remaining constraints, and inject a rename operation. + // Then run this algorithm again. + } + + public void solve() { + solve(100000); + } + + private void initialize() { + for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { + List<Integer> values = variable.getValue(); + for (int i = 0; i < variables.size(); ++i) { + values.add(i); // invariant: values are in increasing order + } + } + } + + private boolean ac3() { + Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); + while (!workList.isEmpty()) { + Arc arc = workList.pop(); + iterations += 1; + if (revise(arc)) { + if (variables.get(arc.from).size() == 0) { + return false; // no solution found + } + for (Arc constraint : constraints.keySet()) { + if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { + workList.add(constraint); + } + } + } + } + return true; + } + + private boolean revise(Arc arc) { + boolean revised = false; + for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { + Integer from = fromIterator.next(); + boolean satisfied = false; + for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { + Integer to = toIterator.next(); + if (constraints.get(arc).test(from, to)) { + satisfied = true; + } + } + if (!satisfied) { + fromIterator.remove(); + revised = true; + } + } + return revised; + } + + public interface Constraint { + boolean test(Integer x, Integer y); + } + + public static boolean equals(Integer x, Integer y) { + return Objects.equals(x, y); + } + + public static boolean lesserThan(Integer x, Integer y) { + return x < y; + } + + public static boolean greaterThan(Integer x, Integer y) { + return x > y; + } + + private static class Arc { + + private final String from; + private final String to; + private final TensorFlowOperation operation; + + Arc(String from, String to, TensorFlowOperation operation) { + this.from = from; + this.to = to; + this.operation = operation; + } + + Arc opposite() { + return new Arc(to, from, operation); + } + + @Override + public int hashCode() { + return Objects.hash(from, to); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof Arc)) { + return false; + } + Arc other = (Arc) obj; + return Objects.equals(from, other.from) && Objects.equals(to, other.to); + } + + @Override + public String toString() { + return String.format("%s -> %s", from, to); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java new file mode 100644 index 00000000000..0fe73fad8ce --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java @@ -0,0 +1,108 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable; +import com.yahoo.tensor.functions.ScalarFunctions; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +/** + * Maps from TensorFlow operations to Vespa operations. + * + * @author bratseth + * @author lesters + */ +public class OperationMapper { + + public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) { + switch (node.getOp().toLowerCase()) { + /* + * array ops + */ + case "const": return new Const(node, inputs, port); + case "expanddims": return new ExpandDims(node, inputs, port); + case "identity": return new Identity(node, inputs, port); + case "placeholder": return new Placeholder(node, inputs, port); + case "placeholderwithdefault": return new PlaceholderWithDefault(node, inputs, port); + case "reshape": return new Reshape(node, inputs, port); + case "shape": return new Shape(node, inputs, port); + case "squeeze": return new Squeeze(node, inputs, port); + + /* + * control flow + */ + case "merge": return new Merge(node, inputs, port); + case "switch": return new Switch(node, inputs, port); + + /* + * math ops + */ + case "add": return new Join(node, inputs, port, ScalarFunctions.add()); + case "add_n": return new Join(node, inputs, port, ScalarFunctions.add()); + case "acos": return new Map(node, inputs, port, ScalarFunctions.acos()); + case "div": return new Join(node, inputs, port, ScalarFunctions.divide()); + case "realdiv": return new Join(node, inputs, port, ScalarFunctions.divide()); + case "floor": return new Map(node, inputs, port, ScalarFunctions.floor()); + case "matmul": return new Matmul(node, inputs, port); + case "maximum": return new Join(node, inputs, port, ScalarFunctions.max()); + case "mean": return new Mean(node, inputs, port); + case "reducemean": return new Mean(node, inputs, port); + case "mul": return new Join(node, inputs, port, ScalarFunctions.multiply()); + case "multiply": return new Join(node, inputs, port, ScalarFunctions.multiply()); + case "rsqrt": return new Map(node, inputs, port, ScalarFunctions.rsqrt()); + case "select": return new Select(node, inputs, port); + case "where3": return new Select(node, inputs, port); + case "sigmoid": return new Map(node, inputs, port, ScalarFunctions.sigmoid()); + case "squareddifference": return new Join(node, inputs, port, ScalarFunctions.squareddifference()); + case "sub": return new Join(node, inputs, port, ScalarFunctions.subtract()); + case "subtract": return new Join(node, inputs, port, ScalarFunctions.subtract()); + + /* + * nn ops + */ + case "biasadd": return new Join(node, inputs, port, ScalarFunctions.add()); + case "elu": return new Map(node, inputs, port, ScalarFunctions.elu()); + case "relu": return new Map(node, inputs, port, ScalarFunctions.relu()); + case "selu": return new Map(node, inputs, port, ScalarFunctions.selu()); + + /* + * random ops + */ + + /* + * state ops + */ + case "variable": return new Variable(node, inputs, port); + case "variablev2": return new Variable(node, inputs, port); + + /* + * evaluation no-ops + */ + case "stopgradient":return new Identity(node, inputs, port); + case "noop": return new NoOp(node, inputs, port); + } + return new NoOp(node, inputs, port); + } + +} + + + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java new file mode 100644 index 00000000000..3742e443a06 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java @@ -0,0 +1,237 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.tensor.TensorType; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.TensorShapeProto; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * A Vespa tensor type is ordered by the lexicographical ordering of dimension + * names. TensorFlow tensors have an explicit ordering of their dimensions. + * During import, we need to track the Vespa dimension that matches the + * corresponding TensorFlow dimension as the ordering can change after + * dimension renaming. That is the purpose of this class. + * + * @author lesters + */ +public class OrderedTensorType { + + private final TensorType type; + private final List<TensorType.Dimension> dimensions; + + private final long[] innerSizesTensorFlow; + private final long[] innerSizesVespa; + private final int[] dimensionMap; + + private OrderedTensorType(List<TensorType.Dimension> dimensions) { + this.dimensions = Collections.unmodifiableList(dimensions); + this.type = new TensorType.Builder(dimensions).build(); + this.innerSizesTensorFlow = new long[dimensions.size()]; + this.innerSizesVespa = new long[dimensions.size()]; + this.dimensionMap = createDimensionMap(); + } + + public TensorType type() { + return this.type; + } + + public List<TensorType.Dimension> dimensions() { + return dimensions; + } + + public List<String> dimensionNames() { + return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList()); + } + + private int[] createDimensionMap() { + int numDimensions = dimensions.size(); + if (numDimensions == 0) { + return null; + } + innerSizesTensorFlow[numDimensions - 1] = 1; + innerSizesVespa[numDimensions - 1] = 1; + for (int i = numDimensions - 1; --i >= 0; ) { + innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1]; + innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; + } + int[] mapping = new int[numDimensions]; + for (int i = 0; i < numDimensions; ++i) { + TensorType.Dimension dim1 = dimensions().get(i); + for (int j = 0; j < numDimensions; ++j) { + TensorType.Dimension dim2 = type.dimensions().get(j); + if (dim1.equals(dim2)) { + mapping[i] = j; + break; + } + } + } + return mapping; + } + + /** + * When dimension ordering between Vespa and TensorFlow differs, i.e. + * after dimension renaming, use the dimension map to read in values + * so that they are correctly laid out in memory for Vespa. + * Used when importing tensors from TensorFlow. + */ + public int toDirectIndex(int index) { + if (dimensions.size() == 0) { + return 0; + } + if (dimensionMap == null) { + throw new IllegalArgumentException("Dimension map is not available"); + } + int directIndex = 0; + long rest = index; + for (int i = 0; i < dimensions.size(); ++i) { + long address = rest / innerSizesTensorFlow[i]; + directIndex += innerSizesVespa[dimensionMap[i]] * address; + rest %= innerSizesTensorFlow[i]; + } + return directIndex; + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof OrderedTensorType)) { + return false; + } + OrderedTensorType other = (OrderedTensorType) obj; + if (dimensions.size() != dimensions.size()) { + return false; + } + List<TensorType.Dimension> thisDimensions = this.dimensions(); + List<TensorType.Dimension> otherDimensions = other.dimensions(); + for (int i = 0; i < thisDimensions.size(); ++i) { + if (!thisDimensions.get(i).equals(otherDimensions.get(i))) { + return false; + } + } + return true; + } + + public static void verifyType(NodeDef node, OrderedTensorType type) { + if (type == null) { + return; + } + TensorShapeProto shape = tensorFlowShape(node); + if (shape != null && type.type != null) { + if (shape.getDimCount() != type.type.rank()) { + throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + + "does not match Vespa shape"); + } + for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions.size(); ++tensorFlowIndex) { + int vespaIndex = type.dimensionMap[tensorFlowIndex]; + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex); + TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); + if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { + throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + + "does not match Vespa dimensions"); + } + } + } + } + + private static TensorShapeProto tensorFlowShape(NodeDef node) { + AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); + if (attrValueList == null) { + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "does not exist"); + } + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "is not of expected type"); + } + List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); + return shapeList.get(0); // support multiple outputs? + } + + public static OrderedTensorType rename(OrderedTensorType type, DimensionRenamer renamer) { + List<TensorType.Dimension> renamedDimensions = new ArrayList<>(type.dimensions.size()); + for (TensorType.Dimension dimension : type.dimensions) { + String oldName = dimension.name(); + Optional<String> newName = renamer.dimensionNameOf(oldName); + if (!newName.isPresent()) + return type; // presumably, already renamed + TensorType.Dimension.Type dimensionType = dimension.type(); + if (dimensionType == TensorType.Dimension.Type.indexedBound) { + renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); + } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) { + renamedDimensions.add(TensorType.Dimension.indexed(newName.get())); + } else if (dimensionType == TensorType.Dimension.Type.mapped) { + renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); + } + } + return new OrderedTensorType(renamedDimensions); + } + + public static OrderedTensorType fromTensorFlowType(NodeDef node) { + return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... + } + + public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { + Builder builder = new Builder(node); + TensorShapeProto shape = tensorFlowShape(node); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); + if (tensorFlowDimension.getSize() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + + public static class Builder { + + private final TensorShapeProto shape; + private final List<TensorType.Dimension> dimensions; + + public Builder(NodeDef node) { + this.shape = tensorFlowShape(node); + this.dimensions = new ArrayList<>(shape.getDimCount()); + } + + public Builder add(TensorType.Dimension vespaDimension) { + int index = dimensions.size(); + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index); + long size = tensorFlowDimension.getSize(); + if (size >= 0) { + if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) { + throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + + "dimension types"); + } + if (!vespaDimension.size().isPresent()) { + throw new IllegalArgumentException("Tensor dimension is indexed bound but does " + + "not have a size"); + } + if (vespaDimension.size().get() != size) { + throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + + "dimension sizes. TensorFlow: " + size + " Vespa: " + vespaDimension.size().get()); + } + } else { + if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) { + throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + + "dimension types"); + } + } + this.dimensions.add(vespaDimension); + return this; + } + + public OrderedTensorType build() { + return new OrderedTensorType(dimensions); + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java new file mode 100644 index 00000000000..3f55e622fdf --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java @@ -0,0 +1,224 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.tensorflow.framework.TensorProto; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; + + +/** + * Converts TensorFlow tensors into Vespa tensors. + * + * @author bratseth + * @author lesters + */ +public class TensorConverter { + + public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { + return toVespaTensor(tfTensor, "d"); + } + + public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { + TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix); + Values values = readValuesOf(tfTensor); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); + for (int i = 0; i < values.size(); i++) + builder.cellByDirectIndex(i, values.get(i)); + return builder.build(); + } + + public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) { + Values values = readValuesOf(tfTensor); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); + for (int i = 0; i < values.size(); i++) { + builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i)); + } + return builder.build(); + } + + public static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) { + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); + Values values = readValuesOf(tensorProto); + for (int i = 0; i < values.size(); ++i) { + builder.cellByDirectIndex(i, values.get(i)); + } + return builder.build(); + } + + private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) { + TensorType.Builder b = new TensorType.Builder(); + int dimensionIndex = 0; + for (long dimensionSize : shape) { + if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... + b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize); + } + return b.build(); + } + + public static Long tensorSize(TensorType type) { + Long size = 1L; + for (TensorType.Dimension dimension : type.dimensions()) { + size *= dimensionSize(dimension); + } + return size; + } + + public static Long dimensionSize(TensorType.Dimension dim) { + return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); + } + + private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { + switch (tfTensor.dataType()) { + case DOUBLE: return new DoubleValues(tfTensor); + case FLOAT: return new FloatValues(tfTensor); + case BOOL: return new BoolValues(tfTensor); + case UINT8: return new IntValues(tfTensor); + case INT32: return new IntValues(tfTensor); + case INT64: return new LongValues(tfTensor); + } + throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + + tfTensor.dataType() + " to a Vespa tensor"); + } + + private static Values readValuesOf(TensorProto tensorProto) { + switch (tensorProto.getDtype()) { + case DT_BOOL: + return new ProtoBoolValues(tensorProto); + case DT_HALF: + return new ProtoHalfValues(tensorProto); + case DT_INT16: + case DT_INT32: + return new ProtoIntValues(tensorProto); + case DT_INT64: + return new ProtoInt64Values(tensorProto); + case DT_FLOAT: + return new ProtoFloatValues(tensorProto); + case DT_DOUBLE: + return new ProtoDoubleValues(tensorProto); + } + throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); + } + + /** Allows reading values from buffers of various numeric types as bytes */ + private static abstract class Values { + abstract double get(int i); + abstract int size(); + } + + private static abstract class TensorFlowValues extends Values { + private final int size; + TensorFlowValues(int size) { + this.size = size; + } + @Override int size() { return this.size; } + } + + private static class DoubleValues extends TensorFlowValues { + private final DoubleBuffer values; + DoubleValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = DoubleBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + @Override double get(int i) { + return values.get(i); + } + } + + private static class FloatValues extends TensorFlowValues { + private final FloatBuffer values; + FloatValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = FloatBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + @Override double get(int i) { + return values.get(i); + } + } + + private static class BoolValues extends TensorFlowValues { + private final ByteBuffer values; + BoolValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = ByteBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + @Override double get(int i) { + return values.get(i); + } + } + + private static class IntValues extends TensorFlowValues { + private final IntBuffer values; + IntValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = IntBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + @Override double get(int i) { + return values.get(i); + } + } + + private static class LongValues extends TensorFlowValues { + private final LongBuffer values; + LongValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = LongBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + @Override double get(int i) { + return values.get(i); + } + } + + private static abstract class ProtoValues extends Values { + protected final TensorProto tensorProto; + protected ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; } + } + + private static class ProtoBoolValues extends ProtoValues { + ProtoBoolValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; } + @Override int size() { return tensorProto.getBoolValCount(); } + } + + private static class ProtoHalfValues extends ProtoValues { + ProtoHalfValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getHalfVal(i); } + @Override int size() { return tensorProto.getHalfValCount(); } + } + + private static class ProtoIntValues extends ProtoValues { + ProtoIntValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getIntVal(i); } + @Override int size() { return tensorProto.getIntValCount(); } + } + + private static class ProtoInt64Values extends ProtoValues { + ProtoInt64Values(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getInt64Val(i); } + @Override int size() { return tensorProto.getInt64ValCount(); } + } + + private static class ProtoFloatValues extends ProtoValues { + ProtoFloatValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getFloatVal(i); } + @Override int size() { return tensorProto.getFloatValCount(); } + } + + private static class ProtoDoubleValues extends ProtoValues { + ProtoDoubleValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getDoubleVal(i); } + @Override int size() { return tensorProto.getDoubleValCount(); } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java new file mode 100644 index 00000000000..7decef51ab7 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java @@ -0,0 +1,93 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; + +public class Const extends TensorFlowOperation { + + public Const(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + setConstantValue(value()); + } + + @Override + protected OrderedTensorType lazyGetType() { + return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + } + + @Override + public Optional<TensorFunction> function() { + if (function == null) { + function = lazyGetFunction(); + } + return Optional.ofNullable(function); + } + + @Override + protected TensorFunction lazyGetFunction() { + ExpressionNode expressionNode; + if (type.type().rank() == 0 && getConstantValue().isPresent()) { + expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue()); + } else { + expressionNode = new ReferenceNode("constant(\"" + vespaName() + "\")"); + } + return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + setConstantValue(value()); + } + + @Override + public boolean isConstant() { + return true; + } + + private Value value() { + if (!node.getAttrMap().containsKey("value")) { + throw new IllegalArgumentException("Node '" + node.getName() + "' of type " + + "const has missing 'value' attribute"); + } + AttrValue attrValue = node.getAttrMap().get("value"); + if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { + return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type())); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.B) { + return new BooleanValue(attrValue.getB()); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.I) { + return new DoubleValue(attrValue.getI()); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.F) { + return new DoubleValue(attrValue.getF()); + } + throw new IllegalArgumentException("Requesting value of constant in " + + node.getName() + " but type is not recognized."); + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java new file mode 100644 index 00000000000..c1ad21f41d8 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java @@ -0,0 +1,107 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +public class ExpandDims extends TensorFlowOperation { + + private List<String> expandDimensions; + + public ExpandDims(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + + TensorFlowOperation axisOperation = inputs().get(1); + if (!axisOperation.getConstantValue().isPresent()) { + throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + + "axis must be a constant."); + } + Tensor axis = axisOperation.getConstantValue().get().asTensor(); + if (axis.type().rank() != 0) { + throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + + "axis argument must be a scalar."); + } + + OrderedTensorType inputType = inputs.get(0).type().get(); + int dimensionToInsert = (int)axis.asDouble(); + if (dimensionToInsert < 0) { + dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; + } + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); + expandDimensions = new ArrayList<>(); + int dimensionIndex = 0; + for (TensorType.Dimension dimension : inputType.dimensions()) { + if (dimensionIndex == dimensionToInsert) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + expandDimensions.add(name); + typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); + } + typeBuilder.add(dimension); + dimensionIndex++; + } + + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(2)) { + return null; + } + + // multiply with a generated tensor created from the reduced dimensions + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (String name : expandDimensions) { + typeBuilder.indexed(name, 1); + } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(expandDimensions.size()); + for (String name : expandDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (!newName.isPresent()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + expandDimensions = renamedDimensions; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java new file mode 100644 index 00000000000..d79707a42e6 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Identity extends TensorFlowOperation { + + public Identity(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) + return null; + return inputs.get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) + return null; + return inputs.get(0).function().orElse(null); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java new file mode 100644 index 00000000000..aa27ba2684d --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java @@ -0,0 +1,79 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; +import java.util.function.DoubleBinaryOperator; + +public class Join extends TensorFlowOperation { + + private final DoubleBinaryOperator operator; + + public Join(NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) { + super(node, inputs, port); + this.operator = operator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType a = inputs.get(0).type().get(); + OrderedTensorType b = inputs.get(1).type().get(); + OrderedTensorType out = a.type().rank() >= b.type().rank() ? a : b; + return out; + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + Optional<TensorFunction> aFunction = inputs.get(0).function(); + Optional<TensorFunction> bFunction = inputs.get(1).function(); + if (!aFunction.isPresent() || !bFunction.isPresent()) { + return null; + } + + // The dimension renaming below takes care of broadcasting. + + return new com.yahoo.tensor.functions.Join(aFunction.get(), bFunction.get(), operator); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(2)) { + return; + } + + // Well now we have potentially entered the wonderful world of "broadcasting" + // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html + // I'm not able to extract from that any unambiguous specification of which dimensions + // should be "stretched" when the tensor do not have the same number of dimensions. + // From trying this with TensorFlow it appears that the second tensor is matched to the + // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. + // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). + + TensorType a = inputs.get(0).type().get().type(); + TensorType b = inputs.get(1).type().get().type(); + if (a.rank() < b.rank()) { + TensorType temp = a; + a = b; + b = temp; + } + int sizeDifference = a.rank() - b.rank(); + for (int i = 0; i < b.rank(); ++i) { + String bDim = b.dimensions().get(i).name(); + String aDim = a.dimensions().get(i + sizeDifference).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java new file mode 100644 index 00000000000..105d65b3d69 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java @@ -0,0 +1,38 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; +import java.util.function.DoubleUnaryOperator; + +public class Map extends TensorFlowOperation { + + private final DoubleUnaryOperator operator; + + public Map(NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) { + super(node, inputs, port); + this.operator = operator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + return inputs.get(0).type().get(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) { + return null; + } + Optional<TensorFunction> input = inputs.get(0).function(); + return new com.yahoo.tensor.functions.Map(input.get(), operator); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java new file mode 100644 index 00000000000..ac4f78653d6 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java @@ -0,0 +1,74 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; + +public class Matmul extends TensorFlowOperation { + + public Matmul(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); + typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); + typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType aType = inputs.get(0).type().get(); + OrderedTensorType bType = inputs.get(1).type().get(); + if (aType.type().rank() < 2 || bType.type().rank() < 2) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); + if (aType.type().rank() != bType.type().rank()) + throw new IllegalArgumentException("Tensors in matmul must have the same rank"); + + Optional<TensorFunction> aFunction = inputs.get(0).function(); + Optional<TensorFunction> bFunction = inputs.get(1).function(); + if (!aFunction.isPresent() || !bFunction.isPresent()) { + return null; + } + return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(2)) { + return; + } + List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); + List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + + String aDim0 = aDimensions.get(0).name(); + String aDim1 = aDimensions.get(1).name(); + String bDim0 = bDimensions.get(0).name(); + String bDim1 = bDimensions.get(1).name(); + + // The second dimension of a should have the same name as the first dimension of b + renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); + + // The first dimension of a should have a different name than the second dimension of b + renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); + + // For efficiency, the dimensions to join over should be innermost - soft constraint + renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); + renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java new file mode 100644 index 00000000000..dfe0796d9b8 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java @@ -0,0 +1,112 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +public class Mean extends TensorFlowOperation { + + private List<String> reduceDimensions; + + public Mean(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + TensorFlowOperation reductionIndices = inputs.get(1); + if (!reductionIndices.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Mean in " + node.getName() + ": " + + "reduction indices must be a constant."); + } + Tensor indices = reductionIndices.getConstantValue().get().asTensor(); + reduceDimensions = new ArrayList<>(); + + OrderedTensorType inputType = inputs.get(0).type().get(); + for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext();) { + Tensor.Cell cell = cellIterator.next(); + int dimensionIndex = cell.getValue().intValue(); + if (dimensionIndex < 0) { + dimensionIndex = inputType.dimensions().size() - dimensionIndex; + } + reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); + } + return reducedType(inputType, shouldKeepDimensions()); + } + + // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity. + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + TensorFunction inputFunction = inputs.get(0).function().get(); + TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); + if (shouldKeepDimensions()) { + // multiply with a generated tensor created from the reduced dimensions + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (String name : reduceDimensions) { + typeBuilder.indexed(name, 1); + } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + } + return output; + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size()); + for (String name : reduceDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (!newName.isPresent()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + reduceDimensions = renamedDimensions; + } + + private boolean shouldKeepDimensions() { + AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims"); + return keepDimsAttr != null && keepDimsAttr.getB(); + } + + private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); + for (TensorType.Dimension dimension: inputType.type().dimensions()) { + if (!reduceDimensions.contains(dimension.name())) { + builder.add(dimension); + } else if (keepDimensions) { + builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); + } + } + return builder.build(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java new file mode 100644 index 00000000000..d3561716725 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java @@ -0,0 +1,36 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Merge extends TensorFlowOperation { + + public Merge(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + for (TensorFlowOperation operation : inputs) { + if (operation.type().isPresent()) { + return operation.type().get(); + } + } + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + for (TensorFlowOperation operation : inputs) { + if (operation.function().isPresent()) { + return operation.function().get(); + } + } + return null; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java new file mode 100644 index 00000000000..acf5d13b057 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java @@ -0,0 +1,32 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; + +public class NoOp extends TensorFlowOperation { + + public NoOp(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, Collections.emptyList(), port); // don't propagate inputs + } + + @Override + protected OrderedTensorType lazyGetType() { + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; + } + + @Override + public boolean isConstant() { + return true; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java new file mode 100644 index 00000000000..dadce395faf --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java @@ -0,0 +1,57 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Placeholder extends TensorFlowOperation { + + private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... + + public Placeholder(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + standardNamingType = OrderedTensorType.fromTensorFlowType(node); + } + + @Override + protected OrderedTensorType lazyGetType() { + return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + } + + @Override + protected TensorFunction lazyGetFunction() { + TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); + if (!standardNamingType.equals(type)) { + List<String> renameFrom = standardNamingType.dimensionNames(); + List<String> renameTo = type.dimensionNames(); + output = new Rename(output, renameFrom, renameTo); + } + return output; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public boolean isInput() { + return true; + } + + @Override + public boolean isConstant() { + return false; + } + + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java new file mode 100644 index 00000000000..ab091b77a65 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java @@ -0,0 +1,50 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; + +public class PlaceholderWithDefault extends TensorFlowOperation { + + public PlaceholderWithDefault(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + return inputs().get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) { + return null; + } + // This should be a call to the macro we add below, but for now + // we treat this as as identity function and just pass the constant. + return inputs.get(0).function().orElse(null); + } + + @Override + public Optional<RankingExpression> macro() { + // For now, it is much more efficient to assume we always will return + // the default value, as we can prune away large parts of the expression + // tree by having it calculated as a constant. If a case arises where + // it is important to support this, implement this. + return Optional.empty(); + } + + @Override + public boolean isConstant() { + return true; // not true if we add to macro + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java new file mode 100644 index 00000000000..9b3e28ce56b --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java @@ -0,0 +1,135 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; + +import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize; + +public class Reshape extends TensorFlowOperation { + + public Reshape(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + TensorFlowOperation newShape = inputs.get(1); + if (!newShape.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Reshape in " + node.getName() + ": " + + "shape input must be a constant."); + } + Tensor shape = newShape.getConstantValue().get().asTensor(); + + OrderedTensorType inputType = inputs.get(0).type().get(); + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node); + int dimensionIndex = 0; + for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { + Tensor.Cell cell = cellIterator.next(); + int size = cell.getValue().intValue(); + if (size < 0) { + size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / + tensorSize(inputType.type()).intValue(); + } + outputTypeBuilder.add(TensorType.Dimension.indexed( + String.format("%s_%d", vespaName(), dimensionIndex), size)); + dimensionIndex++; + } + return outputTypeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + if (!allInputFunctionsPresent(2)) { + return null; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + TensorFunction inputFunction = inputs.get(0).function().get(); + return reshape(inputFunction, inputType.type(), type.type()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { + if (!tensorSize(inputType).equals(tensorSize(outputType))) { + throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); + } + + // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, + // then use the dimension order of the new shape to roll back into a tensor. + // Here we create a transformation tensor that is multiplied with the from tensor to map into + // the new shape. We have to introduce temporary dimension names and rename back if dimension names + // in the new and old tensor type overlap. + + ExpressionNode unrollFrom = unrollTensorExpression(inputType); + ExpressionNode unrollTo = unrollTensorExpression(outputType); + ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo); + + TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); + Generate transformTensor = new Generate(transformationType, + new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); + + TensorFunction outputFunction = new Reduce( + new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); + + return outputFunction; + } + + private static ExpressionNode unrollTensorExpression(TensorType type) { + if (type.rank() == 0) { + return new ConstantNode(DoubleValue.zero); + } + List<ExpressionNode> children = new ArrayList<>(); + List<ArithmeticOperator> operators = new ArrayList<>(); + int size = 1; + for (int i = type.dimensions().size() - 1; i >= 0; --i) { + TensorType.Dimension dimension = type.dimensions().get(i); + children.add(0, new ReferenceNode(dimension.name())); + if (size > 1) { + operators.add(0, ArithmeticOperator.MULTIPLY); + children.add(0, new ConstantNode(new DoubleValue(size))); + } + size *= TensorConverter.dimensionSize(dimension); + if (i > 0) { + operators.add(0, ArithmeticOperator.PLUS); + } + } + return new ArithmeticNode(children, operators); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java new file mode 100644 index 00000000000..6a29d428cf3 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java @@ -0,0 +1,89 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.function.DoubleBinaryOperator; + +import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize; +import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize; + +public class Select extends TensorFlowOperation { + + public Select(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(3)) { + return null; + } + OrderedTensorType a = inputs.get(1).type().get(); + OrderedTensorType b = inputs.get(2).type().get(); + if ((a.type().rank() != b.type().rank()) || !(tensorSize(a.type()).equals(tensorSize(b.type())))) { + throw new IllegalArgumentException("'Select': input tensors must have the same shape"); + } + return a; + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(3)) { + return null; + } + TensorFlowOperation conditionOperation = inputs().get(0); + TensorFunction a = inputs().get(1).function().get(); + TensorFunction b = inputs().get(2).function().get(); + + // Shortcut: if we know during import which tensor to select, do that directly here. + if (conditionOperation.getConstantValue().isPresent()) { + Tensor condition = conditionOperation.getConstantValue().get().asTensor(); + if (condition.type().rank() == 0) { + return ((int) condition.asDouble() == 0) ? b : a; + } + if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { + return condition.cellIterator().next().getValue().intValue() == 0 ? b : a; + } + } + + // The task is to select cells from 'x' or 'y' based on 'condition'. + // If 'condition' is 0 (false), select from 'y', if 1 (true) select + // from 'x'. We do this by individually joining 'x' and 'y' with + // 'condition', and then joining the resulting two tensors. + + TensorFunction conditionFunction = conditionOperation.function().get(); + TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply()); + TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() { + @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } + @Override public String toString() { return "f(a,b)(a * (1-b))"; } + }); + return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(3)) { + return; + } + List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions(); + List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions(); + + String aDim0 = aDimensions.get(0).name(); + String aDim1 = aDimensions.get(1).name(); + String bDim0 = bDimensions.get(0).name(); + String bDim1 = bDimensions.get(1).name(); + + // These tensors should have the same dimension names + renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this); + renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java new file mode 100644 index 00000000000..8f4313022e0 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java @@ -0,0 +1,55 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Shape extends TensorFlowOperation { + + public Shape(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + createConstantValue(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + return new OrderedTensorType.Builder(node) + .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size())) + .build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; // will be added by function() since this is constant. + } + + @Override + public boolean isConstant() { + return true; + } + + private void createConstantValue() { + if (!allInputTypesPresent(1)) { + return; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type()); + List<TensorType.Dimension> inputDimensions = inputType.dimensions(); + for (int i = 0; i < inputDimensions.size(); i++) { + builder.cellByDirectIndex(i, inputDimensions.get(i).size().orElse(-1L)); + } + this.setConstantValue(new TensorValue(builder.build())); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java new file mode 100644 index 00000000000..d7750b52fc3 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java @@ -0,0 +1,84 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +public class Squeeze extends TensorFlowOperation { + + private List<String> squeezeDimensions; + + public Squeeze(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + squeezeDimensions = new ArrayList<>(); + + AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims"); + if (squeezeDimsAttr == null) { + squeezeDimensions = inputType.type().dimensions().stream(). + filter(dim -> TensorConverter.dimensionSize(dim) == 1). + map(TensorType.Dimension::name). + collect(Collectors.toList()); + } else { + squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). + map(i -> i < 0 ? inputType.type().dimensions().size() - i : i). + map(i -> inputType.type().dimensions().get(i.intValue())). + filter(dim -> TensorConverter.dimensionSize(dim) == 1). + map(TensorType.Dimension::name). + collect(Collectors.toList()); + } + return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) { + return null; + } + TensorFunction inputFunction = inputs.get(0).function().get(); + return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(squeezeDimensions.size()); + for (String name : squeezeDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (!newName.isPresent()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + squeezeDimensions = renamedDimensions; + } + + private OrderedTensorType reducedType(OrderedTensorType inputType) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); + for (TensorType.Dimension dimension: inputType.type().dimensions()) { + if ( ! squeezeDimensions.contains(dimension.name())) { + builder.add(dimension); + } + } + return builder.build(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java new file mode 100644 index 00000000000..1cc0e1936eb --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java @@ -0,0 +1,48 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; + +public class Switch extends TensorFlowOperation { + + public Switch(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + Optional<OrderedTensorType> predicate = inputs.get(1).type(); + if (predicate.get().type().rank() != 0) { + throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + "predicate must be a scalar"); + } + return inputs.get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + TensorFlowOperation predicateOperation = inputs().get(1); + if (!predicateOperation.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + "predicate must be a constant"); + } + if (port < 0 || port > 1) { + throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + "choice should be boolean"); + } + + double predicate = predicateOperation.getConstantValue().get().asDouble(); + return predicate == port ? inputs().get(0).function().get() : null; + } + +} + + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java new file mode 100644 index 00000000000..fd9dfd167fb --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java @@ -0,0 +1,136 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +/** + * Wraps a TensorFlow node and produces the respective Vespa tensor operation. + * During import, a graph of these operations are constructed. Then, the + * types are used to deduce sensible dimension names using the + * DimensionRenamer. After the types have been renamed, the proper + * Vespa expressions can be extracted. + * + * @author lesters + */ +public abstract class TensorFlowOperation { + + protected final NodeDef node; + protected final int port; + protected final List<TensorFlowOperation> inputs; + protected final List<TensorFlowOperation> outputs = new ArrayList<>(); + protected final List<String> importWarnings = new ArrayList<>(); + + protected OrderedTensorType type; + protected TensorFunction function; + + private Value constantValue = null; + private List<TensorFlowOperation> controlInputs = Collections.emptyList(); + + TensorFlowOperation(NodeDef node, List<TensorFlowOperation> inputs, int port) { + this.node = node; + this.port = port; + this.inputs = Collections.unmodifiableList(inputs); + this.inputs.forEach(i -> i.outputs.add(this)); + } + + protected abstract OrderedTensorType lazyGetType(); + protected abstract TensorFunction lazyGetFunction(); + + /** Returns the Vespa tensor type of this operation if it exists */ + public Optional<OrderedTensorType> type() { + if (type == null) { + type = lazyGetType(); + } + OrderedTensorType.verifyType(node, type); + return Optional.ofNullable(type); + } + + /** Returns the Vespa tensor function implementing all operations from this node with inputs */ + public Optional<TensorFunction> function() { + if (function == null) { + if (isConstant()) { + ExpressionNode constant = new ReferenceNode("constant(\"" + vespaName() + "\")"); + function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); + } else { + function = lazyGetFunction(); + } + } + return Optional.ofNullable(function); + } + + /** Return TensorFlow node */ + public NodeDef node() { return node; } + + /** Return unmodifiable list of inputs */ + public List<TensorFlowOperation> inputs() { return inputs; } + + /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */ + public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); } + + /** Returns a Vespa ranking expression that should be added as a macro */ + public Optional<RankingExpression> macro() { return Optional.empty(); } + + /** Add dimension name constraints for this operation */ + public void addDimensionNameConstraints(DimensionRenamer renamer) { } + + /** Performs dimension rename for this operation */ + public void renameDimensions(DimensionRenamer renamer) { type = OrderedTensorType.rename(type, renamer); } + + /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ + public boolean isInput() { return false; } + + /** Return true if this node is constant */ + public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); } + + /** Sets the constant value */ + public void setConstantValue(Value value) { constantValue = value; } + + /** Gets the constant value if it exists */ + public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); } + + /** Sets the external control inputs */ + public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; } + + /** Retrieve the control inputs for this operation */ + public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } + + /** Retrieve the valid Vespa name of this node */ + public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; } + + /** Retrieve the list of warnings produced during its lifetime */ + public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } + + boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) { + if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) { + return false; + } + if (inputs.size() != expected) { + throw new IllegalArgumentException("Expected " + expected + " inputs " + + "for '" + node.getName() + "', got " + inputs.size()); + } + return inputs.stream().map(func).allMatch(Optional::isPresent); + } + + boolean allInputTypesPresent(int expected) { + return verifyInputs(expected, TensorFlowOperation::type); + } + + boolean allInputFunctionsPresent(int expected) { + return verifyInputs(expected, TensorFlowOperation::function); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java new file mode 100644 index 00000000000..6f377c4bda2 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java @@ -0,0 +1,40 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Variable extends TensorFlowOperation { + + public Variable(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; // will be added by function() since this is constant. + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public boolean isConstant() { + return true; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java index fb9a7cb9ad7..d3a12d0f312 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java @@ -13,7 +13,7 @@ import java.util.List; /** * A set of argument expressions to a function or feature. - * This is immutable. + * This is a value object. * * @author bratseth */ @@ -22,7 +22,11 @@ public final class Arguments implements Serializable { private final ImmutableList<ExpressionNode> expressions; public Arguments() { - this(null); + this(ImmutableList.of()); + } + + public Arguments(ExpressionNode singleArgument) { + this(ImmutableList.of(singleArgument)); } public Arguments(List<? extends ExpressionNode> expressions) { @@ -38,9 +42,12 @@ public final class Arguments implements Serializable { this.expressions = b.build(); } - /** Returns an unmodifiable list of the expressions in this */ + /** Returns an unmodifiable list of the expressions in this, never null */ public List<ExpressionNode> expressions() { return expressions; } + /** Returns the number of arguments in this */ + public int size() { return expressions.size(); } + /** Evaluate all arguments in this */ public Value[] evaluate(Context context) { Value[] values=new Value[expressions.size()]; @@ -62,8 +69,9 @@ public final class Arguments implements Serializable { } @Override - public boolean equals(Object rhs) { - return rhs instanceof Arguments && expressions.equals(((Arguments)rhs).expressions); + public boolean equals(Object other) { + if (other == this) return true; + return other instanceof Arguments && expressions.equals(((Arguments)other).expressions); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java index fc6428a4c33..49c49bed9bd 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -80,7 +81,7 @@ public final class ArithmeticNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { // Compute type using tensor types as arithmetic operators are supported on tensors // and is correct also in the special case of doubles. // As all our functions are type-commutative, we don't need to take operator precedence into account diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java index 1d7d9b1ecda..cd4ddbcae55 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java @@ -5,7 +5,6 @@ package com.yahoo.searchlib.rankingexpression.rule; * A node which produces a boolean value when evaluated. * * @author bratseth - * @since 5.1.21 */ public abstract class BooleanNode extends CompositeNode { } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index 7601c0e6180..eb328486045 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -49,7 +50,7 @@ public class ComparisonNode extends BooleanNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return TensorType.empty; // by definition } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java index 1ea8d03f0eb..3ddd7223349 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -49,7 +50,7 @@ public final class ConstantNode extends ExpressionNode { } @Override - public TensorType type(TypeContext context) { return value.type(); } + public TensorType type(TypeContext<Reference> context) { return value.type(); } @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java index fd9fab99db8..47c2897e4a4 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -50,7 +51,7 @@ public final class EmbracedNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java index 477f4db4981..6bb163590de 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -48,7 +49,7 @@ public abstract class ExpressionNode implements Serializable { * @param context the variable type bindings to use for this evaluation * @throws IllegalArgumentException if there are variables which are not bound in the given map */ - public abstract TensorType type(TypeContext context); + public abstract TensorType type(TypeContext<Reference> context); /** * Returns the value of evaluating this expression over the given context. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java index 79515229019..1da2210a39c 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -67,7 +68,7 @@ public final class FunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { if (arguments.expressions().size() == 0) return TensorType.empty; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java new file mode 100644 index 00000000000..ed1e2838717 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java @@ -0,0 +1,74 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.google.common.collect.ImmutableMap; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; + +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * The context of a function invocation. + * + * @author bratseth + */ +public class FunctionReferenceContext { + + /** Expression functions indexed by name */ + private final ImmutableMap<String, ExpressionFunction> functions; + + /** Mapping from argument names to the expressions they resolve to */ + // TODO: Make private + public final Map<String, String> bindings = new HashMap<>(); + + /** Create a context for a single serialization task */ + public FunctionReferenceContext() { + this(Collections.emptyList()); + } + + /** Create a context for a single serialization task */ + public FunctionReferenceContext(Collection<ExpressionFunction> functions) { + this(toMap(functions), Collections.emptyMap()); + } + + public FunctionReferenceContext(Collection<ExpressionFunction> functions, Map<String, String> bindings) { + this(toMap(functions), bindings); + } + + /** Create a context for a single serialization task */ + public FunctionReferenceContext(Map<String, ExpressionFunction> functions) { + this(functions.values()); + } + + /** Create a context for a single serialization task */ + public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings) { + this.functions = ImmutableMap.copyOf(functions); + if (bindings != null) + this.bindings.putAll(bindings); + } + + private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) { + ImmutableMap.Builder<String,ExpressionFunction> mapBuilder = new ImmutableMap.Builder<>(); + for (ExpressionFunction function : list) + mapBuilder.put(function.getName(), function); + return mapBuilder.build(); + } + + /** + * Returns a function or null if it isn't defined in this context + */ + public ExpressionFunction getFunction(String name) { return functions.get(name); } + + protected Map<String, ExpressionFunction> functions() { return functions; } + + /** Returns the resolution of an argument, or null if it isn't defined in this context */ + public String getBinding(String name) { return bindings.get(name); } + + /** Returns a new context with the bindings replaced by the given bindings */ + public FunctionReferenceContext withBindings(Map<String, String> bindings) { + return new FunctionReferenceContext(this.functions, bindings); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index e42884ecc05..c87eb0ace39 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -48,7 +49,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { return type; } + public TensorType type(TypeContext<Reference> context) { return type; } /** Evaluate this in a context which must have the arguments bound */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index 66b250736e8..ee4edac4941 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -75,7 +76,7 @@ public final class IfNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { TensorType trueType = trueExpression.type(context); TensorType falseType = falseExpression.type(context); return trueType.dimensionwiseGeneralizationWith(falseType).orElseThrow(() -> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index da946228291..61086f8182a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -57,7 +58,7 @@ public class LambdaFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return TensorType.empty; // by definition - no nested lambdas } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java index f55ed59b65c..f1adf331630 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -14,6 +15,7 @@ import java.util.Deque; * * @author Simon Thoresen */ +// TODO: This is achieved by ReferenceNode in almost all cases - remove this public final class NameNode extends ExpressionNode { private final String name; @@ -32,7 +34,7 @@ public final class NameNode extends ExpressionNode { } @Override - public TensorType type(TypeContext context) { throw new RuntimeException("Named nodes can not have a type"); } + public TensorType type(TypeContext<Reference> context) { throw new RuntimeException("Named nodes can not have a type"); } @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java index 9cbe5f98c72..fcc03dc4862 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -38,7 +39,7 @@ public class NegativeNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java index e7041600635..a539f496ff5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -38,7 +39,7 @@ public class NotNode extends BooleanNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 05a6773c5cb..78f53b1593d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -13,114 +14,102 @@ import java.util.Deque; import java.util.List; /** - * A node referring either to a value in the context or to another named ranking expression. + * A node referring either to a value in the context or to a named ranking expression (function aka macro). * * @author simon * @author bratseth */ public final class ReferenceNode extends CompositeNode { - private final String name, output; - - private final Arguments arguments; + private final Reference reference; + /* Creates a node with a simple identifier reference */ public ReferenceNode(String name) { this(name, null, null); } public ReferenceNode(String name, List<? extends ExpressionNode> arguments, String output) { - this.name = name; - this.arguments = arguments != null ? new Arguments(arguments) : new Arguments(); - this.output = output; + this.reference = new Reference(name, + arguments != null ? new Arguments(arguments) : new Arguments(), + output); + } + + public ReferenceNode(Reference reference) { + this.reference = reference; } public String getName() { - return name; + return reference.name(); } /** Returns the arguments, never null */ - public Arguments getArguments() { return arguments; } + public Arguments getArguments() { return reference.arguments(); } /** Returns a copy of this where the arguments are replaced by the given arguments */ public ReferenceNode setArguments(List<ExpressionNode> arguments) { - return new ReferenceNode(name, arguments, output); + return new ReferenceNode(reference.withArguments(new Arguments(arguments))); } /** Returns the specific output this references, or null if none specified */ - public String getOutput() { return output; } + public String getOutput() { return reference.output(); } /** Returns a copy of this node with a modified output */ public ReferenceNode setOutput(String output) { - return new ReferenceNode(name, arguments.expressions(), output); + return new ReferenceNode(reference.withOutput(output)); } /** Returns an empty list as this has no children */ @Override - public List<ExpressionNode> children() { return arguments.expressions(); } + public List<ExpressionNode> children() { return reference.arguments().expressions(); } @Override public String toString(SerializationContext context, Deque<String> path, CompositeNode parent) { - if (path == null) - path = new ArrayDeque<>(); - String myName = this.name; - String myOutput = this.output; - List<ExpressionNode> myArguments = this.arguments.expressions(); - - String resolvedArgument = context.getBinding(myName); - if (resolvedArgument != null && this.arguments.expressions().size() == 0 && myOutput == null) { - // Replace this whole node with the value of the argument value that it maps to - myName = resolvedArgument; - myArguments = null; - myOutput = null; - } else if (context.getFunction(myName) != null) { - // Replace by the referenced expression - ExpressionFunction function = context.getFunction(myName); - if (function != null && myArguments != null && function.arguments().size() == myArguments.size() && myOutput == null) { - String myPath = name + this.arguments.expressions(); - if (path.contains(myPath)) { - throw new IllegalStateException("Cycle in ranking expression function: " + path); - } - path.addLast(myPath); - ExpressionFunction.Instance instance = function.expand(context, myArguments, path); - path.removeLast(); - context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); - myName = "rankingExpression(" + instance.getName() + ")"; - myArguments = null; - myOutput = null; - } + if (reference.isIdentifier() && context.getBinding(getName()) != null) { + // a bound identifier: replace by the value it is bound to + return context.getBinding(getName()); } - // Always print the same way, the magic is already done. - StringBuilder ret = new StringBuilder(myName); - if (myArguments != null && myArguments.size() > 0) { - ret.append("("); - for (int i = 0; i < myArguments.size(); ++i) { - ret.append(myArguments.get(i).toString(context, path, this)); - if (i < myArguments.size() - 1) { - ret.append(","); - } - } - ret.append(")"); + + ExpressionFunction function = context.getFunction(getName()); + if (function != null && function.arguments().size() == getArguments().size() && getOutput() == null) { + // a function reference: replace by the referenced function wrapped in rankingExpression + if (path == null) + path = new ArrayDeque<>(); + String myPath = getName() + getArguments().expressions(); + if (path.contains(myPath)) + throw new IllegalStateException("Cycle in ranking expression function: " + path); + path.addLast(myPath); + ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path); + path.removeLast(); + context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); + return "rankingExpression(" + instance.getName() + ")"; } - ret.append(myOutput != null ? "." + myOutput : ""); - return ret.toString(); + + // not resolved in this context: output as-is + return reference.toString(context, path, parent); } + /** Returns the reference of this node */ + public Reference reference() { return reference; } + @Override - public TensorType type(TypeContext context) { - // Don't support outputs of different type, for simplicity - return context.getType(toString()); + public TensorType type(TypeContext<Reference> context) { + TensorType type = context.getType(reference); + if (type == null) + throw new IllegalArgumentException("Unknown feature '" + toString() + "'"); + return type; } @Override public Value evaluate(Context context) { - if (arguments.expressions().isEmpty() && output == null) - return context.get(name); - return context.get(name, arguments, output); + // TODO: Context should accept a Reference instead. + if (reference.isIdentifier()) + return context.get(reference.name()); + return context.get(getName(), getArguments(), getOutput()); } @Override public CompositeNode setChildren(List<ExpressionNode> newChildren) { - return new ReferenceNode(name, newChildren, output); + return setArguments(newChildren); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index ba765d07094..796c13a8669 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -16,17 +16,11 @@ import java.util.Map; * * @author bratseth */ -public class SerializationContext { +public class SerializationContext extends FunctionReferenceContext { - /** Expression functions indexed by name */ - private final ImmutableMap<String, ExpressionFunction> functions; - - /** A cache of already serialized expressions indexed by name */ + /** Serialized form of functions indexed by name */ private final Map<String, String> serializedFunctions; - /** Mapping from argument names to the expressions they resolve to */ - public final Map<String, String> bindings = new HashMap<>(); - /** Create a context for a single serialization task */ public SerializationContext() { this(Collections.emptyList()); @@ -77,17 +71,10 @@ public class SerializationContext { */ public SerializationContext(ImmutableMap<String,ExpressionFunction> functions, Map<String, String> bindings, Map<String, String> serializedFunctions) { - this.functions = functions; + super(functions, bindings); this.serializedFunctions = serializedFunctions; - if (bindings != null) - this.bindings.putAll(bindings); } - /** - * Returns a function or null if it isn't defined in this context - */ - public ExpressionFunction getFunction(String name) { return functions.get(name); } - /** Adds the serialization of a function */ public void addFunctionSerialization(String name, String expressionString) { serializedFunctions.put(name, expressionString); @@ -98,17 +85,9 @@ public class SerializationContext { return serializedFunctions.get(name); } - /** - * Returns the resolution of an argument, or null if it isn't defined in this context - */ - public String getBinding(String name) { return bindings.get(name); } - - /** - * Returns a new context which shares the functions and serialized function map with this but has different - * arguments. - */ - public SerializationContext createBinding(Map<String, String> arguments) { - return new SerializationContext(this.functions, arguments, this.serializedFunctions); + @Override + public SerializationContext withBindings(Map<String, String> bindings) { + return new SerializationContext(functions().values(), bindings, this.serializedFunctions); } public Map<String, String> serializedFunctions() { return serializedFunctions; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java index a7b82f4753f..cb31219579a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -60,7 +61,7 @@ public class SetMembershipNode extends BooleanNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return TensorType.empty; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index ec6af4bb413..6c9b6bb4a98 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -64,7 +65,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { return function.type(context); } + public TensorType type(TypeContext<Reference> context) { return function.type(context); } @Override public Value evaluate(Context context) { @@ -111,12 +112,13 @@ public class TensorFunctionNode extends CompositeNode { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { - return expression.type(context); + @SuppressWarnings("unchecked") // Generics awkwardness + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + return expression.type((TypeContext<Reference>)context); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return expression.evaluate((Context)context).asTensor(); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index e9030cf5852..f2122bb5da9 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -378,8 +378,13 @@ public class EvaluationTestCase { private static class StructuredTestContext extends MapContext { @Override + public Value get(String feature) { + throw new RuntimeException("Called simple get for feature " + feature); + } + + @Override public Value get(String name, Arguments arguments, String output) { - if (!name.equals("average")) { + if ( ! name.equals("average")) { throw new IllegalArgumentException("Unknown operation '" + name + "'"); } if (arguments.expressions().size() != 2) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java index c882c887c8d..a08d510eec4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.TypeContext; @@ -18,12 +19,17 @@ public class TypeResolutionTestCase { @Test public void testTypeResolution() { - TypeMapContext context = new TypeMapContext(); - context.setType("query(x1)", TensorType.fromSpec("tensor(x[])")); - context.setType("query(x2)", TensorType.fromSpec("tensor(x[10])")); - context.setType("query(y1)", TensorType.fromSpec("tensor(y[])")); - context.setType("query(xy1)", TensorType.fromSpec("tensor(x[10],y[])")); - context.setType("query(xy2)", TensorType.fromSpec("tensor(x[],y[10])")); + MapTypeContext context = new MapTypeContext(); + context.setType(Reference.simple("query", "x1"), + TensorType.fromSpec("tensor(x[])")); + context.setType(Reference.simple("query", "x2"), + TensorType.fromSpec("tensor(x[10])")); + context.setType(Reference.simple("query", "y1"), + TensorType.fromSpec("tensor(y[])")); + context.setType(Reference.simple("query", "xy1"), + TensorType.fromSpec("tensor(x[10],y[])")); + context.setType(Reference.simple("query", "xy2"), + TensorType.fromSpec("tensor(x[],y[10])")); assertType("tensor(x[])", "query(x1)", context); assertType("tensor(x[])", "if (1>0, query(x1), query(x2))", context); @@ -31,7 +37,7 @@ public class TypeResolutionTestCase { assertIncompatibleType("if (1>0, query(x1), query(y1))", context); } - private void assertType(String type, String expression, TypeContext context) { + private void assertType(String type, String expression, TypeContext<Reference> context) { try { assertEquals(TensorType.fromSpec(type), new RankingExpression(expression).type(context)); } @@ -40,7 +46,7 @@ public class TypeResolutionTestCase { } } - private void assertIncompatibleType(String expression, TypeContext context) { + private void assertIncompatibleType(String expression, TypeContext<Reference> context) { try { new RankingExpression(expression).type(context); fail("Expected type incompatibility exception"); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java new file mode 100644 index 00000000000..ebcfde54c70 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java @@ -0,0 +1,49 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import org.junit.Test; + +import static junit.framework.TestCase.assertTrue; + +public class DimensionRenamerTest { + + @Test + public void testMnistRenaming() { + DimensionRenamer renamer = new DimensionRenamer(); + + renamer.addDimension("first_dimension_of_x"); + renamer.addDimension("second_dimension_of_x"); + renamer.addDimension("first_dimension_of_w"); + renamer.addDimension("second_dimension_of_w"); + renamer.addDimension("first_dimension_of_b"); + + // which dimension to join on matmul + renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null); + + // other dimensions in matmul can't be equal + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null); + + // for efficiency, put dimension to join on innermost + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null); + renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null); + + // bias + renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null); + + renamer.solve(); + + String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get(); + String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get(); + String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get(); + String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get(); + String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get(); + + assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0); + assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0); + assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0); + assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0); + assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0); + + + } +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index 3b25bfe1b1e..f64d697d9b9 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -18,11 +18,6 @@ public class DropoutImportTestCase { public void testDropoutImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved"); - // Check (provided) macros - assertEquals(1, model.get().macros().size()); - assertTrue(model.get().macros().containsKey("training_input")); - assertEquals("constant(\"training_input\")", model.get().macros().get("training_input").getRoot().toString()); - // Check required macros assertEquals(1, model.get().requiredMacros().size()); assertTrue(model.get().requiredMacros().containsKey("X")); @@ -37,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/BiasAdd", output.getName()); - assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs_kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs_bias\"), d0, d1), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(X, (d0, d1), (d0, d2)), constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index ad5abd4c03d..60dd3865aa1 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -22,15 +22,15 @@ public class MnistSoftmaxImportTestCase { // Check constants assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().largeConstants().get("Variable"); + Tensor constant0 = model.get().largeConstants().get("Variable_read"); assertNotNull(constant0); - assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), + assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().largeConstants().get("Variable_1"); + Tensor constant1 = model.get().largeConstants().get("Variable_1_read"); assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d0", 10).build(), + assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); @@ -59,12 +59,10 @@ public class MnistSoftmaxImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("add", output.getName()); - assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"Variable_1_read\"), f(a,b)(a + b))", output.getRoot().toString()); // Test execution - model.assertEqualResult("Placeholder", "Variable/read"); - model.assertEqualResult("Placeholder", "Variable_1/read"); model.assertEqualResult("Placeholder", "MatMul"); model.assertEqualResult("Placeholder", "add"); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index ae7714b271a..1691756a64d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.tensorflow.SavedModelBundle; @@ -47,8 +48,11 @@ public class TestableTensorFlowModel { private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { Session.Runner runner = model.session().runner(); - org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, - FloatBuffer.allocate(d0Size * d1Size)); + FloatBuffer fb = FloatBuffer.allocate(d0Size * d1Size); + for (int i = 0; i < d1Size; ++i) { + fb.put(i, (float)(i * 1.0 / d1Size)); + } + org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb); runner.feed(inputName, placeholder); List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); assertEquals(1, results.size()); @@ -66,7 +70,7 @@ public class TestableTensorFlowModel { Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); for (int d0 = 0; d0 < d0Size; d0++) for (int d1 = 0; d1 < d1Size; d1++) - b.cell(0, d0, d1); + b.cell(d1 * 1.0 / d1Size, d0, d1); return b.build(); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java index 867331e99ce..303135888d8 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/rule/ArgumentsTestCase.java @@ -9,13 +9,13 @@ import java.util.Collections; import static org.junit.Assert.*; /** - * @author <a href="mailto:simon@yahoo-inc.com">Simon Thoresen</a> + * @author Simon Thoresen */ public class ArgumentsTestCase { @Test public void requireThatAccessorsWork() { - Arguments args = new Arguments(null); + Arguments args = new Arguments(); assertTrue(args.expressions().isEmpty()); args = new Arguments(Collections.<ExpressionNode>emptyList()); diff --git a/storage/src/tests/common/testnodestateupdater.cpp b/storage/src/tests/common/testnodestateupdater.cpp index 18f296e5583..c7fd47e37c7 100644 --- a/storage/src/tests/common/testnodestateupdater.cpp +++ b/storage/src/tests/common/testnodestateupdater.cpp @@ -1,7 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "testnodestateupdater.h" -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> namespace storage { @@ -14,7 +14,7 @@ TestNodeStateUpdater::TestNodeStateUpdater(const lib::NodeType& type) TestNodeStateUpdater::~TestNodeStateUpdater() = default; -std::shared_ptr<const ClusterStateBundle> +std::shared_ptr<const lib::ClusterStateBundle> TestNodeStateUpdater::getClusterStateBundle() const { return _clusterStateBundle; @@ -23,7 +23,7 @@ TestNodeStateUpdater::getClusterStateBundle() const void TestNodeStateUpdater::setClusterState(lib::ClusterState::CSP c) { - _clusterStateBundle = std::make_shared<const ClusterStateBundle>(*c); + _clusterStateBundle = std::make_shared<const lib::ClusterStateBundle>(*c); for (uint32_t i = 0; i < _listeners.size(); ++i) { _listeners[i]->handleNewState(); } diff --git a/storage/src/tests/common/testnodestateupdater.h b/storage/src/tests/common/testnodestateupdater.h index daecb45ece4..1e898e84b18 100644 --- a/storage/src/tests/common/testnodestateupdater.h +++ b/storage/src/tests/common/testnodestateupdater.h @@ -16,7 +16,7 @@ struct TestNodeStateUpdater : public NodeStateUpdater { lib::NodeState::CSP _reported; lib::NodeState::CSP _current; - std::shared_ptr<const ClusterStateBundle> _clusterStateBundle; + std::shared_ptr<const lib::ClusterStateBundle> _clusterStateBundle; std::vector<StateListener*> _listeners; public: @@ -25,7 +25,7 @@ public: lib::NodeState::CSP getReportedNodeState() const override { return _reported; } lib::NodeState::CSP getCurrentNodeState() const override { return _current; } - std::shared_ptr<const ClusterStateBundle> getClusterStateBundle() const override; + std::shared_ptr<const lib::ClusterStateBundle> getClusterStateBundle() const override; void addStateListener(StateListener& s) override { _listeners.push_back(&s); } void removeStateListener(StateListener&) override {} Lock::SP grabStateChangeLock() override { return Lock::SP(new Lock); } diff --git a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp index 2192ae4d634..248fb1e5203 100644 --- a/storage/src/tests/persistence/filestorage/filestormanagertest.cpp +++ b/storage/src/tests/persistence/filestorage/filestormanagertest.cpp @@ -7,7 +7,7 @@ #include <vespa/document/test/make_document_bucket.h> #include <vespa/storage/storageserver/statemanager.h> #include <vespa/storage/bucketdb/bucketmanager.h> -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> #include <vespa/storage/persistence/persistencethread.h> #include <vespa/storage/persistence/filestorage/filestormanager.h> #include <vespa/storage/persistence/filestorage/modifiedbucketchecker.h> diff --git a/storage/src/tests/storageserver/statemanagertest.cpp b/storage/src/tests/storageserver/statemanagertest.cpp index 0676d3684ff..7c5303f74fe 100644 --- a/storage/src/tests/storageserver/statemanagertest.cpp +++ b/storage/src/tests/storageserver/statemanagertest.cpp @@ -4,7 +4,7 @@ #include <vespa/metrics/metricmanager.h> #include <vespa/storageapi/message/bucket.h> #include <vespa/storageapi/message/state.h> -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> #include <vespa/storage/frameworkimpl/component/storagecomponentregisterimpl.h> #include <vespa/storage/storageserver/statemanager.h> #include <tests/common/teststorageapp.h> diff --git a/storage/src/vespa/storage/bucketdb/bucketmanager.cpp b/storage/src/vespa/storage/bucketdb/bucketmanager.cpp index 142003735b8..5078d35956a 100644 --- a/storage/src/vespa/storage/bucketdb/bucketmanager.cpp +++ b/storage/src/vespa/storage/bucketdb/bucketmanager.cpp @@ -7,7 +7,7 @@ #include <iomanip> #include <vespa/storage/common/content_bucket_space_repo.h> #include <vespa/storage/common/nodestateupdater.h> -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> #include <vespa/storage/storageutil/distributorstatecache.h> #include <vespa/storageframework/generic/status/htmlstatusreporter.h> #include <vespa/storageframework/generic/status/xmlstatusreporter.h> diff --git a/storage/src/vespa/storage/common/CMakeLists.txt b/storage/src/vespa/storage/common/CMakeLists.txt index d1e819523d7..c53aead2ba2 100644 --- a/storage/src/vespa/storage/common/CMakeLists.txt +++ b/storage/src/vespa/storage/common/CMakeLists.txt @@ -3,7 +3,6 @@ vespa_add_library(storage_common OBJECT SOURCES bucketmessages.cpp bucketoperationlogger.cpp - cluster_state_bundle.cpp content_bucket_space.cpp content_bucket_space_repo.cpp distributorcomponent.cpp diff --git a/storage/src/vespa/storage/common/nodestateupdater.h b/storage/src/vespa/storage/common/nodestateupdater.h index 7fd3dedbcab..c2887a971f3 100644 --- a/storage/src/vespa/storage/common/nodestateupdater.h +++ b/storage/src/vespa/storage/common/nodestateupdater.h @@ -29,7 +29,7 @@ namespace storage { -class ClusterStateBundle; +namespace lib { class ClusterStateBundle; } struct StateListener { virtual ~StateListener() {} @@ -43,7 +43,7 @@ struct NodeStateUpdater { virtual lib::NodeState::CSP getReportedNodeState() const = 0; virtual lib::NodeState::CSP getCurrentNodeState() const = 0; - virtual std::shared_ptr<const ClusterStateBundle> getClusterStateBundle() const = 0; + virtual std::shared_ptr<const lib::ClusterStateBundle> getClusterStateBundle() const = 0; virtual void addStateListener(StateListener&) = 0; virtual void removeStateListener(StateListener&) = 0; diff --git a/storage/src/vespa/storage/frameworkimpl/component/distributorcomponentregisterimpl.cpp b/storage/src/vespa/storage/frameworkimpl/component/distributorcomponentregisterimpl.cpp index cf290c78acf..439bc9e078c 100644 --- a/storage/src/vespa/storage/frameworkimpl/component/distributorcomponentregisterimpl.cpp +++ b/storage/src/vespa/storage/frameworkimpl/component/distributorcomponentregisterimpl.cpp @@ -2,7 +2,7 @@ #include "distributorcomponentregisterimpl.h" #include <vespa/vdslib/distribution/idealnodecalculatorimpl.h> #include <vespa/vespalib/util/exceptions.h> -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> namespace storage { diff --git a/storage/src/vespa/storage/persistence/bucketownershipnotifier.cpp b/storage/src/vespa/storage/persistence/bucketownershipnotifier.cpp index 5e561951260..a0f05a70f4e 100644 --- a/storage/src/vespa/storage/persistence/bucketownershipnotifier.cpp +++ b/storage/src/vespa/storage/persistence/bucketownershipnotifier.cpp @@ -4,7 +4,7 @@ #include <vespa/storage/common/nodestateupdater.h> #include <vespa/storage/common/bucketoperationlogger.h> #include <vespa/storage/common/content_bucket_space_repo.h> -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> #include <vespa/storageapi/message/bucket.h> #include <vespa/vdslib/distribution/distribution.h> #include <vespa/vespalib/util/backtrace.h> diff --git a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp index 2773c19eaa1..311dc52767d 100644 --- a/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp +++ b/storage/src/vespa/storage/persistence/filestorage/filestormanager.cpp @@ -6,7 +6,7 @@ #include <vespa/storage/common/bucketmessages.h> #include <vespa/storage/common/bucketoperationlogger.h> #include <vespa/storage/common/content_bucket_space_repo.h> -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> #include <vespa/storage/common/messagebucket.h> #include <vespa/storage/config/config-stor-server.h> #include <vespa/storage/persistence/bucketownershipnotifier.h> diff --git a/storage/src/vespa/storage/storageserver/bouncer.cpp b/storage/src/vespa/storage/storageserver/bouncer.cpp index af274c9b3e6..72edbfd095e 100644 --- a/storage/src/vespa/storage/storageserver/bouncer.cpp +++ b/storage/src/vespa/storage/storageserver/bouncer.cpp @@ -2,7 +2,7 @@ #include "bouncer.h" #include "bouncer_metrics.h" -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> #include <vespa/storageapi/message/state.h> #include <vespa/storageapi/message/persistence.h> #include <vespa/config/subscription/configuri.h> diff --git a/storage/src/vespa/storage/storageserver/changedbucketownershiphandler.cpp b/storage/src/vespa/storage/storageserver/changedbucketownershiphandler.cpp index cd7be21a369..7cf42af841d 100644 --- a/storage/src/vespa/storage/storageserver/changedbucketownershiphandler.cpp +++ b/storage/src/vespa/storage/storageserver/changedbucketownershiphandler.cpp @@ -3,7 +3,7 @@ #include "changedbucketownershiphandler.h" #include <vespa/storageapi/message/state.h> #include <vespa/storage/bucketdb/storbucketdb.h> -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> #include <vespa/storage/common/messagebucket.h> #include <vespa/storage/common/nodestateupdater.h> #include <vespa/storage/common/content_bucket_space_repo.h> diff --git a/storage/src/vespa/storage/storageserver/fnetlistener.cpp b/storage/src/vespa/storage/storageserver/fnetlistener.cpp index f53af6dc225..7a0711c9f7c 100644 --- a/storage/src/vespa/storage/storageserver/fnetlistener.cpp +++ b/storage/src/vespa/storage/storageserver/fnetlistener.cpp @@ -152,7 +152,7 @@ FNetListener::RPC_setSystemState2(FRT_RPCRequest *req) req->GetParams()->GetValue(0)._string._len); lib::ClusterState systemState(systemStateStr); - auto cmd(std::make_shared<api::SetSystemStateCommand>(systemState)); + auto cmd(std::make_shared<api::SetSystemStateCommand>(lib::ClusterStateBundle(systemState))); cmd->setPriority(api::StorageMessage::VERYHIGH); // Create a request object to avoid needing a separate transport type diff --git a/storage/src/vespa/storage/storageserver/mergethrottler.cpp b/storage/src/vespa/storage/storageserver/mergethrottler.cpp index 73fa61e9fb7..a15b1b98d63 100644 --- a/storage/src/vespa/storage/storageserver/mergethrottler.cpp +++ b/storage/src/vespa/storage/storageserver/mergethrottler.cpp @@ -1,7 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "mergethrottler.h" -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> #include <vespa/storage/common/nodestateupdater.h> #include <vespa/storage/persistence/messages.h> #include <vespa/messagebus/message.h> diff --git a/storage/src/vespa/storage/storageserver/statemanager.cpp b/storage/src/vespa/storage/storageserver/statemanager.cpp index 1908eab96ec..11ca0bcc9ae 100644 --- a/storage/src/vespa/storage/storageserver/statemanager.cpp +++ b/storage/src/vespa/storage/storageserver/statemanager.cpp @@ -9,7 +9,7 @@ #include <vespa/storageapi/messageapi/storagemessage.h> #include <vespa/storage/storageserver/storagemetricsset.h> #include <vespa/storage/common/bucketoperationlogger.h> -#include <vespa/storage/common/cluster_state_bundle.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> #include <sys/types.h> #include <unistd.h> #include <vespa/vespalib/util/stringfmt.h> @@ -192,7 +192,7 @@ StateManager::getCurrentNodeState() const (_systemState->getBaselineClusterState()->getNodeState(thisNode())); } -std::shared_ptr<const ClusterStateBundle> +std::shared_ptr<const lib::ClusterStateBundle> StateManager::getClusterStateBundle() const { vespalib::LockGuard lock(_stateLock); diff --git a/storage/src/vespa/storage/storageserver/statemanager.h b/storage/src/vespa/storage/storageserver/statemanager.h index 8d3e4d75a88..9f5c60b42aa 100644 --- a/storage/src/vespa/storage/storageserver/statemanager.h +++ b/storage/src/vespa/storage/storageserver/statemanager.h @@ -33,7 +33,7 @@ namespace metrics { namespace storage { -class ClusterStateBundle; +namespace lib { class ClusterStateBundle; } class StateManager : public NodeStateUpdater, public StorageLink, @@ -50,6 +50,7 @@ class StateManager : public NodeStateUpdater, std::atomic<bool> _notifyingListeners; std::shared_ptr<lib::NodeState> _nodeState; std::shared_ptr<lib::NodeState> _nextNodeState; + using ClusterStateBundle = lib::ClusterStateBundle; std::shared_ptr<const ClusterStateBundle> _systemState; std::shared_ptr<const ClusterStateBundle> _nextSystemState; std::list<StateListener*> _stateListeners; diff --git a/storageapi/src/vespa/storageapi/message/state.cpp b/storageapi/src/vespa/storageapi/message/state.cpp index b128e8f6485..efa9a45764f 100644 --- a/storageapi/src/vespa/storageapi/message/state.cpp +++ b/storageapi/src/vespa/storageapi/message/state.cpp @@ -2,6 +2,7 @@ #include "state.h" #include <vespa/storageapi/messageapi/storagemessage.h> +#include <vespa/vdslib/state/clusterstate.h> #include <ostream> namespace storage { @@ -61,6 +62,12 @@ GetNodeStateReply::print(std::ostream& out, bool verbose, } } +SetSystemStateCommand::SetSystemStateCommand(const lib::ClusterStateBundle& state) + : StorageCommand(MessageType::SETSYSTEMSTATE), + _state(state) +{ +} + SetSystemStateCommand::SetSystemStateCommand(const lib::ClusterState& state) : StorageCommand(MessageType::SETSYSTEMSTATE), _state(state) @@ -71,7 +78,7 @@ void SetSystemStateCommand::print(std::ostream& out, bool verbose, const std::string& indent) const { - out << "SetSystemStateCommand(" << _state << ")"; + out << "SetSystemStateCommand(" << *_state.getBaselineClusterState() << ")"; if (verbose) { out << " : "; StorageCommand::print(out, verbose, indent); @@ -80,7 +87,7 @@ SetSystemStateCommand::print(std::ostream& out, bool verbose, SetSystemStateReply::SetSystemStateReply(const SetSystemStateCommand& cmd) : StorageReply(cmd), - _state(cmd.getSystemState()) + _state(cmd.getClusterStateBundle()) { } diff --git a/storageapi/src/vespa/storageapi/message/state.h b/storageapi/src/vespa/storageapi/message/state.h index 746d92fce6b..4e5ad92b259 100644 --- a/storageapi/src/vespa/storageapi/message/state.h +++ b/storageapi/src/vespa/storageapi/message/state.h @@ -4,7 +4,8 @@ #include <vespa/storageapi/messageapi/storagecommand.h> #include <vespa/storageapi/messageapi/storagereply.h> -#include <vespa/vdslib/state/clusterstate.h> +#include <vespa/vdslib/state/nodestate.h> +#include <vespa/vdslib/state/cluster_state_bundle.h> namespace storage::api { @@ -60,11 +61,13 @@ public: * put/get/remove etx) */ class SetSystemStateCommand : public StorageCommand { - lib::ClusterState _state; + lib::ClusterStateBundle _state; public: - explicit SetSystemStateCommand(const lib::ClusterState&); - const lib::ClusterState& getSystemState() const { return _state; } + explicit SetSystemStateCommand(const lib::ClusterStateBundle &state); + explicit SetSystemStateCommand(const lib::ClusterState &state); + const lib::ClusterState& getSystemState() const { return *_state.getBaselineClusterState(); } + const lib::ClusterStateBundle& getClusterStateBundle() const { return _state; } void print(std::ostream& out, bool verbose, const std::string& indent) const override; DECLARE_STORAGECOMMAND(SetSystemStateCommand, onSetSystemState) @@ -77,13 +80,14 @@ public: * @brief Reply received after a SetSystemStateCommand. */ class SetSystemStateReply : public StorageReply { - lib::ClusterState _state; + lib::ClusterStateBundle _state; public: explicit SetSystemStateReply(const SetSystemStateCommand& cmd); // Not serialized. Available locally - const lib::ClusterState& getSystemState() const { return _state; } + const lib::ClusterState& getSystemState() const { return *_state.getBaselineClusterState(); } + const lib::ClusterStateBundle& getClusterStateBundle() const { return _state; } void print(std::ostream& out, bool verbose, const std::string& indent) const override; DECLARE_STORAGEREPLY(SetSystemStateReply, onSetSystemStateReply) diff --git a/vdslib/src/vespa/vdslib/state/CMakeLists.txt b/vdslib/src/vespa/vdslib/state/CMakeLists.txt index 24402526c85..620e86c2677 100644 --- a/vdslib/src/vespa/vdslib/state/CMakeLists.txt +++ b/vdslib/src/vespa/vdslib/state/CMakeLists.txt @@ -7,5 +7,6 @@ vespa_add_library(vdslib_state OBJECT diskstate.cpp nodestate.cpp clusterstate.cpp + cluster_state_bundle.cpp DEPENDS ) diff --git a/storage/src/vespa/storage/common/cluster_state_bundle.cpp b/vdslib/src/vespa/vdslib/state/cluster_state_bundle.cpp index 1793c74d378..c55f1aadd06 100644 --- a/storage/src/vespa/storage/common/cluster_state_bundle.cpp +++ b/vdslib/src/vespa/vdslib/state/cluster_state_bundle.cpp @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "cluster_state_bundle.h" -#include <vespa/vdslib/state/clusterstate.h> +#include "clusterstate.h" -namespace storage { +namespace storage::lib { ClusterStateBundle::ClusterStateBundle(const ClusterState &baselineClusterState) : _baselineClusterState(std::make_shared<const ClusterState>(baselineClusterState)) diff --git a/storage/src/vespa/storage/common/cluster_state_bundle.h b/vdslib/src/vespa/vdslib/state/cluster_state_bundle.h index af4a12a8b3c..c54df1d1952 100644 --- a/storage/src/vespa/storage/common/cluster_state_bundle.h +++ b/vdslib/src/vespa/vdslib/state/cluster_state_bundle.h @@ -4,9 +4,9 @@ #include <vespa/document/bucket/bucketspace.h> -namespace storage { +namespace storage::lib { -namespace lib { class ClusterState; } +class ClusterState; /** * Class representing the baseline cluster state and the derived cluster @@ -14,10 +14,9 @@ namespace lib { class ClusterState; } */ class ClusterStateBundle { - using ClusterState = lib::ClusterState; std::shared_ptr<const ClusterState> _baselineClusterState; public: - ClusterStateBundle(const ClusterState &baselineClusterState); + explicit ClusterStateBundle(const ClusterState &baselineClusterState); ~ClusterStateBundle(); const std::shared_ptr<const ClusterState> &getBaselineClusterState() const; const std::shared_ptr<const ClusterState> &getDerivedClusterState(document::BucketSpace bucketSpace) const; diff --git a/vespa-athenz/CMakeLists.txt b/vespa-athenz/CMakeLists.txt new file mode 100644 index 00000000000..bb5a1f5b6de --- /dev/null +++ b/vespa-athenz/CMakeLists.txt @@ -0,0 +1,2 @@ +# Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +install_fat_java_artifact(vespa-athenz) diff --git a/vespa-athenz/pom.xml b/vespa-athenz/pom.xml index 5312594472f..31e56f76dd2 100644 --- a/vespa-athenz/pom.xml +++ b/vespa-athenz/pom.xml @@ -41,7 +41,12 @@ <artifactId>mockito-core</artifactId> <scope>test</scope> </dependency> - + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>testutil</artifactId> + <version>${project.version}</version> + <scope>test</scope> + </dependency> <!-- compile --> <dependency> @@ -110,31 +115,6 @@ <groupId>com.yahoo.vespa</groupId> <artifactId>bundle-plugin</artifactId> <extensions>true</extensions> - <configuration> - <useCommonAssemblyIds>false</useCommonAssemblyIds> - </configuration> - </plugin> - <plugin> - <groupId>org.codehaus.mojo</groupId> - <artifactId>build-helper-maven-plugin</artifactId> - <executions> - <execution> - <id>attach-artifacts</id> - <phase>package</phase> - <goals> - <goal>attach-artifact</goal> - </goals> - <configuration> - <artifacts> - <artifact> - <file>target/${project.artifactId}-deploy.jar</file> - <type>jar</type> - <classifier>deploy</classifier> - </artifact> - </artifacts> - </configuration> - </execution> - </executions> </plugin> <plugin> <groupId>org.apache.maven.plugins</groupId> diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzCredentials.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentials.java index 36c1aee49e0..c5dce1c5b1d 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzCredentials.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentials.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import java.security.KeyPair; import java.security.cert.X509Certificate; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzCredentialsService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentialsService.java index 4072568d9d2..dd816929bfb 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzCredentialsService.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzCredentialsService.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import com.fasterxml.jackson.databind.ObjectMapper; import com.yahoo.container.core.identity.IdentityConfig; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzIdentityProviderImpl.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImpl.java index 18f90ce545f..95113e1b0b1 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzIdentityProviderImpl.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImpl.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import com.google.inject.Inject; import com.yahoo.component.AbstractComponent; @@ -8,16 +8,15 @@ import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider; import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException; import com.yahoo.jdisc.Metric; import com.yahoo.log.LogLevel; -import com.yahoo.vespa.defaults.Defaults; import com.yahoo.vespa.athenz.api.AthenzIdentityCertificate; import com.yahoo.vespa.athenz.tls.AthenzSslContextBuilder; +import com.yahoo.vespa.defaults.Defaults; import javax.net.ssl.SSLContext; import java.io.File; import java.time.Clock; import java.time.Duration; import java.time.Instant; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.Executors; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzService.java index c9e3809ea96..18576ab9bab 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzService.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/AthenzService.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.ObjectMapper; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/CryptoUtils.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/CryptoUtils.java index 6a766e7c49d..6e74d3bc8b1 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/CryptoUtils.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/CryptoUtils.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import org.bouncycastle.asn1.pkcs.PKCSObjectIdentifiers; import org.bouncycastle.asn1.x509.Extension; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/IdentityDocumentService.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/IdentityDocumentService.java index 8a9137a491d..4e88234d5de 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/IdentityDocumentService.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/IdentityDocumentService.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import com.yahoo.vespa.defaults.Defaults; import org.apache.http.client.methods.CloseableHttpResponse; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceIdentity.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceIdentity.java index d6e986959cb..b90ce56ca7e 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceIdentity.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceIdentity.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceRefreshInformation.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceRefreshInformation.java index d0c22d1d0d2..c627363c0f5 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceRefreshInformation.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceRefreshInformation.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceRegisterInformation.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceRegisterInformation.java index dd9f164fef1..69ddb72b8b8 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/InstanceRegisterInformation.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/InstanceRegisterInformation.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/SignedIdentityDocument.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/SignedIdentityDocument.java index 7bbd49c953f..c3b073765ac 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/SignedIdentityDocument.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/SignedIdentityDocument.java @@ -1,5 +1,5 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.fasterxml.jackson.annotation.JsonInclude; diff --git a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/package-info.java b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/package-info.java index 1b4842327dd..f23ea9406b3 100644 --- a/athenz-identity-provider-service/src/main/java/com/yahoo/vespa/hosted/athenz/identityprovider/package-info.java +++ b/vespa-athenz/src/main/java/com/yahoo/vespa/athenz/identityprovider/package-info.java @@ -3,6 +3,6 @@ * @author mortent */ @ExportPackage -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzIdentityProviderImplTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImplTest.java index 3a506a39c43..d9dbd73a94e 100644 --- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/identityprovider/AthenzIdentityProviderImplTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/AthenzIdentityProviderImplTest.java @@ -1,13 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import com.yahoo.container.core.identity.IdentityConfig; import com.yahoo.container.jdisc.athenz.AthenzIdentityProvider; import com.yahoo.container.jdisc.athenz.AthenzIdentityProviderException; -import com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.RunnableWithTag; -import com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.Scheduler; import com.yahoo.jdisc.Metric; import com.yahoo.test.ManualClock; +import com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.RunnableWithTag; +import com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.Scheduler; import org.junit.Test; import java.security.cert.X509Certificate; @@ -19,21 +19,13 @@ import java.util.List; import java.util.Map; import java.util.Objects; import java.util.PriorityQueue; -import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Predicate; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.INITIAL_BACKOFF_DELAY; -import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.INITIAL_WAIT_NTOKEN; -import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.MAX_REGISTER_BACKOFF_DELAY; -import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.METRICS_UPDATER_TAG; -import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.REDUCED_UPDATE_PERIOD; -import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.REGISTER_INSTANCE_TAG; -import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.TIMEOUT_INITIAL_WAIT_TAG; -import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.UPDATE_CREDENTIALS_TAG; -import static com.yahoo.vespa.hosted.athenz.identityprovider.AthenzIdentityProviderImpl.UPDATE_PERIOD; + +import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.METRICS_UPDATER_TAG; +import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.REDUCED_UPDATE_PERIOD; +import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.UPDATE_CREDENTIALS_TAG; +import static com.yahoo.vespa.athenz.identityprovider.AthenzIdentityProviderImpl.UPDATE_PERIOD; import static org.junit.Assert.assertEquals; import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyString; diff --git a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/identityprovider/CryptoUtilsTest.java b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/CryptoUtilsTest.java index 0412b9071dd..353c5d3c504 100644 --- a/athenz-identity-provider-service/src/test/java/com/yahoo/vespa/hosted/athenz/identityprovider/CryptoUtilsTest.java +++ b/vespa-athenz/src/test/java/com/yahoo/vespa/athenz/identityprovider/CryptoUtilsTest.java @@ -1,5 +1,5 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.vespa.hosted.athenz.identityprovider; +package com.yahoo.vespa.athenz.identityprovider; import org.bouncycastle.pkcs.PKCS10CertificationRequest; import org.junit.Test; diff --git a/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java new file mode 100644 index 00000000000..e0e4a0828a9 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/lang/MutableLong.java @@ -0,0 +1,33 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.lang; + +/** + * A mutable long + * + * @author bratseth + */ +public class MutableLong { + + private long value; + + public MutableLong(long value) { + this.value = value; + } + + public long get() { return value; } + + public void set(long value) { this.value = value; } + + /** Adds the increment to the current value and returns the resulting value */ + public long add(long increment) { + value += increment; + return value; + } + + /** Adds the increment to the current value and returns the resulting value */ + public long subtract(long increment) { + value -= increment; + return value; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 14cd3e70866..0176dac6821 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -77,6 +77,13 @@ public class TensorType { return Optional.empty(); } + /* Returns the bound of this dimension if it is present and bound in this, empty otherwise */ + public Optional<Long> sizeOfDimension(String dimension) { + Optional<Dimension> d = dimension(dimension); + if ( ! d.isPresent()) return Optional.empty(); + return d.get().size(); + } + /** * Returns whether this type can be assigned to the given type, * i.e if the given type is a generalization of this type. @@ -207,7 +214,7 @@ public class TensorType { /** Returns a copy of this with the name set to the given name */ public abstract Dimension withName(String name); - /** Returns true if this is an indexed bound or unboun type */ + /** Returns true if this is an indexed bound or unbound type */ public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; } /** @@ -254,6 +261,14 @@ public class TensorType { return new IndexedBoundDimension(name, size); } + public static Dimension indexed(String name) { + return new IndexedUnboundDimension(name); + } + + public static Dimension mapped(String name) { + return new MappedDimension(name); + } + } public static class IndexedBoundDimension extends TensorType.Dimension { @@ -367,6 +382,15 @@ public class TensorType { addDimensionsOf(type); } + /** + * Creates a builder from the given dimensions. + */ + public Builder(Iterable<Dimension> dimensions) { + for (TensorType.Dimension dimension : dimensions) { + dimension(dimension); + } + } + private static final boolean supportsMixedTypes = false; private void addDimensionsOf(TensorType type) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 3fb94f1251b..8a969180113 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -10,7 +10,7 @@ import com.yahoo.tensor.Tensor; * @author bratseth */ @Beta -public interface EvaluationContext extends TypeContext { +public interface EvaluationContext<NAMETYPE extends TypeContext.Name> extends TypeContext<NAMETYPE> { /** Returns the tensor bound to this name, or null if none */ Tensor getTensor(String name); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index 9fe6b7d053f..b9394da31e3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -11,17 +11,20 @@ import java.util.HashMap; * @author bratseth */ @Beta -public class MapEvaluationContext implements EvaluationContext { +public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> { private final java.util.Map<String, Tensor> bindings = new HashMap<>(); - static MapEvaluationContext empty() { return new MapEvaluationContext(); } - public void put(String name, Tensor tensor) { bindings.put(name, tensor); } @Override public TensorType getType(String name) { - Tensor tensor = bindings.get(name); + return getType(new Name(name)); + } + + @Override + public TensorType getType(Name name) { + Tensor tensor = bindings.get(name.toString()); if (tensor == null) return null; return tensor.type(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java index 760a225efdf..ff2e6318b37 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -8,7 +8,7 @@ import com.yahoo.tensor.TensorType; * * @author bratseth */ -public interface TypeContext { +public interface TypeContext<NAMETYPE extends TypeContext.Name> { /** * Returns the type of the tensor with this name. @@ -16,6 +16,39 @@ public interface TypeContext { * @return returns the type of the tensor which will be returned by calling getTensor(name) * or null if getTensor will return null. */ + TensorType getType(NAMETYPE name); + + /** + * Returns the type of the tensor with this name by converting from a string name. + * + * @return returns the type of the tensor which will be returned by calling getTensor(name) + * or null if getTensor will return null. + */ TensorType getType(String name); + /** A name which is just a string. Names are value objects. */ + class Name { + + private final String name; + + public Name(String name) { + this.name = name; + } + + @Override + public String toString() { return name; } + + @Override + public int hashCode() { return name.hashCode(); } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + if ( ! (other instanceof Name)) return false; + return ((Name)other).name.equals(this.name); + } + + } + + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 34beb465d4c..acb2363cba4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -44,7 +44,7 @@ public class VariableTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { TensorType givenType = context.getType(name); if (givenType == null) return null; verifyType(givenType); @@ -52,7 +52,7 @@ public class VariableTensor extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = context.getTensor(name); if (tensor == null) return null; verifyType(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 2109b730e1a..bfc0938abcc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -18,10 +18,14 @@ public abstract class CompositeTensorFunction extends TensorFunction { /** Finds the type this produces by first converting it to a primitive function */ @Override - public final TensorType type(TypeContext context) { return toPrimitive().type(context); } + public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + return toPrimitive().type(context); + } /** Evaluates this by first converting it to a primitive function */ @Override - public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } + public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + return toPrimitive().evaluate(context); + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index c77ed1c0526..13e7c136feb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -3,6 +3,8 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.lang.MutableInteger; +import com.yahoo.lang.MutableLong; import com.yahoo.tensor.DimensionSizes; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; @@ -60,21 +62,35 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argumentA.type(context), argumentB.type(context)); } /** Returns the type resulting from concatenating a and b */ private TensorType type(TensorType a, TensorType b) { + // TODO: Fail if concat dimension is present but not indexed in a or b TensorType.Builder builder = new TensorType.Builder(a, b); - if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size - builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() + - b.dimension(dimension).get().size().get())); + if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) { + builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) + + b.sizeOfDimension(dimension).orElse(1L))); + /* + MutableLong concatSize = new MutableLong(0); + a.sizeOfDimension(dimension).ifPresent(concatSize::add); + b.sizeOfDimension(dimension).ifPresent(concatSize::add); + builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); + */ + } return builder.build(); } + /** Returns true if this dimension is present and unbound */ + private boolean unboundIn(TensorType type, String dimensionName) { + Optional<TensorType.Dimension> dimension = type.dimension(dimensionName); + return dimension.isPresent() && ! dimension.get().size().isPresent(); + } + @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); a = ensureIndexedDimension(dimension, a); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 50b479da168..a43de297b9a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -42,10 +42,10 @@ public class ConstantTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { return constant.type(); } + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); } @Override - public Tensor evaluate(EvaluationContext context) { return constant; } + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } @Override public String toString(ToStringContext context) { return constant.toString(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index e70d1de3db7..edfa8253eb9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -61,10 +61,10 @@ public class Generate extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { return type; } + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); for (int i = 0; i < indexes.size(); i++) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 7812c985091..50b0e706a43 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -95,12 +95,12 @@ public class Join extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); @@ -251,7 +251,7 @@ public class Join extends PrimitiveTensorFunction { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder); - joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder); +// joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder); return builder.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 53504868ff2..4a338e5501e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -53,12 +53,12 @@ public class Map extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return argument.type(context); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 76a938b9fe2..e045effbe7e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -101,11 +101,12 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context)); } private TensorType type(TensorType argumentType) { + if (dimensions.isEmpty()) return TensorType.empty; // means reduce all TensorType.Builder builder = new TensorType.Builder(); for (TensorType.Dimension dimension : argumentType.dimensions()) if ( ! dimensions.contains(dimension.name())) // keep @@ -114,7 +115,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index de3d2be265a..af4492ca1e4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -72,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context)); } @@ -84,7 +84,7 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = argument.evaluate(context); TensorType renamedType = type(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 78ab09c7820..e805e9d87bb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -43,14 +43,14 @@ public abstract class TensorFunction { * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract Tensor evaluate(EvaluationContext context); + public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context); /** * Returns the type of the tensor this produces given the input types in the context * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract TensorType type(TypeContext context); + public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context); /** Evaluate with no context */ public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } @@ -58,7 +58,7 @@ public abstract class TensorFunction { /** * Return a string representation of this context. * - * @param context a context which must be passed to all nexted functions when requesting the string value + * @param context a context which must be passed to all nested functions when requesting the string value */ public abstract String toString(ToStringContext context); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java index 7e1f292eb7b..eafa5c4addf 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java @@ -2,6 +2,9 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -16,51 +19,98 @@ public class ConcatTestCase { public void testConcatNumbers() { Tensor a = Tensor.from("{1}"); Tensor b = Tensor.from("{2}"); - assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[2]):{ {x:0}:2, {x:1}:1 }"), b.concat(a, "x")); + assertConcat("tensor(x[2]):{ {x:0}:1, {x:1}:2 }", a, b, "x"); + assertConcat("tensor(x[2]):{ {x:0}:2, {x:1}:1 }", b, a , "x"); } @Test public void testConcatEqualShapes() { - Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2, {x:2}:3 }"); - Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); - assertEquals(Tensor.from("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " + - "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }"), a.concat(b, "y")); + Tensor a = Tensor.from("tensor(x[3]):{ {x:0}:1, {x:1}:2, {x:2}:3 }"); + Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); + assertConcat("tensor(x[6]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4, {x:4}:5, {x:5}:6 }", a, b, "x"); + assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:2,y:0}:3, " + + "{x:0,y:1}:4, {x:1,y:1}:5, {x:2,y:1}:6 }", + a, b, "y"); } @Test public void testConcatNumberAndVector() { Tensor a = Tensor.from("{1}"); + Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:2, {x:1}:3, {x:2}:4 }"); + assertConcat("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x"); + assertConcat("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + + "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }", + a, b, "y"); + } + + @Test + public void testConcatNumberAndVectorUnbound() { + Tensor a = Tensor.from("{1}"); Tensor b = Tensor.from("tensor(x[]):{ {x:0}:2, {x:1}:3, {x:2}:4 }"); - assertEquals(Tensor.from("tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + - "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }"), a.concat(b, "y")); + assertConcat("tensor(x[])","tensor(x[4]):{ {x:0}:1, {x:1}:2, {x:2}:3, {x:3}:4 }", a, b, "x"); + assertConcat("tensor(x[],y[2])", "tensor(x[3],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:1, {x:2,y:0}:1, " + + "{x:0,y:1}:2, {x:1,y:1}:3, {x:2,y:1}:4 }", + a, b, "y"); } @Test public void testUnequalSizesSameDimension() { + Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"); + Tensor b = Tensor.from("tensor(x[3]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); + assertConcat("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x"); + assertConcat("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y"); + } + + @Test + public void testUnequalSizesSameDimensionUnbound() { Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }"); Tensor b = Tensor.from("tensor(x[]):{ {x:0}:4, {x:1}:5, {x:2}:6 }"); - assertEquals(Tensor.from("tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }"), a.concat(b, "y")); + assertConcat("tensor(x[])", "tensor(x[5]):{ {x:0}:1, {x:1}:2, {x:2}:4, {x:3}:5, {x:4}:6 }", a, b, "x"); + assertConcat("tensor(x[],y[2])", "tensor(x[2],y[2]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:4, {x:1,y:1}:5 }", a, b,"y"); } @Test public void testUnequalEqualSizesDifferentDimension() { + Tensor a = Tensor.from("tensor(x[2]):{ {x:0}:1, {x:1}:2 }"); + Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }"); + assertConcat("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x"); + assertConcat("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y"); + assertConcat("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z"); + } + + @Test + public void testUnequalEqualSizesDifferentDimensionOneUnbound() { Tensor a = Tensor.from("tensor(x[]):{ {x:0}:1, {x:1}:2 }"); - Tensor b = Tensor.from("tensor(y[]):{ {y:0}:4, {y:1}:5, {y:2}:6 }"); - assertEquals(Tensor.from("tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y")); - assertEquals(Tensor.from("tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}"), a.concat(b, "z")); + Tensor b = Tensor.from("tensor(y[3]):{ {y:0}:4, {y:1}:5, {y:2}:6 }"); + assertConcat("tensor(x[],y[3])", "tensor(x[3],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:1.0,{x:0,y:2}:1.0,{x:1,y:0}:2.0,{x:1,y:1}:2.0,{x:1,y:2}:2.0,{x:2,y:0}:4.0,{x:2,y:1}:5.0,{x:2,y:2}:6.0}", a, b, "x"); + assertConcat("tensor(x[],y[4])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:4.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y"); + assertConcat("tensor(x[],y[3],z[2])", "tensor(x[2],y[3],z[2]):{{x:0,y:0,z:0}:1.0,{x:0,y:0,z:1}:4.0,{x:0,y:1,z:0}:1.0,{x:0,y:1,z:1}:5.0,{x:0,y:2,z:0}:1.0,{x:0,y:2,z:1}:6.0,{x:1,y:0,z:0}:2.0,{x:1,y:0,z:1}:4.0,{x:1,y:1,z:0}:2.0,{x:1,y:1,z:1}:5.0,{x:1,y:2,z:0}:2.0,{x:1,y:2,z:1}:6.0}", a, b, "z"); } @Test public void testDimensionsubset() { Tensor a = Tensor.from("tensor(x[],y[]):{ {x:0,y:0}:1, {x:1,y:0}:2, {x:0,y:1}:3, {x:1,y:1}:4 }"); Tensor b = Tensor.from("tensor(y[2]):{ {y:0}:5, {y:1}:6 }"); - assertEquals(Tensor.from("tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}"), a.concat(b, "x")); - assertEquals(Tensor.from("tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}"), a.concat(b, "y")); + assertConcat("tensor(x[],y[])", "tensor(x[3],y[2]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:2,y:0}:5.0,{x:2,y:1}:6.0}", a, b, "x"); + assertConcat("tensor(x[],y[])", "tensor(x[2],y[4]):{{x:0,y:0}:1.0,{x:0,y:1}:3.0,{x:0,y:2}:5.0,{x:0,y:3}:6.0,{x:1,y:0}:2.0,{x:1,y:1}:4.0,{x:1,y:2}:5.0,{x:1,y:3}:6.0}", a, b, "y"); + } + + private void assertConcat(String expected, Tensor a, Tensor b, String dimension) { + assertConcat(null, expected, a, b, dimension); + } + + private void assertConcat(String expectedType, String expected, Tensor a, Tensor b, String dimension) { + Tensor expectedAsTensor = Tensor.from(expected); + TensorType inferredType = new Concat(new ConstantTensor(a), new ConstantTensor(b), dimension) + .type(new MapEvaluationContext()); + Tensor result = a.concat(b, dimension); + + if (expectedType != null) + assertEquals(TensorType.fromSpec(expectedType), inferredType); + else + assertEquals(expectedAsTensor.type(), inferredType); + + assertEquals(expectedAsTensor, result); } } |