summaryrefslogtreecommitdiffstats
path: root/container-di/src/main/scala/com/yahoo/container/di/componentgraph/core/ComponentRegistryNode.scala
diff options
context:
space:
mode:
Diffstat (limited to 'container-di/src/main/scala/com/yahoo/container/di/componentgraph/core/ComponentRegistryNode.scala')
-rw-r--r--container-di/src/main/scala/com/yahoo/container/di/componentgraph/core/ComponentRegistryNode.scala71
1 files changed, 71 insertions, 0 deletions
diff --git a/container-di/src/main/scala/com/yahoo/container/di/componentgraph/core/ComponentRegistryNode.scala b/container-di/src/main/scala/com/yahoo/container/di/componentgraph/core/ComponentRegistryNode.scala
new file mode 100644
index 00000000000..864eb17ddfb
--- /dev/null
+++ b/container-di/src/main/scala/com/yahoo/container/di/componentgraph/core/ComponentRegistryNode.scala
@@ -0,0 +1,71 @@
+// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.container.di.componentgraph.core
+
+import com.yahoo.component.provider.ComponentRegistry
+import com.yahoo.component.{ComponentId, Component}
+import ComponentRegistryNode._
+import com.google.inject.Key
+import com.google.inject.util.Types
+import Node.syntheticComponentId
+
+/**
+ * @author tonytv
+ * @author gjoranv
+ */
+class ComponentRegistryNode(val componentClass : Class[AnyRef])
+ extends Node(componentId(componentClass)) {
+
+ def usedComponents = componentsToInject
+
+ protected def newInstance() = {
+ val registry = new ComponentRegistry[AnyRef]
+
+ componentsToInject foreach { component =>
+ registry.register(component.componentId, component.newOrCachedInstance())
+ }
+
+ registry
+ }
+
+ override val instanceKey =
+ Key.get(Types.newParameterizedType(classOf[ComponentRegistry[_]], componentClass)).asInstanceOf[Key[AnyRef]]
+
+ override val instanceType: Class[AnyRef] = instanceKey.getTypeLiteral.getRawType.asInstanceOf[Class[AnyRef]]
+ override val componentType: Class[AnyRef] = instanceType
+
+ override def configKeys = Set()
+
+ override def equals(other: Any) = {
+ other match {
+ case that: ComponentRegistryNode =>
+ componentId == that.componentId && // includes componentClass
+ instanceType == that.instanceType &&
+ equalEdges(usedComponents, that.usedComponents)
+ case _ => false
+ }
+ }
+
+ override def label =
+ "{ComponentRegistry\\<%s\\>|%s}".format(componentClass.getSimpleName, Node.packageName(componentClass))
+}
+
+object ComponentRegistryNode {
+ val componentRegistryNamespace = ComponentId.fromString("ComponentRegistry")
+
+ def componentId(componentClass: Class[_]) = {
+ syntheticComponentId(componentClass.getName, componentClass, componentRegistryNamespace)
+ }
+
+ def equalEdges(edges: List[Node], otherEdges: List[Node]): Boolean = {
+ def compareEdges = {
+ (sortByComponentId(edges) zip sortByComponentId(otherEdges)).
+ forall(equalEdge)
+ }
+
+ def sortByComponentId(in: List[Node]) = in.sortBy(_.componentId)
+ def equalEdge(edgePair: (Node, Node)): Boolean = edgePair._1.componentId == edgePair._2.componentId
+
+ edges.size == otherEdges.size &&
+ compareEdges
+ }
+}