#!/usr/bin/env lua5.1

require"net"
require"netutil"

--- Print the hex rendering of a string.
-- This used to be a bare alias of `h`. Elsewhere in this file `h` is used as a
-- value-returning formatter (e.g. `print("b0", h(b0))`), so a statement-style
-- call such as `hex_dump(buf)` produced no output at all.
-- @param s string to dump
-- @return the text that was printed, so alias-style callers keep working
function hex_dump(s)
    local hex = h(s)
    print(hex)
    return hex
end

-- Smoke test: build an IPv4 header that carries options, then an ethernet
-- header on top of it. Nothing is asserted; the check passes as long as
-- neither builder raises an error.
local function check_ipv4()
	local ctx = net.init()
	local ipv4_args = {src="1.2.3.1", dst="1.2.3.2", protocol=2, len=20+4, options="AAAA"}
	local eth_args = {src="01:02:03:04:05:01", dst="01:02:03:04:05:02"}

	ctx:ipv4(ipv4_args)
	ctx:eth(eth_args)

	-- Disabled check, kept for reference:
	--ok=pcall(n.ipv4, n, {src="1.2.3.1", dst="1.2.3.2", protocol=2, len=20+4, options="AAAA"})
	--assert(not ok, "net:ipv4 fails to detect incorrect usage of IPv4 options")
end

-- Verify that net:ipv4 raises an error when the caller supplies a ptag while
-- the header being created also carries IPv4 options.
local function check_ptag()
	local n = net.init()
	-- The ptag of an unrelated (ethernet) block is used as the bogus ptag.
	local eth = n:eth{src="01:02:03:04:05:01", dst="01:02:03:04:05:02"}
	local ok = pcall(n.ipv4, n, {src="1.2.3.1", dst="1.2.3.2", protocol=2, len=20+4, options="AAAA", ptag = eth})
	assert(not ok, "net:ipv4 fails to detect specifying a ptag when creating IPv4 options")
end

-- Verify that an IPv4 header with options can be the first block built in a
-- fresh context. The ethernet header is still added before the result of the
-- protected call is checked, as in the original ordering.
local function check_first()
	local n = net.init()
	-- Keep the error value so a failure explains itself.
	local ok, err = pcall(n.ipv4, n, {src="1.2.3.1", dst="1.2.3.2", protocol=2, len=20+4, options="AAAA"})
	n:eth{src="01:02:03:04:05:01", dst="01:02:03:04:05:02"}
	assert(ok, "net:ipv4 fails to construct ipv4 options correctly: "..tostring(err))
end

-- Basic construction test: build udp/ipv4/eth, rebuild the ipv4 header in
-- place through its ptag, and require the encoded packet to be unchanged.
-- Everything is kept local so no state leaks into the tests further down.
do
    local n = net.init()

    print(q(net.pton("01:02:03:04:05:06")))
    print(q(net.pton("1.2.3.4")))
    print(q(net.pton("::1")))

    n:udp{src=3, dst=5, payload="XXX"}
    local ipv4 = n:ipv4{src="1.2.3.1", dst="1.2.3.2", protocol=17, len=20+8+3}
    n:eth{src="01:02:03:04:05:01", dst="01:02:03:04:05:02"}

    print(n:dump())

    local b1 = n:block()
    print(#b1, q(b1))

    -- Same arguments again, this time addressed at the existing block.
    n:ipv4{src="1.2.3.1", dst="1.2.3.2", protocol=17, len=20+8+3, ptag=ipv4}

    --print(q({src="1.2.3.1", dst="1.2.3.2", protocol=17, len=20+8+3, ptag=ipv4}))

    print(n:dump())

    local b2 = n:block()

    print("b1=", q(b1))
    print("b2=", q(b2))

    assert(b1 == b2)
end

check_ipv4()
check_ptag()
check_first()

-- Decode test: encode udp over ipv4 (with options), clear the context, decode
-- the bytes back, and check the resulting block chain, tag navigation and
-- in-place modification of the udp payload.
do
    print"\n\n# decode ipv4"

    local n = net.init()

    n:udp{src=1, dst=2, payload=" "}
    n:ipv4{src="1.2.3.1", dst="1.2.3.2", protocol=17, options="AAAA"}

    local b0 = n:block()

    print"= constructed:"
    print(n:dump())
    print("b0", h(b0))

    n:clear()

    print(n:dump())

    print"= decoded:"

    assert(n:decode_ipv4(b0))

    -- Encoding of the block at tag_below() alone, compared again at the end.
    local ip1 = n:block(n:tag_below())
    local b1 = n:block()

    if b0 ~= b1 then
        print(n:dump())
        print("b1", h(b1))
    end

    assert(b0 == b1)

    local bot = assert(n:tag_below())
    local top = assert(n:tag_above())

    print("bot", bot, "top", top)
    assert(bot == 3)
    assert(n:tag_below(bot) == nil)
    assert(n:tag_above(bot) == 2)

    assert(top == 1)
    assert(n:tag_above(top) == nil)
    assert(n:tag_below(top) == 2)

    assert(n:tag_type(bot) == "ipv4", n:tag_type(bot))
    assert(n:tag_type(top) == "udp", n:tag_type(top))

    -- Was an accidental global; every other variable in this block is local.
    local udpt = n:get_udp()
    udpt.payload = "\0"
    assert(udpt.ptag == top)
    assert(n:udp(udpt))

    local b2 = n:block()
    print(n:dump())
    hex_dump(b2)

    -- everything up to the checksum should be the same
    assert(b1:sub(1, 20+4+6) == b2:sub(1, 20+4+6))
    assert(b1:sub(20+4+7, 20+4+8) ~= b2:sub(20+4+7, 20+4+8))

    assert(#n:block(n:tag_above()) == (8+1))
    assert(n:block(n:tag_below()) == ip1)

    print"+pass"
end

-- Print every protocol block of a small udp/ipv4 packet, starting from
-- tag_above() and following each block's `next` field.
do
    print"\n\n# pblock dump"

    local ctx = net.init()

    ctx:udp{src=1, dst=2, payload=" "}
    ctx:ipv4{src="1.2.3.1", dst="1.2.3.2", protocol=17, len=20+4, options="AAAA"}

    local tag = ctx:tag_above()
    while tag do
        --print("ptag", tag)
        local info = ctx:pblock(tag)
        local raw = info.buf
        -- Remove the raw buffer before printing the table; presumably so the
        -- bytes only appear once, hex-formatted.
        info.buf = nil
        print(info, h(raw))
        tag = info.next
    end
end

-- Print `s` after formatting it with h().
local function hex_dump(s)
    local formatted = h(s)
    print(formatted)
end

-- IP protocol numbers referenced by the tests below.
-- TODO: add to net
net.IP = { TCP = 6, UDP = 17 }

-- Build a udp-over-ipv4 packet (no link layer) and return its encoded bytes.
-- Also asserts that encoding fails while only the udp block exists.
function build_udp()
    local ctx = net.init()
    ctx:udp{src=0x88, dst=0x77, payload="xxxxx"}

    -- block() is expected to raise before the ipv4 header has been added.
    local encoded_early = pcall(ctx.block, ctx)
    assert(not encoded_early)
    print(ctx:get_udp())

    local ip_args = {
        -- TODO: protocol should default to that of the next pblock
        protocol = net.IP.UDP,
        len = 20+8+5,
        src = "10.1.0.1",
        dst = "10.1.0.2",
    }
    ctx:ipv4(ip_args)

    local encoded = ctx:block()
    print("udp=", h(encoded))
    print(ctx:get_udp())
    return encoded
end

-- Build a tcp-over-ipv4 packet (with 8 bytes of ip options, no link layer),
-- dumping the context along the way, and return the result of block().
function build_tcp()
    local ctx = net.init()
    local tcp_args = {src=0x88, dst=0x77, seq=9, ack=10, flags=0xff, win=11, urg=13, payload="xxxxx"}
    ctx:tcp(tcp_args)
    --print(ctx:get_tcp())
    print("n=")
    print(ctx:dump())

    local ip_args = {
        -- TODO: protocol should default to that of the next pblock
        protocol = net.IP.TCP,
        src = "10.1.0.1",
        dst = "10.1.0.2",
        options = "oooooooo",
    }
    ctx:ipv4(ip_args)

    print("n=")
    print(ctx:dump())
    print("tcp=", h(ctx:block()))
    print(ctx:get_tcp())
    return ctx:block()
end

-- Build packets:
-- These are deliberately global: the test blocks below read `udp` and `tcp`.

-- Encoded udp/ipv4 packet returned by build_udp().
udp = build_udp()

-- Encoded tcp/ipv4 packet (with ip options) returned by build_tcp().
tcp = build_tcp()

-- Placeholder; not referenced anywhere else in this file.
igmp = ""

-- A single extra byte appended after the packet must not end up in the
-- decoded tcp payload.
do
    print("test: trailing garbage")

    local ctx = net.init()
    local padded = tcp.."X"
    assert(ctx:decode_ip(padded))

    local decoded = ctx:get_tcp()
    assert(decoded.payload == "xxxxx")
    print("+ok")
end

-- Decode progressively shorter prefixes of several frames and require that
-- each prefix re-encodes to the same length and (almost) the same bytes.
do
  print("test: decode truncated data")
  -- 14-byte ethernet header: 12 address bytes followed by ethertype 0x0800.
  local ethip = "000000111111\8\0"

  -- Decode `input` as ethernet, re-encode it, and compare with the input.
  local function roundtrip(input)
    local ctx = net.init()

    print("input", h(input))
    assert(ctx:decode_eth(input))
    print(ctx:dump())

    local output = ctx:block()
    print("output", h(output))
    print(ctx:dump())

    assert(#output == #input)
    -- strings are allowed to differ in the 4 checksum bytes
    assertmostlyeql(4, output, input)

    ctx:destroy()
  end

  -- Round-trip every prefix of `eth`, from the full frame down to "".
  local function truncate_eth(eth)
    local total = #eth
    for keep = total, 0, -1 do
      print("decode "..keep.."/"..total)
      roundtrip(eth:sub(1, keep))
    end
  end

  -- Prepend the fixed ethernet header to `pkt` and run the truncation sweep.
  local function truncate(pkt)
    assert(#ethip == (6+6+2))
    truncate_eth(ethip..pkt)
  end

  truncate(udp)
  truncate(tcp)

  -- NOTE(review): the three literals below are ASCII hex text that is NOT
  -- converted with b(), and each already starts with what looks like its own
  -- ethernet header. As written they are appended to `ethip` as raw text.
  -- Confirm whether truncate_eth(b"...") was intended.
  -- empty_udp:
  truncate("000e08dadb3f000c29a2eb6908004500002e00000000ff11a5b40a00010a0a00010113c413c4001ac227")
  -- empty_ip:
  truncate("000e08dadb3f000c29a2eb6908004500001400000000ff11a5b40a00010a0a000101")
  -- empty_eth:
  truncate("000e08dadb3f000c29a2eb690800")
end

-- Re-encode the udp block of a decoded packet three ways: unchanged, with a
-- longer payload, and with the payload removed.
-- All state is local to this block so nothing leaks between tests.
do
    print("test: udp reencoding")

    local n = net.init()

    -- Fail here rather than on a confusing later assertion if decoding breaks.
    assert(n:decode_ipv4(udp))

    print("in", h(udp))
    print(n:dump())

    local b0 = n:block()
    local udpt = n:get_udp()

    print("b0", h(b0))
    print("b0", udpt)

    -- identity
    n:udp(udpt)
    local b1 = n:block()
    udpt = n:get_udp()

    print("b1", h(b1))
    print("b1", udpt)

    assert(b0 == b1)

    -- extend payload
    udpt.payload = udpt.payload.."yy"
    n:udp(udpt)

    local b2 = n:block()
    udpt = n:get_udp()

    print("b2", h(b2))
    print("b2", udpt)

    assert(#b2 == #b1+2)
    assert(b2:sub(-2) == "yy")
    assert(udpt.payload == "xxxxxyy")


    -- truncate payload
    udpt.payload = nil

    n:udp(udpt)

    local b3 = n:block()
    udpt = n:get_udp()

    print("b3", h(b3))
    print("b3", udpt)

    assert(#b3 == #b1-5)
    assert(udpt.payload == "")

    print"+ok"
end

-- Re-encode the ipv4 block of a decoded packet three ways: unchanged, with
-- options added, and with the options removed again.
-- All state is local to this block so nothing leaks between tests.
do
    print("ipv4 reencoding")

    local n = net.init()

    -- Fail here rather than on a confusing later assertion if decoding breaks.
    assert(n:decode_ipv4(udp))

    print("in", h(udp))
    print(n:dump())

    local b0 = n:block()
    local argt = n:get_ipv4()

    print("b0", h(b0))
    print("b0", argt)

    -- identity
    n:ipv4(argt)
    local b1 = n:block()
    argt = n:get_ipv4()

    print("b1", h(b1))
    print("b1", argt)

    assert(b0 == b1)

    -- extend options
    argt.options = argt.options.."yy"
    n:ipv4(argt)

    local b2 = n:block()
    argt = n:get_ipv4()

    -- The block directly above the ipv4 header should now hold the options.
    local otag = n:tag_above(argt.ptag)
    local obuf = n:block(otag)

    print("b2", h(b2))
    print("b2", argt)
    print("  otag", otag)
    print("  obuf", h(obuf))

    assert(#b2 == #b1+4)
    assert("yy\0\0" == obuf) -- yy gets padded to 4 bytes
    assert(argt.options == obuf)
    assert(b2:sub(21,24) == obuf)


    -- truncate options
    argt.options = nil

    n:ipv4(argt)

    local b3 = n:block()
    argt = n:get_ipv4()

    print("b3", h(b3))
    print("b3", argt)

    assert(b3 == b1)
    assert(argt.options == "")
    assert(n:tag_above(argt.ptag) == n:get_udp().ptag) -- options block is now gone

    print"+ok"
end

-- Build a tcp block and check that get_tcp() reports every field that was
-- set, both before and after an ipv4 header (with options) is added.
-- All state is local to this block so nothing leaks between tests.
do
    print"tcp encoding/decoding"
    local n = net.init()
    local tcpt = {src=0x88, dst=0x77, seq=9, ack=10, flags=0xff, win=11, urg=13, payload="xxxxx"}
    n:tcp(tcpt)

    local t = n:get_tcp()
    print("tcp", tcpt)
    print("get", t)

    for k,v in pairs(tcpt) do
        print("set", k, v, "get", t[k])
        assert(t[k] == v)
    end

    print("n=") print(n:dump())
    n:ipv4{
        protocol=net.IP.TCP, -- TODO: protocol should default to that of the next pblock
        src="10.1.0.1",
        dst="10.1.0.2",
        options="oooooooo",
    }
    t = n:get_tcp()
    print("tcp", tcpt)
    print("get", t)

    for k,v in pairs(tcpt) do
        --print("set", k, v, "get", t[k])
        assert(t[k] == v)
    end

    print("n=") print(n:dump())
    print("tcp=", h(n:block()))
    print(n:get_tcp())
    print"+ok"
end

-- Below is to check libnet's checksum algorithm... libnet appears to work as I
-- expect, but it does not always get the results I find in pcaps, and that
-- wireshark claims is correct.
--- Build the TCP pseudo-header for an encoded IPv4 packet.
-- Layout: source and destination address (header bytes 13-20), a zero byte,
-- the protocol byte, and the TCP length as produced by net.htons().
-- The TCP length is the packet length minus the IP header length taken from
-- the IHL field, so packets carrying IP options are handled too (previously a
-- fixed 20-byte header was assumed).
-- @param ip encoded IPv4 packet (header plus payload)
-- @return the pseudo-header string, or nil plus "not tcp" when the protocol
--         byte is not 6
local function ip_pseudo(ip)
    local protocol = ip:sub(10, 10)
    if protocol ~= "\6" then -- it's not tcp!
        return nil, "not tcp"
    end
    -- IHL is the low nibble of the first byte, counted in 32-bit words.
    local header_len = (ip:byte(1) % 16) * 4
    return table.concat{
        ip:sub(20-8+1, 20), -- src,dst
        "\0", protocol,
        net.htons(#ip - header_len),
    }
end

-- Return a copy of an encoded tcp segment with bytes 17-18 (the checksum
-- field) replaced by two zero bytes.
local function tcp_pseudo(tcp)
    local before_checksum = tcp:sub(1, 16)
    local after_checksum = tcp:sub(19)
    return before_checksum .. "\0\0" .. after_checksum
end

--- Diagnostic: recompute and print the tcp checksum of an encoded ip packet.
-- Prints the pseudo-header, the tcp segment with its checksum field zeroed,
-- and net.checksum() over the two concatenated. Prints a notice and returns
-- when ip_pseudo() rejects the packet as non-tcp.
-- @param ip0 encoded IPv4 packet
-- @param nodst unused; kept so existing callers stay valid
local function sum(ip0, nodst)
    -- ip_pseudo takes a single argument, so `nodst` is no longer forwarded.
    local ip = ip_pseudo(ip0)
    if not ip then
        print "> Can't redo the checksum for non-tcp!"
        return
    end
    -- The tcp segment starts right after the ip header, whose length comes
    -- from the IHL field (low nibble of byte 1, in 32-bit words) rather than
    -- an assumed 20 bytes.
    local header_len = (ip0:byte(1) % 16) * 4
    local tcp = tcp_pseudo(ip0:sub(header_len + 1))
    print("_ip", h(ip))
    print("tcp", h(tcp), "pseudolen", ip:sub(-1):byte())
    print("sum", h(net.checksum(ip..tcp)))
end

--- Decode a captured ethernet frame, re-encode it, and compare the IPv4 part.
-- @param name label printed in the test banner
-- @param pkt0 the whole ethernet frame as text; it is passed through b()
--        before decoding (presumably hex -> binary)
-- @param checksum_bug when truthy the re-encoded IPv4 packet is required to
--        DIFFER from the input (known checksum mismatch); otherwise it must be
--        byte-identical
local function roundtrip_ipv4_from_eth(name, pkt0, checksum_bug)
    print""
    print("test: roundtrip", name)

    -- Pass 1: decode and re-encode the full frame. The result is only
    -- printed, not compared.
    local eth0 = b(pkt0)
    local eth_ctx = net.init()
    assert(eth_ctx:decode_eth(eth0))
    local eth1 = eth_ctx:block()

    print("eth0", h(eth0))
    print("eth1", h(eth1))

    local etht = eth_ctx:get_eth()
    assert(etht)
    assert(not etht.payload)

    -- Pass 2: the same bytes minus the ethernet header, in a fresh context.
    local ip0 = string.sub(eth0, 15) -- skip 14 bytes of eth header
    local n = net.init()

    assert(n:decode_ipv4(ip0))

    local ip1 = n:block()
    --print(n:dump())

    print("ip0", h(ip0))
    print("ip1", h(ip1))

    local ptag = n:tag_above()
    while ptag do
        print("pblock", n:pblock(ptag))
        ptag = n:tag_below(ptag)
    end

    sum(ip0)

    -- Was an accidental global.
    local ipv4t = n:get_ipv4()

    print("ipv4", ipv4t)

    if ipv4t.protocol == 17 then
        print(n:get_udp())
    elseif ipv4t.protocol == 6 then
        print(n:get_tcp())
    end

    if not checksum_bug then
        assert(ip0 == ip1)
        print"+ok"
    else
        assert(ip0 ~= ip1)
        print"CHECKSUM BUG!!!!!!!!!!!!"
    end
end

-- Captured frames to round-trip, run in the order listed. `checksum_bug`
-- marks captures whose re-encoded form is expected to differ from the input.
do
    local cases = {
        { name = "tcp syn",
          pkt = "00006cc0001200e0ed11c80008004500002c831f00003506a3a09780980997809802dac2c350f463c4f30000000060020800d9ae0000020405b4" },

        { name = "mysql ok response",
          pkt = "000c29161a1e002481ff557a08004508003f0251400040066496c0a8293cc0a8293d0cea9f0f612b1703494bc81a8018005b226c00000101080a576814f102fed42f0700000200000002000000" },

        { name = "modbus tcp ack",
          pkt = "00005413e6b000e0ed11c7f00800450000283bbb40004006b7c0c0a86301c0a86302c3a601f63a18c8dd00018515501016d004070000" },

        -- SR - I have absolutely no idea why libnet and I get different
        -- answers than what is in the pcap, and which wireshark claims is
        -- correct.
        { name = "tcp rst",
          pkt = "00006cc000be001577d8732a080045000028325200008006a906978098629780981404f9de8e00000000c7c2aca350140000f86a0000000000000000",
          checksum_bug = true },

        { name = "modbus response",
          pkt = "00e0ed11c7f000005413e6b00800450000310000000040063373c0a86302c0a8630101f6c3a60001850c3a18c8dd50180108157c000001b800000003ff8f0300",
          checksum_bug = true },

        { name = "snmp_packet",
          pkt = "08003715e6bc00123f4a33d2080045000052aa1a0000801111c3ac1f1336ac1f13493e2d00a1003e8d4d303402010004067075626c6963a027020127020100020100301c300c06082b060102010105000500300c06082b060102010106000500" },

        -- reencoded as
        -- ffaaaaaaaaff00ff42c772910800450000280000000040067ccb7f0000027f000003800e000100000000000000005002ffff31ce0000
        -- with different tcp checksum...
        { name = "tcp syn",
          pkt = "ffaaaaaaaaff00ff42c772910800450000280000000040067ccb7f0000027f000003800e000100000000000000005e02ffff23ce0000" },
    }

    for _, case in ipairs(cases) do
        roundtrip_ipv4_from_eth(case.name, case.pkt, case.checksum_bug)
    end
end

