// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.container.di.componentgraph.core; import com.google.common.collect.Iterables; import com.google.common.collect.Sets; import com.google.inject.BindingAnnotation; import com.google.inject.ConfigurationException; import com.google.inject.Guice; import com.google.inject.Injector; import com.google.inject.Key; import com.yahoo.collections.Pair; import com.yahoo.component.ComponentId; import com.yahoo.component.provider.ComponentRegistry; import com.yahoo.config.ConfigInstance; import com.yahoo.container.di.componentgraph.Provider; import com.yahoo.container.di.componentgraph.cycle.CycleFinder; import com.yahoo.container.di.componentgraph.cycle.Graph; import com.yahoo.vespa.config.ConfigKey; import java.lang.annotation.Annotation; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; import java.lang.reflect.TypeVariable; import java.lang.reflect.WildcardType; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; import static com.yahoo.container.di.componentgraph.core.Exceptions.removeStackTrace; /** * @author Tony Vaagenes * @author gjoranv * @author ollivir * * Not thread safe. */ public class ComponentGraph { private static final Logger log = Logger.getLogger(ComponentGraph.class.getName()); private final long generation; private final Map nodesById = new HashMap<>(); public ComponentGraph(long generation) { this.generation = generation; } public ComponentGraph() { this(0L); } public long generation() { return generation; } public int size() { return nodesById.size(); } public Collection nodes() { return nodesById.values(); } public void add(Node component) { if (nodesById.containsKey(component.componentId())) { throw new IllegalStateException("Multiple components with the same id " + component.componentId()); } nodesById.put(component.componentId(), component); } private Optional lookupGlobalComponent(Key key) { if (!(key.getTypeLiteral().getType() instanceof Class)) { throw new RuntimeException("Type not supported " + key.getTypeLiteral()); } Class clazz = key.getTypeLiteral().getRawType(); Collection components = matchingComponentNodes(nodes(), key); if (components.isEmpty()) { return Optional.empty(); } else if (components.size() == 1) { return Optional.ofNullable(Iterables.get(components, 0)); } else { var nonProviderComponents = components.stream().filter(c -> !Provider.class.isAssignableFrom(c.instanceType())).toList(); if (nonProviderComponents.isEmpty()) { throw new IllegalStateException("Multiple global component providers for class '" + clazz.getName() + "' found :" + components); } else if (nonProviderComponents.size() == 1) { return Optional.of(nonProviderComponents.get(0)); } else { throw new IllegalStateException("Multiple global components with class '" + clazz.getName() + "' found : " + nonProviderComponents); } } } public T getInstance(Class clazz) { return getInstance(Key.get(clazz)); } @SuppressWarnings("unchecked") public T getInstance(Key key) { // TODO: Combine exception handling with lookupGlobalComponent. Object ob = lookupGlobalComponent(key).map(Node::component) .orElseThrow(() -> new IllegalStateException(String.format("No global component with key '%s' ", key))); return (T) ob; } private Collection componentNodes() { return nodesOfType(nodes(), ComponentNode.class); } private Collection componentRegistryNodes() { return nodesOfType(nodes(), ComponentRegistryNode.class); } private Collection osgiComponentsOfClass(Class clazz) { return componentNodes().stream().filter(node -> clazz.isAssignableFrom(node.componentType())).toList(); } public List complete(Injector fallbackInjector) { componentNodes().forEach(node -> completeNode(node, fallbackInjector)); componentRegistryNodes().forEach(this::completeComponentRegistryNode); return topologicalSort(nodes()); } public List complete() { return complete(Guice.createInjector()); } public Set> configKeys() { return nodes().stream().flatMap(node -> node.configKeys().stream()).collect(Collectors.toSet()); } public void setAvailableConfigs(Map, ConfigInstance> configs) { componentNodes().forEach(node -> node.setAvailableConfigs(Keys.invariantCopy(configs))); } public void reuseNodes(ComponentGraph old) { // copy instances if node equal Set commonComponentIds = Sets.intersection(nodesById.keySet(), old.nodesById.keySet()); for (ComponentId id : commonComponentIds) { if (nodesById.get(id).equals(old.nodesById.get(id))) { nodesById.get(id).instance = old.nodesById.get(id).instance; } } // reset instances with modified dependencies for (Node node : topologicalSort(nodes())) { for (Node usedComponent : node.usedComponents()) { if (usedComponent.instance.isEmpty()) { node.instance = Optional.empty(); } } } } /** All constructed components and providers of this, in reverse creation order, i.e., suited for ordered deconstruction. */ public List allConstructedComponentsAndProviders() { List orderedNodes = topologicalSort(nodes()); Collections.reverse(orderedNodes); return orderedNodes.stream() .filter(node -> node.constructedInstance().isPresent()) .map(node -> node.constructedInstance().orElseThrow()) .collect(Collectors.toList()); } private void completeComponentRegistryNode(ComponentRegistryNode registry) { registry.injectAll(osgiComponentsOfClass(registry.componentClass())); } private void completeNode(ComponentNode node, Injector fallbackInjector) { try { Object[] arguments = node.getAnnotatedConstructorParams().stream().map(param -> handleParameter(node, fallbackInjector, param)) .toArray(); node.setArguments(arguments); } catch (Exception e) { throw removeStackTrace(new RuntimeException("When resolving dependencies of " + node.idAndType(), e)); } } private Object handleParameter(Node node, Injector fallbackInjector, Pair> annotatedParameterType) { Type parameterType = annotatedParameterType.getFirst(); List annotations = annotatedParameterType.getSecond(); if (parameterType instanceof Class && parameterType.equals(ComponentId.class)) { return node.componentId(); } else if (parameterType instanceof Class && ConfigInstance.class.isAssignableFrom((Class) parameterType)) { return handleConfigParameter((ComponentNode) node, (Class) parameterType); } else if (parameterType instanceof ParameterizedType registry && ((ParameterizedType) parameterType).getRawType().equals(ComponentRegistry.class)) { return getComponentRegistry(registry.getActualTypeArguments()[0]); } else if (parameterType instanceof Class) { return handleComponentParameter(node, fallbackInjector, (Class) parameterType, annotations); } else if (parameterType instanceof ParameterizedType) { throw new RuntimeException("Injection of parameterized type " + parameterType + " is not supported."); } else { throw new RuntimeException("Injection of type " + parameterType + " is not supported"); } } private ComponentRegistryNode newComponentRegistryNode(Class componentClass) { ComponentRegistryNode registry = new ComponentRegistryNode(componentClass); add(registry); //TODO: don't mutate nodes here. return registry; } private ComponentRegistryNode getComponentRegistry(Type componentType) { Class componentClass; if (componentType instanceof WildcardType wildcardType) { if (wildcardType.getLowerBounds().length > 0 || wildcardType.getUpperBounds().length > 1) { throw new RuntimeException("Can't create ComponentRegistry of unknown wildcard type" + wildcardType); } componentClass = (Class) wildcardType.getUpperBounds()[0]; } else if (componentType instanceof Class) { componentClass = (Class) componentType; } else if (componentType instanceof TypeVariable) { throw new RuntimeException("Can't create ComponentRegistry of unknown type variable " + componentType); } else { throw new RuntimeException("Can't create ComponentRegistry of unknown type " + componentType); } for (ComponentRegistryNode node : componentRegistryNodes()) { if (node.componentClass().equals(componentType)) { return node; } } return newComponentRegistryNode(componentClass); } @SuppressWarnings("unchecked") private ConfigKey handleConfigParameter(ComponentNode node, Class clazz) { Class castClass = (Class) clazz; return new ConfigKey<>(castClass, node.configId()); } private Key getKey(Class clazz, Optional bindingAnnotation) { return bindingAnnotation.map(annotation -> Key.get(clazz, annotation)).orElseGet(() -> Key.get(clazz)); } private Optional matchingGuiceNode(Key key, Object instance) { return matchingNodes(nodes(), GuiceNode.class, key).stream().filter(node -> node.component() == instance). // TODO: assert that there is only one (after filter) findFirst(); } private Node lookupOrCreateGlobalComponent(Node node, Injector fallbackInjector, Class clazz, Key key) { Optional component = lookupGlobalComponent(key); if (component.isEmpty()) { Object instance; try { Level level = hasExplicitBinding(fallbackInjector, key) ? Level.FINE : Level.WARNING; log.log(level, () -> "Trying the fallback injector to create" + messageForNoGlobalComponent(clazz, node)); if (level.intValue() > Level.INFO.intValue()) { log.log(level, "A component of type " + key.getTypeLiteral() + " should probably be declared in services.xml. " + "Not doing so may cause resource leaks and unnecessary reconstruction of components."); } instance = fallbackInjector.getInstance(key); } catch (ConfigurationException e) { throw removeStackTrace(new IllegalStateException( (messageForMultipleClassLoaders(clazz).isEmpty()) ? "No global" + messageForNoGlobalComponent(clazz, node) : messageForMultipleClassLoaders(clazz))); } component = Optional.of(matchingGuiceNode(key, instance).orElseGet(() -> { GuiceNode guiceNode = new GuiceNode(instance, key.getAnnotation()); add(guiceNode); return guiceNode; })); } return component.get(); } private boolean hasExplicitBinding(Injector injector, Key key) { log.log(Level.FINE, () -> "Injector binding for " + key + ": " + injector.getExistingBinding(key)); return injector.getExistingBinding(key) != null; } private Node handleComponentParameter(Node node, Injector fallbackInjector, Class clazz, Collection annotations) { List bindingAnnotations = annotations.stream().filter(ComponentGraph::isBindingAnnotation).toList(); Key key = getKey(clazz, bindingAnnotations.stream().findFirst()); if (bindingAnnotations.size() > 1) { throw new RuntimeException(String.format("More than one binding annotation used in class '%s'", node.instanceType())); } Collection injectedNodesOfCorrectType = matchingComponentNodes(node.componentsToInject, key); if (injectedNodesOfCorrectType.size() == 0) { return lookupOrCreateGlobalComponent(node, fallbackInjector, clazz, key); } else if (injectedNodesOfCorrectType.size() == 1) { return Iterables.get(injectedNodesOfCorrectType, 0); } else { //TODO: !className for last parameter throw new RuntimeException( String.format("Multiple components of type '%s' injected into component '%s'", clazz.getName(), node.instanceType())); } } private static String messageForNoGlobalComponent(Class clazz, Node node) { return String.format(" component of class %s to inject into component %s.", clazz.getName(), node.idAndType()); } private String messageForMultipleClassLoaders(Class clazz) { String errMsg = "Class " + clazz.getName() + " is provided by the framework, and cannot be embedded in a user bundle. " + "To resolve this problem, please refer to osgi-classloading.html#multiple-implementations in the documentation"; try { Class resolvedClass = Class.forName(clazz.getName(), false, this.getClass().getClassLoader()); if (!resolvedClass.equals(clazz)) { return errMsg; } } catch (ClassNotFoundException ignored) { } return ""; } public static Node getNode(ComponentGraph graph, String componentId) { return graph.nodesById.get(new ComponentId(componentId)); } private static Collection nodesOfType(Collection nodes, Class clazz) { List ret = new ArrayList<>(); for (Node node : nodes) { if (clazz.isInstance(node)) { ret.add(clazz.cast(node)); } } return ret; } private static Collection matchingComponentNodes(Collection nodes, Key key) { return matchingNodes(nodes, ComponentNode.class, key); } // Finds all nodes with a given nodeType and instance with given key private static Collection matchingNodes(Collection nodes, Class nodeType, Key key) { Class clazz = key.getTypeLiteral().getRawType(); Annotation annotation = key.getAnnotation(); List filteredByClass = nodesOfType(nodes, nodeType).stream().filter(node -> clazz.isAssignableFrom(node.componentType())) .toList(); if (filteredByClass.size() == 1) { return filteredByClass; } else { List filteredByClassAndAnnotation = filteredByClass.stream() .filter(node -> (annotation == null && node.instanceKey().getAnnotation() == null) || annotation.equals(node.instanceKey().getAnnotation())) .toList(); if (filteredByClassAndAnnotation.size() > 0) { return filteredByClassAndAnnotation; } else { return filteredByClass; } } } // Returns true if annotation is a BindingAnnotation, e.g. com.google.inject.name.Named public static boolean isBindingAnnotation(Annotation annotation) { LinkedList> queue = new LinkedList<>(); queue.add(annotation.getClass()); queue.addAll(Arrays.asList(annotation.getClass().getInterfaces())); while (!queue.isEmpty()) { Class clazz = queue.removeFirst(); if (clazz.getAnnotation(BindingAnnotation.class) != null) { return true; } else { if (clazz.getSuperclass() != null) { queue.addFirst(clazz.getSuperclass()); } } } return false; } /** * The returned list is the nodes from the graph bottom-up. * * For each iteration, the algorithm finds the components that are not "wanted by" any other component, * and prepends those components into the resulting 'sorted' list. Hence, the first element in the returned * list is the component that is directly or indirectly wanted by "most" other components. * * @return A list where a earlier than b in the list implies that there is no path from a to b */ private static List topologicalSort(Collection nodes) { Map numIncoming = new HashMap<>(); nodes.forEach( node -> node.usedComponents().forEach( injectedNode -> numIncoming.merge(injectedNode.componentId(), 1, (a, b) -> a + b))); LinkedList sorted = new LinkedList<>(); List unsorted = new ArrayList<>(nodes); while (!unsorted.isEmpty()) { List ready = new ArrayList<>(); List notReady = new ArrayList<>(); unsorted.forEach(node -> { if (numIncoming.getOrDefault(node.componentId(), 0) == 0) { ready.add(node); } else { notReady.add(node); } }); if (ready.isEmpty()) { throw new IllegalStateException("There is a cycle in the component injection graph: " + findCycle(notReady)); } ready.forEach(node -> node.usedComponents() .forEach(injectedNode -> numIncoming.merge(injectedNode.componentId(), -1, (a, b) -> a + b))); sorted.addAll(0, ready); unsorted = notReady; } return sorted; } private static List findCycle(List nodes) { var cyclicGraph = new Graph(); for (var node : nodes) { for (var adjacent : node.usedComponents()) { cyclicGraph.edge(node.componentId().stringValue(), adjacent.componentId().stringValue()); } } return new CycleFinder<>(cyclicGraph).findCycle(); } }