diff options
-rw-r--r-- | Makefile | 9 | ||||
-rw-r--r-- | lua/inet/init.lua | 490 | ||||
-rw-r--r-- | lua/inet/set.lua | 49 | ||||
-rwxr-xr-x | test.lua | 8 | ||||
-rw-r--r-- | test/all.lua | 4 | ||||
-rw-r--r-- | test/inet.lua | 118 | ||||
-rw-r--r-- | test/inet_set.lua | 73 | ||||
-rw-r--r-- | test/init.lua | 163 |
8 files changed, 914 insertions, 0 deletions
diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..6efd638 --- /dev/null +++ b/Makefile @@ -0,0 +1,9 @@ +DEST_LUA = $(DESTDIR)/usr/share/lua/5.3/inet + +.PHONY: install build test +build: +install: + install -d $(DEST_LUA) + install -m644 lua/inet/*.lua $(DEST_LUA) +test: + ./test.lua diff --git a/lua/inet/init.lua b/lua/inet/init.lua new file mode 100644 index 0000000..62fa4d7 --- /dev/null +++ b/lua/inet/init.lua @@ -0,0 +1,490 @@ +-- ipv4 / 24 = network +-- ipv6/56 * 5 = 5 /56 further down + +local bit32 = require 'bit32' + +local inet, inet4, inet6 + +inet = {} +inet.__index = inet + +function inet.new(ip, mask) + local ipv6 = string.find(ip, ':', 1, true) + if ipv6 then + return inet6.new(ip, mask) + else + return inet4.new(ip, mask) + end +end + +function inet:__len() + return self.mask +end + +local lshift = bit32.lshift +local rshift = bit32.rshift +local band = bit32.band +local replace = bit32.replace +local bxor = bit32.bxor + +inet4 = {} +inet4.__index = inet4 +inet4.__len = inet.__len + +local ipv4_parser +local ipv6_parser +do + local lpeg = require 'lpeg' + local C, Ct = lpeg.C, lpeg.Ct + local S, R = lpeg.S, lpeg.R + local B, Cc = lpeg.B, lpeg.Cc + + local digit = R('09') + + do + local dot = S('.') + local zero = S('0') + local octet0 = B(zero) * Cc('0') + local octet1 = R('19') + local octet2 = R('19') * digit + local octet31 = S('1') * digit * digit + local octet32 = R('04') * digit + local octet325 = S('5') * R('05') + local octet3 = octet31 + (S('2') * (octet32 + octet325)) + local octet = zero^0 * (C(octet3 + octet2 + octet1) + octet0) + local ipv4 = octet * dot * octet * dot * octet * dot * octet + local mask12 = R('12') * digit + local mask3 = S('3') * R('02') + local netmask = S('/') * C(mask12 + mask3 + digit) + ipv4_parser = ipv4 * (netmask + C('')) * -1 + end + + do + local function hextonumber(hex) return tonumber(hex, 16) end + local hexdigit = R('09') + R('af') + R('AF') + local piece = C(hexdigit * (hexdigit^-3)) / hextonumber + local col = S(':') + local colcol = C(col * col) + local picol = piece * col + local colpi = col * piece + local full = picol * picol * picol * picol * picol * picol * picol * piece + local partial = (piece * (colpi^-6))^-1 * colcol * ((picol^-6)*piece)^-1 + local netmask = S('/') * C((digit^-3)) / tonumber + ipv6_parser = Ct(full + partial) * ((netmask + C(''))^-1) * -1 + end +end + +local function parse4(ipstr) + local o1, o2, o3, o4, mask = ipv4_parser:match(ipstr) + if o1 == nil then return nil end + + local bip = lshift(o1, 24) + lshift(o2, 16) + lshift(o3, 8) + o4 + return bip, tonumber(mask) +end + +function inet4.new(ip, mask) + local bip, ourmask + if type(ip) == 'string' then + bip, ourmask = parse4(ip) + if bip == nil then + return nil + end + elseif type(ip) == 'number' then + bip = ip + end + if mask then + if type(mask) == 'number' and mask >= 0 and mask <= 32 then + ourmask = mask + else + error('invalid mask') + end + else + if not ourmask then + ourmask = 32 + end + end + return setmetatable({ + bip = bip, + mask = ourmask, + }, inet4) +end + +local function tostr4(self, withmask) + -- return human readable + local bip, mask = self.bip, self.mask + local o1, o2, o3, o4 + o1 = band(rshift(bip, 24), 0xff) + o2 = band(rshift(bip, 16), 0xff) + o3 = band(rshift(bip, 8), 0xff) + o4 = band(bip, 0xff) + if (mask == nil or mask == 32 or withmask == false) and withmask ~= true then + return string.format('%d.%d.%d.%d', o1, o2, o3, o4) + else + return string.format('%d.%d.%d.%d/%d', o1, o2, o3, o4, mask) + end +end + +function inet4:__tostring() + return tostr4(self) +end + +function inet4:ipstring() + return tostr4(self, false) +end + +function inet4:cidrstring() + return tostr4(self, true) +end + +function inet4:__add(n) + return inet4.new(self.bip + n, self.mask) +end + +function inet4:__sub(n) + return inet4.new(self.bip - n, self.mask) +end + +function inet4:__mul(n) + local new = self.bip + (n * math.pow(2, 32 - self.mask)) + return inet4.new(new, self.mask) +end + +function inet4:__div(n) + return inet4.new(self.bip, n) +end + +function inet4:__pow(n) + return inet4.new(self.bip, self.mask + n) +end + +function inet4:__lt(other) + if self.mask <= other.mask then + return false + end + local mask = other.mask + local selfnet = replace(self.bip, 0, 0, 32-mask) + local othernet = replace(other.bip, 0, 0, 32-mask) + return selfnet == othernet +end + +function inet4:__le(other) + if self.mask < other.mask then + return false + end + local mask = other.mask + local selfnet = replace(self.bip, 0, 0, 32-mask) + local othernet = replace(other.bip, 0, 0, 32-mask) + return selfnet == othernet +end + +function inet4:__eq(other) + return self.bip == other.bip and self.mask == other.mask +end + +function inet4:network() + local hostbits = 32 - self.mask + return inet4.new(lshift(rshift(self.bip, hostbits), hostbits), self.mask) +end + +function inet4:netmask() + local hostbits = 32 - self.mask + return inet4.new(replace(0xffffffff, 0, 0, hostbits), 32) +end + +function inet4:flip() + -- find twin by flipping the last network bit + local mask = self.mask + if mask == 0 then return nil end + local hostbits = 32 - mask + local flipbit = 1 << hostbits + return inet4.new(self.bip ~ flipbit, mask) +end + +local function parse6(ipstr) + local pcs, netmask = ipv6_parser:match(ipstr) + if not pcs then return nil end + if #pcs > 8 then return nil, 'too many pieces' end + local zero_pieces = 8 - #pcs + for i=1,#pcs do + if pcs[i] == '::' then + pcs[i] = 0 + for j=1,#pcs-i do + pcs[i+j+zero_pieces] = pcs[i+j] + end + for j=1,zero_pieces do + pcs[i+j] = 0 + end + end + end + if #pcs > 8 then return nil, 'too many pieces' end + if netmask == '' then + netmask = 128 + elseif netmask > 128 then + return nil, 'invalid netmask' + end + return pcs, netmask +end + +inet6 = setmetatable({}, inet) +inet6.__index = inet6 +inet6.__len = inet.__len + +function inet6.new(ip, netmask) + local pcs, err + if type(ip) == 'string' then + pcs, err = parse6(ip) + if pcs == nil then + return nil, err + end + if not netmask then + netmask = err + end + elseif type(ip) == 'table' then + pcs = { ip[1], ip[2], ip[3], ip[4], + ip[5], ip[6], ip[7], ip[8] } + if not netmask then + netmask = 128 + end + else + return nil + end + + local r = setmetatable({ + pcs = pcs, + mask = netmask, + }, inet6) + + -- ensure that the result is balanced + if not r:is_balanced() then + r:balance() + return nil, tostring(r)..' unbalanced' + end + + return r +end + +-- each ipv6 address is stored as eight pieces +-- 1111:2222:3333:4444:5555:6666:7777:8888 +-- in the table pcs. + +function inet6:is_balanced() + local pcs = self.pcs + local i = 8 + for i=i,8 do + local piece = pcs[i] + if piece < 0 or piece > 0xffff then + return false + end + end + return true +end + +function inet6:balance(quick) + local pcs = self.pcs + local i = 8 + while i > 1 do + if quick and pcs[i] > 0 then + break + end + while pcs[i] < 0 do + pcs[i] = pcs[i] + 0x10000 + pcs[i-1] = pcs[i-1] - 1 + end + i = i - 1 + end + i = 8 + while i > 1 do + local extra = rshift(pcs[i], 16) + if quick and extra == 0 then + break + end + pcs[i] = band(pcs[i], 0xffff) + pcs[i-1] = pcs[i-1] + extra + i = i - 1 + end + pcs[1] = band(pcs[1], 0xffff) + return self +end + +local function tostr6(self, withmask) + -- return human readable + local pcs = self.pcs + local zeros = {} + + -- count zero clusters + local first_zero = 0 + local prev_was_zero = false + for i=1,#pcs do + if pcs[i] == 0 then + if prev_was_zero then + zeros[first_zero] = zeros[first_zero] + 1 + else + first_zero = i + zeros[first_zero] = 1 + end + prev_was_zero = true + else + prev_was_zero = false + end + end + + -- find the largest zero cluster + local zeros_begin = nil + local zeros_cnt = 0 + for begin,cnt in pairs(zeros) do + if cnt > zeros_cnt then + zeros_begin = begin + zeros_cnt = cnt + end + end + + -- format ipv6 address + local out = '' + local i = 1 + while i <= 8 do + if i == zeros_begin then + if i > 1 then + out = out .. ':' + else + out = out .. '::' + end + i = i + zeros_cnt + else + local p = pcs[i] + local hexdigits = string.format('%x', p) + out = out .. hexdigits + if i ~= 8 then + out = out .. ':' + end + i = i + 1 + end + end + + local mask = self.mask + if (mask == nil or mask == 128 or withmask == false) and withmask ~= true then + return out + else + return string.format('%s/%d', out, mask) + end +end + +function inet6:__tostring() + return tostr6(self) +end + +function inet6:ipstring() + return tostr6(self, false) +end + +function inet6:cidrstring() + return tostr6(self, true) +end + +function inet6:clone() + return inet6.new(self.pcs, self.mask) +end + +function inet6:__eq(other) + if self.mask ~= other.mask then + return false + end + local spcs = self.pcs + local opcs = other.pcs + for i=1,8 do + if spcs[i] ~= opcs[i] then + return false + end + end + return true +end + +function inet6:__div(n) + return inet6.new(self.pcs, n) +end + +function inet6:__pow(n) + return inet6.new(self.pcs, self.mask + n) +end + +function inet6:__add(n) + local new = self:clone() + local pcs = new.pcs + pcs[8] = pcs[8] + n + new:balance(true) + return new +end + +function inet6:__sub(n) + return self + (n*-1) +end + +function inet6:network() + local netbits = self.mask + local pcs = self.pcs + local newpcs = { 0, 0, 0, 0, 0, 0, 0, 0 } + for i=1,8 do + if netbits >= i*16 then + newpcs[i] = pcs[i] + elseif netbits <= (i-1)*16 then + break -- the rest is already zero + else + local netbitsleft = 16-(netbits-((i-1)*16)) + newpcs[i] = pcs[i] >> netbitsleft << netbitsleft + end + end + return inet6.new(newpcs, netbits) +end + +function inet6:flip() + -- find twin by flipping the last network bit + local mask = self.mask + if mask == 0 then return nil end + local block = (mask >> 4)+1 + local maskbits = mask & 0xf + local bitno = 16 - maskbits + if bitno == 16 then + block = block - 1 + bitno = 0 + end + local flipbit = 1 << bitno + local r = self:clone() + local val = r.pcs[block] + r.pcs[block] = r.pcs[block] ~ flipbit + --print(mask, block, maskbits, bitno, flipbit, self, r:balance()) + return r +end + + +function inet6:__mul(n) + local new = self:clone() + local mask = new.mask + local pcs = new.pcs + local netbitoverflow = mask % 16 + local netbitremainder = (128-mask) % 16 + local p = (mask - netbitoverflow) / 16 + if netbitremainder ~= 0 then + p = p + 1 + end + local was_negative = false + if n < 0 then + n = n * -1 + was_negative = true + end + local shiftet = lshift(n, netbitremainder) + local high_shift = rshift(shiftet, 16) + local low_shift = band(shiftet, 0xffff) + --print(p, netbitoverflow, hex(shiftet), hex(high_shift), hex(low_shift)) + if was_negative then + high_shift = -high_shift + low_shift = -low_shift + end + if p > 2 then + pcs[p-1] = pcs[p-1] + high_shift + end + pcs[p] = pcs[p] + low_shift + new:balance() + -- print(mask % 8) + return new +end + +return inet.new diff --git a/lua/inet/set.lua b/lua/inet/set.lua new file mode 100644 index 0000000..effb0f2 --- /dev/null +++ b/lua/inet/set.lua @@ -0,0 +1,49 @@ +local M = {} + +local function table_compact(t, n) + -- remove nil entries, and reorder + local i = 0 + for j=1,n do + if t[j] then + if i > 0 then + t[i] = t[j] + i = i + 1 + end + elseif i == 0 then + i = j + end + end + if i > 0 then + for j=i,n do + t[j] = nil + end + end +end + +function M.aggregate(t) + local flag = true + local n = #t + for i=1,n do + t[i] = t[i]:network() + end + while flag do -- loop until no aggregatable addresses are found + flag = false + for i=1,n do + local ia = t[i] + if ia then + local ib = ia:flip() -- counterpart + for j=1,n do + if j ~= i and t[j] == ib then + -- counterpart found, aggregating + t[i] = (ia ^ -1):network() + t[j] = nil + flag = true + end + end + end + end + end + table_compact(t, n) +end + +return M diff --git a/test.lua b/test.lua new file mode 100755 index 0000000..e02345f --- /dev/null +++ b/test.lua @@ -0,0 +1,8 @@ +#!/usr/bin/env lem + +package.path = './?/init.lua;'..package.path..';./lua/?.lua;./lua/?/init.lua;./?.lua' +package.cpath = package.cpath..';./lua/?.so' + +local test = require 'test' + +test.test() diff --git a/test/all.lua b/test/all.lua new file mode 100644 index 0000000..4ee056a --- /dev/null +++ b/test/all.lua @@ -0,0 +1,4 @@ +local all = require('test').new() +all:depend('inet') +all:depend('inet_set') +return all diff --git a/test/inet.lua b/test/inet.lua new file mode 100644 index 0000000..43432cc --- /dev/null +++ b/test/inet.lua @@ -0,0 +1,118 @@ +local inet = require 'inet' +local test = require 'test' + +local function parse(addr) + local ret, err = inet(addr) + assert(ret, (err or '')..' '..addr) + return ret, err +end + +local function dontparse(...) + test.fail(parse, ...) +end + +return test.new(function() + -- parsing + parse('1:2:3:4:5:6:7:8') + parse('::1/33') + parse('1::/33') + parse('1:2:3:4:5:6:7::/33') + parse('::2:3:4:5:6:7:8/33') + parse('2a03:5440:1010::80/64') + dontparse('::1::/33') + dontparse('::1/33a') + dontparse('::1/150') + dontparse('1:2:3:4::2:3:4:5:6:7:8/33') + assert(tostring(parse('1:0:0:1::/64') * 1) == '1:0:0:2::/64') + assert(tostring(parse('1::/64') * 5 / 32 * 3) == '1:3:0:5::/32') + assert(tostring(parse('5::64') / 32 * -3) == '4:fffd::64/32') + assert(tostring(parse('2::/32') ^ 1) == '2::/33') + assert(tostring(parse('2::/32') ^ -1) == '2::/31') + assert(tostring(parse('2::/128') * 5) == '2::5') + assert(tostring(parse('2::/49') - 1) + == '1:ffff:ffff:ffff:ffff:ffff:ffff:ffff/49') + assert(tostring(parse('2::/49') - 1 + 2) == '2::1/49') + assert(tostring(parse('1:ffff:ffff:fe00::/56') * 2) == '2::/56') + assert(tostring(parse('1:ffff:ffff:fe00::/56') * 2 * -2) + == '1:ffff:ffff:fe00::/56') + local ip = inet('10.0.0.0/33') + assert(ip == nil) + + local ip = inet('10.0.0.0/24') + assert(type(ip) == 'table') + assert(#ip == 24, 'incorrect netmask') + assert(tostring(ip) == '10.0.0.0/24', 'not human readable') + + assert(inet('10.0.0.0/32') == inet('10.0.0.0')) + assert(inet('10.0.0.0/31') ~= inet('10.0.0.0')) + + assert(tostring(ip+1) == '10.0.0.1/24', 'ip adding is broken') + assert(tostring(ip+9-1) == '10.0.0.8/24', 'ip subtract is broken') + assert(tostring(ip*1) == '10.0.1.0/24', 'ip multiplification is broken') + assert(tostring(ip/8) == '10.0.0.0/8', 'ip division is broken') + assert(tostring(ip^1) == '10.0.0.0/25', 'ip power is broken') + + -- test inet4.__lt + assert(inet('10.0.0.0/24') > inet('10.0.0.0/30'), 'inet less than is broken') + assert(not (inet('10.0.0.0/30') > inet('10.0.0.0/30')), 'inet less than is broken') + assert(inet('10.0.0.0/30') >= inet('10.0.0.0/30'), 'inet less than is broken') + assert(inet('10.0.0.0/30') <= inet('10.0.0.0/30'), 'inet less than is broken') + assert(inet('10.0.0.0/30') < inet('10.0.0.0/24'), 'inet less than is broken') + assert(not (inet('10.0.0.0/24') < inet('10.0.0.0/30')), 'inet less than is broken') + assert(not (inet('10.0.0.0/30') < inet('10.0.0.0/30')), 'inet less than is broken') + assert(not (inet('20.0.0.0/30') < inet('10.0.0.0/24')), 'inet less than is broken') + + -- test inet4.__le + assert(inet('10.0.1.2/24') <= inet('10.0.0.0/16')) + assert(not (inet('10.0.1.0/24') <= inet('10.0.0.0/24'))) + + assert(inet('127.0.0.1/8'):netmask() == inet('255.0.0.0')) + + + -- test inet*.__eq + assert(inet('10.0.0.0/30') == inet('10.0.0.0/30'), 'inet4 eq is broken') + assert(inet('10.0.1.0/30') ~= inet('10.0.0.0/30'), 'inet4 eq is broken') + assert(inet('10.0.0.0/31') ~= inet('10.0.0.0/30'), 'inet4 eq is broken') + assert(inet('::1') == inet('::1'), 'inet6 eq is broken') + assert(inet('::1') ~= inet('::2'), 'inet6 eq is broken') + assert(inet('::1/64') ~= inet('::1/56'), 'inet6 eq is broken') + + -- test inet*.ipstring + assert((ip+1):ipstring() == '10.0.0.1', 'ip4 string is broken') + assert(inet('::1/64'):ipstring() == '::1', 'ip6 string is broken') + + -- test inet*.network + assert(inet('10.0.0.1/30'):network() == inet('10.0.0.0/30'), 'inet4.network() is broken') + assert(inet('1::2/64'):network() == inet('1::/64'), 'inet6.network() is broken') + local ip = inet('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff/62') + assert((ip/22):network() == inet('ffff:fc00::/22'), 'inet6.network() is broken') + assert((ip/27):network() == inet('ffff:ffe0::/27'), 'inet6.network() is broken') + + --- test inet4:flip + assert(inet('10.0.0.1/24'):flip() == inet('10.0.1.1/24'), 'inet.flip() is broken') + assert(inet('10.0.0.0/24'):flip() == inet('10.0.1.0/24'), 'inet.flip() is broken') + assert(inet('10.0.0.0/24'):flip():flip() == inet('10.0.0.0/24'), 'inet.flip() is broken') + assert(inet('10.20.30.0/24'):flip() == inet('10.20.31.0/24')) + assert(inet('10.20.30.5/24'):flip() == inet('10.20.31.5/24')) + assert(inet('10.20.30.5/32'):flip() == inet('10.20.30.4/32')) + assert(inet('0.0.0.0/0'):flip() == nil) + local ips = { + inet('::'), + inet('ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff'), + } + assert(inet('::/0'):flip() == nil) + assert(inet('::1/32'):flip() == inet('0:1::1/32')) + assert(inet('::1/48'):flip() == inet('0:0:1::1/48')) + for i=1,#ips do + local ip = ips[i] + for j=1,128 do + local foo = ip / j + local bar = foo:flip() + assert(foo ~= bar) + assert(foo == bar:flip()) + end + end + + -- TODO inet6.__le + -- TODO inet6.__eq +end) diff --git a/test/inet_set.lua b/test/inet_set.lua new file mode 100644 index 0000000..044de09 --- /dev/null +++ b/test/inet_set.lua @@ -0,0 +1,73 @@ +local inet = require 'inet' +local inet_set = require 'inet.set' +local test = require 'test' +local inspect = require 'inspect' + +function agg_set(a, b) + inet_set.aggregate(a) + assert(#a == #b, 'wrong set size') + for i=1,#a do + --print(a[i], b[i]) + assert(a[i] == b[i], 'unexpected network') + end +end + + +return test.new(function() + local ip = inet('10.0.0.0/24') + + agg_set({ + inet('10.0.0.0/24'), + inet('10.0.1.0/24'), + }, { + inet('10.0.0.0/23'), + }) + + agg_set({ + inet('10.0.1.0/24'), + inet('10.0.2.0/24'), + }, { + inet('10.0.1.0/24'), + inet('10.0.2.0/24'), + }) + + agg_set({ + inet('10.0.1.0/24'), + inet('10.0.2.0/24'), + inet('10.0.3.0/24'), + inet('10.0.4.0/24'), + }, { + inet('10.0.1.0/24'), + inet('10.0.2.0/23'), + inet('10.0.4.0/24'), + }) + + agg_set({ + inet('10.0.2.1/24'), + inet('10.0.4.0/24'), + inet('10.0.1.0/24'), + inet('10.0.3.0/24'), + }, { + inet('10.0.2.0/23'), + inet('10.0.4.0/24'), + inet('10.0.1.0/24'), + }) + + agg_set({ + inet('10.0.1.1/24'), + inet('10.0.3.2/24'), + inet('10.0.2.3/24'), + inet('10.0.4.4/24'), + }, { + inet('10.0.1.0/24'), + inet('10.0.2.0/23'), + inet('10.0.4.0/24'), + }) + + agg_set({ + inet('::/32'), + inet('0:1::/32'), + }, { + inet('::/31'), + }) +end) diff --git a/test/init.lua b/test/init.lua new file mode 100644 index 0000000..0a4e05b --- /dev/null +++ b/test/init.lua @@ -0,0 +1,163 @@ +local utils = require 'lem.utils' + +local updatenow = utils.updatenow +local pack = table.pack +local unpack = table.unpack +local format = string.format + +local master + +local function test_assert(v, msg, ...) + if not v then + error(msg or 'assertion failed!', 2) + end + master.assert_cnt = master.assert_cnt + 1 + return v, msg, ... +end + +local function msg_handler(msg) + local trace = debug.traceback(msg, 4) + return trace + --return string.match(trace, '^(.-)\n[^\n]-tester_cut_traceback_here') +end + +local test = {} +test.__index = test + +function test:enter() + self.prev_test = master.activetest + master.activetest = self + self.real_assert = assert + _ENV.assert = test_assert +end + +function test:leave() + master.activetest = self.prev_test + self.prev_test = nil + _ENV.assert = self.real_assert + self.real_assert = nil +end + +function test:depend(t) + if type(t) == 'string' then + t = require('test.'..t) + end + table.insert(self.dependencies, t) +end + +local function tester_cut_traceback_here(self) + local deps = self.dependencies + for i=1,#deps do + local dep = deps[i] + dep:run() + end + if self.test then + self:test() + end +end + +local function tester(...) + -- hack due to the first function called by xpcall + -- not being refered to by name in tracebacks + local ret, msg = tester_cut_traceback_here(...) + return ret, msg +end + +function test:run() + if master == nil then + self:setmaster() + end + if master.have_run[self] then + return self.passed -- TODO return previous result + end + self:enter() + local t1, t2 + t1 = updatenow() + local ret, msg = xpcall(tester, msg_handler, self) + t2 = updatenow() + self.runtime = t2 - t1 + self:leave() + if not ret then + self.error_msg = msg + end + master.have_run[self] = true + self.passed = ret + master.run = master.run + 1 + if ret then + master.passed = master.passed + 1 + else + table.insert(master.failed, self) + end + return ret +end + +function test:setmaster() + assert(master == nil, 'master already set') + master = { + test = self, + assert_cnt = 0, + have_run = {}, + passed = 0, + run = 0, + failed = {}, + } +end + +function test:reset() + if self ~= master.test then + error('you can only run reset on current master task') + end + master = nil +end + +local function new_test(func) + local src = debug.getinfo(func or 2, 'S').short_src + return setmetatable({ + test = func, + dependencies = {}, + source = src, + }, test) +end + +function test:stats() + if self ~= master.test then + error('you can only run stats on current master task') + end + local run, passed, asserts = master.run, master.passed, master.assert_cnt + if run == passed then + print(format('%d/%d All tests passed, %d assertions in %.3fs', + run, passed, asserts, master.test.runtime)) + else + print(format('%d/%d %d failed', run, passed, run-passed)) + end +end + +function test:failed() + if self ~= master.test then + error('you can only run stats on current master task') + end + for i=1,#master.failed do + local t = master.failed[i] + print(t.error_msg) + end +end + +local function assert_fail(cb, ...) + local ret = pack(pcall(cb, ...)) + if ret[1] then assert(false, 'doomed function succeeded') end + return unpack(ret) +end + +local M = {} + +function M.test() + local index = require 'test.all' + index:run() + index:stats() + index:failed() +end + +M.new = new_test +M.fail = assert_fail + +return M |