import { type AppContext } from "~/lib/context";
import { MaxOrderedList } from "~/lib/dataStructures";
import { BusinessLogicError } from "~/lib/errors/businessLogicError";
import { logInfo } from "~/lib/logger";
import { createEmbeddings } from "~/lib/openai";
import { assertNotNil } from "~/lib/utils";
import { CURRENT_EMBEDDING_MODEL } from "~/services/job/ai/model";
import { sanitizeJobNameForEmbedding } from "~/services/job/ai/sanitizeJobNameForEmbedding";

/** A scored candidate; `similarity` is the cosine similarity computed by `computeCosineSimilarity`. */
export type SimilarityResult = {
  similarity: number;
};

/** Embedding vector; used both for freshly created title embeddings and for stored alias embeddings. */
type Vector = Array<number>;

/** The matched alias row. */
type JobAliasSummary = {
  id: number;
  name: string;
  isRootAlias: boolean;
};

/** The job that the matched alias belongs to. */
type JobSummary = {
  id: number;
  name: string;
};

/** One classification candidate for a job title: the best-matching alias, its job, and the score. */
export type ClassificationResult = SimilarityResult & {
  jobAlias: JobAliasSummary;
  job: JobSummary;
};

/**
 * Classifies free-text job titles against the stored job-alias embeddings.
 *
 * Each title is sanitized with `sanitizeJobNameForEmbedding`, then wrapped as `Job title: <title>.`
 * and embedded with `createEmbeddings`. The embedding is compared by cosine similarity against every
 * `jobAliasEmbedding` row whose `embeddingModel` is `CURRENT_EMBEDDING_MODEL`. Within one title's
 * candidates, only the best-scoring alias per job is kept; on ties, the alias seen first wins.
 *
 * @param ctx - Application context; used for logging and for `ctx.prisma`.
 * @param params.jobTitles - Raw job titles to classify.
 * @param params.kNearest - Capacity passed to `MaxOrderedList`; presumably the maximum number of
 *   candidates returned per title (confirm against `MaxOrderedList`).
 * @returns One candidate list per embedding returned by `createEmbeddings`. This is in input order,
 *   assuming `createEmbeddings` preserves order. Returns `[]` when `jobTitles` is empty or when no
 *   embeddings came back.
 * @throws BusinessLogicError when the number of embeddings differs from the number of titles.
 */
export const classifyJobTitlesWithEmbeddings = async (
  ctx: AppContext,
  params: { jobTitles: string[]; kNearest: number }
) => {
  logInfo(ctx, `[classification-embedding] Trying to classify '${params.jobTitles.length}' job titles`);

  if (params.jobTitles.length === 0) {
    // Nothing to classify: skip the embeddings request and the database query entirely.
    return [] as ClassificationResult[][];
  }

  const sanitizedJobTitles = params.jobTitles.map((jobTitle) => sanitizeJobNameForEmbedding(jobTitle));

  const embeddingResults = await createEmbeddings(ctx, {
    texts: sanitizedJobTitles.map((jobTitle) => `Job title: ${jobTitle}.`),
  });

  // NOTE(review): a non-empty input with no embeddings back is treated as "nothing classified"
  // rather than as an error. Presumably createEmbeddings returns [] deliberately in some cases — confirm.
  if (embeddingResults.length === 0) {
    return [] as ClassificationResult[][];
  }

  if (embeddingResults.length !== params.jobTitles.length) {
    throw new BusinessLogicError("Issue with embeddings, not all job titles were embedded");
  }

  // Every alias embedding produced by the current model; each title is scored against all of them.
  const jobEmbeddings = await ctx.prisma.jobAliasEmbedding.findMany({
    select: {
      embedding: true,
      alias: {
        select: {
          id: true,
          name: true,
          isRootAlias: true,
          job: {
            select: {
              id: true,
              name: true,
            },
          },
        },
      },
    },
    where: {
      embeddingModel: CURRENT_EMBEDDING_MODEL,
    },
  });

  return embeddingResults.map((embedding) => {
    // Bounded by `kNearest`. The comparator orders by descending similarity, assuming
    // MaxOrderedList uses sort-style comparator semantics.
    const results = new MaxOrderedList<ClassificationResult>(
      params.kNearest,
      (a: SimilarityResult, b: SimilarityResult) => b.similarity - a.similarity
    );

    for (const jobEmbedding of jobEmbeddings) {
      const similarity = computeCosineSimilarity(embedding, jobEmbedding.embedding);

      // Several aliases can belong to the same job: keep at most one entry per job.
      const previousResults = results.values();
      const previousResultIndexForJob = previousResults.findIndex(
        (result) => result.job.id === jobEmbedding.alias.job.id
      );

      if (previousResultIndexForJob !== -1) {
        const previousResult = assertNotNil(previousResults[previousResultIndexForJob]);
        if (similarity > previousResult.similarity) {
          // Strictly better alias for this job: replace the existing entry.
          results.removeAt(previousResultIndexForJob);
        } else {
          // Not better (ties included): keep the alias seen first.
          continue;
        }
      }

      results.push({
        similarity,
        jobAlias: {
          id: jobEmbedding.alias.id,
          name: jobEmbedding.alias.name,
          isRootAlias: jobEmbedding.alias.isRootAlias,
        },
        job: {
          id: jobEmbedding.alias.job.id,
          name: jobEmbedding.alias.job.name,
        },
      });
    }

    return results.values();
  });
};

/**
 * Cosine similarity of two equal-length vectors: dot(a, b) / (|a| * |b|), in [-1, 1].
 *
 * Returns 0 when either vector has zero magnitude. The ratio is undefined there and would otherwise
 * evaluate to NaN. NaN compares false against every number and makes a subtraction-based
 * comparator return NaN, so it would corrupt any ranking built on this score.
 *
 * @throws BusinessLogicError when the vectors have different lengths.
 */
const computeCosineSimilarity = (embedding1: Vector, embedding2: Vector) => {
  if (embedding1.length !== embedding2.length) {
    throw new BusinessLogicError(
      `Embeddings must have the same size. Got ${embedding1.length} and ${embedding2.length}`
    );
  }

  const embedding1Size = calcVectorSize(embedding1);
  const embedding2Size = calcVectorSize(embedding2);
  if (embedding1Size === 0 || embedding2Size === 0) {
    return 0;
  }

  // Accumulate the products in one pass, left to right, without materialising a products array.
  // Lengths are equal (checked above), so `embedding2[i]` is always defined; `?? 0` only satisfies
  // `noUncheckedIndexedAccess` without a type assertion.
  const dotProduct = embedding1.reduce((accum, value, i) => accum + value * (embedding2[i] ?? 0), 0);

  return dotProduct / (embedding1Size * embedding2Size);
};

/**
 * Euclidean (L2) magnitude of a vector: the square root of the sum of its squared components.
 * An empty vector has magnitude 0.
 */
const calcVectorSize = (embedding: Vector) => {
  let sumOfSquares = 0;
  for (const component of embedding) {
    sumOfSquares += Math.pow(component, 2);
  }
  return Math.sqrt(sumOfSquares);
};
