aboutsummaryrefslogtreecommitdiffstats
path: root/protoc-gen-csi/protoc_gen_csi.cpp
blob: fa9c3e089ca8423b2cc21579eddc98adc4fd95de (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include <cerrno>
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <exception>
#include <memory>
#include <string>
#include <string_view>

#include <google/protobuf/compiler/plugin.h>
#include <google/protobuf/compiler/code_generator.h>
#include <google/protobuf/io/zero_copy_stream.h>
#include <google/protobuf/descriptor.h>

using namespace google::protobuf;
using namespace google::protobuf::compiler;

class MyGen : public CodeGenerator {
public:
    bool Generate(const FileDescriptor * file, const std::string & parameter,
                  GeneratorContext * generator_context, std::string * error) const override;
    ~MyGen();
};

MyGen::~MyGen() = default;

int main(int argc, char* argv[]) {
    MyGen generator;
    return PluginMain(argc, argv, &generator);
}

/// Writes `line` followed by a single '\n' to `target`.
///
/// The stream hands out buffers of arbitrary size via Next(). The text is
/// copied across as many buffers as needed. Any unused tail of the final
/// buffer is returned to the stream with BackUp().
///
/// @param target  Output stream to append to (must be non-null).
/// @param line    Text to write; a trailing newline is added.
/// @throws std::string if target->Next() reports failure.
void write_line(io::ZeroCopyOutputStream *target, const std::string &line) {
    // c_str() guarantees a NUL at index line.size(). Copy it as the last
    // payload byte and overwrite it with '\n' once it lands in the buffer.
    const char *src = line.c_str();
    // size_t, not int: avoids overflow/narrowing for very long lines.
    std::size_t left = line.size() + 1;
    while (left > 0) {
        void *data = nullptr;
        int size = 0;
        if (!target->Next(&data, &size)) {
            // Save errno first: perror() is itself allowed to modify errno,
            // which would make the strerror() text below unreliable.
            const int saved_errno = errno;
            std::perror("target->Next() returned false");
            std::string message = "Error writing output: ";
            message += std::strerror(saved_errno);
            throw message;
        }
        if (size <= 0) {
            continue;  // zero-length buffer: ask the stream for another one
        }
        const auto avail = static_cast<std::size_t>(size);
        if (avail >= left) {
            // Final chunk: fits entirely, including the newline slot.
            std::memcpy(data, src, left);
            static_cast<char *>(data)[left - 1] = '\n';
            if (avail > left) {
                // avail - left < avail == size, so this fits in an int.
                target->BackUp(static_cast<int>(avail - left));
            }
            return;
        }
        // Buffer too small: fill it completely and continue with the rest.
        std::memcpy(data, src, avail);
        left -= avail;
        src += avail;
    }
}

void my_generate(const std::string &name,
                 const FileDescriptor &file,
                 GeneratorContext &context)
{
    if (file.dependency_count() > 0
        || file.public_dependency_count() > 0
        || file.weak_dependency_count() > 0)
    {
        std::string message = "Importing dependencies not supported";
        throw message;
    }
    if (file.extension_count() > 0) {
        std::string message = "Extensions not supported";
        throw message;
    }
    if (file.is_placeholder()) {
        std::string message = "Unexpected placeholder file";
        throw message;
    }
    auto filename_csi_h = name + ".csi.h";
    auto csi_h = context.Open(filename_csi_h);
    write_line(csi_h, "// DO NOT EDIT (generated by protoc-gen-csi)");
    write_line(csi_h, "// Coroutine Service Interface: " + name);
    write_line(csi_h, "#pragma once");

    auto filename_csi_cpp = name + ".csi.cpp";
    auto csi_cpp = context.Open(filename_csi_cpp);
    write_line(csi_cpp, "// DO NOT EDIT (generated by protoc-gen-csi)");
    write_line(csi_cpp, "// Coroutine Service Interface: " + name);
    write_line(csi_cpp, "#include \"" + filename_csi_h + "\"");
}

/// Validates the inputs and generates CSI output for one file.
///
/// The output base name is the input file name with a trailing ".proto"
/// removed. No plugin parameters are currently accepted.
///
/// @return true on success. On failure returns false and, if `error` is
///         non-null, stores "<name>: <reason>" in *error.
bool MyGen::Generate(const FileDescriptor * file, const std::string & parameter,
                     GeneratorContext * generator_context, std::string * error) const
{
    // Suffix removed from the input file name to form the output base name.
    constexpr std::string_view proto_suffix = ".proto";

    std::string name = "[unknown]";

    // Formats a failure report. `name` is captured by reference so the
    // message uses whatever name had been determined when the error occurred.
    auto fail = [&](std::string_view message) {
        if (error != nullptr) {
            *error = name + ": ";
            *error += message;
        }
        return false;
    };

    try {
        if (file == nullptr) {
            std::string m = "No FileDescriptor";
            throw m;
        }
        name = file->name();
        if (name.ends_with(proto_suffix)) {
            name.resize(name.size() - proto_suffix.size());
        }
        if (generator_context == nullptr) {
            std::string m = "No GeneratorContext";
            throw m;
        }
        if (! parameter.empty()) {
            std::string m = "unknown command line parameter " + parameter;
            throw m;
        }
        my_generate(name, *file, *generator_context);
    } catch (const std::string &message) {
        return fail(message);
    } catch (const std::exception &e) {
        // e.g. std::bad_alloc: report it as a generator error instead of
        // letting it propagate out of the plugin.
        return fail(e.what());
    }
    return true;
}