package transport

import (
	"context"
	"io"
	"sync"
	"time"

	pb "deevirt.fr/compute/pkg/api/proto"
	"github.com/hashicorp/raft"
	"google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/status"
)

// raftAPI adapts Manager to the raft.Transport interface. These are calls from
// the Raft engine that we need to send out over gRPC; incoming RPCs are handed
// to the engine through Consumer. The struct holds only a pointer to the shared
// Manager, so copies of the value (every method here uses a value receiver)
// all act on the same state.
type raftAPI struct {
	manager *Manager
}

// Compile-time assertions that raftAPI implements raft.Transport together with
// the WithClose, WithPeers and WithPreVote extension interfaces.
var (
	_ raft.Transport   = raftAPI{}
	_ raft.WithClose   = raftAPI{}
	_ raft.WithPeers   = raftAPI{}
	_ raft.WithPreVote = raftAPI{}
)

// conn caches the gRPC connection to one peer address; Manager.connections maps
// each target address to its conn. mtx guards clientConn and client. getPeer
// fills both in on first use, and clientConn stays nil until a call to
// grpc.NewClient has succeeded. client is the RaftTransport stub built from
// clientConn.
type conn struct {
	clientConn *grpc.ClientConn
	client     pb.RaftTransportClient
	mtx        sync.Mutex
}

// Consumer returns a channel that can be used to consume and respond to RPC requests.
func (r raftAPI) Consumer() <-chan raft.RPC {
	return r.manager.rpcChan
}

// LocalAddr is used to return our local address to distinguish from our peers.
func (r raftAPI) LocalAddr() raft.ServerAddress {
	return r.manager.localAddress
}

// getPeer returns the gRPC client for target. It creates the client on first
// use and caches it in Manager.connections.
//
// Locking: connectionsMtx guards only the map, and each entry has its own mtx,
// so building a client for one peer happens outside the map lock. A freshly
// created entry is locked before it is published in the map. Concurrent callers
// for the same target therefore wait for the first caller instead of creating a
// second client. If grpc.NewClient fails, the entry stays in the map with a nil
// clientConn and the next call tries again.
func (r raftAPI) getPeer(target raft.ServerAddress) (pb.RaftTransportClient, error) {
	m := r.manager

	m.connectionsMtx.Lock()
	entry, found := m.connections[target]
	if found {
		m.connectionsMtx.Unlock()
		entry.mtx.Lock()
	} else {
		entry = &conn{}
		entry.mtx.Lock()
		m.connections[target] = entry
		m.connectionsMtx.Unlock()
	}
	defer entry.mtx.Unlock()

	if entry.clientConn != nil {
		return entry.client, nil
	}

	cc, err := grpc.NewClient(string(target), m.dialOptions...)
	if err != nil {
		return nil, err
	}
	entry.clientConn = cc
	entry.client = pb.NewRaftTransportClient(cc)
	return entry.client, nil
}

// AppendEntries sends the appropriate RPC to the target node.
func (r raftAPI) AppendEntries(id raft.ServerID, target raft.ServerAddress, args *raft.AppendEntriesRequest, resp *raft.AppendEntriesResponse) error {
	c, err := r.getPeer(target)
	if err != nil {
		return err
	}
	ctx := context.TODO()
	if r.manager.heartbeatTimeout > 0 && isHeartbeat(args) {
		var cancel context.CancelFunc
		ctx, cancel = context.WithTimeout(ctx, r.manager.heartbeatTimeout)
		defer cancel()
	}
	appendEntriesRequest := encodeAppendEntriesRequest(args)
	ret, err := c.AppendEntries(ctx, appendEntriesRequest)
	if statusErr, ok := status.FromError(err); ok && statusErr.Code() == codes.ResourceExhausted {
		chunkedRet, chunkedErr := r.appendEntriesChunked(ctx, r.manager.appendEntriesChunkSize, c, appendEntriesRequest)
		if statusErr, ok := status.FromError(chunkedErr); ok && statusErr.Code() != codes.Unimplemented {
			ret, err = chunkedRet, chunkedErr
		}
	}
	if err != nil {
		return err
	}
	*resp = *decodeAppendEntriesResponse(ret)
	return nil
}

// appendEntriesChunked sends appendEntriesRequest over the client-streaming
// AppendEntriesChunked RPC and returns the peer's single response.
// sendAppendEntriesChunkedRequest splits the request into pieces of chunkSize.
// AppendEntries uses this as a fallback when the unary call fails with
// ResourceExhausted.
//
// On error, the returned response is an empty, non-nil message.
func (r raftAPI) appendEntriesChunked(ctx context.Context, chunkSize int, c pb.RaftTransportClient, appendEntriesRequest *pb.AppendEntriesRequest) (*pb.AppendEntriesResponse, error) {
	// A private cancellable context releases the stream on every return path.
	// gRPC otherwise keeps a client stream alive when it is abandoned after a
	// failed send, since no receive has drained it.
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	stream, err := c.AppendEntriesChunked(ctx)
	if err != nil {
		return &pb.AppendEntriesResponse{}, err
	}
	defer stream.CloseSend()

	if err := sendAppendEntriesChunkedRequest(chunkSize, stream, appendEntriesRequest); err != nil {
		// A send that reports io.EOF means the server already ended the stream.
		// Its real status (e.g. Unimplemented, which AppendEntries checks for)
		// is only available from the receive side. This assumes the helper
		// passes Send errors through unwrapped.
		if err == io.EOF {
			if _, recvErr := stream.CloseAndRecv(); recvErr != nil {
				return &pb.AppendEntriesResponse{}, recvErr
			}
		}
		return &pb.AppendEntriesResponse{}, err
	}

	return stream.CloseAndRecv()
}

// RequestVote asks target for its vote by issuing the unary RequestVote RPC.
// It copies the decoded answer into resp. The call carries no deadline of its
// own (context.TODO).
func (r raftAPI) RequestVote(id raft.ServerID, target raft.ServerAddress, args *raft.RequestVoteRequest, resp *raft.RequestVoteResponse) error {
	client, err := r.getPeer(target)
	if err != nil {
		return err
	}
	req := encodeRequestVoteRequest(args)
	reply, err := client.RequestVote(context.TODO(), req)
	if err != nil {
		return err
	}
	decoded := decodeRequestVoteResponse(reply)
	*resp = *decoded
	return nil
}

// TimeoutNow is used to start a leadership transfer to the target node. It
// issues the unary TimeoutNow RPC without a deadline (context.TODO) and copies
// the decoded reply into resp.
func (r raftAPI) TimeoutNow(id raft.ServerID, target raft.ServerAddress, args *raft.TimeoutNowRequest, resp *raft.TimeoutNowResponse) error {
	client, err := r.getPeer(target)
	if err != nil {
		return err
	}
	req := encodeTimeoutNowRequest(args)
	reply, err := client.TimeoutNow(context.TODO(), req)
	if err != nil {
		return err
	}
	decoded := decodeTimeoutNowResponse(reply)
	*resp = *decoded
	return nil
}

// RequestPreVote is the command used by a candidate to ask a Raft peer for a
// vote in an election. It issues the unary RequestPreVote RPC without a deadline
// (context.TODO) and copies the decoded reply into resp.
func (r raftAPI) RequestPreVote(id raft.ServerID, target raft.ServerAddress, args *raft.RequestPreVoteRequest, resp *raft.RequestPreVoteResponse) error {
	client, err := r.getPeer(target)
	if err != nil {
		return err
	}
	req := encodeRequestPreVoteRequest(args)
	reply, err := client.RequestPreVote(context.TODO(), req)
	if err != nil {
		return err
	}
	decoded := decodeRequestPreVoteResponse(reply)
	*resp = *decoded
	return nil
}

// InstallSnapshot is used to push a snapshot down to a follower. The encoded
// request metadata goes out as the first message of a client stream. The
// snapshot bytes are then read from data and streamed to target in messages of
// at most 16 KiB. The follower's reply is decoded into resp.
func (r raftAPI) InstallSnapshot(id raft.ServerID, target raft.ServerAddress, req *raft.InstallSnapshotRequest, resp *raft.InstallSnapshotResponse, data io.Reader) error {
	c, err := r.getPeer(target)
	if err != nil {
		return err
	}
	// gRPC keeps a client stream alive until it is received to completion or
	// its context is cancelled. Cancelling on return therefore releases the
	// stream on every early exit. After a successful CloseAndRecv it is a no-op.
	ctx, cancel := context.WithCancel(context.TODO())
	defer cancel()

	stream, err := c.InstallSnapshot(ctx)
	if err != nil {
		return err
	}

	// send writes one message. When the server has already ended the stream,
	// Send only reports io.EOF and the actual status must be fetched with
	// CloseAndRecv. That status is preferred when there is one.
	send := func(m *pb.InstallSnapshotRequest) error {
		sendErr := stream.Send(m)
		if sendErr == io.EOF {
			if _, recvErr := stream.CloseAndRecv(); recvErr != nil {
				return recvErr
			}
		}
		return sendErr
	}

	if err := send(encodeInstallSnapshotRequest(req)); err != nil {
		return err
	}

	var buf [16384]byte
	for {
		n, readErr := data.Read(buf[:])
		// A Reader may return data together with an error, including io.EOF.
		// Those bytes are forwarded before the error is examined, so the tail of
		// the snapshot is not lost.
		// NOTE(review): buf is reused for the next chunk once Send returns.
		// grpc-go documents that a sent message must not be modified afterwards
		// because stats handlers may read it lazily.
		if n > 0 {
			if err := send(&pb.InstallSnapshotRequest{Data: buf[:n]}); err != nil {
				return err
			}
		}
		if readErr == io.EOF {
			break
		}
		if readErr != nil {
			return readErr
		}
		if n == 0 {
			// A (0, nil) read is treated as end of data rather than retried,
			// so such a reader cannot spin this loop forever.
			break
		}
	}

	ret, err := stream.CloseAndRecv()
	if err != nil {
		return err
	}
	*resp = *decodeInstallSnapshotResponse(ret)
	return nil
}

// AppendEntriesPipelineInterface is the common view of the two stream kinds
// that AppendEntriesPipeline may open: the chunked pipeline and the plain
// pipeline. Both are gRPC client streams that yield *pb.AppendEntriesResponse
// values from Recv. raftPipelineAPI.AppendEntries type-switches on the concrete
// stream to choose the matching send path.
type AppendEntriesPipelineInterface interface {
	grpc.ClientStream
	Recv() (*pb.AppendEntriesResponse, error)
}

// AppendEntriesPipeline returns an interface that can be used to pipeline
// AppendEntries requests to target.
//
// It opens the chunked pipeline stream first. If that attempt fails with
// Unimplemented, it opens the plain pipeline stream instead. The stream's
// cancel function is handed to the returned pipeline, whose Close calls it. A
// receiver goroutine is started to collect responses. Up to 20 in-flight and
// 20 completed futures are buffered.
func (r raftAPI) AppendEntriesPipeline(id raft.ServerID, target raft.ServerAddress) (raft.AppendPipeline, error) {
	client, err := r.getPeer(target)
	if err != nil {
		return nil, err
	}
	ctx, cancel := context.WithCancel(context.TODO())

	var stream AppendEntriesPipelineInterface
	stream, err = client.AppendEntriesChunkedPipeline(ctx)
	if st, isStatus := status.FromError(err); isStatus && st.Code() == codes.Unimplemented {
		stream, err = client.AppendEntriesPipeline(ctx)
	}
	if err != nil {
		cancel()
		return nil, err
	}

	pipeline := &raftPipelineAPI{
		stream:                 stream,
		appendEntriesChunkSize: r.manager.appendEntriesChunkSize,
		cancel:                 cancel,
		inflightCh:             make(chan *appendFuture, 20),
		doneCh:                 make(chan raft.AppendFuture, 20),
	}
	go pipeline.receiver()
	return pipeline, nil
}

// raftPipelineAPI is the raft.AppendPipeline returned by AppendEntriesPipeline.
//
// AppendEntries writes each request to stream and then queues its future on
// inflightCh. The receiver goroutine reads responses from stream in order,
// resolves the queued futures one by one and publishes them on doneCh, which
// Consumer exposes. appendEntriesChunkSize is the chunk size passed to
// sendAppendEntriesChunkedRequest when stream is the chunked kind. cancel
// cancels the context stream was opened with; Close calls it. inflightChMtx is
// held both around queuing on inflightCh and around Close's close of it, so a
// send never targets a closed channel.
type raftPipelineAPI struct {
	stream                 AppendEntriesPipelineInterface
	appendEntriesChunkSize int
	cancel                 func()
	inflightChMtx          sync.Mutex
	inflightCh             chan *appendFuture
	doneCh                 chan raft.AppendFuture
}

// AppendEntries is used to add another request to the pipeline.
// The send may block which is an effective form of back-pressure.
//
// The request is written to the stream first. Its future is then queued for the
// receiver goroutine, which pairs responses with futures in FIFO order, so
// AppendEntries must not be called concurrently on the same pipeline. resp is
// not written to; the reply is available from the future's Response once its
// Error has returned.
//
// If the stream context is done (Close was called or the stream ended), the
// future is not queued and raft.ErrPipelineShutdown is returned. This also
// applies when the context finishes while waiting for queue space. Callers
// therefore never receive a future that can never complete.
func (r *raftPipelineAPI) AppendEntries(req *raft.AppendEntriesRequest, resp *raft.AppendEntriesResponse) (raft.AppendFuture, error) {
	af := &appendFuture{
		start:   time.Now(),
		request: req,
		done:    make(chan struct{}),
	}
	var err error
	appendEntriesRequest := encodeAppendEntriesRequest(req)
	switch stream := r.stream.(type) {
	case pb.RaftTransport_AppendEntriesPipelineClient:
		err = stream.Send(appendEntriesRequest)
	case pb.RaftTransport_AppendEntriesChunkedPipelineClient:
		err = sendAppendEntriesChunkedRequest(r.appendEntriesChunkSize, stream, appendEntriesRequest)
	default:
		// Nothing was sent. Queuing a future would leave the receiver waiting
		// for a response that never arrives.
		err = status.Errorf(codes.Internal, "unsupported append entries pipeline stream type %T", stream)
	}
	if err != nil {
		return nil, err
	}

	// Holding the mutex keeps Close from closing inflightCh between the
	// shutdown check and the send. Close cancels the stream context before it
	// takes the mutex, so a pending send below is released by the Done case.
	r.inflightChMtx.Lock()
	defer r.inflightChMtx.Unlock()
	done := r.stream.Context().Done()
	select {
	case <-done:
		return nil, raft.ErrPipelineShutdown
	default:
	}
	select {
	case <-done:
		return nil, raft.ErrPipelineShutdown
	case r.inflightCh <- af:
		return af, nil
	}
}

// Consumer returns the channel on which completed futures are delivered, in the
// order their requests were queued. Each future is already resolved when it
// appears here, so its Error does not block.
func (r *raftPipelineAPI) Consumer() <-chan raft.AppendFuture {
	ch := r.doneCh
	return ch
}

// Close closes the pipeline and cancels all inflight RPCs. It always returns
// nil.
//
// The stream context is cancelled before inflightCh is closed. AppendEntries
// checks that context while holding inflightChMtx, so it never sends on the
// closed channel. Closing inflightCh lets the receiver goroutine finish once it
// has drained the remaining futures, provided Consumer keeps being read. Close
// must be called at most once: a second call would close inflightCh again and
// panic.
func (r *raftPipelineAPI) Close() error {
	r.cancel()

	r.inflightChMtx.Lock()
	defer r.inflightChMtx.Unlock()
	close(r.inflightCh)
	return nil
}

// receiver runs in its own goroutine for the lifetime of the pipeline. For each
// queued future it reads the next response from the stream and records either
// the decoded response or the Recv error. It then marks the future done and
// hands it to Consumer. It returns once inflightCh has been closed and drained.
//
// NOTE(review): the send on doneCh blocks while nobody reads Consumer. That
// could keep this goroutine alive after Close; confirm callers always drain it.
func (r *raftPipelineAPI) receiver() {
	queue := r.inflightCh
	for {
		af, open := <-queue
		if !open {
			return
		}
		if msg, recvErr := r.stream.Recv(); recvErr == nil {
			af.response = *decodeAppendEntriesResponse(msg)
		} else {
			af.err = recvErr
		}
		close(af.done)
		r.doneCh <- af
	}
}

// appendFuture is the raft.AppendFuture handed out by
// raftPipelineAPI.AppendEntries. start and request are set when the future is
// created. The receiver goroutine later fills in response (on success) or err,
// and then closes done, which Error waits on.
//
// NOTE(review): the embedded raft.AppendFuture is never assigned in this file.
// Error, Start, Request and Response are defined directly on *appendFuture and
// shadow it. It looks removable; confirm nothing else relies on it.
type appendFuture struct {
	raft.AppendFuture

	start    time.Time
	request  *raft.AppendEntriesRequest
	response raft.AppendEntriesResponse
	err      error
	done     chan struct{}
}

// Error waits until the pipeline receiver has resolved the future and then
// reports the outcome: nil on success, otherwise the error returned by the
// stream's Recv. Every call returns the same value. err is written before done
// is closed, so reading it after the receive is race-free.
func (f *appendFuture) Error() error {
	<-f.done
	return f.err
}

// Start reports when AppendEntries created this future, which is just before
// the request was written to the stream. It may be called at any time.
func (f *appendFuture) Start() time.Time {
	started := f.start
	return started
}

// Request returns the raft request this future was created for. It may be
// called at any time.
func (f *appendFuture) Request() *raft.AppendEntriesRequest {
	req := f.request
	return req
}

// Response returns the decoded reply from the peer. Call it only after Error
// has returned. The value is meaningful only when Error returned nil; otherwise
// it is the zero response.
func (f *appendFuture) Response() *raft.AppendEntriesResponse {
	out := &f.response
	return out
}

// EncodePeer is used to serialize a peer's address. Only the address is kept,
// as its raw bytes; id is ignored.
func (r raftAPI) EncodePeer(id raft.ServerID, addr raft.ServerAddress) []byte {
	encoded := []byte(addr)
	return encoded
}

// DecodePeer is used to deserialize a peer's address. It is the inverse of
// EncodePeer: the bytes are the address itself.
func (r raftAPI) DecodePeer(p []byte) raft.ServerAddress {
	addr := raft.ServerAddress(p)
	return addr
}

// SetHeartbeatHandler is used to setup a heartbeat handler
// as a fast-pass. This is to avoid head-of-line blocking from
// disk IO. If a Transport does not support this, it can simply
// ignore the call, and push the heartbeat onto the Consumer channel.
//
// The callback is stored on the Manager under heartbeatFuncMtx.
func (r raftAPI) SetHeartbeatHandler(cb func(rpc raft.RPC)) {
	m := r.manager
	m.heartbeatFuncMtx.Lock()
	defer m.heartbeatFuncMtx.Unlock()
	m.heartbeatFunc = cb
}

// Close shuts the transport down by closing the underlying Manager and returns
// whatever error that produces.
func (r raftAPI) Close() error {
	err := r.manager.Close()
	return err
}

// Connect eagerly sets up the cached client for target. The transport argument
// is unused.
func (r raftAPI) Connect(target raft.ServerAddress, t raft.Transport) {
	// Best-effort: the method has no error result. A failed attempt leaves a
	// nil clientConn, so the next RPC to target retries through getPeer.
	_, _ = r.getPeer(target)
}

// Disconnect closes the cached connection to target, if any, and removes it
// from Manager.connections so that a later getPeer creates a fresh one.
//
// Callers that fetched the entry just before it was removed may still be handed
// the closed client; their RPC then fails rather than leaking a new connection
// on the orphaned entry.
func (r raftAPI) Disconnect(target raft.ServerAddress) {
	r.manager.connectionsMtx.Lock()
	c, ok := r.manager.connections[target]
	delete(r.manager.connections, target)
	r.manager.connectionsMtx.Unlock()
	if !ok {
		return
	}

	c.mtx.Lock()
	defer c.mtx.Unlock()
	// clientConn is still nil when grpc.NewClient failed in getPeer; there is
	// nothing to close then, and calling Close on it would panic.
	if c.clientConn != nil {
		// Best-effort: Disconnect has no way to report a close error.
		_ = c.clientConn.Close()
	}
}

// DisconnectAll drops every cached peer connection by delegating to the
// Manager.
func (r raftAPI) DisconnectAll() {
	m := r.manager
	// The method returns nothing, so an error from the Manager can only be
	// discarded here.
	_ = m.disconnectAll()
}