// Copyright 2020 - MinIO, Inc. All rights reserved.
// Use of this source code is governed by the AGPLv3
// license that can be found in the LICENSE file.

// Package gemalto implements a key store that fetches/stores
// cryptographic keys on a Gemalto KeySecure instance.
package gemalto

import (
	"bytes"
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"net"
	"net/http"
	"os"
	"path/filepath"
	"strings"
	"time"

	"aead.dev/mem"
	"github.com/minio/kes"
	xhttp "github.com/minio/kes/internal/http"
	"github.com/minio/kes/internal/keystore"
	kesdk "github.com/minio/kms-go/kes"
)

// Credentials represents a Gemalto KeySecure
// refresh token that can be used to obtain a
// short-lived authentication token.
//
// A token is valid within either the default root
// domain (empty) or a specific domain - e.g. my-domain.
//
// Connect uses the credentials for the initial authentication
// and also hands them to the background token renewal.
type Credentials struct {
	Token  string        // The KeySecure refresh token
	Domain string        // The KeySecure domain - similar to a Vault Namespace
	Retry  time.Duration // The time to wait before trying to re-authenticate
}

// Config is a structure containing configuration
// options for connecting to a KeySecure server.
type Config struct {
	// Endpoint is the KeySecure instance endpoint.
	// It is the base URL of every request: API paths
	// such as "/api/v1/vault/secrets" are appended to it
	// as-is, so a trailing slash would produce "//" in
	// request paths.
	Endpoint string

	// CAPath is a path to the root CA certificate(s)
	// used to verify the TLS certificate of the KeySecure
	// instance. It may point to a single PEM file or to a
	// directory of PEM files. If empty, the host's root CA
	// set is used.
	CAPath string

	// Login credentials are used to authenticate to the
	// KeySecure instance and obtain a short-lived authentication
	// token.
	Login Credentials
}

// Store is a Gemalto KeySecure secret store.
//
// A Store is created by Connect. Close stops its background
// authentication-token renewal.
type Store struct {
	// config is a copy of the Config passed to Connect.
	config Config
	// client is the authenticated HTTP client used for all requests.
	client *client
	// stop cancels the context of the token-renewal goroutine; Close calls it.
	stop context.CancelFunc
}

// Connect returns a Store to a Gemalto KeySecure
// server using the given config.
//
// If config.CAPath is set, the server's TLS certificate is verified
// against the CA certificate(s) found there; otherwise the host's root
// CA set is used. Connect authenticates once with config.Login before
// it returns, then starts client.RenewAuthToken in a background
// goroutine. That goroutine receives a context derived from ctx, which
// is canceled by Store.Close - or earlier, if ctx itself is canceled.
// Callers should therefore pass a ctx that outlives the returned Store.
func Connect(ctx context.Context, config *Config) (c *Store, err error) {
	// A nil pool makes crypto/tls fall back to the host's root CA set.
	var rootCAs *x509.CertPool
	if config.CAPath != "" {
		rootCAs, err = loadCustomCAs(config.CAPath)
		if err != nil {
			return nil, err
		}
	}

	client := &client{
		Retry: xhttp.Retry{
			Client: http.Client{
				Transport: &http.Transport{
					TLSClientConfig: &tls.Config{
						RootCAs: rootCAs,
					},
					Proxy: http.ProxyFromEnvironment,
					// The deprecated net.Dialer.DualStack field is not set:
					// it is ignored since Go 1.12, where RFC 6555 "fast
					// fallback" is enabled by default.
					DialContext: (&net.Dialer{
						Timeout:   10 * time.Second,
						KeepAlive: 10 * time.Second,
					}).DialContext,
					ForceAttemptHTTP2:     true,
					MaxIdleConns:          100,
					IdleConnTimeout:       30 * time.Second,
					TLSHandshakeTimeout:   10 * time.Second,
					ExpectContinueTimeout: 1 * time.Second,
				},
			},
		},
	}
	if err = client.Authenticate(ctx, config.Endpoint, config.Login); err != nil {
		return nil, err
	}

	ctx, cancel := context.WithCancel(ctx)
	go client.RenewAuthToken(ctx, config.Endpoint, config.Login)
	return &Store{
		config: *config,
		client: client,
		stop:   cancel,
	}, nil
}

// String returns a human-readable description of the store
// that names the KeySecure endpoint it talks to.
func (s *Store) String() string {
	const prefix = "Gemalto KeySecure: "
	return prefix + s.config.Endpoint
}

// Status reports whether the Gemalto KeySecure instance is reachable and
// how long a plain GET request to its endpoint took.
//
// Any HTTP response, regardless of its status code, counts as reachable.
// An error from the HTTP client is returned wrapped in a
// *keystore.ErrUnreachable.
func (s *Store) Status(ctx context.Context) (kes.KeyStoreState, error) {
	var state kes.KeyStoreState

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, s.config.Endpoint, nil)
	if err != nil {
		return state, err
	}

	began := time.Now()
	resp, err := s.client.Do(req)
	if err != nil {
		return state, &keystore.ErrUnreachable{Err: err}
	}
	defer xhttp.DrainBody(resp.Body)

	state.Latency = time.Since(began)
	return state, nil
}

// Create stores the given key-value pair at Gemalto if and only
// if no secret with that name exists yet. If such an entry already
// exists it returns kesdk.ErrKeyExists.
//
// The value is uploaded as the secret's "material" with the KeySecure
// data type "seed". NOTE(review): the value goes through string(value)
// and JSON encoding, which replaces invalid UTF-8 bytes with U+FFFD.
// Values are therefore presumably expected to be UTF-8 text; confirm
// with callers.
//
// Context cancellation and deadline errors from the HTTP client are
// returned as-is; every other failure is reported as a "gemalto:" error.
func (s *Store) Create(ctx context.Context, name string, value []byte) error {
	payload, err := json.Marshal(struct {
		Type  string `json:"dataType"`
		Value string `json:"material"`
		Name  string `json:"name"`
	}{
		Type:  "seed", // KeySecure supports blob, password and seed
		Value: string(value),
		Name:  name,
	})
	if err != nil {
		return fmt.Errorf("gemalto: failed to create key '%s': %v", name, err)
	}

	endpoint := s.config.Endpoint + "/api/v1/vault/secrets"
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, xhttp.RetryReader(bytes.NewReader(payload)))
	if err != nil {
		return fmt.Errorf("gemalto: failed to create key '%s': %v", name, err)
	}
	req.Header.Set("Content-Type", "application/json")
	req.Header.Set("Authorization", s.client.AuthToken())

	resp, err := s.client.Do(req)
	switch {
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return err
	case err != nil:
		return fmt.Errorf("gemalto: failed to create key '%s': %v", name, err)
	}
	defer xhttp.DrainBody(resp.Body)

	switch resp.StatusCode {
	case http.StatusCreated:
		return nil
	case http.StatusConflict:
		return kesdk.ErrKeyExists
	}

	srvErr, err := parseServerError(resp)
	if err != nil {
		return fmt.Errorf("gemalto: '%s': failed to parse server response: %v", resp.Status, err)
	}
	return fmt.Errorf("gemalto: failed to create key '%s': %q (%d)", name, srvErr.Message, srvErr.Code)
}

// Set stores the given key-value pair at Gemalto. It behaves exactly
// like Create: the key is only written if it does not exist yet;
// otherwise kesdk.ErrKeyExists is returned.
func (s *Store) Set(ctx context.Context, name string, value []byte) error {
	if err := s.Create(ctx, name, value); err != nil {
		return err
	}
	return nil
}

// Get returns the value associated with the given key.
// If no entry for the key exists it returns kesdk.ErrKeyNotFound.
//
// The value is the "material" field of the response from KeySecure's
// export endpoint; at most 2 MiB of that response are read.
// NOTE(review): name is inserted into the request URL unescaped, so a
// name containing URL-reserved characters such as '/', '?' or '#' would
// address a different resource. Confirm that callers validate key names.
func (s *Store) Get(ctx context.Context, name string) ([]byte, error) {
	endpoint := fmt.Sprintf("%s/api/v1/vault/secrets/%s/export?type=name", s.config.Endpoint, name)
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("gemalto: failed to access key '%s': %v", name, err)
	}
	req.Header.Set("Authorization", s.client.AuthToken())

	resp, err := s.client.Do(req)
	if err != nil {
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			return nil, err
		}
		return nil, fmt.Errorf("gemalto: failed to access key '%s': %v", name, err)
	}
	defer xhttp.DrainBody(resp.Body)

	switch resp.StatusCode {
	case http.StatusOK:
		// Success: decode the secret below.
	case http.StatusNotFound:
		return nil, kesdk.ErrKeyNotFound
	default:
		srvErr, err := parseServerError(resp)
		if err != nil {
			return nil, fmt.Errorf("gemalto: '%s': failed to parse server response: %v", resp.Status, err)
		}
		return nil, fmt.Errorf("gemalto: failed to access key %q: %q (%d)", name, srvErr.Message, srvErr.Code)
	}

	var secret struct {
		Value string `json:"material"`
	}
	if err := json.NewDecoder(mem.LimitReader(resp.Body, 2*mem.MiB)).Decode(&secret); err != nil {
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			return nil, err
		}
		return nil, fmt.Errorf("gemalto: failed to parse server response: %v", err)
	}
	return []byte(secret.Value), nil
}

// Delete removes the value associated with the given key
// from Gemalto, if it exists.
//
// Deleting a key that does not exist is not an error: a 404 NotFound
// response is treated as success (see the BUG note in the body for the
// caveat). Context cancellation and deadline errors from the HTTP client
// are returned as-is; other failures are reported as "gemalto:" errors.
func (s *Store) Delete(ctx context.Context, name string) error {
	url := fmt.Sprintf("%s/api/v1/vault/secrets/%s?type=name", s.config.Endpoint, name)
	req, err := http.NewRequestWithContext(ctx, http.MethodDelete, url, nil)
	if err != nil {
		return fmt.Errorf("gemalto: failed to delete key '%s': %v", name, err)
	}
	req.Header.Set("Authorization", s.client.AuthToken())

	resp, err := s.client.Do(req)
	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
		return err
	}
	if err != nil {
		return fmt.Errorf("gemalto: failed to delete key '%s': %v", name, err)
	}
	defer xhttp.DrainBody(resp.Body)

	// BUG(aead): The KeySecure server returns 404 NotFound if the
	// secret does not exist but also when we are not allowed to access/delete
	// the secret due to insufficient policy permissions.
	// The reason for this is probably that a client should not be able
	// to determine whether a particular secret exists (if the client has
	// no access to it).
	// Unfortunately, we cannot guarantee anymore that we actually deleted the
	// secret. It could also be the case that we lost access (e.g. due to a
	// policy change). So, in this case we don't return an error such that the
	// client thinks it has deleted the secret successfully.
	if resp.StatusCode == http.StatusNoContent || resp.StatusCode == http.StatusNotFound {
		return nil
	}

	response, err := parseServerError(resp)
	if err != nil {
		return fmt.Errorf("gemalto: %s: failed to parse server response: %v", resp.Status, err)
	}
	return fmt.Errorf("gemalto: failed to delete key %q: %s (%d)", name, response.Message, response.Code)
}

// List returns the first n key names, that start with the given
// prefix, and the next prefix from which the listing should
// continue.
//
// It returns all keys with the prefix if n < 0 and less than n
// names if n is greater than the number of keys with the prefix.
//
// An empty prefix matches any key name. At the end of the listing
// or when there are no (more) keys starting with the prefix, the
// returned prefix is empty.
//
// List first fetches the names of all KeySecure secrets, page by page,
// and then applies prefix and n to them via keystore.List.
func (s *Store) List(ctx context.Context, prefix string, n int) ([]string, string, error) {
	// Response is the JSON response returned by KeySecure.
	// It only contains the fields that we need to implement
	// paginated listing. The raw response contains much more
	// information - like created-at date etc.
	type Response struct {
		Skip      uint64 `json:"skip"`  // The number of items skipped (in total)
		Total     uint64 `json:"total"` // The total number of items
		Resources []struct {
			Name string `json:"name"` // The name of the key
		} `json:"resources"`
	}

	const (
		// We limit a listing page to 200. This an arbitrary but reasonable value.
		limit = 200
		// A page should not be larger than 32 MiB.
		MaxBody = 32 * mem.MiB
	)

	// fetchPage requests the listing page that starts after the first skip
	// items. Handling each page in its own function call makes the deferred
	// DrainBody run before the next page is requested. A defer placed
	// directly in the loop below would keep every page's response body (and
	// its connection) open until List returns.
	fetchPage := func(skip uint64) (Response, error) {
		// We have to tell KeySecure how many items we want to process per page and how many
		// items we want to skip - resp. how many items we have processed already.
		endpoint := fmt.Sprintf("%s/api/v1/vault/secrets?limit=%d&skip=%d", s.config.Endpoint, limit, skip)
		req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
		if err != nil {
			return Response{}, fmt.Errorf("gemalto: failed to list keys: %v", err)
		}
		req.Header.Set("Authorization", s.client.AuthToken())

		resp, err := s.client.Do(req)
		if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
			return Response{}, err
		}
		if err != nil {
			return Response{}, fmt.Errorf("gemalto: failed to list keys: %v", err)
		}
		defer xhttp.DrainBody(resp.Body)

		if resp.StatusCode != http.StatusOK {
			srvErr, err := parseServerError(resp)
			if err != nil {
				return Response{}, fmt.Errorf("gemalto: %s: failed to parse server response: %v", resp.Status, err)
			}
			return Response{}, fmt.Errorf("gemalto: failed to list keys: '%s' (%d)", srvErr.Message, srvErr.Code)
		}

		// Decode into a fresh value so that no field of a previous page
		// survives if the server omits it from this one.
		var page Response
		if err := json.NewDecoder(mem.LimitReader(resp.Body, MaxBody)).Decode(&page); err != nil {
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				return Response{}, err
			}
			return Response{}, fmt.Errorf("gemalto: failed to list keys: invalid or too large listing page: %v", err)
		}
		return page, nil
	}

	var (
		skip  uint64 // Keep track of the items processed so far and skip them.
		names []string
	)
	for {
		page, err := fetchPage(skip)
		if err != nil {
			return nil, "", err
		}

		// We check that the invariant that the KeySecure instance has skipped as many items
		// as we requested is true. If both numbers are off then the KeySecure would either
		// return items that we've already served to the client or skip items that we haven't
		// served, yet.
		if page.Skip != skip {
			return nil, "", fmt.Errorf("gemalto: failed to list keys: pagination is out-of-sync: tried to skip %d but skipped %d", skip, page.Skip)
		}
		for _, v := range page.Resources {
			names = append(names, v.Name)
		}
		skip += uint64(len(page.Resources))

		// Stop once we have seen all items. An empty page also ends the
		// listing: if the reported total ever exceeded the number of items
		// the server actually returns, waiting for skip to reach it would
		// loop forever.
		if skip >= page.Total || len(page.Resources) == 0 {
			break
		}
	}
	return keystore.List(names, prefix, n)
}

// Close releases the Store by canceling the context that Connect handed
// to the background authentication-token renewal. It always returns nil.
func (s *Store) Close() error {
	cancelRenewal := s.stop
	cancelRenewal()
	return nil
}

// errResponse represents a KeySecure API error
// response.
//
// parseServerError decodes it from JSON error bodies. For non-JSON
// bodies it sets Code to the HTTP status code and Message to the
// whitespace-trimmed body text.
type errResponse struct {
	Code    int    `json:"code"`
	Message string `json:"codeDesc"`
}

// parseServerError converts a KeySecure error response into an
// errResponse. It reads at most 1 MiB of the body and always closes
// resp.Body.
//
// The KeySecure server only returns a JSON body for well-defined API
// errors - e.g. when trying to create a secret with a name that already
// exists - but not, for example, when the authorization header is
// missing. Therefore, the body is decoded as JSON only if the
// Content-Type is application/json. Otherwise, the body is treated as a
// raw text message and the HTTP status code is used as the error code.
func parseServerError(resp *http.Response) (errResponse, error) {
	defer resp.Body.Close()

	const MaxSize = 1 * mem.MiB
	limit := mem.Size(MaxSize)
	if n := mem.Size(resp.ContentLength); n >= 0 && n <= MaxSize {
		limit = n
	}
	body := mem.LimitReader(resp.Body, limit)

	if contentType := strings.TrimSpace(resp.Header.Get("Content-Type")); strings.HasPrefix(contentType, "application/json") {
		var apiErr errResponse
		err := json.NewDecoder(body).Decode(&apiErr)
		return apiErr, err
	}

	var text strings.Builder
	if _, err := io.Copy(&text, body); err != nil {
		return errResponse{}, err
	}

	// TrimSpace also strips the trailing '\n' some messages end with,
	// which would otherwise produce messy logs.
	return errResponse{
		Code:    resp.StatusCode,
		Message: strings.TrimSpace(text.String()),
	}, nil
}

// loadCustomCAs returns a new RootCA certificate pool
// that contains one or multiple certificates found at
// the given path.
//
// If path is a file then loadCustomCAs tries to parse
// the file as a PEM-encoded certificate.
//
// If path is a directory then loadCustomCAs tries to
// parse any file inside path as PEM-encoded certificate.
// Symbolic links are followed. Entries that are, or link to,
// directories are skipped - for example the "..data" link
// found in Kubernetes-mounted secret volumes.
// It returns a non-nil error if one file is not a valid
// PEM-encoded X.509 certificate. On error, the returned
// pool is nil.
func loadCustomCAs(path string) (*x509.CertPool, error) {
	stat, err := os.Stat(path)
	if err != nil {
		return nil, err
	}

	rootCAs := x509.NewCertPool()
	if !stat.IsDir() {
		pemData, err := os.ReadFile(path)
		if err != nil {
			return nil, err
		}
		if !rootCAs.AppendCertsFromPEM(pemData) {
			return nil, fmt.Errorf("'%s' does not contain a valid X.509 PEM-encoded certificate", path)
		}
		return rootCAs, nil
	}

	entries, err := os.ReadDir(path)
	if err != nil {
		return nil, err
	}
	for _, entry := range entries {
		name := filepath.Join(path, entry.Name())

		// Directory entries describe the link itself, not its target.
		// Stat the entry so that symbolic links to directories are
		// recognized and skipped instead of failing in os.ReadFile.
		info, err := os.Stat(name)
		if err != nil {
			return nil, err
		}
		if info.IsDir() {
			continue
		}

		pemData, err := os.ReadFile(name)
		if err != nil {
			return nil, err
		}
		if !rootCAs.AppendCertsFromPEM(pemData) {
			return nil, fmt.Errorf("'%s' does not contain a valid X.509 PEM-encoded certificate", name)
		}
	}
	return rootCAs, nil
}
