// Package transport provides a Transport for github.com/hashicorp/raft over gRPC.
package transport

import (
	"sync"
	"time"

	pb "deevirt.fr/compute/pkg/api/proto"
	"github.com/hashicorp/go-multierror"
	"github.com/hashicorp/raft"
	"github.com/pkg/errors"
	"google.golang.org/grpc"
)

// errCloseErr is the leading entry of the multierror that disconnectAll
// returns when one or more peer connections fail to close.
var errCloseErr = errors.New("error closing connections")

// Manager holds the state shared by the gRPC service (installed with
// Register) and the raft.Transport (obtained from Transport). Create one with
// New and release its peer connections with Close.
type Manager struct {
	// localAddress is this node's Raft address. dialOptions are the gRPC dial
	// options passed to New (presumably applied when dialing peers — confirm
	// in the dial code).
	localAddress raft.ServerAddress
	dialOptions  []grpc.DialOption

	// rpcChan is an unbuffered channel created in New; presumably inbound
	// RPCs are delivered on it to the Transport's consumer (verify in the
	// gRPC handlers). heartbeatFunc is presumably guarded by
	// heartbeatFuncMtx. New leaves heartbeatTimeout at its zero value unless
	// an Option sets it.
	rpcChan          chan raft.RPC
	heartbeatFunc    func(raft.RPC)
	heartbeatFuncMtx sync.Mutex
	heartbeatTimeout time.Duration

	// connectionsMtx guards connections, a map of cached *conn keyed by
	// server address (disconnectAll locks it before iterating). New sets
	// appendEntriesChunkSize to 4 MiB minus 10 bytes; an Option may override it.
	connectionsMtx         sync.Mutex
	connections            map[raft.ServerAddress]*conn
	appendEntriesChunkSize int

	// shutdown records whether Close has already run, and Close closes
	// shutdownCh to signal shutdown. Close reads and writes both while
	// holding shutdownLock.
	shutdown     bool
	shutdownCh   chan struct{}
	shutdownLock sync.Mutex
}

// New creates both components of raft-grpc-transport: a gRPC service (install
// it with Register) and a Raft Transport (obtain it with Transport).
//
// localAddress identifies this node and dialOptions are stored on the Manager.
// Defaults are filled in first, and then each Option is applied in the order
// given, so options can override the defaults.
func New(localAddress raft.ServerAddress, dialOptions []grpc.DialOption, options ...Option) *Manager {
	// Same as gRPC's default maximum message size (4 MiB), minus a little
	// headroom for framing overhead.
	const defaultChunkSize = 4*1024*1024 - 10

	mgr := new(Manager)
	mgr.localAddress = localAddress
	mgr.dialOptions = dialOptions
	mgr.rpcChan = make(chan raft.RPC)
	mgr.connections = make(map[raft.ServerAddress]*conn)
	mgr.appendEntriesChunkSize = defaultChunkSize
	mgr.shutdownCh = make(chan struct{})

	for i := range options {
		options[i](mgr)
	}
	return mgr
}

// Register the RaftTransport gRPC service on a gRPC server.
//
// The service implementation is a gRPCAPI bound to m, so RPCs received by s
// are handled against this Manager's state.
func (m *Manager) Register(s grpc.ServiceRegistrar) {
	service := gRPCAPI{manager: m}
	pb.RegisterRaftTransportServer(s, service)
}

// Transport returns a raft.Transport that communicates over gRPC.
//
// The returned value is a raftAPI wrapping m, so every call returns a
// transport backed by the same Manager state.
func (m *Manager) Transport() raft.Transport {
	var tr raft.Transport = raftAPI{m}
	return tr
}

// Close shuts the Manager down. It closes shutdownCh to signal shutdown, then
// closes and forgets every cached peer connection, returning any close errors.
// Close is idempotent. Calls are serialised on shutdownLock, and every call
// after the first returns nil without doing anything.
func (m *Manager) Close() error {
	m.shutdownLock.Lock()
	defer m.shutdownLock.Unlock()

	if !m.shutdown {
		m.shutdown = true
		close(m.shutdownCh)
		return m.disconnectAll()
	}
	return nil
}

// disconnectAll closes every cached peer connection and removes it from
// m.connections, while holding connectionsMtx.
//
// Every connection is attempted even if an earlier close fails. If all closes
// succeed, it returns nil. Otherwise it returns a multierror whose first entry
// is errCloseErr, followed by each individual close error.
func (m *Manager) disconnectAll() error {
	m.connectionsMtx.Lock()
	defer m.connectionsMtx.Unlock()

	var merged error = errCloseErr
	failed := false
	for addr, c := range m.connections {
		// Taking c.mtx waits for any in-flight Dial() on this conn to finish
		// before its clientConn is closed.
		c.mtx.Lock()
		cerr := c.clientConn.Close()
		c.mtx.Unlock()

		// Deleting the current key while ranging over a map is permitted.
		delete(m.connections, addr)
		if cerr == nil {
			continue
		}
		merged = multierror.Append(merged, cerr)
		failed = true
	}

	if !failed {
		return nil
	}
	return merged
}