// VoRS -- Vo(IP) Really Simple
// Copyright (C) 2024-2025 Sergey Matveev <stargrave@stargrave.org>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, version 3 of the License.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

package main

import (
	"crypto/rand"
	"crypto/sha3"
	"crypto/subtle"
	"flag"
	"fmt"
	"io"
	"log"
	"log/slog"
	"net"
	"net/netip"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/dchest/siphash"
	"github.com/jroimartin/gocui"
	vors "go.stargrave.org/vors/v6/internal"
	"go.stargrave.org/vors/v6/pqhs"
	"golang.org/x/crypto/chacha20poly1305"
)

var (
	// PrvMcEliece and PrvX25519 are the two parts of the server private
	// key, split out of the -prv file contents by pqhs.KeyDecombine in main.
	PrvMcEliece []byte
	PrvX25519   []byte
	// PubHash is the 64-byte SHAKE256 digest of the -pub file contents;
	// it is handed to pqhs.NewServer on every client handshake.
	PubHash     []byte
	// Cookies maps each outstanding cookie (sent to a client over TCP as
	// CmdCookie) to the channel on which the UDP loop delivers the source
	// address the cookie datagram came from.
	// NOTE(review): written/deleted by newPeer goroutines and read by the
	// UDP loop in main without any lock — a data race that needs its own
	// mutex.
	Cookies     = map[vors.Cookie]chan *net.UDPAddr{}
)

// newPeer serves a single client TCP connection for its whole lifetime.
//
// In order it: verifies the protocol magic; runs the pqhs server
// handshake; derives the rx/tx ChaCha20-Poly1305 keys and the peer key
// (whose tail keys the SipHash MAC checked on UDP packets); decodes the
// client's name, room and room password; joins (or creates) the room;
// allocates a free session ID; learns the peer's UDP address through a
// one-shot cookie; exchanges CmdSID/CmdAdd announcements; and finally
// processes control commands from peer.rx until that channel is closed
// or a malformed command arrives. A watchdog closes the peer when nothing
// was received for 2*vors.PingTime. The deferred cleanup removes the peer
// from Peers and from its room and sends CmdDel to the remaining members.
func newPeer(conn *net.TCPConn) {
	logger := slog.With("remote", conn.RemoteAddr().String())
	logger.Info("connected")
	defer conn.Close()
	err := conn.SetNoDelay(true)
	if err != nil {
		// A per-connection socket failure must not take the whole server
		// down: drop just this client.
		logger.Error("nodelay", "err", err)
		return
	}
	nsConn := vors.NewNSConn(conn, 1<<16)
	buf := <-nsConn.Rx
	if buf == nil {
		logger.Error("read magic", "err", nsConn.Err)
		return
	}
	if string(buf) != vors.Magic {
		logger.Error("handshake: wrong magic")
		return
	}
	buf = <-nsConn.Rx
	if buf == nil {
		logger.Error("handshake: read hello", "err", nsConn.Err)
		return
	}
	hs, buf, err := pqhs.NewServer(PrvMcEliece, PrvX25519, PubHash, buf)
	if err != nil {
		logger.Error("handshake: process hello", "err", err)
		return
	}
	nsConn.Tx(buf)
	buf = <-nsConn.Rx
	if buf == nil {
		logger.Error("handshake: read finish", "err", nsConn.Err)
		return
	}
	buf, err = hs.Read(buf)
	if err != nil {
		logger.Error("handshake: process finish", "err", err)
		return
	}
	peer := &Peer{
		logger: logger,
		conn:   nsConn,
		stats:  &Stats{},
		rx:     make(chan []byte),
		tx:     make(chan []byte, 10),
		alive:  make(chan struct{}),
	}
	{
		// Keying material layout: rxKey | txKey | peer.key, where the last
		// vors.SipHash24KeySize bytes of peer.key key the SipHash MAC.
		var rxKey, txKey []byte
		keys := hs.Keymat(3*chacha20poly1305.KeySize + vors.SipHash24KeySize)
		rxKey, keys = keys[:chacha20poly1305.KeySize], keys[chacha20poly1305.KeySize:]
		txKey, peer.key = keys[:chacha20poly1305.KeySize], keys[chacha20poly1305.KeySize:]
		peer.mac = siphash.New(peer.key[vors.ChaCha20KeySize:])
		peer.rxAEAD, err = chacha20poly1305.New(rxKey)
		if err != nil {
			log.Fatal(err)
		}
		peer.txAEAD, err = chacha20poly1305.New(txKey)
		if err != nil {
			log.Fatal(err)
		}
	}
	peer.rxNonce = make([]byte, chacha20poly1305.NonceSize)
	peer.txNonce = make([]byte, chacha20poly1305.NonceSize)
	var room *Room
	{
		var args [][]byte
		args, err = vors.ArgsDecode(buf)
		if err != nil {
			logger.Error("handshake: decode args", "err", err)
			return
		}
		// Name, room and password are all required. Indexing a shorter
		// list would panic, and a panic in this goroutine kills the
		// whole server.
		if len(args) < 3 {
			logger.Error("handshake: wrong number of args", "n", len(args))
			return
		}
		peer.name = string(args[0])
		roomName := string(args[1])
		key := string(args[2])
		logger = logger.With("name", peer.name, "room", roomName)
		RoomsM.Lock()
		room = Rooms[roomName]
		if room == nil {
			room = &Room{
				name:  roomName,
				key:   key,
				peers: make(map[byte]*Peer),
				alive: make(chan struct{}),
			}
			Rooms[roomName] = room
			RoomsM.Unlock()
			// Periodically redraw this room's GUI view until the room dies.
			go func() {
				if *NoGUI {
					return
				}
				tick := time.Tick(vors.ScreenRefresh)
				for {
					select {
					case <-room.alive:
						GUI.DeleteView(room.name)
						return
					case now := <-tick:
						// Local error variable: the enclosing function keeps
						// using its own err concurrently with this goroutine.
						if v, errView := GUI.View(room.name); errView == nil {
							v.Clear()
							v.Write([]byte(strings.Join(room.Stats(now), "\n")))
						}
					}
				}
			}()
		} else {
			RoomsM.Unlock()
		}
		// Constant-time comparison so response timing does not reveal how
		// much of the room password matched.
		if subtle.ConstantTimeCompare([]byte(room.key), []byte(key)) != 1 {
			logger.Error("wrong password")
			nsConn.Tx(peer.txAEAD.Seal(nil, peer.txNonce, vors.ArgsEncode(
				[]byte(vors.CmdErr), []byte("wrong password"),
			), nil))
			return
		}
	}
	peer.room = room

	room.peersM.RLock()
	for _, p := range room.peers {
		if p.name != peer.name {
			continue
		}
		logger.Error("name already taken")
		nsConn.Tx(peer.txAEAD.Seal(nil, peer.txNonce, vors.ArgsEncode(
			[]byte(vors.CmdErr), []byte("name already taken"),
		), nil))
		room.peersM.RUnlock()
		return
	}
	room.peersM.RUnlock()

	{
		found := false
		PeersM.Lock()
		// Iterate with an int: a byte counter wraps from 255 back to 0, so
		// with every session ID taken it would spin forever holding PeersM.
		for i := 0; i < 1<<8; i++ {
			if _, taken := Peers[byte(i)]; !taken {
				peer.sid = byte(i)
				found = true
				break
			}
		}
		if found {
			Peers[peer.sid] = peer
			go peer.Tx()
		}
		PeersM.Unlock()
		if !found {
			logger.Error("too many users")
			nsConn.Tx(peer.txAEAD.Seal(nil, peer.txNonce, vors.ArgsEncode(
				[]byte(vors.CmdErr), []byte("too many users"),
			), nil))
			return
		}
	}
	logger = logger.With("sid", peer.sid)
	room.peersM.Lock()
	room.peers[peer.sid] = peer
	room.peersM.Unlock()
	logger.Info("logged in")

	defer func() {
		logger.Info("removing")
		PeersM.Lock()
		delete(Peers, peer.sid)
		room.peersM.Lock()
		delete(room.peers, peer.sid)
		room.peersM.Unlock()
		PeersM.Unlock()
		s := vors.ArgsEncode([]byte(vors.CmdDel), []byte{peer.sid})
		room.peersM.RLock()
		for _, p := range room.peers {
			p.tx <- s
		}
		room.peersM.RUnlock()
	}()

	{
		// Hand the client a random cookie over TCP and wait for the same
		// cookie to arrive over UDP: its source address becomes peer.addr.
		// NOTE(review): Cookies is also read by the UDP loop without a
		// lock — a data race that needs a dedicated mutex.
		var cookie vors.Cookie
		if _, err = io.ReadFull(rand.Reader, cookie[:]); err != nil {
			log.Fatalln("cookie:", err)
		}
		gotCookie := make(chan *net.UDPAddr)
		Cookies[cookie] = gotCookie

		err = nsConn.Tx(peer.txAEAD.Seal(nil, peer.txNonce, vors.ArgsEncode(
			[]byte(vors.CmdCookie), cookie[:],
		), nil))
		if err != nil {
			logger.Error("handshake write", "err", err)
			delete(Cookies, cookie)
			return
		}

		timeout := time.NewTimer(vors.PingTime)
		select {
		case peer.addr = <-gotCookie:
		case <-timeout.C:
			logger.Error("cookie timeout")
			delete(Cookies, cookie)
			return
		}
		delete(Cookies, cookie)
		if !timeout.Stop() {
			<-timeout.C
		}
	}
	go peer.Rx()
	peer.tx <- vors.ArgsEncode([]byte(vors.CmdSID), []byte{peer.sid})

	room.peersM.RLock()
	for _, p := range room.peers {
		if p.sid == peer.sid {
			continue
		}
		peer.tx <- vors.ArgsEncode(
			[]byte(vors.CmdAdd), []byte{p.sid}, []byte(p.name), p.key)
	}
	room.peersM.RUnlock()

	// broadcast queues s for every room member except this peer.
	broadcast := func(s []byte) {
		room.peersM.RLock()
		for _, p := range room.peers {
			if p.sid != peer.sid {
				p.tx <- s
			}
		}
		room.peersM.RUnlock()
	}

	broadcast(vors.ArgsEncode(
		[]byte(vors.CmdAdd), []byte{peer.sid}, []byte(peer.name), peer.key))

	// heard wakes the watchdog whenever something arrives from the peer.
	// The watchdog goroutine owns its "last seen" time, so the only thing
	// shared with the command loop below is this channel.
	heard := make(chan struct{}, 1)
	go func() {
		seen := time.Now()
		ticker := time.Tick(vors.PingTime)
		for {
			select {
			case <-heard:
				seen = time.Now()
			case now := <-ticker:
				// A wake-up may have raced with this tick; account for it
				// before deciding that the peer went silent.
				select {
				case <-heard:
					seen = now
				default:
				}
				if seen.Add(2 * vors.PingTime).Before(now) {
					logger.Error("timeout", "seen", seen)
					peer.Close()
					return
				}
			case <-peer.alive:
				return
			}
		}
	}()

	for buf := range peer.rx {
		args, err := vors.ArgsDecode(buf)
		if err != nil {
			logger.Error("decode args", "err", err)
			break
		}
		if len(args) == 0 {
			logger.Error("empty args")
			break
		}
		select {
		case heard <- struct{}{}:
		default: // a wake-up is already pending
		}
		switch cmd := string(args[0]); cmd {
		case vors.CmdPing:
			peer.tx <- vors.ArgsEncode([]byte(vors.CmdPong))
		case vors.CmdMuted:
			peer.muted = true
			broadcast(vors.ArgsEncode([]byte(vors.CmdMuted), []byte{peer.sid}))
		case vors.CmdUnmuted:
			peer.muted = false
			broadcast(vors.ArgsEncode([]byte(vors.CmdUnmuted), []byte{peer.sid}))
		case vors.CmdChat:
			if len(args) != 2 {
				logger.Error("wrong len(args)")
				continue
			}
			broadcast(vors.ArgsEncode([]byte(vors.CmdChat), []byte{peer.sid}, args[1]))
		default:
			logger.Error("unknown", "cmd", cmd)
		}
	}
}

// main parses the command line, loads the server key pair, binds the TCP
// (control) and UDP (media relay) listeners on the same -bind address,
// sets up either the terminal GUI or a plain stderr logger, and then runs
// the UDP relay loop and the TCP accept loop until exit.
func main() {
	bind := flag.String("bind", "[::1]:"+strconv.Itoa(vors.DefaultPort),
		"host:TCP/UDP port to listen on")
	pubFile := flag.String("pub", "pub", "path to file with public key")
	prvFile := flag.String("prv", "prv", "path to file with private key")
	prefer4 := flag.Bool("4", false,
		"Prefer obsolete legacy IPv4 address during name resolution")
	version := flag.Bool("version", false, "print version")
	warranty := flag.Bool("warranty", false, "print warranty information")
	flag.Usage = func() {
		// The server has no -srv option; only list flags it really takes.
		fmt.Fprintln(os.Stderr, "Usage: vors-server [opts] -bind HOST:PORT -prv PRV -pub PUB")
		flag.PrintDefaults()
		fmt.Fprintln(os.Stderr, `
List of known rooms is shown by default. If room requires password
authentication, then "protected" is written nearby. Each room's member
username and IP address is shown, together with various statistics:
number of received, transmitted packets, number of bad packets (failed
authentication), amount of traffic.
Green "T" means that recently an audio packet was received.
Red "M" means that peer is in muted mode.
Press F10 to quit.`)
	}
	flag.Parse()
	log.SetFlags(log.Lmicroseconds | log.Lshortfile)

	if *warranty {
		fmt.Println(vors.Warranty)
		return
	}
	if *version {
		fmt.Println(vors.GetVersion())
		return
	}

	{
		prv, err := os.ReadFile(*prvFile)
		if err != nil {
			log.Fatal(err)
		}
		PrvMcEliece, PrvX25519 = pqhs.KeyDecombine(prv)
		pub, err := os.ReadFile(*pubFile)
		if err != nil {
			log.Fatal(err)
		}
		PubHash = sha3.SumSHAKE256(pub, 64)
	}

	vors.PreferIPv4 = *prefer4
	// Parse -bind once and report bad input as an error instead of the
	// panic MustParseAddrPort would raise.
	bindAddr, err := netip.ParseAddrPort(*bind)
	if err != nil {
		log.Fatalln("bind:", err)
	}
	lnTCP, err := net.ListenTCP("tcp", net.TCPAddrFromAddrPort(bindAddr))
	if err != nil {
		log.Fatal(err)
	}
	lnUDP, err := net.ListenUDP("udp", net.UDPAddrFromAddrPort(bindAddr))
	if err != nil {
		log.Fatal(err)
	}

	LoggerReady := make(chan struct{})
	if *NoGUI {
		close(GUIReadyC)
		slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
		close(LoggerReady)
	} else {
		GUI, err = gocui.NewGui(gocui.OutputNormal)
		if err != nil {
			log.Fatal(err)
		}
		defer GUI.Close()
		GUI.SetManagerFunc(guiLayout)
		if err = GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone,
			func(g *gocui.Gui, v *gocui.View) error {
				go func() {
					time.Sleep(100 * time.Millisecond)
					os.Exit(0)
				}()
				return gocui.ErrQuit
			}); err != nil {
			log.Fatal(err)
		}

		go func() {
			<-GUIReadyC
			// Local error variable: main's err is concurrently assigned by
			// GUI.MainLoop below.
			v, errView := GUI.View("logs")
			if errView != nil {
				log.Fatal(errView)
			}
			slog.SetDefault(slog.New(slog.NewTextHandler(v, nil)))
			close(LoggerReady)
			for {
				time.Sleep(vors.ScreenRefresh)
				GUI.Update(func(gui *gocui.Gui) error {
					return nil
				})
			}
		}()
	}

	// UDP loop: cookie datagrams complete a peer's address discovery;
	// everything else is authenticated with the sender's SipHash key and
	// relayed to the other members of the sender's room.
	go func() {
		<-LoggerReady
		buf := make([]byte, 2*vors.FrameLen)
		tag := make([]byte, siphash.Size)
		for {
			n, from, err := lnUDP.ReadFromUDP(buf)
			if err != nil {
				log.Fatalln("recvfrom:", err)
			}
			if n == 0 {
				// No session ID byte: buf[0] would be left over from an
				// earlier datagram.
				continue
			}

			if n == vors.CookieLen {
				var cookie vors.Cookie
				copy(cookie[:], buf)
				// NOTE(review): Cookies is also mutated by newPeer
				// goroutines without a lock — a data race.
				if c, ok := Cookies[cookie]; ok {
					c <- from
					close(c)
					continue
				}
			}

			sid := buf[0]
			// newPeer goroutines mutate Peers under PeersM; an unlocked
			// read here is a concurrent map access that can crash the
			// process.
			PeersM.Lock()
			peer := Peers[sid]
			PeersM.Unlock()
			if peer == nil {
				slog.Info("unknown", "sid", sid, "from", from)
				continue
			}

			if peer.addr == nil ||
				from.Port != peer.addr.Port ||
				!from.IP.Equal(peer.addr.IP) {
				slog.Info("wrong addr",
					"peer", peer.name,
					"our", peer.addr,
					"got", from)
				continue
			}

			// 8 is the UDP header length.
			peer.stats.pktsRx++
			peer.stats.bytesRx += vors.IPHdrLen(from.IP) + 8 + uint64(n)
			if n == 1 {
				continue
			}
			if n <= 4+siphash.Size {
				peer.stats.bads++
				continue
			}

			peer.mac.Reset()
			if _, err = peer.mac.Write(buf[:n-siphash.Size]); err != nil {
				log.Fatal(err)
			}
			peer.mac.Sum(tag[:0])
			if subtle.ConstantTimeCompare(
				tag[:siphash.Size],
				buf[n-siphash.Size:n],
			) != 1 {
				peer.stats.bads++
				continue
			}

			peer.stats.last = time.Now()
			peer.room.peersM.RLock()
			for _, p := range peer.room.peers {
				if p.sid == sid || p.addr == nil {
					continue
				}
				p.stats.pktsTx++
				p.stats.bytesTx += vors.IPHdrLen(p.addr.IP) + 8 + uint64(n)
				if _, err = lnUDP.WriteToUDP(buf[:n], p.addr); err != nil {
					slog.Warn("sendto", "peer", peer.name, "err", err)
				}
			}
			peer.room.peersM.RUnlock()
		}
	}()

	go func() {
		<-LoggerReady
		slog.Info("listening", "bind", *bind)
		for {
			conn, errConn := lnTCP.AcceptTCP()
			// Check the accept error itself: on failure conn is nil and
			// newPeer would panic dereferencing it.
			if errConn != nil {
				log.Fatalln("accept:", errConn)
			}
			go newPeer(conn)
		}
	}()

	if *NoGUI {
		select {} // the worker goroutines run forever
	}
	err = GUI.MainLoop()
	if err != nil && err != gocui.ErrQuit {
		log.Fatal(err)
	}
}
