216 lines
5.3 KiB
Go

package raft
import (
"context"
"flag"
"fmt"
"log"
"os"
"path/filepath"
"time"
"deevirt.fr/compute/pkg/scheduler"
transport "github.com/Jille/raft-grpc-transport"
"github.com/hashicorp/raft"
raftboltdb "github.com/hashicorp/raft-boltdb"
"google.golang.org/grpc"
)
// Command-line flags configuring the local Raft node.
var (
// raftDir is the base directory for Raft data (flag "raft_data_dir",
// default "data/"). NewRaft stores each node's files in a sub-directory
// named after the node ID.
raftDir = flag.String("raft_data_dir", "data/", "Raft data dir")
// raftBootstrap tells NewRaft to bootstrap the cluster when no previous
// Raft state is found (flag "raft_bootstrap", default false).
raftBootstrap = flag.Bool("raft_bootstrap", false, "Whether to bootstrap the Raft cluster")
)
// RaftNode bundles a Raft instance with the observation channel and the
// scheduler that watchStateChanges drives from the node's Raft state.
type RaftNode struct {
// Raft is the underlying Raft instance.
Raft *raft.Raft
// NodeID is the identifier of the local node.
NodeID string
// StateCh delivers the Raft observations consumed by watchStateChanges.
StateCh chan raft.Observation
// scheduler is started when the node is observed becoming leader and
// stopped on any other observed Raft state.
scheduler *scheduler.Scheduler
}
// Worker runs a background loop (see Start) until its context is cancelled.
type Worker struct {
// ctx is the context whose cancellation ends the loop started by Start.
ctx context.Context
// cancel is invoked by Stop; presumably the CancelFunc paired with ctx
// (the constructor is not visible here — confirm).
cancel context.CancelFunc
cancelled bool // Tracks whether cancel has already been called.
}
// checkIfStateExists reports whether Raft state is already present in the
// given stores.
//
// State is considered present when the snapshot store lists at least one
// snapshot, or when the log store holds at least one entry. An error from
// either store is returned unchanged together with a false result.
func checkIfStateExists(logStore *raftboltdb.BoltStore, snapshotStore raft.SnapshotStore) (bool, error) {
	// Any snapshot on disk means this node has run before.
	snapshots, err := snapshotStore.List()
	if err != nil {
		return false, err
	}
	if len(snapshots) > 0 {
		return true, nil
	}

	// An empty log store reports 0 as its last index, so a non-zero value means
	// at least one entry was written. Comparing against FirstIndex instead would
	// miss a log that holds exactly one entry (first == last).
	lastIndex, err := logStore.LastIndex()
	if err != nil {
		return false, err
	}
	return lastIndex > 0, nil
}
// NewRaft builds and starts the Raft instance of the local node.
//
// Storage lives under <raft_data_dir>/<myID>: two BoltDB files ("logs.dat" for
// the log store, "stable.dat" for the stable store) and a file snapshot store
// retaining 3 snapshots. Peers are reached through a gRPC transport manager
// created for myAddress. A scheduler is created and a goroutine is started that
// consumes Raft observations (RaftNode.watchStateChanges).
//
// When the raft_bootstrap flag is set and no previous state is found, the
// cluster is bootstrapped with the local node as its only server.
//
// ctx is currently unused. On error, everything opened so far is released and
// both returned pointers are nil; wrapped errors can be inspected with
// errors.Is / errors.As.
func NewRaft(ctx context.Context, myID, myAddress string) (*raft.Raft, *transport.Manager, error) {
	c := raft.DefaultConfig()
	c.LocalID = raft.ServerID(myID)

	// undo collects release steps that run in reverse order when construction
	// fails part-way, so a failed call leaves no BoltDB file open, no Raft
	// instance running and no observer goroutine behind.
	var undo []func()
	built := false
	defer func() {
		if built {
			return
		}
		for i := len(undo) - 1; i >= 0; i-- {
			undo[i]()
		}
	}()

	baseDir := filepath.Join(*raftDir, myID)

	logsPath := filepath.Join(baseDir, "logs.dat")
	ldb, err := raftboltdb.NewBoltStore(logsPath)
	if err != nil {
		return nil, nil, fmt.Errorf(`boltdb.NewBoltStore(%q): %w`, logsPath, err)
	}
	undo = append(undo, func() { ldb.Close() }) // best effort on the error path

	stablePath := filepath.Join(baseDir, "stable.dat")
	sdb, err := raftboltdb.NewBoltStore(stablePath)
	if err != nil {
		return nil, nil, fmt.Errorf(`boltdb.NewBoltStore(%q): %w`, stablePath, err)
	}
	undo = append(undo, func() { sdb.Close() }) // best effort on the error path

	fss, err := raft.NewFileSnapshotStore(baseDir, 3, os.Stderr)
	if err != nil {
		return nil, nil, fmt.Errorf(`raft.NewFileSnapshotStore(%q, ...): %w`, baseDir, err)
	}

	// The gRPC transport manager is used instead of raft.NewTCPTransport.
	// NOTE(review): grpc.WithInsecure is deprecated and disables transport
	// security; its replacement lives in a package this file does not import.
	tm := transport.New(raft.ServerAddress(myAddress), []grpc.DialOption{grpc.WithInsecure()})

	// NOTE(review): the FSM argument is nil. Confirm that no command log is
	// ever applied and no snapshot is ever taken or restored on this instance,
	// otherwise a real raft.FSM must be supplied here.
	r, err := raft.NewRaft(c, nil, ldb, sdb, fss, tm.Transport())
	if err != nil {
		return nil, nil, fmt.Errorf("raft.NewRaft: %w", err)
	}
	undo = append(undo, func() {
		if err := r.Shutdown().Error(); err != nil {
			log.Println("raft.Raft.Shutdown:", err)
		}
	})

	s, err := scheduler.New()
	if err != nil {
		return nil, nil, fmt.Errorf("scheduler: %w", err)
	}

	// Blocking observer feeding the goroutine that reacts to state changes.
	stateCh := make(chan raft.Observation, 1)
	observer := raft.NewObserver(stateCh, true, nil)
	r.RegisterObserver(observer)

	node := &RaftNode{
		Raft:      r,
		NodeID:    myID,
		StateCh:   stateCh,
		scheduler: s,
	}
	go node.watchStateChanges()
	undo = append(undo, func() {
		// Once deregistered nothing sends on stateCh any more, so closing it
		// is safe and lets watchStateChanges return.
		r.DeregisterObserver(observer)
		close(stateCh)
	})

	// Same debug output as before, routed through the standard logger rather
	// than the builtin println.
	log.Println(myAddress)

	if *raftBootstrap {
		// Only bootstrap a node that has neither logs nor snapshots. If that
		// cannot be determined, fail instead of guessing.
		hasState, err := checkIfStateExists(ldb, fss)
		if err != nil {
			return nil, nil, fmt.Errorf("checking existing raft state: %w", err)
		}
		if !hasState {
			// Single-server bootstrap configuration: only the local node is listed.
			cfg := raft.Configuration{
				Servers: []raft.Server{
					{
						ID:      raft.ServerID(myID),
						Address: raft.ServerAddress(myAddress),
					},
				},
			}
			if err := r.BootstrapCluster(cfg).Error(); err != nil {
				return nil, nil, fmt.Errorf("raft.Raft.BootstrapCluster: %w", err)
			}
		}
	}

	built = true
	return r, tm, nil
}
// Start launches a goroutine that prints a progress line once per second until
// the worker's context is cancelled, then prints a stop message and exits.
//
// Start returns immediately. The wait between two progress lines is itself
// interrupted by cancellation, so the goroutine ends as soon as the context is
// done rather than after finishing a full one-second sleep.
func (w *Worker) Start() {
	go func() {
		ticker := time.NewTicker(1 * time.Second)
		defer ticker.Stop()

		for {
			// Check before reporting so an already-cancelled worker never
			// prints a progress line.
			if w.ctx.Err() != nil {
				fmt.Println("🛑 Worker arrêté !")
				return
			}
			fmt.Println("🔄 Worker en cours...")

			select {
			case <-w.ctx.Done():
				fmt.Println("🛑 Worker arrêté !")
				return
			case <-ticker.C:
			}
		}
	}()
}
// Stop cancels the worker's context the first time it is called and records
// that fact. Every later call only prints a notice that the worker was
// already stopped.
func (w *Worker) Stop() {
	if w.cancelled {
		fmt.Println("❗ Cancel déjà appelé, Worker déjà arrêté.")
		return
	}

	// First call: cancel the context, then remember that it was done.
	w.cancel()
	w.cancelled = true
}
// IsCancelled reports whether the worker's cancelled flag is set, i.e. whether
// Stop has already invoked cancel.
func (w *Worker) IsCancelled() bool {
return w.cancelled
}
// watchStateChanges consumes observations from n.StateCh until the channel is
// closed, and reacts according to the observation payload:
//
//   - raft.RaftState: starts the scheduler (in its own goroutine) when the node
//     becomes leader, stops it for any other state, and logs the new state once;
//   - raft.LeaderObservation: logs the leader ID;
//   - raft.PeerObservation: while this node is leader, adds the observed peer as
//     a voter; observations reporting a removed peer are ignored;
//   - raft.FailedHeartbeatObservation: logs the peer ID;
//   - anything else: logs the payload.
func (n *RaftNode) watchStateChanges() {
	for obs := range n.StateCh {
		switch evt := obs.Data.(type) {
		case raft.RaftState:
			if evt == raft.Leader {
				go n.scheduler.Start()
			} else {
				n.scheduler.Stop()
			}
			log.Println("[ÉVÉNEMENT] Changement d'état Raft :", evt)

		case raft.LeaderObservation:
			log.Println("[ÉVÉNEMENT] Le leader est", evt.LeaderID)

		case raft.PeerObservation:
			// A removal must not be answered with AddVoter: that would put the
			// peer straight back into the configuration.
			if evt.Removed || n.Raft.State() != raft.Leader {
				continue
			}

			peerID := evt.Peer.ID
			peerAddr := evt.Peer.Address

			log.Println("[NOUVEAU NŒUD] Détection de", peerID, "à", peerAddr)
			log.Println("[ACTION] Ajout automatique en tant que voter...")

			// NOTE(review): waiting on the future blocks this loop, and the
			// observer feeding StateCh is registered as blocking in NewRaft —
			// confirm this cannot stall Raft while observations queue up.
			future := n.Raft.AddVoter(peerID, peerAddr, 0, 0)
			if err := future.Error(); err != nil {
				log.Println("[ERREUR] Impossible d'ajouter", peerID, ":", err)
			} else {
				log.Println("[SUCCÈS] Voter ajouté :", peerID)
			}

		case raft.FailedHeartbeatObservation:
			log.Println("[ÉVÉNEMENT] Perte de connexion avec un nœud :", evt.PeerID)

		default:
			log.Println("[ÉVÉNEMENT] Autre événement :", evt)
		}
	}
}