aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lua/inet/init.lua59
-rw-r--r--test/inet.lua19
2 files changed, 68 insertions, 10 deletions
diff --git a/lua/inet/init.lua b/lua/inet/init.lua
index 72677f9..3186d30 100644
--- a/lua/inet/init.lua
+++ b/lua/inet/init.lua
@@ -5,9 +5,39 @@ local bit32 = require 'bit32'
local inet, inet4, inet6
+local mt2fam = {}
+
inet = {}
inet.__index = inet
+inet4 = setmetatable({}, inet)
+inet4.__index = inet4
+mt2fam[inet4] = 4
+
+inet6 = setmetatable({}, inet)
+inet6.__index = inet6
+mt2fam[inet6] = 6
+
+local function get_mt(t)
+ if type(t) ~= 'table' then return nil end
+ return getmetatable(t)
+end
+
+local function is_inet4(t)
+ local mt = get_mt(t)
+ return mt == inet4
+end
+
+local function is_inet6(t)
+ local mt = get_mt(t)
+ return mt == inet6
+end
+
+local function is_inet(t)
+ local mt = get_mt(t)
+ return mt == inet4 or mt == inet6
+end
+
function inet.new(ip, mask)
local ipv6 = string.find(ip, ':', 1, true)
if ipv6 then
@@ -22,6 +52,13 @@ function inet:__len()
if mask == nil then return 0 end -- make metatable inspectable
return mask
end
+inet4.__len = inet.__len
+inet6.__len = inet.__len
+
+function inet:family()
+ local mt = assert(getmetatable(self))
+ return assert(mt2fam[mt])
+end
local lshift = bit32.lshift
local rshift = bit32.rshift
@@ -29,10 +66,6 @@ local band = bit32.band
local replace = bit32.replace
local bxor = bit32.bxor
-inet4 = {}
-inet4.__index = inet4
-inet4.__len = inet.__len
-
local ipv4_parser
local ipv6_parser
do
@@ -231,10 +264,6 @@ local function parse6(ipstr)
return pcs, netmask
end
-inet6 = setmetatable({}, inet)
-inet6.__index = inet6
-inet6.__len = inet.__len
-
function inet6.new(ip, netmask)
local pcs, err
if type(ip) == 'string' then
@@ -488,8 +517,18 @@ function inet6:__mul(n)
end
pcs[p] = pcs[p] + low_shift
new:balance()
- -- print(mask % 8)
return new
end
-return inet.new
+local M = {}
+local mt = {}
+
+function mt.__call(_, ...)
+ return new_inet(...)
+end
+
+M.is4 = is_inet4
+M.is6 = is_inet6
+M.is = is_inet
+
+return setmetatable(M, mt)
diff --git a/test/inet.lua b/test/inet.lua
index 22f7185..705715a 100644
--- a/test/inet.lua
+++ b/test/inet.lua
@@ -42,6 +42,7 @@ return test.new(function()
ip = inet('10.0.0.0/24')
assert(type(ip) == 'table')
assert(#ip == 24, 'incorrect netmask')
+ assert(ip:family() == 4, 'incorrect family')
assert(tostring(ip) == '10.0.0.0/24', 'not human readable')
assert(inet('10.0.0.0/32') == inet('10.0.0.0'))
@@ -119,4 +120,22 @@ return test.new(function()
-- TODO inet6.__le
-- TODO inet6.__eq
+
+ assert(not inet.is4(false))
+ assert(not inet.is4('foo'))
+ assert(not inet.is4(42))
+ assert(inet.is4(inet('0.0.0.0')))
+ assert(not inet.is4(inet('::')))
+
+ assert(not inet.is6(false))
+ assert(not inet.is6('foo'))
+ assert(not inet.is6(42))
+ assert(not inet.is6(inet('0.0.0.0')))
+ assert(inet.is6(inet('::')))
+
+ assert(not inet.is(false))
+ assert(not inet.is('foo'))
+ assert(not inet.is(42))
+ assert(inet.is(inet('0.0.0.0')))
+ assert(inet.is(inet('::')))
end)