// VoRS -- Vo(IP) Really Simple
// Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, version 3 of the License.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

package main

import (
	"crypto/rand"
	"crypto/subtle"
	"crypto/tls"
	"encoding/base64"
	"encoding/hex"
	"flag"
	"fmt"
	"io"
	"log"
	"log/slog"
	"net"
	"net/netip"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/dchest/siphash"
	"github.com/flynn/noise"
	"github.com/jroimartin/gocui"
	vors "go.stargrave.org/vors/internal"
	"golang.org/x/crypto/blake2s"
)

// TLSCfg restricts TLS to version 1.3 and to the X25519 key exchange.
var TLSCfg = &tls.Config{
	MinVersion:       tls.VersionTLS13,
	CurvePreferences: []tls.CurveID{tls.X25519},
}

// Prv and Pub are the server's static Noise keypair; main fills them from
// the first and second halves of the key file respectively.
var Prv, Pub []byte

// Cookies maps each outstanding one-time UDP cookie to the channel on which
// the waiting newPeer expects the client's UDP address.
// NOTE(review): it is touched by newPeer goroutines and by the UDP reader in
// main without any lock.
var Cookies = make(map[vors.Cookie]chan *net.UDPAddr)

// newPeer serves one client connection for its whole lifetime.
//
// Steps, in order:
//   - read and verify vors.NoisePrologue sent in clear;
//   - act as responder of a Noise NK handshake using Prv/Pub; the client's
//     first message carries "NAME[ ROOM[ PASSWORD]]" (ROOM defaults to "/");
//   - find or create the room and check its password;
//   - atomically check the name is free in the room, allocate a free
//     one-byte stream ID (SID) and register the peer in Peers and room.peers;
//   - finish the handshake with "OK <hex cookie>" (or a refusal text),
//     derive the peer's key and MAC from the handshake hash, and wait up to
//     vors.PingTime for the cookie to arrive over UDP, which tells us the
//     client's UDP address;
//   - exchange CmdAdd announcements with the other room members, then
//     answer CmdPing with CmdPong until peer.rx is closed, closing the peer
//     when no ping is seen for 2*vors.PingTime.
//
// On return the peer is unregistered and the remaining room members are
// sent CmdDel.
func newPeer(conn *net.TCPConn) {
	logger := slog.With("remote", conn.RemoteAddr().String())
	logger.Info("connected")
	defer conn.Close()
	// A failure here concerns only this connection (some systems reject
	// setsockopt on an already reset socket), so it must not stop the server.
	if err := conn.SetNoDelay(true); err != nil {
		logger.Error("nodelay", "err", err)
		return
	}

	buf := make([]byte, len(vors.NoisePrologue))
	if _, err := io.ReadFull(conn, buf); err != nil {
		logger.Error("handshake: read prologue", "err", err)
		return
	}
	if string(buf) != vors.NoisePrologue {
		logger.Error("handshake: wrong prologue")
		return
	}

	hs, err := noise.NewHandshakeState(noise.Config{
		CipherSuite:   vors.NoiseCipherSuite,
		Pattern:       noise.HandshakeNK,
		Initiator:     false,
		StaticKeypair: noise.DHKey{Private: Prv, Public: Pub},
		Prologue:      []byte(vors.NoisePrologue),
	})
	if err != nil {
		log.Fatalln("noise.NewHandshakeState:", err)
	}
	buf, err = vors.PktRead(conn)
	if err != nil {
		logger.Error("read handshake", "err", err)
		return
	}
	peer := &Peer{
		logger: logger,
		conn:   conn,
		stats:  &Stats{},
		rx:     make(chan []byte),
		tx:     make(chan []byte, 10),
		alive:  make(chan struct{}),
	}

	// refuse logs reason and sends it to the client as the final handshake
	// message. Delivery is best effort: the connection is dropped anyway.
	refuse := func(reason string) {
		logger.Error(reason)
		msg, _, _, err := hs.WriteMessage(nil, []byte(reason))
		if err != nil {
			log.Fatal(err)
		}
		vors.PktWrite(conn, msg)
	}

	var room *Room
	var password string
	{
		nameAndRoom, _, _, err := hs.ReadMessage(nil, buf)
		if err != nil {
			logger.Error("handshake: decrypt", "err", err)
			return
		}
		cols := strings.SplitN(string(nameAndRoom), " ", 3)
		roomName := "/"
		if len(cols) > 1 {
			roomName = cols[1]
		}
		if len(cols) > 2 {
			password = cols[2]
		}
		peer.name = cols[0]
		logger = logger.With("name", peer.name, "room", roomName)
		RoomsM.Lock()
		room = Rooms[roomName]
		if room == nil {
			room = &Room{
				name:  roomName,
				key:   password,
				peers: make(map[byte]*Peer),
				alive: make(chan struct{}),
			}
			Rooms[roomName] = room
			// Refresh the room's GUI view until the room dies.
			go func(room *Room) {
				if *NoGUI {
					return
				}
				ticker := time.NewTicker(vors.ScreenRefresh)
				defer ticker.Stop()
				for {
					select {
					case <-room.alive:
						GUI.DeleteView(room.name)
						return
					case now := <-ticker.C:
						// v and err are local: sharing the handshake's err
						// variable with this goroutine would be a data race.
						if v, err := GUI.View(room.name); err == nil {
							v.Clear()
							v.Write([]byte(strings.Join(room.Stats(now), "\n")))
						}
					}
				}
			}(room)
		}
		RoomsM.Unlock()
	}
	// Constant-time comparison, so reply timing does not leak the password.
	if subtle.ConstantTimeCompare([]byte(room.key), []byte(password)) != 1 {
		refuse("wrong password")
		return
	}
	peer.room = room

	// Check the name, pick a SID and register the peer in one PeersM
	// critical section, so concurrent logins can take neither the same name
	// nor the same SID.
	var refusal string
	PeersM.Lock()
	for _, p := range room.peers {
		if p.name == peer.name {
			refusal = "name already taken"
			break
		}
	}
	if refusal == "" {
		refusal = "too many users"
		// An int counter is required: a byte one wraps from 255 to 0, so a
		// "<= 255" condition would never end the loop once all SIDs are used.
		for sid := 0; sid < 1<<8; sid++ {
			if _, taken := Peers[byte(sid)]; taken {
				continue
			}
			peer.sid = byte(sid)
			Peers[peer.sid] = peer
			room.peers[peer.sid] = peer
			go peer.Tx()
			refusal = ""
			break
		}
	}
	PeersM.Unlock()
	if refusal != "" {
		refuse(refusal)
		return
	}
	logger = logger.With("sid", peer.sid)
	logger.Info("logged in")

	// roomPeers returns a snapshot of the room members taken under PeersM,
	// so room.peers is never iterated while another goroutine modifies it.
	roomPeers := func() []*Peer {
		PeersM.Lock()
		defer PeersM.Unlock()
		ps := make([]*Peer, 0, len(room.peers))
		for _, p := range room.peers {
			ps = append(ps, p)
		}
		return ps
	}

	defer func() {
		logger.Info("removing")
		PeersM.Lock()
		delete(Peers, peer.sid)
		delete(room.peers, peer.sid)
		PeersM.Unlock()
		s := []byte(fmt.Sprintf("%s %d", vors.CmdDel, peer.sid))
		for _, p := range roomPeers() {
			// Asynchronous, so a stalled member cannot delay our removal.
			go func(tx chan []byte) { tx <- s }(p.tx)
		}
	}()

	{
		var cookie vors.Cookie
		if _, err = io.ReadFull(rand.Reader, cookie[:]); err != nil {
			log.Fatalln("cookie:", err)
		}
		// Buffered, so the UDP reader can deposit the address without
		// blocking even after we have given up waiting below.
		gotCookie := make(chan *net.UDPAddr, 1)
		// NOTE(review): Cookies is also accessed by the UDP reader in main
		// without a shared lock.
		Cookies[cookie] = gotCookie

		var cs1, cs2 *noise.CipherState
		buf, cs1, cs2, err = hs.WriteMessage(nil,
			[]byte(fmt.Sprintf("OK %s", hex.EncodeToString(cookie[:]))))
		if err != nil {
			log.Fatalln("hs.WriteMessage:", err)
		}
		if err = vors.PktWrite(conn, buf); err != nil {
			logger.Error("handshake write", "err", err)
			delete(Cookies, cookie)
			return
		}
		// This message completes the handshake. The first CipherState
		// protects initiator->responder traffic, so as responder we decrypt
		// with it and encrypt with the second one.
		peer.rxCS, peer.txCS = cs1, cs2

		// Derive the key and MAC before peer.addr becomes known: the UDP
		// reader verifies packets with peer.mac as soon as peer.addr is set.
		var xof blake2s.XOF
		if xof, err = blake2s.NewXOF(32+16, nil); err != nil {
			log.Fatalln(err)
		}
		xof.Write([]byte(vors.NoisePrologue))
		xof.Write(hs.ChannelBinding())
		peerKey := make([]byte, 32+16)
		if _, err = io.ReadFull(xof, peerKey); err != nil {
			log.Fatalln(err)
		}
		PeersM.Lock()
		peer.key = peerKey
		peer.mac = siphash.New(peerKey[32:])
		PeersM.Unlock()

		timeout := time.NewTimer(vors.PingTime)
		select {
		case peer.addr = <-gotCookie:
			// The timer is never reused, so nothing needs draining.
			timeout.Stop()
		case <-timeout.C:
			logger.Error("cookie timeout")
			delete(Cookies, cookie)
			return
		}
		delete(Cookies, cookie)
		logger.Info("got cookie", "addr", peer.addr)
	}
	go peer.Rx()
	peer.tx <- []byte(fmt.Sprintf("SID %d", peer.sid))

	// Introduce the current members to the newcomer. Messages are built
	// under PeersM because peers' keys are set under it. Members still in
	// their own login (no key yet) are skipped: they announce themselves to
	// the whole room, including us, once their key is ready.
	var adds [][]byte
	PeersM.Lock()
	for _, p := range room.peers {
		if p.sid == peer.sid || p.key == nil {
			continue
		}
		adds = append(adds, []byte(fmt.Sprintf("%s %d %s %s",
			vors.CmdAdd, p.sid, p.name, hex.EncodeToString(p.key))))
	}
	PeersM.Unlock()
	for _, s := range adds {
		peer.tx <- s
	}

	// Announce the newcomer to everybody else in the room.
	{
		s := []byte(fmt.Sprintf("%s %d %s %s",
			vors.CmdAdd, peer.sid, peer.name, hex.EncodeToString(peer.key)))
		for _, p := range roomPeers() {
			if p.sid != peer.sid {
				p.tx <- s
			}
		}
	}

	// Liveness watchdog. The reader loop below reports pings over a channel
	// instead of writing a shared timestamp, so no variable is shared
	// between the two goroutines.
	pinged := make(chan struct{}, 1)
	go func() {
		ticker := time.NewTicker(vors.PingTime)
		defer ticker.Stop()
		seen := time.Now()
		for {
			select {
			case <-pinged:
				seen = time.Now()
			case now := <-ticker.C:
				// Account for a ping that raced with this tick.
				select {
				case <-pinged:
					seen = time.Now()
				default:
				}
				if seen.Add(2 * vors.PingTime).Before(now) {
					logger.Error("timeout:", "seen", seen)
					peer.Close()
					return
				}
			case <-peer.alive:
				return
			}
		}
	}()

	for buf := range peer.rx {
		if string(buf) == vors.CmdPing {
			select {
			case pinged <- struct{}{}:
			default: // a notification is already pending
			}
			peer.tx <- []byte(vors.CmdPong)
		}
	}
}

// main parses flags and loads the server's static Noise keypair. It listens
// on the same address for TCP and UDP and optionally starts the terminal GUI.
// It then runs two loops:
//   - the UDP reader matches cookie datagrams to pending logins. It relays
//     MAC-verified packets to the other members of the sender's room.
//   - the TCP acceptor serves every connection with newPeer.
func main() {
	bind := flag.String("bind", "[::1]:"+strconv.Itoa(vors.DefaultPort),
		"host:TCP/UDP port to listen on")
	kpFile := flag.String("key", "key", "path to keypair file")
	prefer4 := flag.Bool("4", false,
		"Prefer obsolete legacy IPv4 address during name resolution")
	version := flag.Bool("version", false, "print version")
	warranty := flag.Bool("warranty", false, "print warranty information")
	flag.Parse()
	log.SetFlags(log.Lmicroseconds | log.Lshortfile)

	if *warranty {
		fmt.Println(vors.Warranty)
		return
	}
	if *version {
		fmt.Println(vors.GetVersion())
		return
	}

	{
		data, err := os.ReadFile(*kpFile)
		if err != nil {
			log.Fatal(err)
		}
		// The file is the private key followed by the public key, split in
		// half, so an empty or odd-sized file cannot be valid.
		if len(data) == 0 || len(data)%2 != 0 {
			log.Fatalf("%s: invalid keypair length %d", *kpFile, len(data))
		}
		Prv, Pub = data[:len(data)/2], data[len(data)/2:]
	}

	vors.PreferIPv4 = *prefer4
	// Parse once, and report a malformed -bind instead of panicking.
	bindAddr, err := netip.ParseAddrPort(*bind)
	if err != nil {
		log.Fatalln("bind:", err)
	}
	lnTCP, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(bindAddr))
	if err != nil {
		log.Fatal(err)
	}
	lnUDP, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(bindAddr))
	if err != nil {
		log.Fatal(err)
	}

	LoggerReady := make(chan struct{})
	if *NoGUI {
		close(GUIReadyC)
		slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
		close(LoggerReady)
	} else {
		GUI, err = gocui.NewGui(gocui.OutputNormal)
		if err != nil {
			log.Fatal(err)
		}
		defer GUI.Close()
		GUI.SetManagerFunc(guiLayout)
		if err := GUI.SetKeybinding("", 'q', gocui.ModNone, guiQuit); err != nil {
			log.Fatal(err)
		}

		go func() {
			<-GUIReadyC
			v, err := GUI.View("logs")
			if err != nil {
				log.Fatal(err)
			}
			slog.SetDefault(slog.New(slog.NewTextHandler(v, nil)))
			close(LoggerReady)
			for {
				time.Sleep(vors.ScreenRefresh)
				GUI.Update(func(gui *gocui.Gui) error {
					return nil
				})
			}
		}()
	}

	// UDP reader and relay.
	go func() {
		<-LoggerReady
		buf := make([]byte, 2*vors.FrameLen)
		tag := make([]byte, siphash.Size)
		// Relay destinations, reused between packets to avoid allocations.
		var targets []*Peer
		for {
			n, from, err := lnUDP.ReadFromUDP(buf)
			if err != nil {
				log.Fatalln("recvfrom:", err)
			}
			if n == 0 {
				// Nothing to inspect: buf[0] would be left over from an
				// earlier datagram.
				continue
			}

			if n == vors.CookieLen {
				var cookie vors.Cookie
				copy(cookie[:], buf)
				// NOTE(review): Cookies is also modified by newPeer
				// goroutines without a lock shared with this reader.
				if c, ok := Cookies[cookie]; ok {
					// Forget the cookie before delivering it. Otherwise a
					// duplicated datagram could reach c after close, and a
					// send on a closed channel panics.
					delete(Cookies, cookie)
					c <- from
					close(c)
					continue
				}
			}

			sid := buf[0]
			PeersM.Lock()
			peer := Peers[sid]
			PeersM.Unlock()
			if peer == nil {
				slog.Info("unknown", "sid", sid, "from", from)
				continue
			}

			if peer.addr == nil ||
				from.Port != peer.addr.Port ||
				!from.IP.Equal(peer.addr.IP) {
				slog.Info("wrong addr",
					"peer", peer.name,
					"our", peer.addr,
					"got", from)
				continue
			}

			// Account on-the-wire size: IP header, 8-byte UDP header, payload.
			peer.stats.pktsRx++
			peer.stats.bytesRx += vors.IPHdrLen(from.IP) + 8 + uint64(n)
			if n == 1 {
				// Single-byte datagrams are only accounted, never relayed.
				continue
			}
			if n <= 4+siphash.Size {
				peer.stats.bads++
				continue
			}

			// The trailing siphash.Size bytes authenticate the rest.
			peer.mac.Reset()
			if _, err = peer.mac.Write(buf[:n-siphash.Size]); err != nil {
				log.Fatal(err)
			}
			peer.mac.Sum(tag[:0])
			if subtle.ConstantTimeCompare(
				tag[:siphash.Size],
				buf[n-siphash.Size:n],
			) != 1 {
				peer.stats.bads++
				continue
			}

			peer.stats.last = time.Now()
			// Snapshot the destinations under PeersM, which also guards
			// room membership changes, then send without holding it.
			targets = targets[:0]
			PeersM.Lock()
			for _, p := range peer.room.peers {
				if p.sid != sid && p.addr != nil {
					targets = append(targets, p)
				}
			}
			PeersM.Unlock()
			for _, p := range targets {
				p.stats.pktsTx++
				p.stats.bytesTx += vors.IPHdrLen(p.addr.IP) + 8 + uint64(n)
				if _, err = lnUDP.WriteToUDP(buf[:n], p.addr); err != nil {
					slog.Warn("sendto", "peer", peer.name, "err", err)
				}
			}
		}
	}()

	// TCP acceptor: one newPeer goroutine per connection.
	go func() {
		<-LoggerReady
		slog.Info("listening",
			"bind", *bind,
			"pub", base64.RawURLEncoding.EncodeToString(Pub))
		for {
			conn, err := lnTCP.AcceptTCP()
			if err != nil {
				log.Fatalln("accept:", err)
			}
			go newPeer(conn)
		}
	}()

	if *NoGUI {
		select {} // serve forever; the goroutines above do all the work
	}
	err = GUI.MainLoop()
	if err != nil && err != gocui.ErrQuit {
		log.Fatal(err)
	}
}
