177 lines
4.6 KiB
Go

package transport
import (
"context"
"io"
pb "deevirt.fr/compute/pkg/proto"
"github.com/hashicorp/raft"
)
// gRPCAPI is the gRPC server side of the transport: its methods receive
// requests incoming over gRPC and relay them to the Raft engine.
type gRPCAPI struct {
// manager provides the RPC channel, shutdown channel and heartbeat callback
// that the handlers use to reach Raft.
manager *Manager
// "Unsafe" to ensure compilation fails if new methods are added but not implemented
pb.UnsafeRaftTransportServer
}
// handleRPC hands a decoded request to Raft and blocks until Raft answers or
// the transport shuts down.
//
// command is the decoded Raft request and data is the optional payload reader
// attached to the RPC (callers pass nil when there is none). On success the
// Response field of the raft.RPCResponse is returned; if the response carries
// an error, that error is returned instead. raft.ErrTransportShutdown is
// returned when g.manager.shutdownCh fires before the request is queued or
// before a response arrives.
func (g gRPCAPI) handleRPC(command interface{}, data io.Reader) (interface{}, error) {
	// Buffered so the responder never blocks, even if we have already returned
	// because of a shutdown.
	respCh := make(chan raft.RPCResponse, 1)
	req := raft.RPC{
		Command:  command,
		RespChan: respCh,
		Reader:   data,
	}

	// Heartbeats may take the fast path: when a heartbeat callback is
	// registered it is invoked directly, skipping the g.manager.rpcChan queue.
	delivered := false
	if isHeartbeat(command) {
		g.manager.heartbeatFuncMtx.Lock()
		hb := g.manager.heartbeatFunc
		g.manager.heartbeatFuncMtx.Unlock()
		if hb != nil {
			hb(req)
			delivered = true
		}
	}

	// Everything else (or a heartbeat without a callback) goes through the queue.
	if !delivered {
		select {
		case g.manager.rpcChan <- req:
		case <-g.manager.shutdownCh:
			return nil, raft.ErrTransportShutdown
		}
	}

	select {
	case out := <-respCh:
		if out.Error != nil {
			return nil, out.Error
		}
		return out.Response, nil
	case <-g.manager.shutdownCh:
		return nil, raft.ErrTransportShutdown
	}
}
// AppendEntries handles a unary AppendEntries call: the request is decoded,
// relayed to Raft through handleRPC, and Raft's reply is encoded back into
// its protobuf form. Any error from handleRPC is returned unchanged.
func (g gRPCAPI) AppendEntries(ctx context.Context, req *pb.AppendEntriesRequest) (*pb.AppendEntriesResponse, error) {
	command := decodeAppendEntriesRequest(req)
	result, err := g.handleRPC(command, nil)
	if err != nil {
		return nil, err
	}
	reply := result.(*raft.AppendEntriesResponse)
	return encodeAppendEntriesResponse(reply), nil
}
// AppendEntriesChunked handles an AppendEntries call whose request arrives
// over a client stream. The request is obtained from
// receiveAppendEntriesChunkedRequest, relayed to Raft through handleRPC, and
// the encoded reply is sent with SendAndClose.
func (g gRPCAPI) AppendEntriesChunked(stream pb.RaftTransport_AppendEntriesChunkedServer) error {
	assembled, err := receiveAppendEntriesChunkedRequest(stream)
	if err != nil {
		return err
	}

	command := decodeAppendEntriesRequest(assembled)
	result, err := g.handleRPC(command, nil)
	if err != nil {
		return err
	}

	reply := encodeAppendEntriesResponse(result.(*raft.AppendEntriesResponse))
	return stream.SendAndClose(reply)
}
// RequestVote handles a unary RequestVote call: the request is decoded,
// relayed to Raft through handleRPC, and Raft's reply is encoded back into
// its protobuf form. Any error from handleRPC is returned unchanged.
func (g gRPCAPI) RequestVote(ctx context.Context, req *pb.RequestVoteRequest) (*pb.RequestVoteResponse, error) {
	command := decodeRequestVoteRequest(req)
	result, err := g.handleRPC(command, nil)
	if err != nil {
		return nil, err
	}
	reply := result.(*raft.RequestVoteResponse)
	return encodeRequestVoteResponse(reply), nil
}
// TimeoutNow handles a unary TimeoutNow call: the request is decoded,
// relayed to Raft through handleRPC, and Raft's reply is encoded back into
// its protobuf form. Any error from handleRPC is returned unchanged.
func (g gRPCAPI) TimeoutNow(ctx context.Context, req *pb.TimeoutNowRequest) (*pb.TimeoutNowResponse, error) {
	command := decodeTimeoutNowRequest(req)
	result, err := g.handleRPC(command, nil)
	if err != nil {
		return nil, err
	}
	reply := result.(*raft.TimeoutNowResponse)
	return encodeTimeoutNowResponse(reply), nil
}
// RequestPreVote handles a unary RequestPreVote call: the request is decoded,
// relayed to Raft through handleRPC, and Raft's reply is encoded back into
// its protobuf form. Any error from handleRPC is returned unchanged.
func (g gRPCAPI) RequestPreVote(ctx context.Context, req *pb.RequestPreVoteRequest) (*pb.RequestPreVoteResponse, error) {
	command := decodeRequestPreVoteRequest(req)
	result, err := g.handleRPC(command, nil)
	if err != nil {
		return nil, err
	}
	reply := result.(*raft.RequestPreVoteResponse)
	return encodeRequestPreVoteResponse(reply), nil
}
// InstallSnapshot handles a client-streamed InstallSnapshot call.
//
// The first message received is decoded into the Raft request, and its data
// bytes seed a snapshotStream that is passed to handleRPC as the RPC's reader;
// that reader pulls subsequent messages from the same stream. Raft's reply is
// encoded and sent with SendAndClose.
func (g gRPCAPI) InstallSnapshot(s pb.RaftTransport_InstallSnapshotServer) error {
	first, err := s.Recv()
	if err != nil {
		return err
	}

	command := decodeInstallSnapshotRequest(first)
	payload := &snapshotStream{s: s, buf: first.GetData()}

	result, err := g.handleRPC(command, payload)
	if err != nil {
		return err
	}

	reply := encodeInstallSnapshotResponse(result.(*raft.InstallSnapshotResponse))
	return s.SendAndClose(reply)
}
// snapshotStream exposes the data carried by an InstallSnapshot gRPC stream
// through its Read method.
type snapshotStream struct {
// s is the stream from which further messages are received.
s pb.RaftTransport_InstallSnapshotServer
// buf holds data bytes that were received but not yet handed out by Read.
buf []byte
}
// Read implements io.Reader on top of the snapshot stream.
//
// Bytes left over from a previous message are served first. When nothing is
// buffered, messages are received from the stream until one carries data;
// whatever does not fit into b is kept for the next call. Any error from
// Recv is returned as-is with a zero byte count.
//
// Two io.Reader contract details are handled explicitly:
//   - a zero-length b returns (0, nil) immediately instead of blocking on the
//     stream;
//   - a message with empty data is skipped rather than reported as
//     (0, nil), which io.Reader discourages and which some consumers treat
//     as lack of progress.
func (s *snapshotStream) Read(b []byte) (int, error) {
	if len(b) == 0 {
		return 0, nil
	}

	// Refill the buffer only when it is drained; skip messages without data.
	for len(s.buf) == 0 {
		m, err := s.s.Recv()
		if err != nil {
			return 0, err
		}
		s.buf = m.GetData()
	}

	n := copy(b, s.buf)
	s.buf = s.buf[n:]
	return n, nil
}
// AppendEntriesPipeline serves a stream of AppendEntries requests. Each
// received message is decoded, relayed to Raft through handleRPC, and answered
// on the same stream before the next message is read. The loop ends, returning
// the error, as soon as receiving, handling or sending fails.
func (g gRPCAPI) AppendEntriesPipeline(s pb.RaftTransport_AppendEntriesPipelineServer) error {
	for {
		req, recvErr := s.Recv()
		if recvErr != nil {
			return recvErr
		}

		result, rpcErr := g.handleRPC(decodeAppendEntriesRequest(req), nil)
		if rpcErr != nil {
			// TODO(quis): One failure doesn't have to break the entire stream?
			// Or does it all go wrong when it's out of order anyway?
			return rpcErr
		}

		reply := encodeAppendEntriesResponse(result.(*raft.AppendEntriesResponse))
		if sendErr := s.Send(reply); sendErr != nil {
			return sendErr
		}
	}
}
// AppendEntriesChunkedPipeline serves a stream of AppendEntries requests that
// are obtained through receiveAppendEntriesChunkedRequest. Each request is
// decoded, relayed to Raft through handleRPC, and answered on the same stream
// before the next one is read. The loop ends, returning the error, as soon as
// receiving, handling or sending fails.
func (g gRPCAPI) AppendEntriesChunkedPipeline(s pb.RaftTransport_AppendEntriesChunkedPipelineServer) error {
	for {
		assembled, recvErr := receiveAppendEntriesChunkedRequest(s)
		if recvErr != nil {
			return recvErr
		}

		result, rpcErr := g.handleRPC(decodeAppendEntriesRequest(assembled), nil)
		if rpcErr != nil {
			return rpcErr
		}

		reply := encodeAppendEntriesResponse(result.(*raft.AppendEntriesResponse))
		if sendErr := s.Send(reply); sendErr != nil {
			return sendErr
		}
	}
}
// isHeartbeat reports whether command is a heartbeat.
//
// Only a *raft.AppendEntriesRequest can qualify. It must have a non-zero Term
// and a non-empty Addr, while PrevLogEntry, PrevLogTerm and LeaderCommitIndex
// are all zero and Entries is empty.
func isHeartbeat(command interface{}) bool {
	ae, isAppend := command.(*raft.AppendEntriesRequest)
	if !isAppend {
		return false
	}

	identifiesSender := ae.Term != 0 && len(ae.Addr) != 0
	carriesNoLogState := ae.PrevLogEntry == 0 &&
		ae.PrevLogTerm == 0 &&
		len(ae.Entries) == 0 &&
		ae.LeaderCommitIndex == 0

	return identifiesSender && carriesNoLogState
}