diff --git a/mqtt.lua b/mqtt.lua index c9b2b3e..37943c0 100644 --- a/mqtt.lua +++ b/mqtt.lua @@ -1,26 +1,36 @@ local mqtt = {} +local safeRead = function (conn, n) + local success, result, err = pcall(conn.read, conn, n) + if success then + return result, err + end + + return nil, result +end + local readVarint = function (conn, first_byte) - local b + local b, err if first_byte == nil then - b = conn:read(1) + b, err = safeRead(conn, 1) else b = first_byte end local n, s = 0, 0 - while b ~= nil and b & 0x80 == 0x80 do + while err == nil and b & 0x80 == 0x80 do if s > 21 then return 0, "number too large" end n = n + ((b & 0x7F) << s) s = s + 7 - b = conn:read(1) + + b, err = safeRead(conn, 1) end - if b == nil then - return n, "eof" + if err ~= nil then + return n, err end return n + (b << s), nil @@ -43,6 +53,7 @@ function MqttClient:new (conn) self.__index = self conn.readVarint = readVarint + conn.safeRead = safeRead conn:setTimeout(1) c.conn = conn @@ -53,9 +64,9 @@ function MqttClient:new (conn) end function MqttClient:handle () - local data = self.conn:read(2) - if data == nil then - return "eof" + local data, err = self.conn:safeRead(2) + if err ~= nil then + return err end local ptype, length, _ = string.unpack("B B", s) @@ -66,9 +77,9 @@ function MqttClient:handle () end if length > 0 then - data = self.conn:read(length) - if data == nil then - return "eof" + data, err = self.conn:safeRead(length) + if err ~= nil then + return err end else data = ""