package pqhs

import (
	"crypto/ecdh"
	"crypto/rand"
	"errors"

	vors "go.stargrave.org/vors/v6/internal"
	"go.stargrave.org/vors/v6/pqhs/mceliece6960119"
	sntrup761kem "go.stargrave.org/vors/v6/pqhs/sntrup761/kem"
	sntrup761 "go.stargrave.org/vors/v6/pqhs/sntrup761/kem/ntruprime/sntrup761"
	"golang.org/x/crypto/chacha20poly1305"
)

// Server is the responder side of the handshake. It is created by
// NewServer, which consumes the client's first message, and is completed
// by Read, which consumes the client's reply.
type Server struct {
	// ephPrvSNTRUP is the ephemeral sntrup761 private key generated in
	// NewServer; Read uses it to decapsulate the client's ciphertext.
	ephPrvSNTRUP sntrup761kem.PrivateKey
	// SymmetricState supplies the CK, H, Seal and Open operations used
	// throughout the handshake (presumably Noise-style chaining-key and
	// transcript-hash handling — see SymmetricState for the details).
	SymmetricState
}

// NewServer processes the client's first handshake message and builds the
// server's response.
//
// serverStaticPrvMcElieceRaw and serverStaticPrvX25519Raw are the server's
// long-term Classic McEliece 6960119 and X25519 private keys in binary form.
// serverStaticPubHash is mixed into the transcript via H (presumably a hash
// of the server's static public keys — confirm against the client side).
//
// clientPayload must consist of the McEliece ciphertext
// (mceliece6960119.CiphertextSize bytes) followed by the sealed client
// ephemeral X25519 public key. It comes from the network, so a payload
// shorter than the ciphertext is rejected with an error instead of
// panicking.
//
// On success NewServer returns the server state, ready for Read, and
// payload. The payload is the sealed server ephemeral X25519 public key
// followed by the sealed server ephemeral sntrup761 public key. On any
// error, s and payload are nil.
func NewServer(
	serverStaticPrvMcElieceRaw, serverStaticPrvX25519Raw,
	serverStaticPubHash, clientPayload []byte,
) (s *Server, payload []byte, err error) {
	var serverStaticPrvMcEliece *mceliece6960119.PrivateKey
	serverStaticPrvMcEliece, err = mceliece6960119.UnmarshalBinaryPrivateKey(
		serverStaticPrvMcElieceRaw)
	if err != nil {
		return nil, nil, err
	}
	x25519 := ecdh.X25519()
	var serverStaticPrvX25519 *ecdh.PrivateKey
	serverStaticPrvX25519, err = x25519.NewPrivateKey(serverStaticPrvX25519Raw)
	if err != nil {
		return nil, nil, err
	}

	// Untrusted input: check the length before slicing, otherwise a
	// truncated message triggers a runtime panic.
	if len(clientPayload) < mceliece6960119.CiphertextSize {
		return nil, nil, errors.New("pqhs: client payload too short")
	}
	ctMcEliece := clientPayload[:mceliece6960119.CiphertextSize]
	ctClientX25519 := clientPayload[mceliece6960119.CiphertextSize:]

	var k []byte
	k, err = mceliece6960119.Decapsulate(serverStaticPrvMcEliece, ctMcEliece)
	if err != nil {
		return nil, nil, err
	}

	// Build the state in a local so that no partially initialised
	// *Server escapes on an error path.
	srv := &Server{}
	srv.CK(k)
	srv.H([]byte(vors.Magic))
	srv.H(serverStaticPubHash)
	srv.H(ctMcEliece)

	var clientEphPubX25519Raw []byte
	clientEphPubX25519Raw, err = srv.Open(CtxClientX25519, ctClientX25519)
	if err != nil {
		return nil, nil, err
	}

	// Static (server) x ephemeral (client) X25519 exchange.
	var clientEphPubX25519 *ecdh.PublicKey
	clientEphPubX25519, k, err = DH(serverStaticPrvX25519, clientEphPubX25519Raw)
	if err != nil {
		return nil, nil, err
	}
	srv.CK(k)

	var serverEphPrvX25519 *ecdh.PrivateKey
	serverEphPrvX25519, err = x25519.GenerateKey(rand.Reader)
	if err != nil {
		return nil, nil, err
	}
	// The order of Seal/CK calls below is part of the protocol transcript
	// (presumably mirrored on the client side) and must not be changed:
	// the server ephemeral key is sealed before the ephemeral-ephemeral
	// secret is mixed in, and the sntrup761 key is sealed after.
	ctServerX25519 := srv.Seal(CtxServerX25519, serverEphPrvX25519.PublicKey().Bytes())

	// Ephemeral (server) x ephemeral (client) X25519 exchange.
	k, err = serverEphPrvX25519.ECDH(clientEphPubX25519)
	if err != nil {
		return nil, nil, err
	}
	srv.CK(k)

	var serverEphPubSNTRUP sntrup761kem.PublicKey
	serverEphPubSNTRUP, srv.ephPrvSNTRUP, err = sntrup761.Scheme().GenerateKeyPair()
	if err != nil {
		return nil, nil, err
	}
	var serverEphPubSNTRUPRaw []byte
	serverEphPubSNTRUPRaw, err = serverEphPubSNTRUP.MarshalBinary()
	if err != nil {
		return nil, nil, err
	}
	payload = append(ctServerX25519, srv.Seal(CtxServerSNTRUP761, serverEphPubSNTRUPRaw)...)
	return srv, payload, nil
}

// Read processes the client's reply to the payload produced by NewServer.
//
// reply must consist of the sealed sntrup761 ciphertext followed by the
// sealed prefinish message. The sealed ciphertext is the scheme's
// ciphertext size plus chacha20poly1305.Overhead bytes. Read opens that
// ciphertext under CtxClientSNTRUP761 and decapsulates it with the
// ephemeral sntrup761 private key generated in NewServer. It then mixes
// the shared secret in via CK and opens the remainder under
// CtxClientPrefinish.
//
// reply comes from the peer, so a reply too short to hold the sealed
// ciphertext is rejected with an error instead of panicking. On success
// the decrypted prefinish is returned. On any error before the final
// Open, prefinish is nil.
func (s *Server) Read(reply []byte) (prefinish []byte, err error) {
	sntrup761s := sntrup761.Scheme()
	// Sealing appends one AEAD tag to the KEM ciphertext.
	sealedCTLen := sntrup761s.CiphertextSize() + chacha20poly1305.Overhead
	if len(reply) < sealedCTLen {
		return nil, errors.New("pqhs: reply too short")
	}

	ct, err := s.Open(CtxClientSNTRUP761, reply[:sealedCTLen])
	if err != nil {
		return nil, err
	}
	k, err := sntrup761s.Decapsulate(s.ephPrvSNTRUP, ct)
	if err != nil {
		return nil, err
	}
	s.CK(k)

	return s.Open(CtxClientPrefinish, reply[sealedCTLen:])
}
