diff --git a/zmq.c b/zmq.c index 73d453f..40136e1 100644 --- a/zmq.c +++ b/zmq.c @@ -33,7 +33,10 @@ #define MT_ZMQ_CONTEXT "MT_ZMQ_CONTEXT" #define MT_ZMQ_SOCKET "MT_ZMQ_SOCKET" -typedef struct { void *ptr; } zmq_ptr; +typedef struct { + void *ptr; + int should_free; +} zmq_ptr; static int Lzmq_version(lua_State *L) { @@ -66,16 +69,35 @@ static int Lzmq_version(lua_State *L) static int Lzmq_init(lua_State *L) { - int io_threads = luaL_checkint(L, 1); zmq_ptr *ctx = lua_newuserdata(L, sizeof(zmq_ptr)); + luaL_getmetatable(L, MT_ZMQ_CONTEXT); + lua_setmetatable(L, -2); + + if (lua_islightuserdata(L, 1)) { + // Treat a light userdata as a raw ZMQ context object, which + // we'll silently wrap. (And we won't automatically call term + // on it.) + + ctx->ptr = lua_touserdata(L, 1); + ctx->should_free = 0; + return 1; + } + + int io_threads = luaL_checkint(L, 1); + ctx->ptr = zmq_init(io_threads); if (!ctx->ptr) { zmq_return_error(); } - luaL_getmetatable(L, MT_ZMQ_CONTEXT); - lua_setmetatable(L, -2); + // toboolean defaults to false, but we want a missing param #2 + // to mean true + if (lua_isnil(L, 2)) { + ctx->should_free = 1; + } else { + ctx->should_free = lua_toboolean(L, 2); + } return 1; } @@ -97,6 +119,23 @@ static int Lzmq_term(lua_State *L) return 1; } +static int Lzmq_ctx_gc(lua_State *L) +{ + zmq_ptr *ctx = luaL_checkudata(L, 1, MT_ZMQ_CONTEXT); + if (ctx->should_free) { + return Lzmq_term(L); + } else { + return 0; + } +} + +static int Lzmq_ctx_lightuserdata(lua_State *L) +{ + zmq_ptr *ctx = luaL_checkudata(L, 1, MT_ZMQ_CONTEXT); + lua_pushlightuserdata(L, ctx->ptr); + return 1; +} + static int Lzmq_socket(lua_State *L) { zmq_ptr *ctx = luaL_checkudata(L, 1, MT_ZMQ_CONTEXT); @@ -338,7 +377,8 @@ static const luaL_reg zmqlib[] = { }; static const luaL_reg ctxmethods[] = { - {"__gc", Lzmq_term}, + {"__gc", Lzmq_ctx_gc}, + {"lightuserdata", Lzmq_ctx_lightuserdata}, {"term", Lzmq_term}, {"socket", Lzmq_socket}, {NULL, NULL}