summaryrefslogtreecommitdiffstats
path: root/container-di/src/main/scala/com/yahoo/container/di/componentgraph/core/ComponentRegistryNode.scala
blob: a0e07e7433efbe713fc546d643240679f1fc5666 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.di.componentgraph.core

import com.yahoo.component.provider.ComponentRegistry
import com.yahoo.component.{ComponentId, Component}
import ComponentRegistryNode._
import com.google.inject.Key
import com.google.inject.util.Types
import Node.syntheticComponentId

/**
 * Node whose instance is a ComponentRegistry[AnyRef] filled with the instances
 * of every node in componentsToInject, each registered under that node's
 * component id.
 *
 * NOTE(review): equals is overridden here but no hashCode override is visible
 * in this file; confirm that Node supplies one that is consistent with it.
 *
 * @author tonytv
 * @author gjoranv
 */
class ComponentRegistryNode(val componentClass : Class[AnyRef])
  extends Node(componentId(componentClass)) {

  /** The nodes this registry uses are exactly the injected ones. */
  def usedComponents = componentsToInject

  /** Creates a fresh registry and registers the instance of each injected node in it. */
  protected def newInstance() = {
    val registry = new ComponentRegistry[AnyRef]
    for (node <- componentsToInject)
      registry.register(node.componentId, node.newOrCachedInstance())
    registry
  }

  /** Guice key for the parameterized type ComponentRegistry<componentClass>. */
  override val instanceKey = {
    val registryType = Types.newParameterizedType(classOf[ComponentRegistry[_]], componentClass)
    Key.get(registryType).asInstanceOf[Key[AnyRef]]
  }

  // Declared after instanceKey on purpose: each val below reads the one above it during construction.
  override val instanceType: Class[AnyRef] =
    instanceKey.getTypeLiteral.getRawType.asInstanceOf[Class[AnyRef]]

  override val componentType: Class[AnyRef] = instanceType

  /** This node contributes no config keys. */
  override def configKeys = Set()

  /**
   * Two registry nodes are equal when their component ids and instance types match
   * and their used components carry pairwise equal component ids (see equalEdges).
   */
  override def equals(other: Any) = other match {
    case that: ComponentRegistryNode =>
      // The component id is built from componentClass, so this also compares componentClass.
      val sameIdentity = componentId == that.componentId && instanceType == that.instanceType
      sameIdentity && equalEdges(usedComponents, that.usedComponents)
    case _ =>
      false
  }

  /** Label text built from the simple name and the package name of componentClass. */
  override def label = {
    val simpleName = componentClass.getSimpleName
    val packageName = Node.packageName(componentClass)
    "{ComponentRegistry\\<%s\\>|%s}".format(simpleName, packageName)
  }
}

/**
 * Companion helpers for ComponentRegistryNode: the component id used for a
 * registry node and the edge comparison used by its equals method.
 */
object ComponentRegistryNode {
  /** Namespace passed to syntheticComponentId for every registry node id. */
  val componentRegistryNamespace = ComponentId.fromString("ComponentRegistry")

  /** Builds the synthetic component id of the registry node for the given class. */
  def componentId(componentClass: Class[_]) =
    syntheticComponentId(componentClass.getName, componentClass, componentRegistryNamespace)

  /**
   * Compares two edge lists by component id only, ignoring the order of the lists.
   *
   * @param edges      the first list of nodes
   * @param otherEdges the second list of nodes
   * @return true when both lists have the same size and, once each is sorted by
   *         component id, the nodes at every position have equal component ids
   */
  def equalEdges(edges: List[Node], otherEdges: List[Node]): Boolean = {
    def byComponentId(nodes: List[Node]) = nodes.sortBy(_.componentId)

    // Size is checked first so lists of different length are rejected without sorting.
    edges.size == otherEdges.size &&
      byComponentId(edges).corresponds(byComponentId(otherEdges)) { (mine, theirs) =>
        mine.componentId == theirs.componentId
      }
  }
}