package pqhs

import (
	"crypto/cipher"
	"crypto/hkdf"
	"crypto/sha3"

	"golang.org/x/crypto/chacha20poly1305"
)

// SymmetricState is the symmetric-key state: a running transcript hash and a
// chaining key, both stored as raw byte slices.
type SymmetricState struct {
	// h is the running hash: replaced by H and passed as AEAD associated
	// data by Seal and Open.
	h []byte
	// ck is the chaining key: replaced by CK and used as the HKDF
	// pseudorandom key by K and Keymat.
	ck []byte
}

// K derives a ChaCha20-Poly1305 key from the current chaining key via
// HKDF-Expand, using ctx as the info string.
//
// The result depends only on state.ck and ctx, so two calls with the same
// ctx and no intervening CK yield the same key.
//
// It panics if hkdf.Expand returns an error.
func (state *SymmetricState) K(ctx string) []byte {
	const keyLen = chacha20poly1305.KeySize
	derived, expandErr := hkdf.Expand(NewSHAKE256, state.ck, ctx, keyLen)
	if expandErr != nil {
		panic(expandErr)
	}
	return derived
}

// H mixes data into the running hash: h = SHAKE256(h || data), squeezed to
// 64 bytes.
//
// The old h and data are absorbed into a streaming SHAKE256 instance instead
// of being concatenated with append. This produces the same output. It also
// avoids writing into any spare capacity of state.h, whose backing array may
// be shared with other memory, and avoids a per-call concatenation buffer.
func (state *SymmetricState) H(data []byte) {
	xof := sha3.NewSHAKE256()
	// SHAKE Write/Read never return an error; absorbing twice is equivalent
	// to absorbing the concatenation.
	xof.Write(state.h)
	xof.Write(data)
	next := make([]byte, 64)
	xof.Read(next)
	state.h = next
}

// CK mixes key into the chaining key:
//
//	prk    = HKDF-Extract(secret=key, salt=ck)
//	ck_new = HKDF-Expand(prk, info=CtxCK, 64)
//
// Both HKDF steps write into locals, and state.ck is replaced only after both
// succeed. A failing step therefore never leaves ck nil or holding the
// intermediate PRK, even if a caller recovers from the panic.
//
// It panics if either HKDF step returns an error.
func (state *SymmetricState) CK(key []byte) {
	prk, err := hkdf.Extract(NewSHAKE256, key, state.ck)
	if err != nil {
		panic(err)
	}
	next, err := hkdf.Expand(NewSHAKE256, prk, CtxCK, 64)
	if err != nil {
		panic(err)
	}
	state.ck = next
}

// Seal encrypts and authenticates data with ChaCha20-Poly1305 under K(ctx).
// It uses the current hash h as associated data and then mixes the resulting
// ciphertext (tag included) into h via H. It returns the sealed bytes.
//
// The nonce is always all-zero, so security relies on K(ctx) never repeating.
// K depends only on ck and ctx, so sealing twice with the same ctx without an
// intervening CK would reuse a (key, nonce) pair.
// NOTE(review): confirm every caller uses a distinct ctx per key epoch.
//
// It panics if the AEAD cannot be constructed from the derived key.
func (state *SymmetricState) Seal(ctx string, data []byte) []byte {
	box, newErr := chacha20poly1305.New(state.K(ctx))
	if newErr != nil {
		panic(newErr)
	}
	zeroNonce := make([]byte, chacha20poly1305.NonceSize)
	sealed := box.Seal(nil, zeroNonce, data, state.h)
	state.H(sealed)
	return sealed
}

// Open authenticates and decrypts ct with ChaCha20-Poly1305 under K(ctx),
// using an all-zero nonce and the current hash h as associated data.
//
// On success it mixes ct into h via H and returns the plaintext. On
// authentication failure it returns the AEAD's error and leaves h unchanged.
//
// It panics if the AEAD cannot be constructed from the derived key.
func (state *SymmetricState) Open(ctx string, ct []byte) (pt []byte, err error) {
	var box cipher.AEAD
	if box, err = chacha20poly1305.New(state.K(ctx)); err != nil {
		panic(err)
	}
	zeroNonce := make([]byte, chacha20poly1305.NonceSize)
	if pt, err = box.Open(nil, zeroNonce, ct, state.h); err != nil {
		return pt, err
	}
	// Only authenticated ciphertext is allowed to advance the transcript.
	state.H(ct)
	return pt, nil
}

// Keymat derives l bytes of keying material from the current chaining key
// via HKDF-Expand with info CtxKeymat. It does not modify the state.
//
// It panics if hkdf.Expand returns an error, for example when l exceeds the
// HKDF output limit for the underlying hash.
func (state *SymmetricState) Keymat(l int) []byte {
	material, expandErr := hkdf.Expand(NewSHAKE256, state.ck, CtxKeymat, l)
	if expandErr != nil {
		panic(expandErr)
	}
	return material
}
