package scheduler

import (
	"context"
	"fmt"
	"time"

	"github.com/prometheus/client_golang/api"
	v1 "github.com/prometheus/client_golang/api/prometheus/v1"
	"github.com/prometheus/common/model"
	"go.uber.org/zap"

	"deevirt.fr/compute/pkg/config"
	"deevirt.fr/compute/pkg/prom"
)

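// Scheduler ranks cluster nodes and domains by querying libvirt
// metrics from Prometheus.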
type Scheduler struct {
	Config *config.Config
	Log    *zap.Logger
	Api    v1.API
}

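// TopNode is a node returned by the top-N node queries, together with
// its computed score.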
type TopNode struct {
	NodeID string
	Score  float64
}

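// TopDomain is a domain (virtual machine) returned by the top-N domain
// queries, together with its computed score.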
type TopDomain struct {
	NodeID   string
	DomainID string
	Score    float64
}

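// New builds a Scheduler with the cluster configuration, a production
// logger and a Prometheus API client.
//
// Minimal caller sketch (error handling elided):
//
//	s, err := scheduler.New()
//	if err != nil {
//		log.Fatal(err)
//	}
//	nodes, err := s.GetTopNode(3)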
func New() (*Scheduler, error) {
	cfg, err := config.New()
	if err != nil {
		return nil, fmt.Errorf("failed to load configuration: %w", err)
	}

	logger, err := zap.NewProduction()
	if err != nil {
		return nil, fmt.Errorf("failed to create logger: %w", err)
	}

	// TODO: the Prometheus address is hardcoded; it should come from
	// the configuration.
	client, err := api.NewClient(api.Config{
		Address: "http://172.16.9.161:9090",
	})
	if err != nil {
		logger.Error("failed to create Prometheus client", zap.Error(err))
		return nil, err
	}

	s := &Scheduler{
		Config: cfg,
		Log:    logger,
		Api:    v1.NewAPI(client),
	}

	return s, nil
}

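// GetTopNode returns the `number` nodes of the cluster with the best
// global free-capacity score (30% CPU, 70% memory), keeping only nodes
// that are up and score above 30.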
func (s *Scheduler) GetTopNode(number int) ([]TopNode, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Global free-capacity score per node: (1 - CPU usage) weighted at 30%
	// plus (1 - memory usage) weighted at 70%, scaled to 0-100.
	query := fmt.Sprintf(`
	topk(%d, sum(
  		(
    		(1 - sum(rate(libvirt_node_cpu_time_seconds_total{cluster_id="%s"}[5m]) / libvirt_node_cpu_threads) by (node_id)) * 0.3
    		+
    		(1 - sum(libvirt_node_memory_usage_bytes{cluster_id="%s"} / libvirt_node_memory_total_bytes) by (node_id)) * 0.7
  			) * 100
		) by (node_id) > 30 
		and on(node_id) libvirt_up == 1)
	`, number, s.Config.ClusterID, s.Config.ClusterID)

	promAPI, err := prom.New()
	if err != nil {
		return nil, fmt.Errorf("failed to create Prometheus client: %w", err)
	}
	res, _, err := promAPI.Query(ctx, query, time.Now())
	if err != nil {
		return nil, fmt.Errorf("failed to query top nodes: %w", err)
	}

	vector, ok := res.(model.Vector)
	if !ok {
		return nil, fmt.Errorf("unexpected Prometheus result type %T, expected a vector", res)
	}

	data := make([]TopNode, 0, len(vector))
	for _, sample := range vector {
		data = append(data, TopNode{
			NodeID: string(sample.Metric["node_id"]),
			Score:  float64(sample.Value),
		})
	}

	return data, nil
}

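// GetTopNodeCPU returns the `number` nodes with the most free CPU
// capacity (score above 30), among nodes that are up.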
func (s *Scheduler) GetTopNodeCPU(number int) ([]TopNode, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Rank nodes by free CPU capacity only.

	query := fmt.Sprintf(`
	topk(%d, ( (1 - sum(rate(libvirt_node_cpu_time_seconds_total{cluster_id="%s"}[5m]) / libvirt_node_cpu_threads) 
      by (node_id)) * 100) > 30
      and on(node_id) libvirt_up == 1)
	`, number, s.Config.ClusterID)

	promAPI, err := prom.New()
	if err != nil {
		return nil, fmt.Errorf("failed to create Prometheus client: %w", err)
	}
	res, _, err := promAPI.Query(ctx, query, time.Now())
	if err != nil {
		return nil, fmt.Errorf("failed to query top CPU nodes: %w", err)
	}

	vector, ok := res.(model.Vector)
	if !ok {
		return nil, fmt.Errorf("unexpected Prometheus result type %T, expected a vector", res)
	}

	data := make([]TopNode, 0, len(vector))
	for _, sample := range vector {
		data = append(data, TopNode{
			NodeID: string(sample.Metric["node_id"]),
			Score:  float64(sample.Value),
		})
	}

	return data, nil
}

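// GetTopNodeMemory returns the `number` nodes with the most free
// memory (score above 30), among nodes that are up.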
func (s *Scheduler) GetTopNodeMemory(number int) ([]TopNode, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Rank nodes by free memory capacity only.
	query := fmt.Sprintf(`
	topk(%d, ( (1 - sum(libvirt_node_memory_usage_bytes{cluster_id="%s"} / libvirt_node_memory_total_bytes) 
      by (node_id)) * 100) > 30
      and on(node_id) libvirt_up == 1)
	`, number, s.Config.ClusterID)

	promAPI, err := prom.New()
	if err != nil {
		return nil, fmt.Errorf("failed to create Prometheus client: %w", err)
	}
	res, _, err := promAPI.Query(ctx, query, time.Now())
	if err != nil {
		return nil, fmt.Errorf("failed to query top memory nodes: %w", err)
	}

	vector, ok := res.(model.Vector)
	if !ok {
		return nil, fmt.Errorf("unexpected Prometheus result type %T, expected a vector", res)
	}

	data := make([]TopNode, 0, len(vector))
	for _, sample := range vector {
		data = append(data, TopNode{
			NodeID: string(sample.Metric["node_id"]),
			Score:  float64(sample.Value),
		})
	}

	return data, nil
}

// Domains
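
// GetTopDomainCPUUse returns the `number` domains running on nodeID with
// the highest CPU usage (rate of CPU time per virtual CPU, as a
// percentage), among running domains on a node that is up.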
func (s *Scheduler) GetTopDomainCPUUse(nodeID string, number int) ([]TopDomain, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Rank the domains running on the node by CPU usage.
	query := fmt.Sprintf(`
	topk(%d, sum(
        rate(libvirt_domain_cpu_time_seconds_total{node_id="%s"}[5m]) / libvirt_domain_virtual_cpus)
        by (node_id, domain_id) * 100
		and on(domain_id) libvirt_domain_state == 1 and on(node_id) libvirt_up == 1)
	`, number, nodeID)

	promAPI, err := prom.New()
	if err != nil {
		return nil, fmt.Errorf("failed to create Prometheus client: %w", err)
	}
	res, _, err := promAPI.Query(ctx, query, time.Now())
	if err != nil {
		return nil, fmt.Errorf("failed to query top CPU domains: %w", err)
	}

	vector, ok := res.(model.Vector)
	if !ok {
		return nil, fmt.Errorf("unexpected Prometheus result type %T, expected a vector", res)
	}

	data := make([]TopDomain, 0, len(vector))
	for _, sample := range vector {
		data = append(data, TopDomain{
			NodeID:   string(sample.Metric["node_id"]),
			DomainID: string(sample.Metric["domain_id"]),
			Score:    float64(sample.Value),
		})
	}

	return data, nil
}

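// GetTopDomainMemoryUse returns the `number` domains running on nodeID
// with the largest current memory balloon, among running domains on a
// node that is up.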
func (s *Scheduler) GetTopDomainMemoryUse(nodeID string, number int) ([]TopDomain, error) {
	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
	defer cancel()

	// Rank the domains running on the node by current balloon (memory) size.
	query := fmt.Sprintf(`
	topk(%d, sum(libvirt_domain_balloon_current_bytes{node_id="%s"})
        by (node_id, domain_id)
		and on(domain_id) libvirt_domain_state == 1 and on(node_id) libvirt_up == 1)
	`, number, nodeID)

	promAPI, err := prom.New()
	if err != nil {
		return nil, fmt.Errorf("failed to create Prometheus client: %w", err)
	}
	res, _, err := promAPI.Query(ctx, query, time.Now())
	if err != nil {
		return nil, fmt.Errorf("failed to query top memory domains: %w", err)
	}

	vector, ok := res.(model.Vector)
	if !ok {
		return nil, fmt.Errorf("unexpected Prometheus result type %T, expected a vector", res)
	}

	data := make([]TopDomain, 0, len(vector))
	for _, sample := range vector {
		data = append(data, TopDomain{
			NodeID:   string(sample.Metric["node_id"]),
			DomainID: string(sample.Metric["domain_id"]),
			Score:    float64(sample.Value),
		})
	}

	return data, nil
}