aboutsummaryrefslogtreecommitdiffstats
path: root/jdisc_core/src/main/java/com/yahoo/jdisc/application/GuiceRepository.java
blob: 290d10e5610f27947a3bed838a9e17dd63b4b709 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.jdisc.application;

import com.google.common.collect.ImmutableSet;
import com.google.inject.Binding;
import com.google.inject.Guice;
import com.google.inject.Injector;
import com.google.inject.Key;
import com.google.inject.Module;
import com.google.inject.spi.DefaultElementVisitor;
import com.google.inject.spi.Element;
import com.google.inject.spi.Elements;
import com.yahoo.jdisc.Container;
import org.osgi.framework.Bundle;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.ListIterator;
import java.util.Map;
import java.util.Set;
import java.util.logging.Logger;

/**
 * This is a repository of {@link Module}s. An instance of this class is owned by the {@link ContainerBuilder}, and is
 * used to configure the set of Modules that eventually form the {@link Injector} of the active {@link Container}.
 * <p>
 * Modules are kept in installation order. When the injector is built, the elements of all modules are merged so
 * that, for each {@link Key}, only the binding that appears last in installation order is kept; in other words a
 * later-installed module overrides bindings of earlier-installed modules. The injector is created lazily and
 * cached until the next install or uninstall.
 * <p>
 * This class performs no synchronization of its mutable state.
 *
 * @author Simon Thoresen Hult
 */
public class GuiceRepository implements Iterable<Module> {

    private static final Logger log = Logger.getLogger(GuiceRepository.class.getName());

    /** Installed modules, in installation order, mapped to the Guice elements recorded from each of them. */
    private final Map<Module, List<Element>> modules = new LinkedHashMap<>();

    /** Cached injector; {@code null} until first requested, and reset whenever the set of modules changes. */
    private Injector injector;

    /**
     * Creates a repository and installs the given modules, in order.
     *
     * @param modules the modules to install initially
     */
    public GuiceRepository(Module... modules) {
        installAll(Arrays.asList(modules));
    }

    /**
     * Returns the injector for the currently installed modules, creating it if necessary.
     * This is equivalent to {@link #getInjector()}.
     *
     * @return the injector built from the installed modules
     */
    public Injector activate() {
        return getInjector();
    }

    /**
     * Loads, instantiates and installs each named module class from the given bundle, in iteration order.
     *
     * @param bundle      the bundle used to load the module classes
     * @param moduleNames the names of the module classes to load
     * @return the installed module instances, in the same order as {@code moduleNames}
     * @throws ClassNotFoundException if the bundle cannot load one of the named classes; modules installed before
     *                                the failing one remain installed
     */
    public List<Module> installAll(Bundle bundle, Iterable<String> moduleNames) throws ClassNotFoundException {
        List<Module> installed = new ArrayList<>();
        for (String moduleName : moduleNames) {
            installed.add(install(bundle, moduleName));
        }
        return installed;
    }

    /**
     * Loads the named class from the given bundle, instantiates it through this repository's injector, and
     * installs the resulting module.
     * <p>
     * Instantiating the module goes through {@link #getInstance(Class)}, which builds (or reuses) an injector for
     * the modules installed so far; the subsequent {@link #install(Module)} drops that cached injector again.
     *
     * @param bundle     the bundle used to load the module class
     * @param moduleName the name of the module class to load
     * @return the installed module instance
     * @throws ClassNotFoundException if the bundle cannot load the named class
     */
    public Module install(Bundle bundle, String moduleName) throws ClassNotFoundException {
        log.finer("Installing Guice module '" + moduleName + "'.");
        Class<?> namedClass = bundle.loadClass(moduleName);
        // NOTE(review): presumably safeClassCast rejects classes that do not implement Module — confirm in ContainerBuilder.
        Class<Module> moduleClass = ContainerBuilder.safeClassCast(Module.class, namedClass);
        Module module = getInstance(moduleClass);
        install(module);
        return module;
    }

    /**
     * Installs each of the given modules, in iteration order.
     * <p>
     * This method is {@code final} because it is called from the constructor.
     *
     * @param modules the modules to install
     */
    public final void installAll(Iterable<? extends Module> modules) {
        for (Module module : modules) {
            install(module);
        }
    }

    /**
     * Installs the given module by recording its Guice elements, and invalidates the cached injector.
     * <p>
     * Installing a module that is already installed re-records its elements but keeps its original position in the
     * installation order, since re-inserting an existing key does not affect {@link LinkedHashMap} iteration order.
     * This method is {@code final} because it is (indirectly) called from the constructor.
     *
     * @param module the module to install
     */
    public final void install(Module module) {
        modules.put(module, Elements.getElements(module));
        injector = null;
    }

    /**
     * Uninstalls each of the given modules, in iteration order.
     *
     * @param modules the modules to uninstall
     */
    public void uninstallAll(Iterable<? extends Module> modules) {
        for (Module module : modules) {
            uninstall(module);
        }
    }

    /**
     * Removes the given module and invalidates the cached injector. The injector is invalidated even if the module
     * was not installed.
     *
     * @param module the module to uninstall
     */
    public void uninstall(Module module) {
        modules.remove(module);
        injector = null;
    }

    /**
     * Returns the injector for the currently installed modules. It is created on first use and cached until the
     * next install or uninstall.
     *
     * @return the injector built from the installed modules
     */
    public Injector getInjector() {
        if (injector == null) {
            injector = Guice.createInjector(createModule());
        }
        return injector;
    }

    /**
     * Returns the instance bound to the given key in this repository's injector.
     *
     * @param key the key to look up
     * @param <T> the type of the instance
     * @return the instance provided by the injector
     */
    public <T> T getInstance(Key<T> key) {
        return getInjector().getInstance(key);
    }

    /**
     * Returns the instance of the given type from this repository's injector.
     *
     * @param type the type to look up
     * @param <T>  the type of the instance
     * @return the instance provided by the injector
     */
    public <T> T getInstance(Class<T> type) {
        return getInjector().getInstance(type);
    }

    /**
     * Returns an immutable snapshot of the installed modules, in installation order.
     *
     * @return the installed modules
     */
    public Collection<Module> collection() { return ImmutableSet.copyOf(modules.keySet()); }

    /**
     * Returns an iterator over an immutable snapshot of the installed modules, in installation order. Later
     * installs or uninstalls do not affect an iterator that has already been returned.
     */
    @Override
    public Iterator<Module> iterator() {
        return collection().iterator();
    }

    /**
     * Merges the elements of all installed modules into a single module.
     * <p>
     * The elements are concatenated in installation order and then visited from the end backwards. Because
     * {@link ElementCollector} keeps only the first binding it sees per {@link Key}, the binding that appears last
     * in installation order wins. All non-binding elements are kept. The resulting module holds the kept elements
     * in that reversed visiting order.
     */
    private Module createModule() {
        List<Element> allElements = new ArrayList<>();
        for (List<Element> moduleElements : modules.values()) {
            allElements.addAll(moduleElements);
        }
        ElementCollector collector = new ElementCollector();
        // Walk backwards so that later-installed modules override earlier ones for the same key.
        for (ListIterator<Element> it = allElements.listIterator(allElements.size()); it.hasPrevious(); ) {
            it.previous().acceptVisitor(collector);
        }
        return Elements.getModule(collector.elements);
    }

    /**
     * Collects visited elements, dropping any binding whose {@link Key} has already been seen. Every other element
     * kind falls through to {@link #visitOther(Element)} and is always kept.
     */
    private static class ElementCollector extends DefaultElementVisitor<Boolean> {

        /** Keys of the bindings collected so far. */
        final Set<Key<?>> seenKeys = new HashSet<>();
        /** Collected elements, in visiting order. */
        final List<Element> elements = new ArrayList<>();

        @Override
        public <T> Boolean visit(Binding<T> binding) {
            // Only the first binding visited for a key is kept.
            if (seenKeys.add(binding.getKey())) {
                elements.add(binding);
            }
            return Boolean.TRUE;
        }

        @Override
        public Boolean visitOther(Element element) {
            elements.add(element);
            return Boolean.TRUE;
        }
    }
}