// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.

package libolm

// #cgo LDFLAGS: -lolm -lstdc++
// #include <olm/olm.h>
// #include <stdlib.h>
// #include <stdio.h>
// void olm_session_describe(OlmSession * session, char *buf, size_t buflen) __attribute__((weak));
// void meowlm_session_describe(OlmSession * session, char *buf, size_t buflen) {
//   if (olm_session_describe) {
//     olm_session_describe(session, buf, buflen);
//   } else {
//     sprintf(buf, "olm_session_describe not supported");
//   }
// }
import "C"

import (
	"crypto/rand"
	"encoding/base64"
	"runtime"
	"unsafe"

	"maunium.net/go/mautrix/crypto/olm"
	"maunium.net/go/mautrix/id"
)

// Session stores an end to end encrypted messaging session.
//
// It wraps a libolm OlmSession that is initialised in place inside the
// Go-allocated mem buffer (see NewBlankSession).
//
// NOTE(review): libolm keeps using a pointer into Go memory across calls;
// cgo's pointer-passing rules generally forbid C from retaining Go pointers
// after a call returns — consider C.malloc-backed storage. TODO confirm.
type Session struct {
	// int is the libolm handle returned by olm_session for the mem buffer.
	int *C.OlmSession
	// mem is the backing storage for the OlmSession; holding it here keeps
	// the buffer referenced for as long as the Session is.
	mem []byte
}

// Ensure that [Session] implements [olm.Session].
var _ olm.Session = (*Session)(nil)

// sessionSize reports how many bytes libolm needs to hold one OlmSession.
func sessionSize() uint {
	n := C.olm_session_size()
	return uint(n)
}

// SessionFromPickled creates a Session from base64-encoded pickle data,
// decrypting it with key. It returns [olm.EmptyInput] if pickled is empty and
// otherwise any error reported by [Session.Unpickle] (for example when the key
// does not match the one used to pickle the session, or the data is not valid
// base64). On error the returned Session is nil.
//
// Like [Session.Unpickle], this mutates the contents of pickled.
func SessionFromPickled(pickled, key []byte) (*Session, error) {
	if len(pickled) == 0 {
		return nil, olm.EmptyInput
	}
	s := NewBlankSession()
	if err := s.Unpickle(pickled, key); err != nil {
		// Don't hand back a half-initialised session alongside an error.
		return nil, err
	}
	return s, nil
}

// NewBlankSession allocates a buffer of sessionSize() bytes and initialises an
// empty OlmSession inside it. The session holds no keys until it is unpickled
// or otherwise set up by libolm.
func NewBlankSession() *Session {
	buf := make([]byte, sessionSize())
	handle := C.olm_session(unsafe.Pointer(unsafe.SliceData(buf)))
	return &Session{mem: buf, int: handle}
}

// lastError fetches the most recent libolm error string recorded on this
// session and converts it into a Go error via convertError.
func (s *Session) lastError() error {
	msg := C.olm_session_last_error((*C.OlmSession)(s.int))
	return convertError(C.GoString(msg))
}

// Clear asks libolm to wipe the memory backing this Session. It returns the
// session's last olm error if libolm reports a failure.
func (s *Session) Clear() error {
	if C.olm_clear_session((*C.OlmSession)(s.int)) != errorVal() {
		return nil
	}
	return s.lastError()
}

// pickleLen reports the size of the buffer libolm needs to write this
// session's pickle.
func (s *Session) pickleLen() uint {
	n := C.olm_pickle_session_length((*C.OlmSession)(s.int))
	return uint(n)
}

// createOutboundRandomLen reports how many random bytes libolm consumes when
// creating an outbound session.
func (s *Session) createOutboundRandomLen() uint {
	n := C.olm_create_outbound_session_random_length((*C.OlmSession)(s.int))
	return uint(n)
}

// idLen reports the size of the buffer olm_session_id writes this session's
// identifier into.
func (s *Session) idLen() uint {
	n := C.olm_session_id_length((*C.OlmSession)(s.int))
	return uint(n)
}

// encryptRandomLen reports how many random bytes libolm needs to encrypt the
// next message on this session.
func (s *Session) encryptRandomLen() uint {
	n := C.olm_encrypt_random_length((*C.OlmSession)(s.int))
	return uint(n)
}

// encryptMsgLen reports the size in bytes of the next encrypted message for a
// plaintext of plainTextLen bytes.
func (s *Session) encryptMsgLen(plainTextLen int) uint {
	n := C.olm_encrypt_message_length((*C.OlmSession)(s.int), C.size_t(plainTextLen))
	return uint(n)
}

// decryptMaxPlaintextLen reports an upper bound on the plaintext size that the
// given message of type msgType could decrypt to; the real size may be smaller
// because of padding. It returns [olm.EmptyInput] for an empty message and
// otherwise the session's last olm error on failure (e.g. INVALID_BASE64,
// BAD_MESSAGE_VERSION or BAD_MESSAGE_FORMAT).
func (s *Session) decryptMaxPlaintextLen(message string, msgType id.OlmMsgType) (uint, error) {
	if message == "" {
		return 0, olm.EmptyInput
	}
	// libolm may overwrite the buffer it is given, so pass it a private byte
	// copy rather than the string's memory.
	buf := []byte(message)
	maxLen := C.olm_decrypt_max_plaintext_length(
		(*C.OlmSession)(s.int),
		C.size_t(msgType),
		unsafe.Pointer(unsafe.SliceData(buf)),
		C.size_t(len(buf)),
	)
	runtime.KeepAlive(buf)
	if maxLen == errorVal() {
		return 0, s.lastError()
	}
	return uint(maxLen), nil
}

// Pickle serialises the Session into olm's base64 pickle format, encrypting it
// with the supplied key. It returns [olm.NoKeyProvided] if key is empty, and
// the session's last olm error if libolm fails to pickle the session.
func (s *Session) Pickle(key []byte) ([]byte, error) {
	if len(key) == 0 {
		return nil, olm.NoKeyProvided
	}
	pickled := make([]byte, s.pickleLen())
	r := C.olm_pickle_session(
		(*C.OlmSession)(s.int),
		unsafe.Pointer(unsafe.SliceData(key)),
		C.size_t(len(key)),
		unsafe.Pointer(unsafe.SliceData(pickled)),
		C.size_t(len(pickled)),
	)
	runtime.KeepAlive(key)
	if r == errorVal() {
		// The signature already has an error return; use it rather than
		// panicking inside library code.
		return nil, s.lastError()
	}
	return pickled[:r], nil
}

// Unpickle restores this session from base64-encoded pickle data, decrypting
// it with key. It returns [olm.NoKeyProvided] when key is empty, or the
// session's last olm error if libolm rejects the input.
//
// libolm mutates the pickled slice in place; pass a copy if the caller still
// needs the original bytes.
func (s *Session) Unpickle(pickled, key []byte) error {
	if len(key) == 0 {
		return olm.NoKeyProvided
	}
	res := C.olm_unpickle_session(
		(*C.OlmSession)(s.int),
		unsafe.Pointer(unsafe.SliceData(key)),
		C.size_t(len(key)),
		unsafe.Pointer(unsafe.SliceData(pickled)),
		C.size_t(len(pickled)),
	)
	runtime.KeepAlive(key)
	runtime.KeepAlive(pickled)
	if res != errorVal() {
		return nil
	}
	return s.lastError()
}

// GobEncode implements gob.GobEncoder. It pickles the session with the
// package-level pickleKey and returns the raw (base64-decoded) pickle bytes.
//
// Deprecated: gob encoding of sessions is deprecated; [Session.Pickle]
// produces the same pickle data in base64 form.
func (s *Session) GobEncode() ([]byte, error) {
	pickled, err := s.Pickle(pickleKey)
	if err != nil {
		return nil, err
	}
	rawPickled := make([]byte, base64.RawStdEncoding.DecodedLen(len(pickled)))
	n, err := base64.RawStdEncoding.Decode(rawPickled, pickled)
	if err != nil {
		// Don't return a partially decoded buffer alongside the error.
		return nil, err
	}
	return rawPickled[:n], nil
}

// GobDecode implements gob.GobDecoder. It re-encodes rawPickled as unpadded
// standard base64 and unpickles it with the package-level pickleKey, first
// allocating a blank session if s has no libolm handle yet.
//
// NOTE(review): a nil receiver satisfies the first check and is then
// dereferenced, so calling this on a nil *Session panics.
//
// Deprecated: gob decoding of sessions is deprecated; [Session.Unpickle]
// accepts the same pickle data in base64 form.
func (s *Session) GobDecode(rawPickled []byte) error {
	needsInit := s == nil || s.int == nil
	if needsInit {
		*s = *NewBlankSession()
	}
	encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(rawPickled)))
	base64.RawStdEncoding.Encode(encoded, rawPickled)
	return s.Unpickle(encoded, pickleKey)
}

// MarshalJSON implements json.Marshaler. It pickles the session with the
// package-level pickleKey and wraps the base64 pickle in double quotes,
// producing a JSON string.
//
// Deprecated: JSON encoding of sessions is deprecated; [Session.Pickle]
// produces the same pickle data.
func (s *Session) MarshalJSON() ([]byte, error) {
	pickled, err := s.Pickle(pickleKey)
	if err != nil {
		return nil, err
	}
	out := make([]byte, 0, len(pickled)+2)
	out = append(out, '"')
	out = append(out, pickled...)
	out = append(out, '"')
	return out, nil
}

// UnmarshalJSON implements json.Unmarshaler. data must be a JSON string
// containing a base64 pickle. The pickle is unpickled with the package-level
// pickleKey, after allocating a blank session if s has no libolm handle yet.
// It returns [olm.InputNotJSONString] if data is not a quoted string.
//
// NOTE(review): as before, calling this on a nil *Session panics.
//
// Deprecated: JSON decoding of sessions is deprecated; [Session.Unpickle]
// accepts the same pickle data.
func (s *Session) UnmarshalJSON(data []byte) error {
	// Require two distinct quote characters: a lone `"` would otherwise pass
	// the check and make data[1:len(data)-1] panic.
	if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
		return olm.InputNotJSONString
	}
	if s == nil || s.int == nil {
		*s = *NewBlankSession()
	}
	// Unpickle mutates its input in place, and data may alias the caller's
	// JSON buffer, so work on a private copy.
	inner := data[1 : len(data)-1]
	pickled := make([]byte, len(inner))
	copy(pickled, inner)
	return s.Unpickle(pickled, pickleKey)
}

// ID returns the identifier of this Session, which is the same on both ends
// of the conversation. It panics if libolm reports an error while producing
// the identifier.
func (s *Session) ID() id.SessionID {
	buf := make([]byte, s.idLen())
	n := C.olm_session_id(
		(*C.OlmSession)(s.int),
		unsafe.Pointer(unsafe.SliceData(buf)),
		C.size_t(len(buf)),
	)
	if n == errorVal() {
		panic(s.lastError())
	}
	return id.SessionID(buf)
}

// HasReceivedMessage reports whether libolm says this session has received
// at least one message (any non-zero result counts as true).
func (s *Session) HasReceivedMessage() bool {
	return C.olm_session_has_received_message((*C.OlmSession)(s.int)) != 0
}

// MatchesInboundSession reports whether the PRE_KEY message oneTimeKeyMsg
// belongs to this inbound Session. This can happen when several messages are
// sent to this Account before it replies. It returns [olm.EmptyInput] for an
// empty message, and the session's last olm error for any result other than
// match/no-match: for example INVALID_BASE64, BAD_MESSAGE_VERSION or
// BAD_MESSAGE_FORMAT.
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
	if oneTimeKeyMsg == "" {
		return false, olm.EmptyInput
	}
	// Give libolm its own byte copy of the message.
	msg := []byte(oneTimeKeyMsg)
	res := C.olm_matches_inbound_session(
		(*C.OlmSession)(s.int),
		unsafe.Pointer(unsafe.SliceData(msg)),
		C.size_t(len(msg)),
	)
	runtime.KeepAlive(msg)
	switch res {
	case 1:
		return true, nil
	case 0:
		return false, nil
	default:
		return false, s.lastError()
	}
}

// MatchesInboundSessionFrom reports whether the PRE_KEY message oneTimeKeyMsg,
// sent by the holder of theirIdentityKey, belongs to this inbound Session.
// This can happen when several messages are sent to this Account before it
// replies. It returns [olm.EmptyInput] if either argument is empty, and the
// session's last olm error for any result other than match/no-match: for
// example INVALID_BASE64, BAD_MESSAGE_VERSION or BAD_MESSAGE_FORMAT.
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
	if theirIdentityKey == "" || oneTimeKeyMsg == "" {
		return false, olm.EmptyInput
	}
	// Give libolm its own byte copies of both inputs.
	identity := []byte(theirIdentityKey)
	msg := []byte(oneTimeKeyMsg)
	res := C.olm_matches_inbound_session_from(
		(*C.OlmSession)(s.int),
		unsafe.Pointer(unsafe.SliceData(identity)),
		C.size_t(len(identity)),
		unsafe.Pointer(unsafe.SliceData(msg)),
		C.size_t(len(msg)),
	)
	runtime.KeepAlive(identity)
	runtime.KeepAlive(msg)
	switch res {
	case 1:
		return true, nil
	case 0:
		return false, nil
	default:
		return false, s.lastError()
	}
}

// EncryptMsgType returns the type of the next message Encrypt will produce.
// It returns id.OlmMsgTypePreKey for a PRE_KEY message and id.OlmMsgTypeMsg
// for a normal message. It panics if libolm returns any other value,
// including its error value.
func (s *Session) EncryptMsgType() id.OlmMsgType {
	msgType := C.olm_encrypt_message_type((*C.OlmSession)(s.int))
	if msgType == C.size_t(id.OlmMsgTypePreKey) {
		return id.OlmMsgTypePreKey
	}
	if msgType == C.size_t(id.OlmMsgTypeMsg) {
		return id.OlmMsgTypeMsg
	}
	panic("olm_encrypt_message_type returned invalid result")
}

// Encrypt encrypts plaintext with this Session and returns the message type
// together with the base64 ciphertext. It returns [olm.EmptyInput] for an
// empty plaintext, [olm.NotEnoughGoRandom] if crypto/rand fails, and the
// session's last olm error if encryption fails.
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte, error) {
	if len(plaintext) == 0 {
		return 0, nil, olm.EmptyInput
	}
	// One extra byte guarantees the random buffer is never zero-length.
	randomness := make([]byte, s.encryptRandomLen()+1)
	if _, err := rand.Read(randomness); err != nil {
		// TODO: consider surfacing err itself instead of a generic sentinel.
		return 0, nil, olm.NotEnoughGoRandom
	}
	msgType := s.EncryptMsgType()
	ciphertext := make([]byte, s.encryptMsgLen(len(plaintext)))
	written := C.olm_encrypt(
		(*C.OlmSession)(s.int),
		unsafe.Pointer(unsafe.SliceData(plaintext)),
		C.size_t(len(plaintext)),
		unsafe.Pointer(unsafe.SliceData(randomness)),
		C.size_t(len(randomness)),
		unsafe.Pointer(unsafe.SliceData(ciphertext)),
		C.size_t(len(ciphertext)),
	)
	runtime.KeepAlive(plaintext)
	runtime.KeepAlive(randomness)
	if written == errorVal() {
		return 0, nil, s.lastError()
	}
	return msgType, ciphertext[:written], nil
}

// Decrypt decrypts a message using the Session.  Returns the the plain-text on
// success.  Returns error on failure.  If the base64 couldn't be decoded then
// the error will be "INVALID_BASE64".  If the message is for an unsupported
// version of the protocol then the error will be "BAD_MESSAGE_VERSION".  If
// the message couldn't be decoded then the error will be BAD_MESSAGE_FORMAT".
// If the MAC on the message was invalid then the error will be
// "BAD_MESSAGE_MAC".
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
	if len(message) == 0 {
		return nil, olm.EmptyInput
	}
	decryptMaxPlaintextLen, err := s.decryptMaxPlaintextLen(message, msgType)
	if err != nil {
		return nil, err
	}
	messageCopy := []byte(message)
	plaintext := make([]byte, decryptMaxPlaintextLen)
	r := C.olm_decrypt(
		(*C.OlmSession)(s.int),
		C.size_t(msgType),
		unsafe.Pointer(unsafe.SliceData(messageCopy)),
		C.size_t(len(messageCopy)),
		unsafe.Pointer(unsafe.SliceData(plaintext)),
		C.size_t(len(plaintext)),
	)
	runtime.KeepAlive(messageCopy)
	if r == errorVal() {
		return nil, s.lastError()
	}
	return plaintext[:r], nil
}

// maxDescribeSize is the size of the buffer handed to olm_session_describe,
// taken from the guidance in olm.h:
// https://gitlab.matrix.org/matrix-org/olm/-/blob/3.2.8/include/olm/olm.h#L392-393
const maxDescribeSize = 600

// Describe returns a human-readable description of the session's internal
// state, intended for debugging and logging. If the linked libolm lacks
// olm_session_describe, a fixed "not supported" message is returned instead.
func (s *Session) Describe() string {
	raw := C.malloc(C.size_t(maxDescribeSize))
	defer C.free(raw)
	buf := (*C.char)(raw)
	C.meowlm_session_describe((*C.OlmSession)(s.int), buf, C.size_t(maxDescribeSize))
	return C.GoString(buf)
}
