import { type DatasetEmployee } from "@prisma/client";
import { groupBy, sumBy } from "~/lib/lodash";

/**
 * Fraction of weight a row loses for every month past the first year since its last validation.
 * Decay only starts once 12 calendar months have elapsed.
 */
export const LOST_WEIGHT_PER_MONTH = 0.15;

/**
 * Weight of a dataset employee row based on how long ago it was last validated.
 *
 * - No `lastValidAt`: weight 1.
 * - Fewer than 12 calendar months since `lastValidAt` (including dates in the future): weight 1.
 * - Otherwise: `1 - (months - 12) * LOST_WEIGHT_PER_MONTH`, floored at 0
 *   (e.g. 13 months → 0.85, 18 months → ~0.1, 19+ months → 0).
 *
 * Months are counted from the year/month fields only, in local time. The day of the month is
 * ignored, so Jan 31 → Feb 1 counts as one month.
 *
 * @param datasetEmployee row whose `lastValidAt` is inspected
 * @returns a weight in [0, 1]
 */
export const computeDatasetEmployeeWeight = (datasetEmployee: Pick<DatasetEmployee, "lastValidAt">) => {
  const validatedAt = datasetEmployee.lastValidAt;
  if (!validatedAt) return 1;

  const today = new Date();
  const monthsSinceValidation =
    12 * (today.getFullYear() - validatedAt.getFullYear()) + (today.getMonth() - validatedAt.getMonth());

  // The first year after validation keeps full weight.
  if (monthsSinceValidation < 12) return 1;

  const decayedWeight = 1 - (monthsSinceValidation - 12) * LOST_WEIGHT_PER_MONTH;
  return decayedWeight < 0 ? 0 : decayedWeight;
};

/** Identifier of the company a compensation row belongs to; `null` is allowed (presumably "no company"). */
type CompanyId = number | string | null;

/** A compensation value together with its raw weight and owning company, before any company balancing. */
type UnweightedCompensationData = {
  id: number;
  value: number;
  weight: number;
  companyId: CompanyId;
};

/**
 * Returns `{ value, weight }` pairs, damping the influence of companies that hold too large a
 * share of the total weight.
 *
 * If no company exceeds its allowed share (as decided by `isOneCompanyAboveThreshold`), the raw
 * weights are returned unchanged and in input order. Otherwise per-company shares are square-root
 * damped (`computeAdjustedWeightsPerCompany`) and applied to each row
 * (`computeAdjustedWeightsPerEmployee`). In that case rows come back grouped by company.
 *
 * @param unweightedCompensationData rows carrying a raw weight and a company id
 * @returns `{ value, weight }` pairs; `id` and `companyId` are dropped
 */
export const computeCompensationDataWeighted = (unweightedCompensationData: UnweightedCompensationData[]) => {
  const totalWeight = sumBy(unweightedCompensationData, "weight");
  const companyGroups = groupBy(unweightedCompensationData, "companyId");

  if (!isOneCompanyAboveThreshold(companyGroups, totalWeight)) {
    return unweightedCompensationData.map(({ value, weight }) => ({ value, weight }));
  }

  const perCompanyShares = computeAdjustedWeightsPerCompany(companyGroups, totalWeight);
  return computeAdjustedWeightsPerEmployee(perCompanyShares, companyGroups);
};

/**
 * Reports whether any single company holds more than its allowed share of `totalWeight`.
 *
 * The cap depends on the number of companies:
 * - exactly three companies: 34%, so three equally weighted companies (~33.3% each) do not
 *   trigger re-weighting;
 * - any other count: 25%.
 *
 * With one or two companies whose non-negative weights sum to `totalWeight`, some company always
 * exceeds 25%. So in that case this returns true whenever `totalWeight` is positive.
 *
 * Exported for testing.
 *
 * @param companyGroups rows grouped by company id
 * @param totalWeight total the per-company sums are compared against (expected to be the sum of
 *   all weights in `companyGroups`)
 * @returns false when there are no companies or `totalWeight` is not positive
 */
export const isOneCompanyAboveThreshold = (
  companyGroups: Record<string, UnweightedCompensationData[]>,
  totalWeight: number
) => {
  const groups = Object.values(companyGroups);

  // No company can hold an excessive share of a zero (or invalid) total. Handle this explicitly
  // instead of relying on `0 / 0` being NaN and NaN comparisons being false.
  if (groups.length === 0 || !(totalWeight > 0)) {
    return false;
  }

  const maxAllowedCompanyWeight = groups.length === 3 ? 0.34 : 0.25;

  return groups.some((group) => sumBy(group, "weight") / totalWeight > maxAllowedCompanyWeight);
};

/**
 * Computes a damped share of the total weight for every company.
 *
 * Each company's raw share (`sum of its weights / totalWeight`) is passed through a square root,
 * then all damped values are renormalised so they sum to 1. Large companies therefore lose
 * relative influence and small ones gain it. Despite the name, `adjustedWeightPercentage` is a
 * fraction in [0, 1], not a 0–100 percentage.
 *
 * @param companyGroups rows grouped by company id
 * @param totalWeight divisor for each company's weight sum; should be positive (a zero total
 *   makes every share NaN)
 * @returns one `{ companyId, adjustedWeightPercentage }` entry per company, in `Object.entries` order
 */
export const computeAdjustedWeightsPerCompany = (
  companyGroups: Record<string, UnweightedCompensationData[]>,
  totalWeight: number
) => {
  const dampedShares = Object.entries(companyGroups).map(([companyId, group]) => {
    const rawShare = sumBy(group, "weight") / totalWeight;
    return { companyId, weight: Math.sqrt(rawShare) };
  });

  const dampedTotal = sumBy(dampedShares, "weight");

  return dampedShares.map(({ companyId, weight }) => ({
    companyId,
    adjustedWeightPercentage: weight / dampedTotal,
  }));
};

/**
 * Applies each company's adjusted share to the weights of that company's rows.
 *
 * Every row of company `c` gets `weight / adjustedWeightPercentage(c)`. When the percentages come
 * from `computeAdjustedWeightsPerCompany` on the same groups (with `totalWeight` equal to the sum of
 * all weights), each company's share of the resulting total equals its `adjustedWeightPercentage`.
 * The absolute total is not rescaled back to the original sum.
 *
 * Behaviour at the edges:
 * - Companies listed in `adjustedWeightsInPercentagePerCompany` but absent from `companyGroups`
 *   are skipped.
 * - Rows of a company whose adjusted percentage is not positive get weight 0. This happens when all
 *   of a company's rows carry weight 0, and would otherwise produce `0 / 0 = NaN`.
 *
 * @param adjustedWeightsInPercentagePerCompany per-company shares, as fractions in [0, 1]
 * @param companyGroups rows grouped by company id
 * @returns `{ value, weight }` rows, grouped in the order of `adjustedWeightsInPercentagePerCompany`
 */
export const computeAdjustedWeightsPerEmployee = (
  adjustedWeightsInPercentagePerCompany: { companyId: string; adjustedWeightPercentage: number }[],
  companyGroups: Record<string, UnweightedCompensationData[]>
): { value: number; weight: number }[] =>
  // flatMap avoids `push(...rows)`, which hits argument-count limits on very large groups.
  adjustedWeightsInPercentagePerCompany.flatMap(({ companyId, adjustedWeightPercentage }) => {
    const group = companyGroups[companyId];
    if (!group) return [];

    return group.map(({ value, weight }) => ({
      value,
      weight: adjustedWeightPercentage > 0 ? weight / adjustedWeightPercentage : 0,
    }));
  });
