aboutsummaryrefslogtreecommitdiffstats
path: root/container-di/src/main/scala/com/yahoo/container/di/osgi/OsgiUtil.scala
blob: 3769eed6d2d75707ecad655416d69606a17740ba (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.container.di.osgi

import java.nio.file.{Files, Path, Paths}
import java.util.function.Predicate
import java.util.jar.{JarEntry, JarFile}
import java.util.logging.{Level, Logger}
import java.util.stream.Collectors

import com.google.common.io.Files.fileTreeTraverser
import com.yahoo.component.ComponentSpecification
import com.yahoo.container.di.Osgi.RelativePath
import com.yahoo.osgi.maven.ProjectBundleClassPaths
import com.yahoo.osgi.maven.ProjectBundleClassPaths.BundleClasspathMapping
import org.osgi.framework.Bundle
import org.osgi.framework.wiring.BundleWiring

import scala.collection.JavaConverters._

/**
 * Tested by com.yahoo.application.container.jersey.JerseyTest
 * @author Tony Vaagenes
 */
object OsgiUtil {
  private val log = Logger.getLogger(getClass.getName)
  private val classFileTypeSuffix = ".class"

  /**
   * Lists the class files of a bundle by asking its OSGi wiring.
   *
   * @param bundle         the bundle to scan; it is adapted to a [[BundleWiring]]
   * @param packagesToScan dotted package names (e.g. "com.example") whose direct class files
   *                       should be listed. If empty, the bundle is scanned recursively from "/".
   * @return resource names of the class files, as returned by `BundleWiring.listResources`
   *         with `LISTRESOURCES_LOCAL` (plus `LISTRESOURCES_RECURSE` for the full scan)
   */
  def getClassEntriesInBundleClassPath(bundle: Bundle, packagesToScan: Set[String]): Iterable[RelativePath] = {
    val bundleWiring = bundle.adapt(classOf[BundleWiring])

    // Lists "*.class" resources under `path`, optionally recursing into sub-directories.
    def listClasses(path: String, recurse: Boolean): Iterable[RelativePath] = {
      val options =
        if (recurse) BundleWiring.LISTRESOURCES_LOCAL | BundleWiring.LISTRESOURCES_RECURSE
        else         BundleWiring.LISTRESOURCES_LOCAL

      bundleWiring.listResources(path, "*" + classFileTypeSuffix, options).asScala
    }

    if (packagesToScan.isEmpty) listClasses("/", recurse = true)
    else packagesToScan flatMap { packageName => listClasses(packageToPath(packageName), recurse = false) }
  }

  /**
   * Lists the class files of a bundle without the OSGi framework.
   *
   * The bundle's class path is looked up in the project's class path mappings file
   * (`ProjectBundleClassPaths.CLASSPATH_MAPPINGS_FILENAME`), which is loaded via `classLoader`.
   *
   * @param classLoader    class loader used to locate the class path mappings file
   * @param bundleSpec     the bundle whose class path elements should be scanned
   * @param packagesToScan dotted package names whose direct class files should be listed;
   *                       if empty, every class file in each class path element is listed
   * @return class file paths relative to the class path element (directory or jar) they were found in
   * @throws RuntimeException if the mappings file or the bundle cannot be found, or if a class
   *                          path element is neither a directory nor a regular ".jar" file
   */
  def getClassEntriesForBundleUsingProjectClassPathMappings(classLoader: ClassLoader,
                                                            bundleSpec: ComponentSpecification,
                                                            packagesToScan: Set[String]): Iterable[RelativePath] = {
    classEntriesFrom(
      bundleClassPathMapping(bundleSpec, classLoader).classPathElements.asScala.toList,
      packagesToScan)
  }

  /**
   * Finds the class path mapping for `bundleSpec`.
   *
   * The main bundle is preferred. Otherwise the provided dependencies are searched, and a
   * warning is logged because dependencies of that bundle are not scanned.
   */
  private def bundleClassPathMapping(bundleSpec: ComponentSpecification,
                                     classLoader: ClassLoader): BundleClasspathMapping = {

    val projectBundleClassPaths = loadProjectBundleClassPaths(classLoader)

    if (projectBundleClassPaths.mainBundle.bundleSymbolicName == bundleSpec.getName) {
      projectBundleClassPaths.mainBundle
    } else {
      log.log(Level.WARNING, s"Dependencies of the bundle $bundleSpec will not be scanned. Please file a feature request if you need this" )
      matchingBundleClassPathMapping(bundleSpec, projectBundleClassPaths.providedDependencies.asScala.toList)
    }
  }

  /**
   * Returns the first mapping whose bundle symbolic name equals `bundleSpec.getName`.
   *
   * @throws RuntimeException if no mapping matches
   */
  def matchingBundleClassPathMapping(bundleSpec: ComponentSpecification,
                                     providedBundlesClassPathMappings: List[BundleClasspathMapping]): BundleClasspathMapping = {
    providedBundlesClassPathMappings.
      find(_.bundleSymbolicName == bundleSpec.getName).
      getOrElse(throw new RuntimeException("No such bundle: " + bundleSpec))
  }

  /**
   * Loads the project's class path mappings file, located as a resource of `classLoader`.
   *
   * @throws RuntimeException if the resource is not found
   */
  private def loadProjectBundleClassPaths(classLoader: ClassLoader): ProjectBundleClassPaths = {
    val classPathMappingsFileLocation = classLoader.getResource(ProjectBundleClassPaths.CLASSPATH_MAPPINGS_FILENAME)
    if (classPathMappingsFileLocation == null)
      throw new RuntimeException(s"Couldn't find ${ProjectBundleClassPaths.CLASSPATH_MAPPINGS_FILENAME}  in the class path.")

    // NOTE(review): Paths.get(URI) needs a URI from a file system provider that is available;
    // a mappings file packaged inside a jar would presumably fail here -- confirm if that case matters.
    ProjectBundleClassPaths.load(Paths.get(classPathMappingsFileLocation.toURI))
  }

  /**
   * Collects class file entries from each class path element.
   *
   * Directories are walked on the file system and ".jar" files are read as jars; any other
   * kind of element is rejected.
   */
  private def classEntriesFrom(classPathEntries: List[String], packagesToScan: Set[String]): Iterable[RelativePath] = {
    val packagePathsToScan = packagesToScan map packageToPath

    classPathEntries.flatMap { entry =>
      val path = Paths.get(entry)
      if (Files.isDirectory(path)) classEntriesInPath(path, packagePathsToScan)
      else if (Files.isRegularFile(path) && path.getFileName.toString.endsWith(".jar")) classEntriesInJar(path, packagePathsToScan)
      else throw new RuntimeException("Unsupported path " + path + " in the class path")
    }
  }

  /**
   * Lists the class files under a directory, as paths relative to `rootPath`.
   *
   * If `packagePathsToScan` is empty, the whole tree is traversed. Otherwise only the direct
   * children of each `rootPath/packagePath` directory are considered.
   */
  private def classEntriesInPath(rootPath: Path, packagePathsToScan: Traversable[String]): Traversable[RelativePath] = {
    def relativePathToClass(pathToClass: Path): RelativePath = {
      val relativePath = rootPath.relativize(pathToClass)
      relativePath.toString
    }

    val fileIterator =
      if (packagePathsToScan.isEmpty) fileTreeTraverser().preOrderTraversal(rootPath.toFile).asScala
      else packagePathsToScan.view flatMap  { packagePath =>  fileTreeTraverser().children(rootPath.resolve(packagePath).toFile).asScala }

    for {
      file <- fileIterator
      if file.isFile
      if file.getName.endsWith(classFileTypeSuffix)
    } yield relativePathToClass(file.toPath)
  }


  /**
   * Lists the ".class" entry names of a jar file.
   *
   * If `packagePathsToScan` is non-empty, only entries whose directory part is exactly one of
   * those package paths are kept; entries in the jar root belong to the package path "".
   * The entries are collected eagerly, so the jar is always closed before returning.
   */
  private def classEntriesInJar(jarPath: Path, packagePathsToScan: Set[String]): Traversable[RelativePath] = {
    // Directory part of a jar entry name; "" for entries in the jar root (the default package),
    // matching how a directory class path element treats the package path "".
    def packagePath(entryName: String): String =
      entryName.lastIndexOf('/') match {
        case -1 => ""
        case n  => entryName.substring(0, n)
      }

    val acceptedPackage: Predicate[String] =
      if (packagePathsToScan.isEmpty) (name: String) => true
      else (name: String) => packagePathsToScan(packagePath(name))

    // Opened outside the try: if the constructor throws there is nothing to close.
    val jarFile = new JarFile(jarPath.toFile)
    try {
      jarFile.stream().
        map[String] { entry: JarEntry => entry.getName }.
        filter { name: String => name.endsWith(classFileTypeSuffix) }.
        filter(acceptedPackage).
        collect(Collectors.toList()).
        asScala
    } finally {
      jarFile.close()
    }
  }

  /** Converts a dotted package name to a '/'-separated path, e.g. "a.b.c" -> "a/b/c". */
  def packageToPath(packageName: String): String = packageName.replaceAllLiterally(".", "/")

  /** Implicit adapter from a Scala function `T => Boolean` to a `java.util.function.Predicate[T]`. */
  implicit class JavaPredicate[T](f: T => Boolean) extends Predicate[T] {
    override def test(t: T): Boolean = f(t)
  }

  /** Implicit adapter from a Scala function `T => R` to a `java.util.function.Function[T, R]`. */
  implicit class JavaFunction[T, R](f: T => R) extends java.util.function.Function[T, R] {
    override def apply(t: T): R = f(t)
  }
}