package transport

import (
	"context"
	"io"

	pb "deevirt.fr/compute/pkg/api/proto"
	"github.com/hashicorp/raft"
)

// These are requests incoming over gRPC that we need to relay to the Raft engine.

type gRPCAPI struct {
	manager *Manager

	// "Unsafe" to ensure compilation fails if new methods are added but not implemented
	pb.UnsafeRaftTransportServer
}

func (g gRPCAPI) handleRPC(command interface{}, data io.Reader) (interface{}, error) {
	ch := make(chan raft.RPCResponse, 1)
	rpc := raft.RPC{
		Command:  command,
		RespChan: ch,
		Reader:   data,
	}
	if isHeartbeat(command) {
		// We can take the fast path and use the heartbeat callback and skip the queue in g.manager.rpcChan.
		g.manager.heartbeatFuncMtx.Lock()
		fn := g.manager.heartbeatFunc
		g.manager.heartbeatFuncMtx.Unlock()
		if fn != nil {
			fn(rpc)
			goto wait
		}
	}
	select {
	case g.manager.rpcChan <- rpc:
	case <-g.manager.shutdownCh:
		return nil, raft.ErrTransportShutdown
	}

wait:
	select {
	case resp := <-ch:
		if resp.Error != nil {
			return nil, resp.Error
		}
		return resp.Response, nil
	case <-g.manager.shutdownCh:
		return nil, raft.ErrTransportShutdown
	}
}

func (g gRPCAPI) AppendEntries(ctx context.Context, req *pb.AppendEntriesRequest) (*pb.AppendEntriesResponse, error) {
	resp, err := g.handleRPC(decodeAppendEntriesRequest(req), nil)
	if err != nil {
		return nil, err
	}
	return encodeAppendEntriesResponse(resp.(*raft.AppendEntriesResponse)), nil
}

func (g gRPCAPI) AppendEntriesChunked(stream pb.RaftTransport_AppendEntriesChunkedServer) error {
	appendEntriesRequest, err := receiveAppendEntriesChunkedRequest(stream)
	if err != nil {
		return err
	}

	resp, err := g.handleRPC(decodeAppendEntriesRequest(appendEntriesRequest), nil)
	if err != nil {
		return err
	}

	return stream.SendAndClose(encodeAppendEntriesResponse(resp.(*raft.AppendEntriesResponse)))
}

func (g gRPCAPI) RequestVote(ctx context.Context, req *pb.RequestVoteRequest) (*pb.RequestVoteResponse, error) {
	resp, err := g.handleRPC(decodeRequestVoteRequest(req), nil)
	if err != nil {
		return nil, err
	}
	return encodeRequestVoteResponse(resp.(*raft.RequestVoteResponse)), nil
}

func (g gRPCAPI) TimeoutNow(ctx context.Context, req *pb.TimeoutNowRequest) (*pb.TimeoutNowResponse, error) {
	resp, err := g.handleRPC(decodeTimeoutNowRequest(req), nil)
	if err != nil {
		return nil, err
	}
	return encodeTimeoutNowResponse(resp.(*raft.TimeoutNowResponse)), nil
}

func (g gRPCAPI) RequestPreVote(ctx context.Context, req *pb.RequestPreVoteRequest) (*pb.RequestPreVoteResponse, error) {
	resp, err := g.handleRPC(decodeRequestPreVoteRequest(req), nil)
	if err != nil {
		return nil, err
	}
	return encodeRequestPreVoteResponse(resp.(*raft.RequestPreVoteResponse)), nil
}

func (g gRPCAPI) InstallSnapshot(s pb.RaftTransport_InstallSnapshotServer) error {
	isr, err := s.Recv()
	if err != nil {
		return err
	}
	resp, err := g.handleRPC(decodeInstallSnapshotRequest(isr), &snapshotStream{s, isr.GetData()})
	if err != nil {
		return err
	}
	return s.SendAndClose(encodeInstallSnapshotResponse(resp.(*raft.InstallSnapshotResponse)))
}

type snapshotStream struct {
	s pb.RaftTransport_InstallSnapshotServer

	buf []byte
}

func (s *snapshotStream) Read(b []byte) (int, error) {
	if len(s.buf) > 0 {
		n := copy(b, s.buf)
		s.buf = s.buf[n:]
		return n, nil
	}
	m, err := s.s.Recv()
	if err != nil {
		return 0, err
	}
	n := copy(b, m.GetData())
	if n < len(m.GetData()) {
		s.buf = m.GetData()[n:]
	}
	return n, nil
}

func (g gRPCAPI) AppendEntriesPipeline(s pb.RaftTransport_AppendEntriesPipelineServer) error {
	for {
		msg, err := s.Recv()
		if err != nil {
			return err
		}
		resp, err := g.handleRPC(decodeAppendEntriesRequest(msg), nil)
		if err != nil {
			// TODO(quis): One failure doesn't have to break the entire stream?
			// Or does it all go wrong when it's out of order anyway?
			return err
		}
		if err := s.Send(encodeAppendEntriesResponse(resp.(*raft.AppendEntriesResponse))); err != nil {
			return err
		}
	}
}

func (g gRPCAPI) AppendEntriesChunkedPipeline(s pb.RaftTransport_AppendEntriesChunkedPipelineServer) error {
	for {
		msg, err := receiveAppendEntriesChunkedRequest(s)
		if err != nil {
			return err
		}
		resp, err := g.handleRPC(decodeAppendEntriesRequest(msg), nil)
		if err != nil {
			return err
		}
		if err := s.Send(encodeAppendEntriesResponse(resp.(*raft.AppendEntriesResponse))); err != nil {
			return err
		}
	}
}

func isHeartbeat(command interface{}) bool {
	req, ok := command.(*raft.AppendEntriesRequest)
	if !ok {
		return false
	}
	return req.Term != 0 && len(req.Addr) != 0 && req.PrevLogEntry == 0 && req.PrevLogTerm == 0 && len(req.Entries) == 0 && req.LeaderCommitIndex == 0
}