package raft

import (
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"regexp"
	"strings"
	"sync"
	"time"

	transport "deevirt.fr/compute/pkg/api/raft/transport"
	"github.com/hashicorp/raft"
	raftboltdb "github.com/hashicorp/raft-boltdb/v2"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"

	"deevirt.fr/compute/pkg/config"
	etcd_client "deevirt.fr/compute/pkg/etcd"
)

const (
	// retainSnapshotCount is a snapshot retention count.
	// NOTE(review): it is not referenced in the visible part of this file;
	// Open passes a literal 3 to raft.NewFileSnapshotStore. Confirm which
	// value is intended.
	retainSnapshotCount = 2
	// raftTimeout is the timeout that Set and Delete pass to Raft.Apply.
	raftTimeout         = 10 * time.Second
)

// Domain holds a state value, a configuration byte and a CPU-map byte.
//
// NOTE(review): Config and CPUMAP are single bytes; confirm whether []byte
// was intended. The meaning of State is not established by the visible code.
type Domain struct {
	State  int
	Config byte
	CPUMAP byte
}

// Node holds a list of Domain values.
type Node struct {
	Domains []Domain
}

// command is the JSON payload that Set and Delete marshal and submit through
// Raft.Apply. Empty fields are omitted from the encoded form.
type command struct {
	// Op is the operation name; Set uses "set" and Delete uses "delete".
	Op    string `json:"op,omitempty"`
	// Key is the store key the operation targets.
	Key   string `json:"key,omitempty"`
	// Value is the payload for a "set"; Delete leaves it unset.
	Value []byte `json:"value,omitempty"`
}

// Store is a key-value store whose writes (Set, Delete) are submitted through
// Raft and whose reads (Get, Ls) are served from the in-memory map.
type Store struct {
	// mu is held by Ls and Get while they read m.
	mu   sync.Mutex
	conf *config.Config // General configuration

	m map[string][]byte // The key-value store for the system.

	Raft *raft.Raft // The consensus mechanism; assigned by Open.

	// logger is created by New; it writes to os.Stderr with the "[store] " prefix.
	logger *log.Logger
}

// Peers pairs an identifier with an address.
//
// NOTE(review): not referenced in the visible part of this file; presumably
// describes a cluster peer — confirm against callers.
type Peers struct {
	Id      string
	Address string
}

// getTLSCredentials builds the gRPC transport credentials that Open passes to
// grpc.WithTransportCredentials, using the certificate and key configured
// under conf.Manager.
//
// The process exits through log.Fatalf if the key pair cannot be loaded, if
// the certificate file cannot be read, or if that file yields no usable PEM
// certificate for the pool.
func getTLSCredentials(conf *config.Config) credentials.TransportCredentials {
	cert, err := tls.LoadX509KeyPair(conf.Manager.TlsCert, conf.Manager.TlsKey)
	if err != nil {
		log.Fatalf("Erreur chargement du certificat: %v", err)
	}

	// Load the CA pool (optional, used to verify clients).
	// NOTE(review): the pool is filled from TlsCert itself, not from a
	// dedicated CA file — confirm that this is intended.
	caCert, err := os.ReadFile(conf.Manager.TlsCert)
	if err != nil {
		log.Fatalf("Erreur chargement CA: %v", err)
	}
	certPool := x509.NewCertPool()
	// AppendCertsFromPEM reports whether at least one certificate was added;
	// ignoring it would silently leave the pool empty.
	if !certPool.AppendCertsFromPEM(caCert) {
		log.Fatalf("Erreur chargement CA: aucun certificat PEM valide dans %q", conf.Manager.TlsCert)
	}

	// SECURITY NOTE(review): InsecureSkipVerify disables verification of the
	// peer's certificate chain and host name, so these credentials encrypt the
	// connection but do not authenticate the remote node. ClientCAs is a
	// server-side setting in crypto/tls, while these credentials are used as a
	// dial option; RootCAs is the client-side counterpart. Left unchanged
	// because tightening it may break existing deployments — review.
	return credentials.NewTLS(&tls.Config{
		Certificates:       []tls.Certificate{cert},
		ClientCAs:          certPool,
		InsecureSkipVerify: true,
	})
}

// New returns a Store bound to conf, with an empty in-memory map and a logger
// that writes to os.Stderr using the "[store] " prefix. Raft is not started
// here; the Raft field is assigned by Open.
func New(conf *config.Config) *Store {
	store := new(Store)
	store.conf = conf
	store.m = map[string][]byte{}
	store.logger = log.New(os.Stderr, "[store] ", log.LstdFlags)
	return store
}

// Open prepares this node's on-disk Raft state, starts Raft and, when the
// bootstrap condition holds, bootstraps the cluster.
//
// Steps, in order:
//   - create /var/lib/deevirt/mgr/<NodeID>;
//   - open the BoltDB store used as both log store and stable store;
//   - record whether that store already contains log entries;
//   - create the file snapshot store and the transport manager;
//   - start Raft with the Store (converted to *Fsm) as state machine, assign
//     it to s.Raft and start a RaftNode goroutine watching observations;
//   - if the bootstrap condition holds, build the server list from etcd and
//     the configured peers and call BootstrapCluster.
//
// It returns the transport manager created for this node, or an error if any
// step fails. When a step fails before Raft has started, the BoltDB store is
// closed again so its file lock is not left held.
func (s *Store) Open() (*transport.Manager, error) {
	// Create the per-node data directory.
	baseDir := filepath.Join("/var/lib/deevirt/mgr/", s.conf.NodeID)
	if err := os.MkdirAll(baseDir, 0740); err != nil {
		return nil, err
	}

	c := raft.DefaultConfig()
	c.SnapshotInterval = 60 * time.Second
	c.SnapshotThreshold = 1000
	c.HeartbeatTimeout = 2 * time.Second
	c.ElectionTimeout = 3 * time.Second
	c.LocalID = raft.ServerID(s.conf.NodeID)

	logPath := filepath.Join(baseDir, "logs.dat")
	ldb, err := raftboltdb.NewBoltStore(logPath)
	if err != nil {
		return nil, fmt.Errorf(`boltdb.NewBoltStore(%q): %w`, logPath, err)
	}
	// Until Raft is running nothing else owns the store, so failure paths
	// must release it. The Close error is dropped on purpose: the failure
	// that triggered the cleanup is the one worth reporting.
	closeStore := func() { _ = ldb.Close() }

	// Look for existing state before Raft starts, so that a read failure can
	// be reported cleanly instead of being treated as "no state".
	hasState, err := checkIfStateExists(ldb)
	if err != nil {
		closeStore()
		return nil, fmt.Errorf("checkIfStateExists: %w", err)
	}

	fss, err := raft.NewFileSnapshotStore(baseDir, 3, os.Stderr)
	if err != nil {
		closeStore()
		return nil, fmt.Errorf(`raft.NewFileSnapshotStore(%q, ...): %w`, baseDir, err)
	}

	dialOption := []grpc.DialOption{}
	if s.conf.Manager.TlsKey != "" {
		dialOption = append(dialOption, grpc.WithTransportCredentials(getTLSCredentials(s.conf)))
	}

	tm := transport.New(raft.ServerAddress(s.conf.AddressPrivate), dialOption)

	r, err := raft.NewRaft(c, (*Fsm)(s), ldb, ldb, fss, tm.Transport())
	if err != nil {
		closeStore()
		return nil, fmt.Errorf("raft.NewRaft: %w", err)
	}
	s.Raft = r

	// Observer used to watch state changes.
	stateCh := make(chan raft.Observation, 1)
	r.RegisterObserver(raft.NewObserver(stateCh, true, nil))

	node := &RaftNode{
		Bootstrap: false,
		Raft:      r,
		Store:     s,
		NodeID:    s.conf.NodeID,
		StateCh:   stateCh,
	}

	go node.WatchStateChanges()

	// NOTE(review): this condition is true only when AddressPrivate contains
	// no ':' at all (it is equivalent to the original
	// strings.Split(addr, ":")[0] == addr). Confirm that restricting the
	// bootstrap to port-less addresses is intended.
	if !strings.Contains(s.conf.AddressPrivate, ":") && !hasState {
		println("Démarrage du bootstrap ! ")
		node.Bootstrap = true

		// Fetch the node IDs from etcd.
		etcd, err := etcd_client.New(s.conf.EtcdURI)
		if err != nil {
			return nil, err
		}
		defer etcd.Close()

		peers := []raft.Server{}

		// A configured peer becomes a Raft server when its host part equals
		// the management IP of a node registered in etcd; the etcd key is
		// used as the server ID.
		for key, value := range etcd_client.GetNodes(etcd, s.conf.ClusterID) {
			for _, peer := range s.conf.Manager.Peers {
				addressPort := strings.Split(peer, ":")
				if addressPort[0] == value.IpManagement {
					peers = append(peers, raft.Server{
						ID:      raft.ServerID(key),
						Address: raft.ServerAddress(peer),
					})
				}
			}
		}

		cfg := raft.Configuration{
			Servers: peers,
		}
		f := r.BootstrapCluster(cfg)
		if err := f.Error(); err != nil {
			return nil, fmt.Errorf("raft.Raft.BootstrapCluster: %w", err)
		}
	}

	return tm, nil
}

// LsOptions controls how Store.Ls matches keys and what it returns.
type LsOptions struct {
	// Recursive selects the two-level pattern: keys of the form
	// <key>/<a>/<b>... are reported as "<a>/<b>". When false, only keys of
	// the exact form <key>/<a> are reported, as "<a>".
	Recursive bool
	// Data makes Ls return the stored values; when false every returned
	// entry has a nil value.
	Data      bool
}

// Ls lists the entries stored under key.
//
// Without options.Recursive it reports store keys of the exact form
// "<key>/<name>", indexed by "<name>". With options.Recursive it reports store
// keys that start with "<key>/<a>/<b>", indexed by "<a>/<b>"; deeper keys
// sharing the same two segments collapse onto one entry.
//
// Values are included only when options.Data is set; otherwise every entry
// maps to nil. The key is matched literally: regular-expression
// metacharacters in it have no special meaning.
//
// An error is returned only if the pattern built from key cannot be compiled
// (for example when key is not valid UTF-8).
func (s *Store) Ls(key string, options LsOptions) (map[string][]byte, error) {
	// Build and compile the pattern once rather than for every stored key,
	// and quote key so that characters such as '.' or '(' neither widen the
	// match nor make compilation fail.
	prefix := "^" + regexp.QuoteMeta(key)
	pattern := prefix + "/([^/]+)$"
	if options.Recursive {
		pattern = prefix + "/([^/]+)/([^/]+)"
	}
	re, err := regexp.Compile(pattern)
	if err != nil {
		return nil, fmt.Errorf("compiling listing pattern for key %q: %w", key, err)
	}

	s.mu.Lock()
	defer s.mu.Unlock()

	dir := map[string][]byte{}
	for k, v := range s.m {
		matches := re.FindStringSubmatch(k)
		if matches == nil {
			continue
		}
		// One capture group in the flat form, two in the recursive form.
		name := strings.Join(matches[1:], "/")
		if options.Data {
			dir[name] = v
		} else {
			dir[name] = nil
		}
	}

	return dir, nil
}

// Get looks key up in the in-memory map while holding the store mutex.
// A missing key yields a nil slice; the returned error is always nil.
func (s *Store) Get(key string) ([]byte, error) {
	s.mu.Lock()
	value := s.m[key]
	s.mu.Unlock()

	return value, nil
}

// Set submits a "set" command for key and value through Raft.
//
// It returns an error when this node does not report itself as leader, when
// the command cannot be marshalled to JSON, or when the future returned by
// Raft.Apply (called with raftTimeout) reports an error.
func (s *Store) Set(key string, value []byte) error {
	if state := s.Raft.State(); state != raft.Leader {
		return fmt.Errorf("not leader")
	}

	payload, err := json.Marshal(&command{Op: "set", Key: key, Value: value})
	if err != nil {
		return err
	}

	return s.Raft.Apply(payload, raftTimeout).Error()
}

// Delete submits a "delete" command for key through Raft.
//
// It returns an error when this node does not report itself as leader, when
// the command cannot be marshalled to JSON, or when the future returned by
// Raft.Apply (called with raftTimeout) reports an error.
func (s *Store) Delete(key string) error {
	isLeader := s.Raft.State() == raft.Leader
	if !isLeader {
		return fmt.Errorf("not leader")
	}

	payload, err := json.Marshal(&command{Op: "delete", Key: key})
	if err != nil {
		return err
	}

	future := s.Raft.Apply(payload, raftTimeout)
	return future.Error()
}

// checkIfStateExists reports whether the Raft log store already holds
// entries, judged by its first index being greater than zero. An error from
// the store is returned together with false.
func checkIfStateExists(logStore *raftboltdb.BoltStore) (bool, error) {
	first, err := logStore.FirstIndex()
	if err != nil {
		return false, err
	}

	return first > 0, nil
}