summaryrefslogtreecommitdiffstats
path: root/container-di/src/main/scala/com/yahoo/container/di/componentgraph/core/ComponentNode.scala
blob: 4261b874330267dd85eaf638846e7f9d311d3b26 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.di.componentgraph.core

import java.lang.reflect.{Constructor, InvocationTargetException, Modifier, ParameterizedType, Type}
import java.util.logging.Logger

import com.google.inject.Inject
import com.yahoo.component.{AbstractComponent, ComponentId}
import com.yahoo.config.ConfigInstance
import com.yahoo.container.di.componentgraph.Provider
import com.yahoo.container.di.componentgraph.core.ComponentNode._
import com.yahoo.container.di.componentgraph.core.Node.equalEdges
import com.yahoo.container.di.{ConfigKeyT, JavaAnnotation, createKey, makeClassCovariant, preserveStackTrace, removeStackTrace}
import com.yahoo.vespa.config.ConfigKey

import scala.language.postfixOps

/**
 * A component-graph node for a component that is created by reflectively invoking a
 * public constructor of `clazz`.
 *
 * [[setArguments]] must be called, with one entry per constructor parameter, before the node
 * can be instantiated. Each argument is one of:
 *  - another [[Node]], whose instance is passed to the constructor;
 *  - a config key, which is looked up in the map given to [[setAvailableConfigs]];
 *  - any other value, which is passed through unchanged.
 *
 * @param componentId the id of this component; also assigned to instances that are [[AbstractComponent]]s
 * @param configId    the config id used when building the keys returned by [[configKeys]]
 * @param clazz       the class to instantiate; must not be abstract
 * @param XXX_key     optional annotation passed to `createKey` when building [[instanceKey]]; may be null
 *
 * @author Tony Vaagenes
 * @author gjoranv
 */
class ComponentNode(componentId: ComponentId,
                    val configId: String,
                    clazz: Class[_ <: AnyRef],
                    val XXX_key: JavaAnnotation = null) // TODO expose key, not javaAnnotation
  extends Node(componentId)
{
  require(!isAbstract(clazz), "Can't instantiate abstract class " + clazz.getName)

  /** Constructor arguments, one per parameter of [[constructor]]; null until [[setArguments]] is called. */
  var arguments : Array[AnyRef] = _

  /** The constructor used to create instances, as selected by `ComponentNode.bestConstructor`. */
  val constructor: Constructor[AnyRef] = bestConstructor(clazz)

  /** Config instances keyed by config key; null until [[setAvailableConfigs]] is called. */
  var availableConfigs: Map[ConfigKeyT, ConfigInstance] = null

  override val instanceKey = createKey(clazz, XXX_key)

  override val instanceType = clazz

  /** The arguments that are graph nodes, in argument order. Requires [[setArguments]] to have been called. */
  override def usedComponents: List[Node] = {
    require(arguments != null, "Arguments must be set first.")
    arguments.collect{case node: Node => node}.toList
  }

  /**
   * The type of component this node provides. For a [[Provider]] it is the type argument of its
   * `Provider` interface; for any other class it is `clazz` itself.
   */
  override val componentType: Class[AnyRef] = {
    def allSuperClasses(clazz: Class[_], coll : List[Class[_]]) : List[Class[_]] = {
      if (clazz == null) coll
      else allSuperClasses(clazz.getSuperclass, clazz :: coll)
    }

    // NOTE(review): only interfaces declared directly on clazz and its superclasses are inspected,
    // not their super-interfaces. A Provider reached through another interface, or one implemented
    // as a raw type, would make `.head` below throw NoSuchElementException; confirm this cannot occur.
    def allGenericInterfaces(clazz : Class[_]) = allSuperClasses(clazz, List()) flatMap (_.getGenericInterfaces)

    def isProvider = classOf[Provider[_]].isAssignableFrom(clazz)
    def providerComponentType = (allGenericInterfaces(clazz).collect {
      case t: ParameterizedType if t.getRawType == classOf[Provider[_]] => t.getActualTypeArguments.head
    }).head

    if (isProvider) providerComponentType.asInstanceOf[Class[AnyRef]] //TODO: Test what happens if you ask for something that isn't a class, e.g. a parametrized type.
    else            clazz.asInstanceOf[Class[AnyRef]]
  }

  /** Sets the constructor arguments; see the class documentation for the accepted kinds of argument. */
  def setArguments(arguments: Array[AnyRef]): Unit = {
    this.arguments = arguments
  }

  /**
   * Trims the stack traces of `throwable` and of every throwable in its cause chain, and appends a
   * marker frame to each. Does nothing when `preserveStackTrace` is set or `throwable` is null.
   *
   * For each throwable:
   *  - frames are kept up to, but not including, the first frame belonging to [[ComponentNode]];
   *  - trailing frames after the last `<init>` frame among those are then dropped.
   *
   * @return the same `throwable`; its stack trace, and those of its causes, are modified in place
   */
  def cutStackTraceAtConstructor(throwable: Throwable): Throwable = {
    def takeUntilComponentNode(elements: Array[StackTraceElement]) =
      elements.takeWhile(_.getClassName != classOf[ComponentNode].getName)

    def dropToInitAtEnd(elements: Array[StackTraceElement]) =
      elements.reverse.dropWhile(_.getMethodName != "<init>").reverse

    val modifyStackTrace = takeUntilComponentNode _ andThen dropToInitAtEnd

    val dependencyInjectorStackTraceMarker = new StackTraceElement("============= Dependency Injection =============", "newInstance", null, -1)

    // Walk the cause chain, remembering (by identity) the throwables already handled. A cyclic
    // cause chain then ends the walk instead of recursing until the stack overflows.
    @scala.annotation.tailrec
    def cutChain(current: Throwable, visited: List[Throwable]): Unit =
      if (current != null && !visited.exists(_ eq current)) {
        current.setStackTrace(modifyStackTrace(current.getStackTrace) :+ dependencyInjectorStackTraceMarker)
        cutChain(current.getCause, current :: visited)
      }

    if (throwable != null && !preserveStackTrace)
      cutChain(throwable, Nil)
    throwable
  }

  /**
   * Creates a new component instance. Steps:
   *  1. Resolve each argument: a [[Node]] becomes its `newOrCachedInstance()`, a config key is looked
   *     up in [[availableConfigs]], and anything else is used as-is.
   *  2. Invoke [[constructor]] with the resolved arguments.
   *  3. Assign the component id to the result if it is an [[AbstractComponent]].
   *
   * If the component's constructor throws, the exception is rethrown as follows, with trimmed
   * stack traces:
   *  - an `Error` if the original was an `Error`;
   *  - otherwise a `ComponentNode.ComponentConstructorException`.
   */
  override protected def newInstance() : AnyRef = {
    assert (arguments != null, "graph.complete must be called before retrieving instances.")

    val actualArguments = arguments.map {
      case node: Node => node.newOrCachedInstance()
      case config: ConfigKeyT => availableConfigs(config.asInstanceOf[ConfigKeyT])
      case other => other
    }

    val instance =
      try {
        constructor.newInstance(actualArguments: _*)
      } catch {
        case e: InvocationTargetException =>
          throw removeStackTrace(ErrorOrComponentConstructorException(cutStackTraceAtConstructor(e.getCause), s"Error constructing $idAndType"))
      }

    initId(instance)
  }

  /** Wraps `cause` in an `Error` if it is one (so Errors are not turned into RuntimeExceptions), else in a ComponentConstructorException. */
  private def ErrorOrComponentConstructorException(cause: Throwable, message: String) : Throwable = {
    cause match {
      case error: Error => new Error(message, error)
      case _            => new ComponentConstructorException(message, cause)
    }
  }

  /**
   * Assigns this node's component id to `component` if it is an [[AbstractComponent]].
   *
   * @throws IllegalStateException if the component already has a different id
   * @return `component`
   */
  private def initId(component: AnyRef) = {
    def checkAndSetId(c: AbstractComponent): Unit = {
      if (c.hasInitializedId && c.getId != componentId )
        throw new IllegalStateException("Component with id '" + componentId + "' has set a conflicting component id: '" + c.getId + "'")

      c.initId(componentId)
    }

    component match {
      case component: AbstractComponent => checkAndSetId(component)
      case other => ()
    }
    component
  }

  /**
   * Two component nodes are equal when all of the following hold:
   *  - [[Node]]'s equality holds;
   *  - their arguments have equal edges;
   *  - they use equal configs.
   *
   * Any value that is not a ComponentNode is unequal. Comparing configs requires
   * [[setAvailableConfigs]] to have been called on both nodes.
   */
  // NOTE(review): equals is overridden without hashCode here; presumably Node supplies a compatible hashCode. Confirm.
  override def equals(other: Any): Boolean = {
    other match {
      case that: ComponentNode =>
        super.equals(that) &&
          equalEdges(arguments.toList, that.arguments.toList) &&
          usedConfigs == that.usedConfigs
      case _ => false
    }
  }

  /** The config instances for the config-key arguments, in argument order. */
  private def usedConfigs = {
    require(availableConfigs != null, "setAvailableConfigs must be called!")
    ( arguments collect {case c: ConfigKeyT => c} map (availableConfigs) ).toList
  }

  /** Each generic constructor parameter type paired with that parameter's annotations. */
  def getAnnotatedConstructorParams: Array[(Type, Array[JavaAnnotation])] = {
    constructor.getGenericParameterTypes zip constructor.getParameterAnnotations
  }

  /** Sets the configs used to resolve config-key arguments; requires [[setArguments]] to have been called first. */
  def setAvailableConfigs(configs: Map[ConfigKeyT, ConfigInstance]): Unit = {
    require (arguments != null, "graph.complete must be called before graph.setAvailableConfigs.")
    availableConfigs = configs
  }

  /** One config key, using [[configId]], for each config class taken by the constructor. */
  override def configKeys = {
    configParameterClasses.map(new ConfigKey(_, configId)).toSet
  }


  /** Constructor parameter types that are plain (non-parameterized) classes assignable to ConfigInstance. */
  private def configParameterClasses: Array[Class[ConfigInstance]] = {
    constructor.getGenericParameterTypes.collect {
      case clazz: Class[_] if classOf[ConfigInstance].isAssignableFrom(clazz) => clazz.asInstanceOf[Class[ConfigInstance]]
    }
  }

  /**
   * A label of the form `{<simple class name>|<Node.packageName(instanceType)>|<config name>.def|...}`.
   */
  override def label = {
    val configNames = configKeys.map(_.getName + ".def").toList

    (List(instanceType.getSimpleName, Node.packageName(instanceType)) ::: configNames).
      mkString("{", "|", "}")
  }

}

object ComponentNode {
  val log: Logger = Logger.getLogger(classOf[ComponentNode].getName)

  /**
   * Chooses the constructor used to instantiate `clazz`, in this order of preference:
   *  1. The single public constructor annotated with `@Inject`, if there is one.
   *  2. The only public constructor, if there is exactly one.
   *  3. If there are several, the one with the most [[ConfigInstance]] parameters, after logging
   *     a warning. Ties go to the constructor appearing later in `getConstructors`.
   *
   * @throws ComponentConstructorException if more than one public constructor is annotated with
   *         `@Inject`, or if `clazz` has no public constructors
   */
  private def bestConstructor(clazz: Class[AnyRef]) = {
    val candidates = clazz.getConstructors.asInstanceOf[Array[Constructor[AnyRef]]]

    // At most one constructor may carry @Inject; more than one is a configuration error.
    def injectAnnotated: Option[Constructor[AnyRef]] = {
      val annotated = candidates.filter(c => c.getAnnotation(classOf[Inject]) != null)
      if (annotated.length > 1)
        throwComponentConstructorException("Multiple constructors annotated with inject in class " + clazz.getName)
      annotated.headOption
    }

    // Used only when no constructor is annotated with @Inject (evaluated lazily by getOrElse).
    def mostConfigParameters: Constructor[AnyRef] = {
      def configParameterCount(c: Constructor[AnyRef]): Int =
        c.getParameterTypes.count(t => classOf[ConfigInstance].isAssignableFrom(t))

      if (candidates.isEmpty) {
        throwComponentConstructorException("No public constructors in class " + clazz.getName)
      } else if (candidates.length == 1) {
        candidates(0)
      } else {
        log.warning(s"Multiple public constructors found in class ${clazz.getName}, there should only be one. " +
          "If more than one public constructor is needed, the primary one must be annotated with @Inject.")
        // '>=' lets a later constructor replace an earlier one with the same count.
        candidates.reduceLeft { (best, candidate) =>
          if (configParameterCount(candidate) >= configParameterCount(best)) candidate else best
        }
      }
    }

    injectAnnotated getOrElse mostConfigParameters
  }

  /** Throws a ComponentConstructorException with `message`, passed through `removeStackTrace`. */
  private def throwComponentConstructorException(message: String) =
    throw removeStackTrace(new ComponentConstructorException(message))

  /** Thrown when no usable constructor can be chosen, or when a component's constructor fails. */
  class ComponentConstructorException(message: String, cause: Throwable) extends RuntimeException(message, cause) {
    /** Creates an exception without a cause. */
    def this(message: String) = this(message, null)
  }

  /** Whether `clazz` carries the `abstract` modifier. */
  def isAbstract(clazz: Class[_ <: AnyRef]): Boolean = Modifier.isAbstract(clazz.getModifiers)
}