diff --git a/zmq.c b/zmq.c index 14a4e29..ea09c61 100644 --- a/zmq.c +++ b/zmq.c @@ -42,7 +42,10 @@ #define MT_ZMQ_CONTEXT "MT_ZMQ_CONTEXT" #define MT_ZMQ_SOCKET "MT_ZMQ_SOCKET" -typedef struct { void *ptr; } zmq_ptr; +typedef struct { + void *ptr; + int should_free; +} zmq_ptr; static int Lzmq_version(lua_State *L) { @@ -88,16 +91,35 @@ static int Lzmq_push_error(lua_State *L) static int Lzmq_init(lua_State *L) { - int io_threads = luaL_checkint(L, 1); zmq_ptr *ctx = lua_newuserdata(L, sizeof(zmq_ptr)); + luaL_getmetatable(L, MT_ZMQ_CONTEXT); + lua_setmetatable(L, -2); + + if (lua_islightuserdata(L, 1)) { + // Treat a light userdata as a raw ZMQ context object, which + // we'll silently wrap. (And we won't automatically call term + // on it.) + + ctx->ptr = lua_touserdata(L, 1); + ctx->should_free = 0; + return 1; + } + + int io_threads = luaL_checkint(L, 1); + ctx->ptr = zmq_init(io_threads); if (!ctx->ptr) { return Lzmq_push_error(L); } - luaL_getmetatable(L, MT_ZMQ_CONTEXT); - lua_setmetatable(L, -2); + // toboolean defaults to false, but we want a missing param #2 + // to mean true + if (lua_isnil(L, 2)) { + ctx->should_free = 1; + } else { + ctx->should_free = lua_toboolean(L, 2); + } return 1; } @@ -119,6 +141,23 @@ static int Lzmq_term(lua_State *L) return 1; } +static int Lzmq_ctx_gc(lua_State *L) +{ + zmq_ptr *ctx = luaL_checkudata(L, 1, MT_ZMQ_CONTEXT); + if (ctx->should_free) { + return Lzmq_term(L); + } else { + return 0; + } +} + +static int Lzmq_ctx_lightuserdata(lua_State *L) +{ + zmq_ptr *ctx = luaL_checkudata(L, 1, MT_ZMQ_CONTEXT); + lua_pushlightuserdata(L, ctx->ptr); + return 1; +} + static int Lzmq_socket(lua_State *L) { zmq_ptr *ctx = luaL_checkudata(L, 1, MT_ZMQ_CONTEXT); @@ -430,7 +469,8 @@ static const luaL_reg zmqlib[] = { }; static const luaL_reg ctxmethods[] = { - {"__gc", Lzmq_term}, + {"__gc", Lzmq_ctx_gc}, + {"lightuserdata", Lzmq_ctx_lightuserdata}, {"term", Lzmq_term}, {"socket", Lzmq_socket}, {NULL, NULL}