import { useApolloClient } from '@apollo/client';

import { gql } from '../../__generated__';
import type {
  GetLeadsQuery,
  GetLeadsQueryVariables,
  GetPipelineStagesQuery,
} from '../../__generated__/graphql';

import { buildQueryVars } from './LeadsKanbanView';
import { GET_LEADS, GET_LEADS_COUNT } from './leadsQueries';

/** Arguments accepted by the function returned from `useUpdateLeadsStageInCache`. */
interface UpdateLeadsStageParams {
  /** Stage the leads are being moved into. */
  stageId: string;
  /** Ids of the leads to move. */
  leadIds: string[];
  /** Board filter; combined with a stage id via `buildQueryVars` to address each stage's cached queries. */
  where: GetLeadsQueryVariables['where'];
  /** Pipeline stages, used to look up the destination stage's label/status and the pipeline id. */
  stagesData?: GetPipelineStagesQuery;
  /** Cache updates are skipped unless this is true. */
  isKanbanView?: boolean;
}

/**
 * Returns a function that moves leads between kanban stage columns directly
 * in the Apollo cache, so the board updates without waiting for a refetch.
 *
 * For each id in `leadIds` the returned function:
 *  1. removes the lead from its current stage's `GET_LEADS` list (when that
 *     list is cached) and decrements the current stage's `GET_LEADS_COUNT`
 *     (when that count is cached);
 *  2. inserts the lead into the destination stage's `GET_LEADS` list (only
 *     when that list is already cached) at `destinationIndex`, or at the end
 *     when no index is given, and increments the destination stage's count
 *     (when cached);
 *  3. overwrites the lead entity's own `stage_id` and `stage` fields.
 *
 * Moving a lead within its current stage only reorders the list; counts are
 * left untouched. Leads that are not present in the cache are skipped. All
 * leads are processed synchronously inside one `cache.batch`, so watchers are
 * notified once per call. Nothing happens unless `isKanbanView` is true.
 */
export const useUpdateLeadsStageInCache = () => {
  const apolloClient = useApolloClient();

  return ({
    stageId: newStageId,
    leadIds,
    where,
    stagesData,
    isKanbanView = false,
    destinationIndex,
  }: UpdateLeadsStageParams & { destinationIndex?: number }) => {
    if (!isKanbanView || newStageId == null) {
      return;
    }

    // --- Per-call values: independent of which lead is being moved. ---

    const stageInfo = stagesData?.pipelines_by_pk?.stages?.find(
      s => s.id === newStageId,
    );
    const pipelineId = stagesData?.pipelines_by_pk?.id ?? '';

    // Fresh object per use so no two cache writes share the same instance.
    const buildNewStage = () => ({
      __typename: 'lead_stages',
      id: newStageId,
      label: stageInfo?.label,
      status: stageInfo?.status,
      pipeline: {
        __typename: 'pipelines',
        id: pipelineId,
        // NOTE(review): hard-coded; confirm every kanban pipeline is 'sales'.
        lead_type: 'sales',
      },
    });

    const getQueryVars = (stageId: string) => ({
      query: GET_LEADS,
      variables: buildQueryVars(where, stageId),
    });

    const getCountQueryVars = (stageId: string) => ({
      query: GET_LEADS_COUNT,
      variables: {
        where: buildQueryVars(where, stageId).where,
      },
    });

    /** Reads the lead's id and current stage id from the normalized cache. */
    const readCurrentLead = (
      cache: typeof apolloClient.cache,
      leadRef: string | undefined,
    ) =>
      cache.readFragment<{
        id: string;
        stage: { id: string };
      }>({
        id: leadRef,
        fragment: gql(/* GraphQL */ `
          fragment CurrentLead on leads {
            id
            stage {
              id
            }
          }
        `),
      });

    /** Adds `delta` to a stage's cached lead count (floored at 0); no-op if the count is not cached. */
    const adjustCount = (
      cache: typeof apolloClient.cache,
      stageId: string,
      delta: number,
    ) => {
      const countQuery = getCountQueryVars(stageId);
      const countData = cache.readQuery(countQuery);
      if (!countData) {
        return;
      }
      cache.writeQuery({
        ...countQuery,
        data: {
          leads_aggregate: {
            aggregate: {
              count: Math.max(
                0,
                (countData.leads_aggregate?.aggregate?.count ?? 0) + delta,
              ),
            },
          },
        },
      });
    };

    /** Moves a single lead to `newStageId` using the given (batched) cache. */
    const moveLead = (cache: typeof apolloClient.cache, leadId: string) => {
      const leadRef = cache.identify({ __typename: 'leads', id: leadId });
      const currentLead = readCurrentLead(cache, leadRef);
      if (!currentLead) {
        // The lead is not in the cache, so there is no valid object to insert
        // into the destination list (it would lack even an `id`). Skip it and
        // let the next fetch reconcile.
        return;
      }

      const oldStageId = currentLead.stage?.id;
      const isSameStage = oldStageId === newStageId;

      // 1. Remove from the old stage. Keep the full list entry so the
      //    destination list receives every field GET_LEADS selects, not only
      //    the id/stage that the fragment above reads.
      let listedLead: GetLeadsQuery['leads'][number] | undefined;
      if (oldStageId) {
        const oldQuery = getQueryVars(oldStageId);
        const oldData = cache.readQuery<GetLeadsQuery>(oldQuery);
        if (oldData?.leads) {
          listedLead = oldData.leads.find(l => l.id === leadId);
          cache.writeQuery<GetLeadsQuery>({
            ...oldQuery,
            data: {
              ...oldData,
              leads: oldData.leads.filter(l => l.id !== leadId),
            },
          });
        }
        // The count query is independent of the list query, so adjust it
        // even when the list itself is not cached.
        if (!isSameStage) {
          adjustCount(cache, oldStageId, -1);
        }
      }

      // 2. Add to the new stage. Only patch a list that is already cached:
      //    writing a brand-new one-item list would seed the cache with an
      //    incomplete result that later cache-first reads would trust.
      const newQuery = getQueryVars(newStageId);
      const newData = cache.readQuery<GetLeadsQuery>(newQuery);
      if (newData?.leads) {
        // Cast: the synthesized `stage` may not match the generated type
        // exactly (e.g. `label`/`status` are undefined if the stage is
        // missing from stagesData).
        const updatedLead = {
          ...(listedLead ?? currentLead),
          stage_id: newStageId,
          stage: buildNewStage(),
        } as unknown as GetLeadsQuery['leads'][number];

        // Drop any existing copy so the lead can never appear twice.
        const updatedLeads = newData.leads.filter(l => l.id !== leadId);
        if (typeof destinationIndex === 'number') {
          updatedLeads.splice(destinationIndex, 0, updatedLead);
        } else {
          updatedLeads.push(updatedLead);
        }

        cache.writeQuery<GetLeadsQuery>({
          ...newQuery,
          data: {
            ...newData,
            leads: updatedLeads,
          },
        });
      }
      if (!isSameStage) {
        adjustCount(cache, newStageId, 1);
      }

      // 3. Update the lead's own cache entry.
      cache.modify({
        id: leadRef,
        fields: {
          stage_id: () => newStageId,
          stage: () => buildNewStage(),
        },
      });
    };

    // One batch for all leads: watchers are broadcast once, at the end.
    // Everything runs synchronously, so leads are processed strictly in order
    // and any error surfaces to the caller instead of becoming an unhandled
    // promise rejection.
    apolloClient.cache.batch({
      update: cache => {
        for (const leadId of leadIds) {
          moveLead(cache, leadId);
        }
      },
    });
  };
};
