// VoRS -- Vo(IP) Really Simple
// Copyright (C) 2024-2025 Sergey Matveev <stargrave@stargrave.org>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, version 3 of the License.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

package main

import (
	"bytes"
	"crypto/cipher"
	"crypto/sha3"
	"crypto/subtle"
	"encoding/binary"
	"flag"
	"fmt"
	"io"
	"log"
	"log/slog"
	"net"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/aead/chacha20"
	"github.com/dchest/siphash"
	"github.com/jroimartin/gocui"
	"go.stargrave.org/opus/v2"
	vors "go.stargrave.org/vors/v6/internal"
	"go.stargrave.org/vors/v6/pqhs"
	"golang.org/x/crypto/chacha20poly1305"
)

// Stream is the client-side state of one participant's audio stream,
// registered in Streams under its one-byte stream ID. Our own stream is
// registered there too, but without an in channel.
type Stream struct {
	in       chan []byte // raw UDP packets consumed by the stream's decoder goroutine; nil for our own stream
	stats    *Stats      // counters displayed for this stream
	name     string      // participant's name as announced by the server
	ctr      uint32      // packet counter of the last MAC-verified packet
	actr     uint32      // 24-bit audio counter carried in that packet
	muted    bool        // set/cleared by the server's CmdMuted/CmdUnmuted
	silenced bool        // locally silenced: its decoded audio is not played
}

// Known streams indexed by their one-byte stream ID. Writers take
// StreamsM exclusively; concurrent readers should hold it for reading.
var (
	Streams  = make(map[byte]*Stream)
	StreamsM sync.RWMutex
)

// Process-wide client state shared between goroutines.
var (
	// Finish is signalled to request termination of the client.
	Finish = make(chan struct{})
	// OurStats accumulates statistics of our own outgoing stream.
	OurStats = &Stats{dead: make(chan struct{})}
)

// Command-line options that are needed outside of main.
var (
	Name = flag.String("name", "test", "username")
	Room = flag.String("room", "/", "room name")
)

var (
	// Muted suppresses sending of microphone audio when true.
	Muted bool
	// Ctrl carries plaintext control messages to be encrypted and sent
	// to the server.
	Ctrl = make(chan []byte)
)

// muteToggle flips the global Muted flag. When a control channel is set,
// it also queues CmdMuted or CmdUnmuted (matching the new state) for the
// server. It returns the resulting value of Muted.
func muteToggle() (muted bool) {
	Muted = !Muted
	if Ctrl == nil {
		return Muted
	}
	state := vors.CmdUnmuted
	if Muted {
		state = vors.CmdMuted
	}
	Ctrl <- vors.ArgsEncode([]byte(state))
	return Muted
}

// main runs the VoRS client. It parses the command line, performs the
// pqhs handshake with the server over a TCP control connection, derives
// the control-channel AEAD keys and our audio cipher/MAC keys, and
// registers our UDP endpoint using the cookie received from the server.
// It then starts goroutines that:
//   - decrypt control messages from the server and dispatch them
//     (peer add/del, mute state, chat, pong);
//   - encrypt and send control messages queued on Ctrl;
//   - send a ping every vors.PingTime and request exit when no pong was
//     seen for 2*vors.PingTime;
//   - receive, authenticate, decrypt, decode and play peers' audio;
//   - capture, encode, encrypt and send microphone audio, or a one-byte
//     keepalive when no audio was sent during the last second.
//
// The process terminates through os.Exit after something is sent on
// Finish (F10 in the GUI, control channel failure, pong timeout, UDP read
// error). It also terminates through log.Fatal* on unrecoverable errors.
//
// NOTE(review): Stream and Stats counters are still updated from several
// goroutines without synchronization.
func main() {
	srvAddr := flag.String("srv", "vors.home.arpa:"+strconv.Itoa(vors.DefaultPort),
		"host:TCP/UDP port to connect to")
	srvPubPth := flag.String("pub", "", "Path to server's public key")
	recCmd := flag.String("rec", "rec "+vors.SoxParams, "rec command")
	playCmd := flag.String("play", "play "+vors.SoxParams, "play command")
	vadRaw := flag.Uint("vad", 0, "VAD threshold")
	passwd := flag.String("passwd", "", "protected room's password")
	muteTogglePth := flag.String("mute-toggle", "",
		"path to FIFO to toggle mute")
	prefer4 := flag.Bool("4", false,
		"Prefer obsolete legacy IPv4 address during name resolution")
	version := flag.Bool("version", false, "print version")
	warranty := flag.Bool("warranty", false, "print warranty information")
	flag.Usage = func() {
		fmt.Fprintln(os.Stderr, "Usage: vors-client [opts] -name NAME -pub PUB -srv HOST:PORT")
		flag.PrintDefaults()
		fmt.Fprintln(os.Stderr, `
Press Tab to cycle through peers and chat windows. Pressing Enter in a
peer window toggles silencing (no audio will be played from it). The chat
window allows you to enter text and send it to everyone in the room by
pressing Enter.

Press F1 to toggle mute (no microphone audio is sent to the server).
Press F10 to quit.

Each peer contains various statistics: number of packets received from
it (or sent, if it is you), traffic amount, number of silence seconds,
number of bad packets (malformed or altered), number of lost packets,
number of reordered packets.
Green "T" means that recently an audio packet was received.
Red "M" means that peer is in muted mode.
Magenta "S" means that peer is locally silenced.`)
	}
	flag.Parse()
	log.SetFlags(log.Lmicroseconds)

	if *warranty {
		fmt.Println(vors.Warranty)
		return
	}
	if *version {
		fmt.Println(vors.GetVersion())
		return
	}

	// Only the SHAKE128 hash of the room password goes into the handshake.
	var passwdHsh []byte
	if *passwd != "" {
		hsh := sha3.SumSHAKE128([]byte(*passwd), 32)
		passwdHsh = hsh[:]
	}

	srvPub, err := os.ReadFile(*srvPubPth)
	if err != nil {
		log.Fatal(err)
	}
	*Name = strings.ReplaceAll(*Name, " ", "-")

	// Every time a reader opens the FIFO (opening it for writing blocks
	// until then), toggle mute and report the new state to that reader.
	go func() {
		if *muteTogglePth == "" {
			return
		}
		for {
			fd, err := os.OpenFile(*muteTogglePth, os.O_WRONLY, os.FileMode(0o666))
			if err != nil {
				log.Fatalln(err)
			}
			var reply string
			if muteToggle() {
				reply = "muted"
			} else {
				reply = "unmuted"
			}
			fd.WriteString(reply + "\n")
			fd.Close()
			time.Sleep(time.Second)
		}
	}()

	vad := uint64(*vadRaw)
	opusEnc := newOpusEnc()
	var mic io.ReadCloser
	if *recCmd != "" {
		cmd := vors.MakeCmd(*recCmd)
		mic, err = cmd.StdoutPipe()
		if err != nil {
			log.Fatal(err)
		}
		err = cmd.Start()
		if err != nil {
			log.Fatal(err)
		}
	}

	vors.PreferIPv4 = *prefer4
	hs, buf, err := pqhs.NewClient(pqhs.KeyDecombine(srvPub))
	if err != nil {
		log.Fatalln("pqhs.NewClient:", err)
	}

	ctrlConn, err := net.DialTCP("tcp", nil, vors.MustResolveTCP(*srvAddr))
	if err != nil {
		log.Fatalln("dial server:", err)
	}
	defer ctrlConn.Close()
	if err = ctrlConn.SetNoDelay(true); err != nil {
		log.Fatalln("nodelay:", err)
	}

	ctrl := vors.NewNSConn(ctrlConn, 1<<16)
	if err = ctrl.Tx([]byte(vors.Magic)); err != nil {
		log.Fatalln("write handshake magic:", err)
	}
	if err = ctrl.Tx(buf); err != nil {
		log.Fatalln("write handshake hello:", err)
	}
	buf = <-ctrl.Rx
	if buf == nil {
		log.Fatalln("read handshake hello:", ctrl.Err)
	}
	buf, err = hs.Read(buf, vors.ArgsEncode(
		[]byte(*Name), []byte(*Room), passwdHsh,
	))
	if err != nil {
		log.Fatalln("process handshake:", err)
	}
	if err = ctrl.Tx(buf); err != nil {
		log.Fatalln("write handshake finish:", err)
	}

	// Keying material layout: control tx key, control rx key, then our
	// audio ChaCha20 key followed by our SipHash key.
	var txKey, rxKey, keyCiphOur, keyMACOur []byte
	var txAEAD, rxAEAD cipher.AEAD
	keys := hs.Keymat(3*chacha20poly1305.KeySize + vors.SipHash24KeySize)
	txKey, keys = keys[:chacha20poly1305.KeySize], keys[chacha20poly1305.KeySize:]
	rxKey, keys = keys[:chacha20poly1305.KeySize], keys[chacha20poly1305.KeySize:]
	keyCiphOur, keyMACOur = keys[:vors.ChaCha20KeySize], keys[vors.ChaCha20KeySize:]
	txAEAD, err = chacha20poly1305.New(txKey)
	if err != nil {
		log.Fatal(err)
	}
	rxAEAD, err = chacha20poly1305.New(rxKey)
	if err != nil {
		log.Fatal(err)
	}
	txNonce := make([]byte, chacha20poly1305.NonceSize)
	rxNonce := make([]byte, chacha20poly1305.NonceSize)

	buf = <-ctrl.Rx
	if buf == nil {
		log.Fatalln("read handshake finish:", ctrl.Err)
	}
	// NOTE(review): rxNonce is not incremented after this message, so the
	// first message read by the goroutine below is opened with the same
	// nonce. This has to mirror the server's sealing order; confirm that
	// the server does not seal two different messages under one nonce.
	buf, err = rxAEAD.Open(buf[:0], rxNonce, buf, nil)
	if err != nil {
		log.Fatalln("handshake decrypt:", err)
	}

	// Decrypt incoming control messages and forward the plaintexts to rx.
	// Any failure ends the session.
	rx := make(chan []byte)
	go func() {
		for ct := range ctrl.Rx {
			pt, err := rxAEAD.Open(ct[:0], rxNonce, ct, nil)
			if err != nil {
				log.Println("rx decrypt", err)
				break
			}
			rx <- pt
			vors.Incr(rxNonce)
		}
		Finish <- struct{}{}
	}()

	srvAddrUDP := vors.MustResolveUDP(*srvAddr)
	conn, err := net.DialUDP("udp", nil, srvAddrUDP)
	if err != nil {
		log.Fatalln("connect:", err)
	}
	var sidConnected byte
	{
		var args [][]byte
		args, err = vors.ArgsDecode(buf)
		if err != nil {
			log.Fatalln("args decode:", err)
		}
		if len(args) < 2 {
			log.Fatalln("empty args")
		}
		var cookie vors.Cookie
		switch cmd := string(args[0]); cmd {
		case vors.CmdErr:
			log.Fatalln("handshake failed:", string(args[1]))
		case vors.CmdCookie:
			copy(cookie[:], args[1])
		default:
			log.Fatalln("unexpected post-handshake cmd:", cmd)
		}
		// Send the cookie over UDP every second until the server accepts
		// it (replying with our SID over the control channel) or
		// vors.PingTime elapses.
		timeout := time.NewTimer(vors.PingTime)
		ticker := time.NewTicker(time.Second)
		if _, err = conn.Write(cookie[:]); err != nil {
			log.Fatalln("write:", err)
		}
	WaitForCookieAcceptance:
		for {
			select {
			case <-timeout.C:
				log.Fatalln("cookie acceptance timeout")
			case <-ticker.C:
				if _, err = conn.Write(cookie[:]); err != nil {
					log.Fatalln("write:", err)
				}
			case buf = <-rx:
				var args [][]byte
				args, err = vors.ArgsDecode(buf)
				if err != nil {
					log.Fatalln("args decode:", err)
				}
				if len(args) < 2 {
					log.Fatalln("empty args")
				}
				switch cmd := string(args[0]); cmd {
				case vors.CmdErr:
					log.Fatalln("cookie acceptance failed:", string(args[1]))
				case vors.CmdSID:
					if len(args[1]) == 0 {
						log.Fatalln("empty sid")
					}
					sidConnected = args[1][0]
					StreamsM.Lock()
					Streams[sidConnected] = &Stream{name: *Name, stats: OurStats}
					StreamsM.Unlock()
				default:
					log.Fatalln("unexpected post-cookie cmd:", cmd)
				}
				break WaitForCookieAcceptance
			}
		}
		// Neither timer is used any more. A value possibly left in
		// timeout.C is never read, so no draining is needed.
		ticker.Stop()
		timeout.Stop()
	}

	// seen is the time of the last received pong. The dispatcher writes it
	// and the watchdog reads it, hence seenM.
	var seenM sync.Mutex
	seen := time.Now()

	LoggerReady := make(chan struct{})
	if *NoGUI {
		close(GUIReadyC)
		slog.SetDefault(slog.New(slog.NewTextHandler(os.Stderr, nil)))
		log.Println("connected", "sid:", sidConnected,
			"addr:", conn.LocalAddr().String())
		close(LoggerReady)
	} else {
		GUI, err = gocui.NewGui(gocui.OutputNormal)
		if err != nil {
			log.Fatal(err)
		}
		defer GUI.Close()
		GUI.SelFgColor = gocui.ColorCyan
		GUI.Highlight = true
		GUI.SetManagerFunc(guiLayout)
		if err = GUI.SetKeybinding("", gocui.KeyTab, gocui.ModNone, tabHandle); err != nil {
			log.Fatal(err)
		}
		if err = GUI.SetKeybinding("", gocui.KeyF1, gocui.ModNone,
			func(gui *gocui.Gui, v *gocui.View) error {
				muteToggle()
				return nil
			},
		); err != nil {
			log.Fatal(err)
		}
		if err = GUI.SetKeybinding("", gocui.KeyF10, gocui.ModNone,
			func(gui *gocui.Gui, v *gocui.View) error {
				Finish <- struct{}{}
				return gocui.ErrQuit
			},
		); err != nil {
			log.Fatal(err)
		}
		// Once the layout exists, redirect logging into the "logs" view
		// and keep redrawing the screen periodically.
		go func() {
			<-GUIReadyC
			v, errView := GUI.View("logs")
			if errView != nil {
				log.Fatal(errView)
			}
			log.SetOutput(v)
			log.Println("connected", "sid:", sidConnected,
				"addr:", conn.LocalAddr().String())
			close(LoggerReady)
			for {
				time.Sleep(vors.ScreenRefresh)
				GUI.Update(func(gui *gocui.Gui) error {
					return nil
				})
			}
		}()
	}

	// The only receiver of Finish: close the GUI (if any) and exit.
	go func() {
		<-Finish
		if !*NoGUI {
			go GUI.Close()
			time.Sleep(100 * time.Millisecond)
		}
		os.Exit(0)
	}()

	// Encrypt and send queued control messages.
	go func() {
		for buf := range Ctrl {
			buf = txAEAD.Seal(nil, txNonce, buf, nil)
			if err := ctrl.Tx(buf); err != nil {
				log.Fatalln("tx:", err)
			}
			vors.Incr(txNonce)
		}
	}()

	go func() {
		for {
			time.Sleep(vors.PingTime)
			Ctrl <- vors.ArgsEncode([]byte(vors.CmdPing))
		}
	}()

	// Dispatch decrypted control messages from the server.
	go func() {
		// hasSID reports whether args has at least n (>= 2) elements and
		// a non-empty stream ID in args[1]. A malformed message must not
		// crash the client with an index panic.
		hasSID := func(args [][]byte, n int) bool {
			return len(args) >= n && len(args[1]) > 0
		}
		for buf := range rx {
			args, err := vors.ArgsDecode(buf)
			if err != nil {
				log.Fatalln("args decode:", err)
			}
			if len(args) == 0 {
				log.Fatalln("empty args")
			}
			switch cmd := string(args[0]); cmd {
			case vors.CmdPong:
				seenM.Lock()
				seen = time.Now()
				seenM.Unlock()
			case vors.CmdAdd:
				if !hasSID(args, 4) ||
					len(args[3]) < vors.ChaCha20KeySize+vors.SipHash24KeySize {
					log.Println("malformed cmd:", cmd)
					continue
				}
				sidRaw, name, key := args[1], args[2], args[3]
				sid := sidRaw[0]
				printBell()
				log.Println("add", string(name), "sid:", sid)
				keyCiph, keyMAC := key[:vors.ChaCha20KeySize], key[vors.ChaCha20KeySize:]
				stream := &Stream{
					name:  string(name),
					in:    make(chan []byte, 1<<10),
					stats: &Stats{dead: make(chan struct{})},
				}
				// Per-stream decoder: verify, decrypt, conceal losses and
				// decode packets from stream.in, then feed the player.
				go func() {
					dec, err := opus.NewDecoder(vors.Rate, 1)
					if err != nil {
						log.Fatal(err)
					}
					if err = dec.SetComplexity(10); err != nil {
						log.Fatal(err)
					}

					var player io.WriteCloser
					playerTx := make(chan []byte, 5)
					var cmd *exec.Cmd
					if *playCmd != "" {
						cmd = vors.MakeCmd(*playCmd)
						player, err = cmd.StdinPipe()
						if err != nil {
							log.Fatal(err)
						}
						err = cmd.Start()
						if err != nil {
							log.Fatal(err)
						}
						go func() {
							var pcmbuf []byte
							var ok bool
							var err error
							for {
								// Drop backlog beyond vors.MaxLost frames.
								for len(playerTx) > vors.MaxLost {
									<-playerTx
									stream.stats.reorder++
								}
								pcmbuf, ok = <-playerTx
								if !ok {
									break
								}
								if stream.silenced {
									continue
								}
								if _, err = io.Copy(player,
									bytes.NewReader(pcmbuf)); err != nil {
									log.Println("play:", err)
								}
							}
							cmd.Process.Kill()
							// Reap the killed player so it does not linger as
							// a zombie; its "killed" exit status is expected.
							_ = cmd.Wait()
						}()
					}

					mac := siphash.New(keyMAC)
					tag := make([]byte, siphash.Size)
					var ctr uint32
					pcm := make([]int16, vors.FrameLen)
					nonce := make([]byte, 12)
					var pkt []byte
					lost := -1
					var lastDur int
					for buf := range stream.in {
						copy(nonce[len(nonce)-4:], buf)
						mac.Reset()
						if _, err = mac.Write(
							buf[:len(buf)-siphash.Size],
						); err != nil {
							log.Fatal(err)
						}
						mac.Sum(tag[:0])
						if subtle.ConstantTimeCompare(
							tag[:siphash.Size],
							buf[len(buf)-siphash.Size:],
						) != 1 {
							stream.stats.bads++
							continue
						}
						pkt = buf[4+3 : len(buf)-siphash.Size]
						chacha20.XORKeyStream(pkt, pkt, nonce, keyCiph)

						ctr = binary.BigEndian.Uint32(nonce[len(nonce)-4:])
						if lost == -1 {
							// ignore the very first packet in the stream
							lost = 0
						} else {
							lost = int(ctr - (stream.ctr + 1))
						}
						stream.ctr = ctr
						stream.actr = uint32(buf[4+0])<<16 |
							uint32(buf[4+1])<<8 | uint32(buf[4+2])
						stream.stats.lost += int64(lost)
						if lost > vors.MaxLost {
							lost = 0
						}
						// Packet loss concealment for each missing frame.
						for ; lost > 0; lost-- {
							lastDur, err = dec.LastPacketDuration()
							if err != nil {
								log.Println("PLC:", err)
								continue
							}
							err = dec.DecodePLC(pcm[:lastDur])
							if err != nil {
								log.Println("PLC:", err)
								continue
							}
							stream.stats.AddRMS(pcm)
							if cmd == nil {
								continue
							}
							pcmbuf := make([]byte, 2*lastDur)
							pcmConv(pcmbuf, pcm[:lastDur])
							playerTx <- pcmbuf
						}
						_, err = dec.Decode(pkt, pcm)
						if err != nil {
							log.Println("decode:", err)
							continue
						}
						stream.stats.AddRMS(pcm)
						stream.stats.last = time.Now()
						if cmd == nil {
							continue
						}
						pcmbuf := make([]byte, 2*len(pcm))
						pcmConv(pcmbuf, pcm)
						playerTx <- pcmbuf
					}
					if cmd != nil {
						close(playerTx)
					}
				}()
				if !*NoGUI {
					go statsDrawer(stream)
				}
				StreamsM.Lock()
				Streams[sid] = stream
				StreamsM.Unlock()
			case vors.CmdDel:
				if !hasSID(args, 2) {
					log.Println("malformed cmd:", cmd)
					continue
				}
				sid := args[1][0]
				s := Streams[sid]
				if s == nil {
					log.Println("unknown sid:", sid)
					continue
				}
				printBell()
				log.Println("del", s.name, "sid:", sid)
				StreamsM.Lock()
				delete(Streams, sid)
				StreamsM.Unlock()
				close(s.in)
				close(s.stats.dead)
			case vors.CmdMuted:
				if !hasSID(args, 2) {
					log.Println("malformed cmd:", cmd)
					continue
				}
				sid := args[1][0]
				s := Streams[sid]
				if s == nil {
					log.Println("unknown sid:", sid)
					continue
				}
				s.muted = true
			case vors.CmdUnmuted:
				if !hasSID(args, 2) {
					log.Println("malformed cmd:", cmd)
					continue
				}
				sid := args[1][0]
				s := Streams[sid]
				if s == nil {
					log.Println("unknown sid:", sid)
					continue
				}
				s.muted = false
			case vors.CmdChat:
				if !hasSID(args, 3) {
					log.Println("malformed cmd:", cmd)
					continue
				}
				sid := args[1][0]
				s := Streams[sid]
				if s == nil {
					log.Println("unknown sid:", sid)
					continue
				}
				printBell()
				log.Println(s.name, ":", string(args[2]))
			default:
				log.Fatal("unexpected cmd:", cmd)
			}
		}
	}()

	// Watchdog: request exit when no pong arrived for 2*vors.PingTime.
	go func() {
		for now := range time.Tick(vors.PingTime) {
			seenM.Lock()
			last := seen
			seenM.Unlock()
			if last.Add(2 * vors.PingTime).Before(now) {
				log.Println("timeout:", last)
				Finish <- struct{}{}
				break
			}
		}
	}()

	// Receive audio packets from the server and route them to streams.
	go func() {
		<-LoggerReady
		// deliver hands a packet to its stream's decoder and reports
		// whether the stream is known. StreamsM is read-locked across
		// the lookup and the channel send. This keeps the map from being
		// read during a concurrent write. It also keeps CmdDel from
		// closing stream.in between lookup and send.
		deliver := func(pkt []byte, from *net.UDPAddr) bool {
			StreamsM.RLock()
			defer StreamsM.RUnlock()
			stream := Streams[pkt[0]]
			if stream == nil {
				return false
			}
			if stream.in == nil {
				// Our own stream has no decoder; sending on its nil
				// channel would block forever.
				return true
			}
			stream.stats.pkts++
			stream.stats.bytes += vors.IPHdrLen(from.IP) + 8 + uint64(len(pkt))
			if binary.BigEndian.Uint32(pkt) <= stream.ctr {
				stream.stats.reorder++
				return true
			}
			stream.in <- pkt
			return true
		}
		for {
			buf := make([]byte, 2*vors.FrameLen)
			n, from, err := conn.ReadFromUDP(buf)
			if err != nil {
				log.Println("recvfrom:", err)
				Finish <- struct{}{}
				break
			}
			if from.Port != srvAddrUDP.Port || !from.IP.Equal(srvAddrUDP.IP) {
				log.Println("wrong addr:", from)
				continue
			}
			// The decoder needs the 4-byte packet counter, the 3-byte
			// audio counter and the MAC tag; anything shorter would make
			// it slice out of range.
			if n < 4+3+siphash.Size {
				log.Println("too small:", n)
				continue
			}
			if !deliver(buf[:n], from) {
				log.Println("unknown stream:", buf[0])
			}
		}
	}()

	if !*NoGUI {
		go statsDrawer(&Stream{name: *Name, stats: OurStats})
	}
	// Keepalive: send our one-byte SID when no audio packet was sent
	// during the last second.
	go func() {
		<-LoggerReady
		for now := range time.NewTicker(time.Second).C {
			if !OurStats.last.Add(time.Second).Before(now) {
				continue
			}
			OurStats.pkts++
			OurStats.bytes += vors.IPHdrLen(srvAddrUDP.IP) + 8 + 1
			if _, err := conn.Write([]byte{sidConnected}); err != nil {
				log.Println("send:", err)
			}
		}
	}()
	// Capture microphone audio, encode, encrypt and authenticate it.
	// Packet layout: 4-byte counter (SID + 24-bit packet counter),
	// 3-byte audio counter, encrypted Opus payload, SipHash tag.
	go func() {
		if *recCmd == "" {
			return
		}
		<-LoggerReady
		mac := siphash.New(keyMACOur)
		tag := make([]byte, siphash.Size)
		buf := make([]byte, 2*vors.FrameLen)
		pcm := make([]int16, vors.FrameLen)
		actr := make([]byte, 3)
		nonce := make([]byte, 12)
		nonce[len(nonce)-4] = sidConnected
		var pkt []byte
		var n int
		var err error
		for {
			_, err = io.ReadFull(mic, buf)
			if err != nil {
				log.Println("mic:", err)
				break
			}
			// The audio counter advances for every captured frame, even
			// when nothing is sent.
			vors.Incr(actr)
			if Muted {
				continue
			}
			for i := range vors.FrameLen {
				pcm[i] = int16(uint16(buf[i*2+0]) | (uint16(buf[i*2+1]) << 8))
			}
			if vad != 0 && vors.RMS(pcm) < vad {
				continue
			}
			// Bound the payload so the MAC tag always fits behind it.
			n, err = opusEnc.Encode(pcm, buf[4+len(actr):len(buf)-siphash.Size])
			if err != nil {
				log.Fatal(err)
			}
			if n <= 2 {
				// DTX
				continue
			}

			vors.Incr(nonce[len(nonce)-3:])
			copy(buf, nonce[len(nonce)-4:])
			copy(buf[4:], actr)
			chacha20.XORKeyStream(
				buf[4+len(actr):4+len(actr)+n],
				buf[4+len(actr):4+len(actr)+n],
				nonce, keyCiphOur,
			)
			mac.Reset()
			if _, err = mac.Write(buf[:4+len(actr)+n]); err != nil {
				log.Fatal(err)
			}
			mac.Sum(tag[:0])
			copy(buf[4+len(actr)+n:], tag)
			pkt = buf[:4+len(actr)+n+siphash.Size]

			OurStats.pkts++
			OurStats.bytes += vors.IPHdrLen(srvAddrUDP.IP) + 8 + uint64(len(pkt))
			OurStats.last = time.Now()
			OurStats.AddRMS(pcm)
			if _, err = conn.Write(pkt); err != nil {
				log.Println("send:", err)
			}
		}
	}()

	if !*NoGUI {
		err = GUI.MainLoop()
		if err != nil && err != gocui.ErrQuit {
			log.Fatal(err)
		}
	}
	// The process ends only through os.Exit, called by the Finish handler
	// above; block here so main never returns first (in -nogui mode nothing
	// else would keep it alive). The deferred calls above therefore never
	// run.
	select {}
}
