diff options
-rw-r--r-- | lua/inet/init.lua | 59 | ||||
-rw-r--r-- | test/inet.lua | 19 |
2 files changed, 68 insertions, 10 deletions
diff --git a/lua/inet/init.lua b/lua/inet/init.lua index 72677f9..3186d30 100644 --- a/lua/inet/init.lua +++ b/lua/inet/init.lua @@ -5,9 +5,39 @@ local bit32 = require 'bit32' local inet, inet4, inet6 +local mt2fam = {} + inet = {} inet.__index = inet +inet4 = setmetatable({}, inet) +inet4.__index = inet4 +mt2fam[inet4] = 4 + +inet6 = setmetatable({}, inet) +inet6.__index = inet6 +mt2fam[inet6] = 6 + +local function get_mt(t) + if type(t) ~= 'table' then return nil end + return getmetatable(t) +end + +local function is_inet4(t) + local mt = get_mt(t) + return mt == inet4 +end + +local function is_inet6(t) + local mt = get_mt(t) + return mt == inet6 +end + +local function is_inet(t) + local mt = get_mt(t) + return mt == inet4 or mt == inet6 +end + function inet.new(ip, mask) local ipv6 = string.find(ip, ':', 1, true) if ipv6 then @@ -22,6 +52,13 @@ function inet:__len() if mask == nil then return 0 end -- make metatable inspectable return mask end +inet4.__len = inet.__len +inet6.__len = inet.__len + +function inet:family() + local mt = assert(getmetatable(self)) + return assert(mt2fam[mt]) +end local lshift = bit32.lshift local rshift = bit32.rshift @@ -29,10 +66,6 @@ local band = bit32.band local replace = bit32.replace local bxor = bit32.bxor -inet4 = {} -inet4.__index = inet4 -inet4.__len = inet.__len - local ipv4_parser local ipv6_parser do @@ -231,10 +264,6 @@ local function parse6(ipstr) return pcs, netmask end -inet6 = setmetatable({}, inet) -inet6.__index = inet6 -inet6.__len = inet.__len - function inet6.new(ip, netmask) local pcs, err if type(ip) == 'string' then @@ -488,8 +517,18 @@ function inet6:__mul(n) end pcs[p] = pcs[p] + low_shift new:balance() - -- print(mask % 8) return new end -return inet.new +local M = {} +local mt = {} + +function mt.__call(_, ...) + return new_inet(...) +end + +M.is4 = is_inet4 +M.is6 = is_inet6 +M.is = is_inet + +return setmetatable(M, mt) diff --git a/test/inet.lua b/test/inet.lua index 22f7185..705715a 100644 --- a/test/inet.lua +++ b/test/inet.lua @@ -42,6 +42,7 @@ return test.new(function() ip = inet('10.0.0.0/24') assert(type(ip) == 'table') assert(#ip == 24, 'incorrect netmask') + assert(ip:family() == 4, 'incorrect family') assert(tostring(ip) == '10.0.0.0/24', 'not human readable') assert(inet('10.0.0.0/32') == inet('10.0.0.0')) @@ -119,4 +120,22 @@ return test.new(function() -- TODO inet6.__le -- TODO inet6.__eq + + assert(not inet.is4(false)) + assert(not inet.is4('foo')) + assert(not inet.is4(42)) + assert(inet.is4(inet('0.0.0.0'))) + assert(not inet.is4(inet('::'))) + + assert(not inet.is6(false)) + assert(not inet.is6('foo')) + assert(not inet.is6(42)) + assert(not inet.is6(inet('0.0.0.0'))) + assert(inet.is6(inet('::'))) + + assert(not inet.is(false)) + assert(not inet.is('foo')) + assert(not inet.is(42)) + assert(inet.is(inet('0.0.0.0'))) + assert(inet.is(inet('::'))) end) |