0%

Lua二进制扩展多线程任务回调做法

摘要

我希望在lua脚本中执行http请求,然后在注册的回调函数中异步地处理response。http请求任务以队列的形式进行排队,每次只有一个任务在执行。

lua本身没有提供多线程支持。虽然它支持协程,但协程本质上是在单个线程内进行切换,并不是并行执行。

lua可以在不同线程中分别创建相互独立的lua状态机;而lua_newthread可以根据已有的lua状态机创建一个新的线程(协程)状态机,它与原状态机共享全局环境和注册表,但拥有自己独立的栈,可以独立运行。可以利用这些特性来实现异步回调。

实现方法

二进制模块中使用libcurl库来执行http请求

初始化

模块命名为async_request,在luaopen_async_request中完成以下工作:创建模块、初始化libcurl、创建任务线程,并等待任务线程创建好它所使用的lua状态机。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
/**
 * Worker-thread entry point: creates the Lua coroutine used for callbacks,
 * signals the creator through `p`, then serves queued requests one at a time
 * until g_quit_flag is set.
 *
 * @param p  Promise fulfilled (with the new coroutine pointer) once the Lua
 *           side of the worker is ready. Not touched after set_value().
 */
static void routine(std::promise<void*>* p)
{
    // g_main_lua_state is the Lua state of the thread that loaded the module.
    // lua_newthread() derives a coroutine from it: it shares the global
    // environment with the main state but owns an independent execution stack.
    // The creator blocks on the future until set_value(), so the main state is
    // not used concurrently during these two calls.
    g_routine_lua_state = lua_newthread(g_main_lua_state);

    // lua_newthread() left the coroutine on the main stack. Storing it in a
    // global pops it (re-balancing the stack) and keeps it reachable so the
    // garbage collector does not reclaim it.
    lua_setglobal(g_main_lua_state, "async_request_thread");

    // Initialisation finished: wake the creator.
    p->set_value(g_routine_lua_state);

    for (;;)
    {
        RequestContext next_context;

        {
            std::unique_lock<std::mutex> lock(g_mutex);
            g_cv.wait_for(lock, std::chrono::milliseconds(100), [] {
                return !g_queue.empty() || g_quit_flag;
            });

            // Quit request takes priority over pending work.
            if (g_quit_flag) {
                break;
            }

            // Timed out with nothing to do.
            if (g_queue.empty()) {
                continue;
            }

            // Move instead of copy: the queue slot is discarded right after.
            next_context = std::move(g_queue.front());
            g_queue.pop();
        }

        // The transfer runs outside the lock so producers are never blocked.
        func_request(next_context);
    }

    // Deliberately no lua_close() here. lua_close() tears down the *entire*
    // Lua state, not just one coroutine, so calling it on a lua_newthread()
    // result would destroy the interpreter that still owns this module.
    // Coroutines are ordinary collectable objects: this one is released when
    // the "async_request_thread" global is cleared or the main state closes.
    // The former busy-wait on lua_status() was removed as well: nothing else
    // runs this coroutine, so its status could never change while spinning.
}

/**
 * Module entry point for `require('async_request')`.
 *
 * Pushes the module table, points the DLL search path at the module's own
 * directory, and - on first load only - initialises libcurl and starts the
 * worker thread, waiting until the worker has created its Lua coroutine.
 *
 * @param L  Lua state loading the module; remembered as g_main_lua_state.
 * @return   1 (the module table is left on the stack). Raises a Lua error if
 *           libcurl cannot be initialised.
 */
int LUA_API luaopen_async_request(lua_State* L)
{
    luaL_newlib(L, AsyncRequestModuleMethods);

    char full_path[MAX_PATH] = { 0 };
    char abs_full_path[MAX_PATH] = { 0 };

    // libcurl.dll is delay-loaded; make the loader look for it in the
    // directory that contains async_request.dll.
    if (GetModuleFileNameA(GetModuleHandleA("async_request.dll"), full_path, MAX_PATH)) {
        if (GetFullPathNameA(full_path, MAX_PATH, abs_full_path, nullptr)) {
            // Construct the path with the class name; `path::path(...)` names
            // the constructor itself and is rejected by conforming compilers.
            const std::string dll_search_path =
                std::filesystem::path(abs_full_path).parent_path().string();
            SetDllDirectoryA(dll_search_path.c_str());
        }
    }

    // A non-null g_routine_lua_state means the worker already exists.
    if (g_routine_lua_state) {
        return 1;
    }

    // Without a successful global init no other libcurl call is usable, so
    // fail the require() instead of starting a worker that cannot work.
    // No C++ object with a destructor is alive here, so raising is safe.
    if (curl_global_init(CURL_GLOBAL_ALL) != CURLE_OK) {
        return luaL_error(L, "async_request: curl_global_init() failed");
    }

    std::promise<void*> p;
    std::future<void*> f = p.get_future();

    // Start the worker and block until it has finished touching L.
    g_main_lua_state = L;
    g_routine_thread = std::thread(routine, &p);

    f.wait();

    // Disabled: would attach luaL_async_request_destroy as the module table's
    // __gc metamethod.
    //luaL_newmetatable(L, "async_request");
    //lua_pushcfunction(L, luaL_async_request_destroy);
    //lua_setfield(L, -2, "__gc");
    //lua_setmetatable(L, -2);

    return 1;
}

注册回调函数

1
2
3
4
5
6
7
/**
 * Lua: `register_response_callback(fn)`.
 *
 * Stores `fn` in the global "async_request_response_callback", where the
 * response handler looks it up.
 *
 * @param L  Calling Lua state; argument 1 must be a function.
 * @return   1 (pushes `true`).
 */
static int __ar_register_response_callback(lua_State* L)
{
    luaL_checktype(L, 1, LUA_TFUNCTION);

    // lua_setglobal() pops whatever is on top of the stack. Drop any extra
    // arguments first so that the value stored is always argument 1.
    lua_settop(L, 1);
    lua_setglobal(L, "async_request_response_callback");

    lua_pushboolean(L, true);
    return 1;
}

添加请求任务

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/**
 * Appends a request to the work queue and wakes the worker thread.
 *
 * @param context  Request to enqueue; taken by value and moved into the queue.
 */
static void push_mission(RequestContext context)
{
    {
        std::lock_guard<std::mutex> lock(g_mutex);
        // `context` is already our own copy, so move it rather than copying
        // both strings a second time.
        g_queue.push(std::move(context));
    }

    // Notify after releasing the lock so the worker can acquire it at once.
    g_cv.notify_one();
}

/**
 * Lua: `request(url, context)`.
 *
 * Queues an HTTP request for the worker thread. Both arguments are validated
 * before any C++ object is constructed, because a failed check leaves this
 * function through a Lua error.
 *
 * @param L  Calling Lua state; arguments 1 and 2 must be strings.
 * @return   1 (pushes `true`).
 */
static int __ar_request(lua_State* L)
{
    const char* const request_url = luaL_checkstring(L, 1);
    const char* const user_context = luaL_checkstring(L, 2);

    push_mission(RequestContext{ request_url, user_context });

    lua_pushboolean(L, 1);
    return 1;
}

执行请求以及响应回调

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
size_t func_response_chunk(void* buffer, size_t size, size_t nmemb, void* data)
{
size_t bytesize = size * nmemb;
PRequestContext context = reinterpret_cast<PRequestContext>(data);

if (!buffer || !context) {
return 0;
}

if (!g_routine_lua_state) {
return bytesize;
}

lua_getglobal(g_routine_lua_state, "async_request_response_callback");
if (lua_isfunction(g_routine_lua_state, -1))
{
lua_pushlstring(g_routine_lua_state, (const char*)(buffer), bytesize);
lua_pushstring(g_routine_lua_state, context->context.c_str());

int ret = lua_pcall(g_routine_lua_state, 2, 0, 0);
if (ret != LUA_OK) {
std::cerr << "error calling callback : "
<< lua_tostring(g_routine_lua_state, -1)
<< std::endl;
}
}

std::cout << "bytesize : " << bytesize << std::endl;
std::cout << context->url << std::endl;
std::cout << (char*)buffer << std::endl;

return bytesize;
}

/**
 * Performs one blocking HTTP transfer for `context.url`, streaming the
 * response body through func_response_chunk.
 *
 * @param context  Request description; its address is handed to the write
 *                 callback, so it must stay alive for the whole call.
 */
void func_request(RequestContext& context)
{
    CURL* const curl = curl_easy_init();
    if (!curl) {
        return;
    }

    curl_easy_setopt(curl, CURLOPT_URL, context.url.c_str());
    // curl_easy_setopt() is variadic: numeric options are read as `long`, so
    // the literals must be long as well (an int is not promoted).
    curl_easy_setopt(curl, CURLOPT_NOSIGNAL, 1L);
    curl_easy_setopt(curl, CURLOPT_BUFFERSIZE, 4096L);
    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, func_response_chunk);
    curl_easy_setopt(curl, CURLOPT_WRITEDATA, &context);

    const CURLcode res = curl_easy_perform(curl);
    if (res != CURLE_OK) {
        std::cerr << "curl_easy_perform() failed : "
                  << curl_easy_strerror(res)
                  << std::endl;
    }

    curl_easy_cleanup(curl);
}

完整二进制模块代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
#ifndef __WXBOX_DYNAMIC_PLUGIN_ASYNC_REQUEST_H
#define __WXBOX_DYNAMIC_PLUGIN_ASYNC_REQUEST_H

// WXBOX_DYNAMIC_PLUGIN_PUBLIC: symbol-visibility decoration for the plugin.
//   Windows / Cygwin : dllexport while building the plugin, dllimport otherwise.
//   Other platforms  : default visibility while building, nothing otherwise.
#if defined(_WIN32) || defined(__CYGWIN__)
#  if defined(BUILDING_WXBOX_DYNAMIC_PLUGIN)
#    define WXBOX_DYNAMIC_PLUGIN_PUBLIC __declspec(dllexport)
#  else
#    define WXBOX_DYNAMIC_PLUGIN_PUBLIC __declspec(dllimport)
#  endif
#elif defined(BUILDING_WXBOX_DYNAMIC_PLUGIN)
#  define WXBOX_DYNAMIC_PLUGIN_PUBLIC __attribute__((visibility("default")))
#else
#  define WXBOX_DYNAMIC_PLUGIN_PUBLIC
#endif

// WXBOX_DYNAMIC_PLUGIN_PUBLIC_API: the decoration above, plus C linkage when
// compiled as C++.
#if defined(__cplusplus)
#  define WXBOX_DYNAMIC_PLUGIN_PUBLIC_API extern "C" WXBOX_DYNAMIC_PLUGIN_PUBLIC
#else
#  define WXBOX_DYNAMIC_PLUGIN_PUBLIC_API WXBOX_DYNAMIC_PLUGIN_PUBLIC
#endif

#include <lua.hpp>
#include <string>

/// One queued HTTP request.
struct _RequestContext
{
    std::string url;      ///< Address handed to libcurl as the request URL.
    std::string context;  ///< Caller-supplied string passed back to the Lua callback.
};

typedef _RequestContext  RequestContext;
typedef _RequestContext* PRequestContext;

/// Module entry point for the `async_request` Lua module, exported with C
/// linkage. Leaves the module table on the stack of `L` and returns 1.
WXBOX_DYNAMIC_PLUGIN_PUBLIC_API int LUA_API luaopen_async_request(lua_State* L);


#endif // #ifndef __WXBOX_DYNAMIC_PLUGIN_ASYNC_REQUEST_H
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#include "async_request.hpp"

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <filesystem>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <sstream>
#include <thread>
#include <utility>

#include <curl/curl.h>

// Lua state that loaded the module; assigned in luaopen_async_request().
static lua_State* g_main_lua_state{ nullptr };

// Coroutine created by the worker via lua_newthread(); non-null once the
// worker is up, which luaopen_async_request() uses as its "already started"
// marker.
static lua_State* g_routine_lua_state{ nullptr };

// Worker thread running routine().
// NOTE(review): a std::thread that is still joinable when it is destroyed
// calls std::terminate - confirm the thread is always joined before unload.
static std::thread g_routine_thread;

// Pending requests; every access is made while holding g_mutex.
static std::queue<RequestContext> g_queue;
static std::mutex g_mutex;

// Signalled when a request is queued or shutdown is requested.
static std::condition_variable g_cv;

// Set to ask the worker loop to exit.
static std::atomic_bool g_quit_flag{ false };


size_t func_response_chunk(void* buffer, size_t size, size_t nmemb, void* data)
{
size_t bytesize = size * nmemb;
PRequestContext context = reinterpret_cast<PRequestContext>(data);

if (!buffer || !context) {
return 0;
}

if (!g_routine_lua_state) {
return bytesize;
}

lua_getglobal(g_routine_lua_state, "async_request_response_callback");
if (lua_isfunction(g_routine_lua_state, -1))
{
lua_pushlstring(g_routine_lua_state, (const char*)(buffer), bytesize);
lua_pushstring(g_routine_lua_state, context->context.c_str());

int ret = lua_pcall(g_routine_lua_state, 2, 0, 0);
if (ret != LUA_OK) {
std::cerr << "error calling callback : "
<< lua_tostring(g_routine_lua_state, -1)
<< std::endl;
}
}

std::cout << "bytesize : " << bytesize << std::endl;
std::cout << context->url << std::endl;
std::cout << (char*)buffer << std::endl;

return bytesize;
}

/**
 * Performs one blocking HTTP transfer for `context.url`, streaming the
 * response body through func_response_chunk.
 *
 * @param context  Request description; its address is handed to the write
 *                 callback, so it must stay alive for the whole call.
 */
void func_request(RequestContext& context)
{
    CURL* const curl = curl_easy_init();
    if (!curl) {
        return;
    }

    curl_easy_setopt(curl, CURLOPT_URL, context.url.c_str());
    // curl_easy_setopt() is variadic: numeric options are read as `long`, so
    // the literals must be long as well (an int is not promoted).
    curl_easy_setopt(curl, CURLOPT_NOSIGNAL, 1L);
    curl_easy_setopt(curl, CURLOPT_BUFFERSIZE, 4096L);
    curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, func_response_chunk);
    curl_easy_setopt(curl, CURLOPT_WRITEDATA, &context);

    const CURLcode res = curl_easy_perform(curl);
    if (res != CURLE_OK) {
        std::cerr << "curl_easy_perform() failed : "
                  << curl_easy_strerror(res)
                  << std::endl;
    }

    curl_easy_cleanup(curl);
}

/**
 * Worker-thread entry point: creates the Lua coroutine used for callbacks,
 * signals the creator through `p`, then serves queued requests one at a time
 * until g_quit_flag is set.
 *
 * @param p  Promise fulfilled (with the new coroutine pointer) once the Lua
 *           side of the worker is ready. Not touched after set_value().
 */
static void routine(std::promise<void*>* p)
{
    // lua_newthread() derives a coroutine from the main state: it shares the
    // global environment but owns an independent execution stack. The creator
    // blocks on the future until set_value(), so the main state is not used
    // concurrently during these two calls.
    g_routine_lua_state = lua_newthread(g_main_lua_state);

    // lua_newthread() left the coroutine on the main stack. Storing it in a
    // global pops it (re-balancing the stack) and keeps it reachable so the
    // garbage collector does not reclaim it.
    lua_setglobal(g_main_lua_state, "async_request_thread");

    // Initialisation finished: wake the creator.
    p->set_value(g_routine_lua_state);

    for (;;)
    {
        RequestContext next_context;

        {
            std::unique_lock<std::mutex> lock(g_mutex);
            g_cv.wait_for(lock, std::chrono::milliseconds(100), [] {
                return !g_queue.empty() || g_quit_flag;
            });

            // Quit request takes priority over pending work.
            if (g_quit_flag) {
                break;
            }

            // Timed out with nothing to do.
            if (g_queue.empty()) {
                continue;
            }

            // Move instead of copy: the queue slot is discarded right after.
            next_context = std::move(g_queue.front());
            g_queue.pop();
        }

        // The transfer runs outside the lock so producers are never blocked.
        func_request(next_context);
    }

    // Deliberately no lua_close() here. lua_close() tears down the *entire*
    // Lua state, not just one coroutine, so calling it on a lua_newthread()
    // result would destroy the interpreter that still owns this module.
    // Coroutines are ordinary collectable objects: this one is released when
    // the "async_request_thread" global is cleared or the main state closes.
    // The former busy-wait on lua_status() was removed as well: nothing else
    // runs this coroutine, so its status could never change while spinning.
}

/**
 * Appends a request to the work queue and wakes the worker thread.
 *
 * @param context  Request to enqueue; taken by value and moved into the queue.
 */
static void push_mission(RequestContext context)
{
    {
        std::lock_guard<std::mutex> lock(g_mutex);
        // `context` is already our own copy, so move it rather than copying
        // both strings a second time.
        g_queue.push(std::move(context));
    }

    // Notify after releasing the lock so the worker can acquire it at once.
    g_cv.notify_one();
}

/**
 * Lua: `register_response_callback(fn)`.
 *
 * Stores `fn` in the global "async_request_response_callback", where the
 * response handler looks it up.
 *
 * @param L  Calling Lua state; argument 1 must be a function.
 * @return   1 (pushes `true`).
 */
static int __ar_register_response_callback(lua_State* L)
{
    luaL_checktype(L, 1, LUA_TFUNCTION);

    // lua_setglobal() pops whatever is on top of the stack. Drop any extra
    // arguments first so that the value stored is always argument 1.
    lua_settop(L, 1);
    lua_setglobal(L, "async_request_response_callback");

    lua_pushboolean(L, true);
    return 1;
}

/**
 * Lua: `request(url, context)`.
 *
 * Queues an HTTP request for the worker thread. Both arguments are validated
 * before any C++ object is constructed, because a failed check leaves this
 * function through a Lua error.
 *
 * @param L  Calling Lua state; arguments 1 and 2 must be strings.
 * @return   1 (pushes `true`).
 */
static int __ar_request(lua_State* L)
{
    const char* const request_url = luaL_checkstring(L, 1);
    const char* const user_context = luaL_checkstring(L, 2);

    push_mission(RequestContext{ request_url, user_context });

    lua_pushboolean(L, 1);
    return 1;
}

const struct luaL_Reg AsyncRequestModuleMethods[] = {
{"register_response_callback", __ar_register_response_callback},
{"request", __ar_request},
{NULL, NULL},
};

/**
 * Shuts the module down: asks the worker thread to stop, joins it, clears the
 * cached Lua state pointers and releases libcurl's global state.
 *
 * Safe to call when the worker was never started or has already been
 * stopped; in that case it does nothing.
 *
 * @param L  Unused.
 * @return   0 (no Lua return values).
 */
int LUA_API luaL_async_request_destroy(lua_State* L)
{
    UNREFERENCED_PARAMETER(L);

    // join() on a thread that is not joinable throws std::system_error, which
    // must not escape into the Lua runtime. Not joinable also means there is
    // nothing to tear down.
    if (!g_routine_thread.joinable()) {
        return 0;
    }

    {
        // Publish the flag under the mutex the worker waits on, so the
        // notification cannot slip in between its predicate check and wait.
        std::lock_guard<std::mutex> lock(g_mutex);
        g_quit_flag = true;
    }
    g_cv.notify_all();
    g_routine_thread.join();

    // Re-arm the flag so that a later luaopen_async_request() can start a
    // fresh worker that does not exit immediately.
    g_quit_flag = false;

    g_routine_lua_state = nullptr;
    g_main_lua_state = nullptr;

    curl_global_cleanup();

    return 0;
}

/**
 * Module entry point for `require('async_request')`.
 *
 * Pushes the module table, points the DLL search path at the module's own
 * directory, and - on first load only - initialises libcurl and starts the
 * worker thread, waiting until the worker has created its Lua coroutine.
 *
 * @param L  Lua state loading the module; remembered as g_main_lua_state.
 * @return   1 (the module table is left on the stack). Raises a Lua error if
 *           libcurl cannot be initialised.
 */
int LUA_API luaopen_async_request(lua_State* L)
{
    luaL_newlib(L, AsyncRequestModuleMethods);

    char full_path[MAX_PATH] = { 0 };
    char abs_full_path[MAX_PATH] = { 0 };

    // libcurl.dll is delay-loaded; make the loader look for it in the
    // directory that contains async_request.dll.
    if (GetModuleFileNameA(GetModuleHandleA("async_request.dll"), full_path, MAX_PATH)) {
        if (GetFullPathNameA(full_path, MAX_PATH, abs_full_path, nullptr)) {
            // Construct the path with the class name; `path::path(...)` names
            // the constructor itself and is rejected by conforming compilers.
            const std::string dll_search_path =
                std::filesystem::path(abs_full_path).parent_path().string();
            SetDllDirectoryA(dll_search_path.c_str());
        }
    }

    // A non-null g_routine_lua_state means the worker already exists.
    if (g_routine_lua_state) {
        return 1;
    }

    // Without a successful global init no other libcurl call is usable, so
    // fail the require() instead of starting a worker that cannot work.
    // No C++ object with a destructor is alive here, so raising is safe.
    if (curl_global_init(CURL_GLOBAL_ALL) != CURLE_OK) {
        return luaL_error(L, "async_request: curl_global_init() failed");
    }

    std::promise<void*> p;
    std::future<void*> f = p.get_future();

    // Start the worker and block until it has finished touching L.
    g_main_lua_state = L;
    g_routine_thread = std::thread(routine, &p);

    f.wait();

    // Disabled: would attach luaL_async_request_destroy as the module table's
    // __gc metamethod.
    //luaL_newmetatable(L, "async_request");
    //lua_pushcfunction(L, luaL_async_request_destroy);
    //lua_setfield(L, -2, "__gc");
    //lua_setmetatable(L, -2);

    return 1;
}

调用方法

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
-- Load the native module (this also starts its worker thread).
local ar = require('async_request')

-- Invoked by the native module with a chunk of the response body and the
-- context string that was passed to ar.request().
function async_request_callback(response, data)
    local context = JSON:decode(data)
    local wxid = context.wxid

    wxbox.send_text(wxid, response)
    wxbox.info('send to : ' .. wxid)
end

-- Queue a question for the local service; the answer is delivered to `wxid`
-- through async_request_callback.
function async_ask(wxid, prompt, max_tokens)
    local url = 'http://127.0.0.1:5000/question?prompt='
        .. urlencode.encode_url(prompt)
        .. '&max_tokens=' .. max_tokens

    ar.request(url, JSON:encode({ wxid = wxid, url = url }))
end

ar.register_response_callback(async_request_callback)

Notes

  • lua源码中的lua_lock和lua_unlock默认是空操作。本文中任务线程与主线程会同时操作共享同一全局状态的lua状态机,为了确保线程安全,最好自行实现这两个宏来加锁。
请我喝瓶肥仔快乐水?

欢迎关注我的其它发布渠道