// VoRS -- Vo(IP) Really Simple
// Copyright (C) 2024 Sergey Matveev <stargrave@stargrave.org>
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, version 3 of the License.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program.  If not, see <http://www.gnu.org/licenses/>.

package main

import (
	"bytes"
	"crypto/subtle"
	"encoding/base64"
	"encoding/binary"
	"encoding/hex"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/exec"
	"strconv"
	"strings"
	"time"

	"github.com/dchest/siphash"
	"github.com/flynn/noise"
	"github.com/jroimartin/gocui"
	"go.stargrave.org/opus/v2"
	vors "go.stargrave.org/vors/internal"
	"golang.org/x/crypto/blake2s"
	"golang.org/x/crypto/chacha20"
)

// Stream is the per-participant state kept in Streams, keyed by the
// one-byte stream ID that prefixes each received UDP audio packet.
//
// Fields:
//   - name: participant name passed to statsDrawer and used in log
//     messages (our own -name for our stream, otherwise the name from the
//     server's add command).
//   - ctr: counter of the last packet that passed MAC verification; the
//     UDP reader drops packets whose counter is not greater than it.
//   - in: queue of raw received packets (4-byte header, ciphertext,
//     SipHash tag) for the stream's decoder goroutine; nil for our own
//     stream, which has no decoder.
//   - stats: statistics updated by the receive path (OurStats for our own
//     stream).
type Stream struct {
	name  string
	ctr   uint32
	in    chan []byte
	stats *Stats
}

// Package-level client state:
//   - Streams: all known streams by stream ID, including our own (which is
//     registered once the server accepts our cookie).
//   - Finish: main exits the process after receiving from it; goroutines
//     send on it when the TCP or UDP connection fails or pings time out.
//   - OurStats: statistics of our own outgoing stream.
//   - Name, Room: the -name and -room command-line flags.
//   - Muted: while true, captured microphone audio is not sent; it is
//     toggled through the -mute-toggle FIFO (and presumably by the mute
//     key handler).
//
// NOTE(review): Streams and Muted are read and written from several
// goroutines without synchronization.
var (
	Streams  = map[byte]*Stream{}
	Finish   = make(chan struct{})
	OurStats = &Stats{dead: make(chan struct{})}
	Name     = flag.String("name", "test", "username")
	Room     = flag.String("room", "/", "room name")
	Muted    bool
)

// parseSID parses a decimal stream ID received from the server.
// Stream IDs are single bytes, so anything that is not an integer in
// the range 0..255 terminates the program via log.Fatal.
func parseSID(s string) byte {
	n, err := strconv.Atoi(s)
	if err != nil {
		log.Fatal(err)
	}
	// Without the lower bound, "-1" would silently wrap to byte 255.
	if n < 0 {
		log.Fatal("negative stream num")
	}
	if n > 255 {
		log.Fatal("too big stream num")
	}
	return byte(n)
}

// incr adds one to data, interpreted as a big-endian unsigned integer,
// modifying it in place. The carry moves from the last byte towards the
// first. If it runs off the front (every byte was 0xFF, or data is
// empty), the bytes are left all zero and incr panics with "overflow"
// instead of silently wrapping.
func incr(data []byte) {
	pos := len(data) - 1
	for pos >= 0 {
		data[pos]++
		if data[pos] > 0 {
			return
		}
		pos--
	}
	panic("overflow")
}

// main runs the VoRS client.
//
// It first authenticates to the server with a Noise NK handshake over
// TCP, sending "name room password" as the payload. It then registers
// its UDP source address by repeatedly sending the server-issued cookie
// until a "SID" reply arrives. Next it derives this client's ChaCha20
// and SipHash keys from the handshake channel binding. Finally it runs
// the gocui interface while background goroutines:
//   - forward decrypted control messages from the TCP connection,
//   - send periodic pings and exit when pongs stop arriving,
//   - add/remove peer streams on server commands, each with its own Opus
//     decoder and optional player process,
//   - receive, authenticate, decrypt and dispatch UDP audio packets,
//   - capture, encode, encrypt and send microphone audio, plus a one-byte
//     keepalive whenever no audio was sent for a second.
func main() {
	srvAddr := flag.String("srv", "vors.home.arpa:"+strconv.Itoa(vors.DefaultPort),
		"host:TCP/UDP port to connect to")
	srvPubB64 := flag.String("pub", "", "server's public key, Base64")
	recCmd := flag.String("rec", "rec "+vors.SoxParams, "rec command")
	playCmd := flag.String("play", "play "+vors.SoxParams, "play command")
	vadRaw := flag.Uint("vad", 0, "VAD threshold")
	passwd := flag.String("passwd", "", "protected room's password")
	muteToggle := flag.String("mute-toggle", "", "path to FIFO to toggle mute")
	prefer4 := flag.Bool("4", false,
		"Prefer obsolete legacy IPv4 address during name resolution")
	version := flag.Bool("version", false, "print version")
	warranty := flag.Bool("warranty", false, "print warranty information")
	flag.Parse()
	log.SetFlags(log.Lmicroseconds | log.Lshortfile)

	if *warranty {
		fmt.Println(vors.Warranty)
		return
	}
	if *version {
		fmt.Println(vors.GetVersion())
		return
	}

	srvPub, err := base64.RawURLEncoding.DecodeString(*srvPubB64)
	if err != nil {
		log.Fatal(err)
	}
	// The handshake payload below is space-separated, so the name must
	// not contain spaces.
	*Name = strings.ReplaceAll(*Name, " ", "-")

	// Each time the mute FIFO is successfully opened for writing
	// (presumably when some process opens it for reading), flip the mute
	// state and report the new state through it.
	go func() {
		if *muteToggle == "" {
			return
		}
		for {
			fd, err := os.OpenFile(*muteToggle, os.O_WRONLY, os.FileMode(0666))
			if err != nil {
				log.Fatalln(err)
			}
			Muted = !Muted
			var reply string
			if Muted {
				reply = "muted"
			} else {
				reply = "unmuted"
			}
			fd.WriteString(reply + "\n")
			fd.Close()
			time.Sleep(time.Second)
		}
	}()

	vad := uint64(*vadRaw)
	opusEnc := newOpusEnc()
	var mic io.ReadCloser
	if *recCmd != "" {
		cmd := vors.MakeCmd(*recCmd)
		mic, err = cmd.StdoutPipe()
		if err != nil {
			log.Fatal(err)
		}
		err = cmd.Start()
		if err != nil {
			log.Fatal(err)
		}
	}

	vors.PreferIPv4 = *prefer4
	ctrl, err := net.DialTCP("tcp", nil, vors.MustResolveTCP(*srvAddr))
	if err != nil {
		log.Fatalln("dial server:", err)
	}
	defer ctrl.Close()
	if err = ctrl.SetNoDelay(true); err != nil {
		log.Fatalln("nodelay:", err)
	}

	hs, err := noise.NewHandshakeState(noise.Config{
		CipherSuite: vors.NoiseCipherSuite,
		Pattern:     noise.HandshakeNK,
		Initiator:   true,
		PeerStatic:  srvPub,
		Prologue:    []byte(vors.NoisePrologue),
	})
	if err != nil {
		log.Fatalln("noise.NewHandshakeState:", err)
	}
	buf, _, _, err := hs.WriteMessage(nil, []byte(*Name+" "+*Room+" "+*passwd))
	if err != nil {
		log.Fatalln("handshake encrypt:", err)
	}
	// First message on the wire: prologue, 16-bit big-endian length of the
	// Noise message, then the message itself.
	buf = append(
		append(
			[]byte(vors.NoisePrologue),
			byte((len(buf)&0xFF00)>>8),
			byte((len(buf)&0x00FF)>>0),
		),
		buf...,
	)
	if _, err = io.Copy(ctrl, bytes.NewReader(buf)); err != nil {
		log.Fatalln("write handshake:", err)
	}
	buf, err = vors.PktRead(ctrl)
	if err != nil {
		log.Fatalln("read handshake:", err)
	}
	buf, txCS, rxCS, err := hs.ReadMessage(nil, buf)
	if err != nil {
		log.Fatalln("handshake decrypt:", err)
	}

	// Decrypt every subsequent control packet and hand it over on rx;
	// any read or decrypt failure ends the client.
	rx := make(chan []byte)
	go func() {
		for {
			buf, err := vors.PktRead(ctrl)
			if err != nil {
				log.Println("rx", err)
				break
			}
			buf, err = rxCS.Decrypt(buf[:0], nil, buf)
			if err != nil {
				log.Println("rx decrypt", err)
				break
			}
			rx <- buf
		}
		Finish <- struct{}{}
	}()

	srvAddrUDP := vors.MustResolveUDP(*srvAddr)
	conn, err := net.DialUDP("udp", nil, srvAddrUDP)
	if err != nil {
		log.Fatalln("connect:", err)
	}
	var sid byte
	{
		// The handshake reply is "OK <hex cookie>".
		cols := strings.Fields(string(buf))
		if len(cols) != 2 || cols[0] != "OK" {
			log.Fatalln("handshake failed:", cols)
		}
		var cookie vors.Cookie
		cookieRaw, err := hex.DecodeString(cols[1])
		if err != nil {
			log.Fatal(err)
		}
		if len(cookieRaw) != len(cookie) {
			log.Fatalln("bad cookie length:", len(cookieRaw))
		}
		copy(cookie[:], cookieRaw)
		// Resend the cookie over UDP every second until the server answers
		// "SID <n>" over TCP, giving up after vors.PingTime.
		timeout := time.NewTimer(vors.PingTime)
		ticker := time.NewTicker(time.Second)
		if _, err = conn.Write(cookie[:]); err != nil {
			log.Fatalln("write:", err)
		}
	WaitForCookieAcceptance:
		for {
			select {
			case <-timeout.C:
				log.Fatalln("cookie acceptance timeout")
			case <-ticker.C:
				if _, err = conn.Write(cookie[:]); err != nil {
					log.Fatalln("write:", err)
				}
			case buf = <-rx:
				cols = strings.Fields(string(buf))
				if len(cols) != 2 || cols[0] != "SID" {
					log.Fatalln("cookie acceptance failed:", string(buf))
				}
				sid = parseSID(cols[1])
				Streams[sid] = &Stream{name: *Name, stats: OurStats}
				break WaitForCookieAcceptance
			}
		}
		// Neither is reused, so stopping is enough. Draining timeout.C here
		// or in a deferred call could block forever once the timer is
		// already stopped.
		ticker.Stop()
		timeout.Stop()
	}

	// Derive our stream's ChaCha20 key (32 bytes) and SipHash key
	// (16 bytes) from the prologue and the handshake channel binding.
	var keyCiphOur []byte
	var keyMACOur []byte
	{
		xof, err := blake2s.NewXOF(32+16, nil)
		if err != nil {
			log.Fatalln(err)
		}
		xof.Write([]byte(vors.NoisePrologue))
		xof.Write(hs.ChannelBinding())
		buf := make([]byte, 32+16)
		if _, err = io.ReadFull(xof, buf); err != nil {
			log.Fatalln(err)
		}
		keyCiphOur, keyMACOur = buf[:32], buf[32:]
	}

	// Time of the last pong. The command goroutine publishes newer values
	// through pongs and only the timeout checker owns seen, so no time.Time
	// is shared between goroutines.
	seen := time.Now()
	pongs := make(chan time.Time, 1)

	LoggerReady := make(chan struct{})
	GUI, err = gocui.NewGui(gocui.OutputNormal)
	if err != nil {
		log.Fatal(err)
	}
	defer GUI.Close()
	GUI.SelFgColor = gocui.ColorCyan
	GUI.Highlight = true
	GUI.SetManagerFunc(guiLayout)
	if err := GUI.SetKeybinding("", 'q', gocui.ModNone, guiQuit); err != nil {
		log.Fatal(err)
	}
	if err := GUI.SetKeybinding("", gocui.KeyEnter, gocui.ModNone, mute); err != nil {
		log.Fatal(err)
	}

	// Once the GUI is up, redirect logging into its "logs" view and keep
	// forcing periodic redraws.
	go func() {
		<-GUIReadyC
		v, err := GUI.View("logs")
		if err != nil {
			log.Fatal(err)
		}
		log.SetOutput(v)
		log.Println("connected", "sid:", sid,
			"addr:", conn.LocalAddr().String())
		close(LoggerReady)
		for {
			time.Sleep(vors.ScreenRefresh)
			GUI.Update(func(gui *gocui.Gui) error {
				return nil
			})
		}
	}()

	go func() {
		<-Finish
		go GUI.Close()
		time.Sleep(100 * time.Millisecond)
		os.Exit(0)
	}()

	go func() {
		for {
			time.Sleep(vors.PingTime)
			buf, err := txCS.Encrypt(nil, nil, []byte(vors.CmdPing))
			if err != nil {
				log.Fatalln("tx encrypt:", err)
			}
			if err = vors.PktWrite(ctrl, buf); err != nil {
				log.Fatalln("tx:", err)
			}
		}
	}()

	// Control command dispatcher: pongs, stream additions and deletions.
	go func() {
		for buf := range rx {
			if string(buf) == vors.CmdPong {
				// Replace any unconsumed pong time with the fresh one. This
				// goroutine is the only sender, so the send cannot block
				// after the drain.
				select {
				case <-pongs:
				default:
				}
				pongs <- time.Now()
				continue
			}
			cols := strings.Fields(string(buf))
			if len(cols) == 0 {
				log.Fatal("empty cmd")
			}
			switch cols[0] {
			case vors.CmdAdd:
				// "<add> <sid> <name> <hex key>"
				if len(cols) < 4 {
					log.Fatalln("malformed add cmd:", cols)
				}
				sidRaw, name, keyHex := cols[1], cols[2], cols[3]
				log.Println("add", name, "sid:", sidRaw)
				sid := parseSID(sidRaw)
				key, err := hex.DecodeString(keyHex)
				if err != nil {
					log.Fatal(err)
				}
				// 32-byte ChaCha20 key followed by a 16-byte SipHash key.
				if len(key) < 32+16 {
					log.Fatalln("too short key:", len(key))
				}
				keyCiph, keyMAC := key[:32], key[32:]
				stream := &Stream{
					name:  name,
					in:    make(chan []byte, 1<<10),
					stats: &Stats{dead: make(chan struct{})},
				}
				go func() {
					dec, err := opus.NewDecoder(vors.Rate, 1)
					if err != nil {
						log.Fatal(err)
					}
					if err = dec.SetComplexity(10); err != nil {
						log.Fatal(err)
					}

					var player io.WriteCloser
					playerTx := make(chan []byte, 5)
					var cmd *exec.Cmd
					if *playCmd != "" {
						cmd = vors.MakeCmd(*playCmd)
						player, err = cmd.StdinPipe()
						if err != nil {
							log.Fatal(err)
						}
						err = cmd.Start()
						if err != nil {
							log.Fatal(err)
						}
						go func() {
							for {
								// Discard the oldest queued frames while more than
								// vors.MaxLost are waiting.
								for len(playerTx) > vors.MaxLost {
									<-playerTx
									stream.stats.reorder++
								}
								pcmbuf, ok := <-playerTx
								if !ok {
									break
								}
								if _, err := io.Copy(player,
									bytes.NewReader(pcmbuf)); err != nil {
									log.Println("play:", err)
								}
							}
							cmd.Process.Kill()
							// Reap the killed player so it does not linger as a
							// zombie; its exit status is irrelevant here.
							_ = cmd.Wait()
						}()
					}

					var ciph *chacha20.Cipher
					mac := siphash.New(keyMAC)
					tag := make([]byte, siphash.Size)
					var ctr uint32
					pcm := make([]int16, vors.FrameLen)
					nonce := make([]byte, 12)
					var pkt []byte
					var n int
					lost := -1
					var lastDur int
					for buf := range stream.in {
						// Packet layout: 4-byte header (sid + counter), which also
						// forms the nonce tail, then ciphertext, then the MAC tag.
						copy(nonce[len(nonce)-4:], buf)
						mac.Reset()
						if _, err = mac.Write(buf[:len(buf)-siphash.Size]); err != nil {
							log.Fatal(err)
						}
						mac.Sum(tag[:0])
						if subtle.ConstantTimeCompare(
							tag[:siphash.Size],
							buf[len(buf)-siphash.Size:],
						) != 1 {
							stream.stats.bads++
							continue
						}

						ctr = binary.BigEndian.Uint32(nonce[len(nonce)-4:])
						if lost == -1 {
							// ignore the very first packet in the stream
							lost = 0
						} else {
							if ctr <= stream.ctr {
								// A late or duplicate packet slipped past the UDP
								// reader's check. Computing loss from it would
								// underflow.
								stream.stats.reorder++
								continue
							}
							lost = int(ctr - (stream.ctr + 1))
						}

						ciph, err = chacha20.NewUnauthenticatedCipher(keyCiph, nonce)
						if err != nil {
							log.Fatal(err)
						}
						pkt = buf[4 : len(buf)-siphash.Size]
						ciph.XORKeyStream(pkt, pkt)

						stream.ctr = ctr
						stream.stats.lost += int64(lost)
						if lost > vors.MaxLost {
							lost = 0
						}
						// Conceal each missing packet with Opus packet loss
						// concealment of the last packet's duration.
						for ; lost > 0; lost-- {
							lastDur, err = dec.LastPacketDuration()
							if err != nil {
								log.Println("PLC:", err)
								continue
							}
							err = dec.DecodePLC(pcm[:lastDur])
							if err != nil {
								log.Println("PLC:", err)
								continue
							}
							stream.stats.AddRMS(pcm[:lastDur])
							if cmd == nil {
								continue
							}
							pcmbuf := make([]byte, 2*lastDur)
							pcmConv(pcmbuf, pcm[:lastDur])
							playerTx <- pcmbuf
						}
						// Only the first n samples belong to this packet.
						n, err = dec.Decode(pkt, pcm)
						if err != nil {
							log.Println("decode:", err)
							continue
						}
						stream.stats.AddRMS(pcm[:n])
						stream.stats.last = time.Now()
						if cmd == nil {
							continue
						}
						pcmbuf := make([]byte, 2*n)
						pcmConv(pcmbuf, pcm[:n])
						playerTx <- pcmbuf
					}
					if cmd != nil {
						close(playerTx)
					}
				}()
				go statsDrawer(stream.stats, stream.name)
				Streams[sid] = stream
			case vors.CmdDel:
				// "<del> <sid>"
				if len(cols) < 2 {
					log.Fatalln("malformed del cmd:", cols)
				}
				sid := parseSID(cols[1])
				s := Streams[sid]
				if s == nil {
					log.Println("unknown sid:", sid)
					continue
				}
				log.Println("del", s.name, "sid:", cols[1])
				delete(Streams, sid)
				// Our own stream has no input queue.
				if s.in != nil {
					close(s.in)
				}
				close(s.stats.dead)
			default:
				log.Fatal("unknown cmd:", cols[0])
			}
		}
	}()

	// Exit once no pong has arrived for more than two ping periods.
	go func(seen time.Time) {
		for now := range time.Tick(vors.PingTime) {
			select {
			case seen = <-pongs:
			default:
			}
			if seen.Add(2 * vors.PingTime).Before(now) {
				log.Println("timeout:", seen)
				Finish <- struct{}{}
				break
			}
		}
	}(seen)

	// UDP receiver: validate the source and size, then queue each packet
	// for its stream's decoder.
	// NOTE(review): Streams and stream.ctr are also written by other
	// goroutines here without synchronization, and the del command can
	// close stream.in concurrently with the send below.
	go func() {
		<-LoggerReady
		var n int
		var from *net.UDPAddr
		var err error
		var stream *Stream
		var ctr uint32
		for {
			// A fresh buffer each time: ownership passes to the decoder.
			buf := make([]byte, 2*vors.FrameLen)
			n, from, err = conn.ReadFromUDP(buf)
			if err != nil {
				log.Println("recvfrom:", err)
				Finish <- struct{}{}
				break
			}
			if from.Port != srvAddrUDP.Port || !from.IP.Equal(srvAddrUDP.IP) {
				log.Println("wrong addr:", from)
				continue
			}
			if n <= 4+siphash.Size {
				log.Println("too small:", n)
				continue
			}
			stream = Streams[buf[0]]
			if stream == nil || stream.in == nil {
				// Unknown stream, or our own stream, which has no decoder
				// (sending on its nil queue would block forever).
				continue
			}
			stream.stats.pkts++
			stream.stats.bytes += vors.IPHdrLen(from.IP) + 8 + uint64(n)
			ctr = binary.BigEndian.Uint32(buf)
			if ctr <= stream.ctr {
				stream.stats.reorder++
				continue
			}
			stream.in <- buf[:n]
		}
	}()

	go statsDrawer(OurStats, *Name)
	// Keepalive: send our one-byte sid if no audio went out during the
	// last second.
	go func() {
		<-LoggerReady
		for now := range time.NewTicker(time.Second).C {
			if !OurStats.last.Add(time.Second).Before(now) {
				continue
			}
			OurStats.pkts++
			OurStats.bytes += vors.IPHdrLen(srvAddrUDP.IP) + 8 + 1
			if _, err := conn.Write([]byte{sid}); err != nil {
				log.Println("send:", err)
			}
		}
	}()
	// Microphone sender: read raw frames, apply mute/VAD/DTX, then encrypt,
	// authenticate and send them.
	go func() {
		if *recCmd == "" {
			return
		}
		<-LoggerReady
		// Local err: main's err is also written by other goroutines.
		var err error
		var ciph *chacha20.Cipher
		mac := siphash.New(keyMACOur)
		tag := make([]byte, siphash.Size)
		buf := make([]byte, 2*vors.FrameLen)
		pcm := make([]int16, vors.FrameLen)
		nonce := make([]byte, 12)
		nonce[len(nonce)-4] = sid
		var pkt []byte
		var n int
		for {
			if _, err = io.ReadFull(mic, buf); err != nil {
				log.Println("mic:", err)
				break
			}
			if Muted {
				continue
			}
			// Captured samples are little-endian signed 16-bit.
			for i := range pcm {
				pcm[i] = int16(binary.LittleEndian.Uint16(buf[2*i:]))
			}
			if vad != 0 && vors.RMS(pcm) < vad {
				continue
			}
			// Leave room for the 4-byte header in front and the MAC tag after.
			n, err = opusEnc.Encode(pcm, buf[4:len(buf)-siphash.Size])
			if err != nil {
				log.Fatal(err)
			}
			if n <= 2 {
				// DTX
				continue
			}

			// The nonce tail is sid + 3-byte big-endian packet counter and
			// doubles as the packet header.
			incr(nonce[len(nonce)-3:])
			copy(buf, nonce[len(nonce)-4:])
			ciph, err = chacha20.NewUnauthenticatedCipher(keyCiphOur, nonce)
			if err != nil {
				log.Fatal(err)
			}
			ciph.XORKeyStream(buf[4:4+n], buf[4:4+n])
			mac.Reset()
			if _, err = mac.Write(buf[:4+n]); err != nil {
				log.Fatal(err)
			}
			mac.Sum(tag[:0])
			copy(buf[4+n:], tag)
			pkt = buf[:4+n+siphash.Size]

			OurStats.pkts++
			OurStats.bytes += vors.IPHdrLen(srvAddrUDP.IP) + 8 + uint64(len(pkt))
			OurStats.last = time.Now()
			OurStats.AddRMS(pcm)
			if _, err = conn.Write(pkt); err != nil {
				log.Println("send:", err)
			}
		}
	}()

	err = GUI.MainLoop()
	if err != nil && err != gocui.ErrQuit {
		log.Fatal(err)
	}
}
