from pathlib import Path
from tqdm.auto import tqdm
import pandas as pd
from bs4 import BeautifulSoup
import re
import string
from transformers import AutoTokenizer, AutoModel
from transformers.adapters import AutoAdapterModel
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
from rake_nltk import Rake
import spacy
from random import shuffle
import uuid
import shutil
import numpy as np

# Text-processing helpers: the small English spaCy pipeline and a RAKE
# extractor configured with min_length=2 / max_length=4 that skips repeated
# phrases.
nlp = spacy.load("en_core_web_sm")
r = Rake(min_length=2, max_length=4, include_repeated_phrases=False)

# SPECTER2: tokenizer first, then the base model, then the "specter2" adapter,
# which is fetched with source="hf" and activated immediately.
tokenizer = AutoTokenizer.from_pretrained('allenai/specter2_base')
model = AutoAdapterModel.from_pretrained('allenai/specter2_base')
model.load_adapter("allenai/specter2", source="hf", load_as="specter2", set_active=True)

# Module-level accumulators.
topics = []
sentences = {}
input_data = {}
test_set = []
citations = []
input_files_filtered = []
titles = []
file_names = []

# Every *.tei file below the GROBID output directory, collected eagerly.
input_files = list(
    Path("/home/jovyan/projects/publication_similarity/grobid_outputs").glob("**/*.tei")
)

# Upper bound on the number of publications; adjust as needed.
max_files = 15000

# Each run starts from an empty output directory: remove a leftover one, then
# create it anew.
output_path = Path("output")
if output_path.exists():
    shutil.rmtree(output_path)
output_path.mkdir()
# Disabled preprocessing stage.
#
# Everything between the triple quotes below is a plain string expression, so
# Python evaluates it and discards it: none of this code runs. As written, the
# stage would
#   1. scan `input_files`, skipping files with an empty or already-seen title
#      or without abstract/body text, and stop once `max_files` files are kept;
#   2. for every kept file collect NOUN/PROPN lemmas and RAKE keyphrases (from
#      the abstract and from the body text), chapter headings, reference
#      titles, author names, <idno> identifiers, and an embedding taken from
#      position 0 of the model's last hidden state for
#      title + sep_token + abstract;
#   3. write the collected records to "publication_data.tsv".
# NOTE(review): `child_text` is reassigned for every child of every <div>, so
# only the last child's text survives as "full_text" -- confirm whether the
# text was meant to be accumulated before re-enabling this stage.
'''
for input_file in tqdm(input_files):
    file_path = input_file
    
    with open(file_path, encoding="UTF-8") as f:
        soup = BeautifulSoup(f, features="xml")
    
    title = soup.find("fileDesc").find("title").text.lower()
    
    if not title:
        continue
    if title in titles:
        continue
    else:
        titles.append(title)
    abstract = soup.find("abstract").text.lower()

    child_text = None
    body = soup.find("body")
    for div in body.find_all("div"):
        get_text = True
        for child in div.children:
            if get_text:
                child_text = child.text
                child_text = re.sub(r"\[[^)]*\]", " ", child.text)
                child_text = child_text.translate(str.maketrans("", "", string.punctuation))
                child_text = re.sub(r" +", " ", child_text)
                child_text = child_text.replace('"', "'")

    if not abstract or not child_text:
        continue
    else:
        input_files_filtered.append(input_file)

    if len(input_files_filtered) == max_files:
        break

for input_file in tqdm(input_files_filtered):
    file_path = input_file
    
    with open(file_path, encoding="UTF-8") as f:
        soup = BeautifulSoup(f, features="xml")
    
    title = soup.find("fileDesc").find("title").text.lower()
    
    abstract = soup.find("abstract").text.lower()
    doc = nlp(abstract)
    keywords = set()
    for token in doc:
        if token.pos_ == "NOUN" or token.pos_ == "PROPN":
            lemma = token.lemma_.lower()
            if len(lemma) > 1:
                keywords.add(lemma)
    keyphrases = r.extract_keywords_from_text(abstract)
    keyphrases = r.get_ranked_phrases()
    chapters = []
    body = soup.find("body")
    child_text = None
    for div in body.find_all("div"):
        get_text = True
        for child in div.children:
            if child.name == "head":
                chapter = child.text.lower()
                chapter_words = chapter.split()
                if len(chapter_words[0]) < 3 or "." in chapter_words[0]:
                    chapter = " ".join(chapter_words[1:])
                chapters.append(chapter)
            if get_text:
                child_text = child.text
                child_text = re.sub(r"\[[^)]*\]", " ", child.text)
                child_text = child_text.translate(str.maketrans("", "", string.punctuation))
                child_text = re.sub(r" +", " ", child_text)
                child_text = child_text.replace('"', "'")
    references = []
    authors = []
    ref_sections = soup.find_all("biblStruct")
    for ref in ref_sections:
        analytic = ref.find("analytic")
        if analytic:
            ref_title = analytic.find("title")
            if ref_title:
                ref_title = ref_title.text.lower()
                references.append(ref_title)
    
    references = list(set(references))
    references = [x for x in references if len(x)>0]
    if child_text:
        doc = nlp(child_text)
        keywords_fulltext = set()
        for token in doc:
            if token.pos_ == "NOUN" or token.pos_ == "PROPN":
                lemma = token.lemma_.lower()
                if len(lemma) > 1:
                    keywords_fulltext.add(lemma)
        keyphrases_fulltext = r.extract_keywords_from_text(child_text)
        keyphrases_fulltext = r.get_ranked_phrases()
    else:
        keywords_fulltext = list()
        keyphrases_fulltext = list()
        
    
    for author_element in soup.find("fileDesc").find_all("author"):
            try:
                authors.append(author_element.find("persName").find("forename").text + " " + author_element.find("persName").find("surname").text)
            except:
                pass
    
    abstract_cleaned = abstract.strip()
    abstract_cleaned = re.sub(r"\[[^)]*\]", " ", abstract_cleaned)
    abstract_cleaned = abstract_cleaned.translate(str.maketrans("", "", string.punctuation))
    abstract_cleaned = re.sub(r" +", " ", abstract_cleaned)
    abstract_cleaned = abstract_cleaned.replace('"', "'")

    text_batch = title + tokenizer.sep_token + abstract
    inputs = tokenizer(text_batch, padding=True, truncation=True,
                                   return_tensors="pt", return_token_type_ids=False, max_length=512)
    output = model(**inputs)
    embedding = output.last_hidden_state[:, 0, :].detach().numpy()

    external_ids = {}
    external_id_elements = soup.find("fileDesc").find_all("idno")

    for external_id_element in external_id_elements:
        external_ids[external_id_element.get("type")] = external_id_element.text
    
    identifier = "_".join(title.lower().split())
    if not keyphrases:
        keyphrases = []
    if not keyphrases_fulltext:
        keyphrases_fulltext = []
    test_entry = {"file_path" : file_path,
                  "abstract" : abstract.strip(),
                  "identifier" : str(identifier),
                  "full_text" : child_text,
                  "keyphrases_fulltext" : keyphrases_fulltext,
                  "keywords_fulltext" : keywords_fulltext,
                  "title" : title,
                  "embedding" : embedding,
                  "chapters" : chapters,
                  "references" : references,
                  "authors" : authors,
                  "external_ids" : external_ids,
                  "keywords" : keywords,
                  "keyphrases" : keyphrases
                  }
    
    test_set.append(test_entry)

df = pd.DataFrame(test_set)

with open("publication_data.tsv", "w", encoding="UTF-8") as f:
    df.to_csv(f,index=False, sep="\t")
matches = {}
matches_list = []

'''

# Read the tab-separated publication table and keep only the first 100 records
# as the working set for the pairwise comparison below.
with open("publication_data.tsv", encoding="UTF-8") as tsv_file:
    df = pd.read_csv(tsv_file, sep="\t", index_col=None)

test_set = df.to_dict(orient="records")[:100]


def str_to_np(input_str):
    """Parse the printed form of a NumPy array into a ``(1, n)`` float array.

    The input is expected to look like ``"[[ 0.12 -0.34\\n   0.56]]"``:
    whitespace-separated numbers wrapped in square brackets, possibly spread
    over several lines.

    Args:
        input_str: Text representation of a numeric array.

    Returns:
        A two-dimensional ``float64`` array with a single row holding the
        parsed values (shape ``(1, 0)`` when the text contains no numbers).

    Raises:
        ValueError: If a token is not a valid number, e.g. the ``...`` that
            NumPy prints when it abbreviates a large array.
    """
    # Brackets carry no information once the row is rebuilt explicitly below.
    without_brackets = re.sub(r"[\[\]]", "", input_str)
    # str.split() treats newlines as separators and drops empty tokens, so two
    # numbers separated only by a line break are never fused together.
    tokens = without_brackets.split()
    # Convert explicitly: without dtype the result would be an array of
    # strings, which numeric routines cannot consume.
    return np.array([tokens], dtype=float)

def _as_collection(value):
    """Return *value* as a collection, decoding text-serialised literals.

    Records loaded from a TSV file carry their list/set columns as text (for
    example ``"['a', 'b']"``); iterating such a value would yield single
    characters. Strings holding a list, tuple or set literal are therefore
    decoded, missing values (``None``/NaN) become an empty list, and anything
    else is returned unchanged.
    """
    import ast  # function-scope imports: only this helper needs them
    import math

    if value is None or (isinstance(value, float) and math.isnan(value)):
        return []
    if not isinstance(value, str):
        return value
    if value.strip() == "set()":
        # repr() of an empty set; literal_eval only accepts it on Python 3.9+.
        return set()
    try:
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError):
        # Not a Python literal: leave the text untouched.
        return value
    if isinstance(parsed, (list, tuple, set)):
        return parsed
    return value


def _as_embedding(value):
    """Return an embedding as a numeric array, parsing its text form if needed."""
    if isinstance(value, str):
        # Force a float dtype so the similarity computation gets numbers even
        # if the parser hands back textual tokens.
        return np.asarray(str_to_np(value), dtype=float)
    return value


def _shared(left, right):
    """Return the items of *left* that also occur in *right*, in *left*'s order."""
    return [item for item in left if item in right]


def match(test_tuple):
    """Compare one publication against a list of others and write the result.

    For every candidate a record is built with both documents' identifier,
    title, file path and abstract, the chapters, references, authors,
    keywords and keyphrases they share (abstract and full text), and the
    cosine similarity of their embeddings. Candidates with the same
    identifier or the same title as the reference document are skipped.

    Args:
        test_tuple: Sequence whose first item is the reference record (a
            mapping) and whose second item is an iterable of candidate
            records. Embeddings and collection fields may be real objects or
            their text form as read back from a TSV file.

    Returns:
        None. The records are written as a TSV file named
        ``publication_matches_<uuid4>.tsv`` inside ``output_path``.
    """
    test_entry = test_tuple[0]
    candidates = test_tuple[1]

    doc_id = test_entry["identifier"]
    test_entry_embedding = _as_embedding(test_entry["embedding"])

    # Decode the reference document's collections once, not once per candidate.
    collection_fields = (
        "chapters",
        "references",
        "authors",
        "keywords",
        "keyphrases",
        "keywords_fulltext",
        "keyphrases_fulltext",
    )
    own = {field: _as_collection(test_entry[field]) for field in collection_fields}

    matches = {}
    matches_list = []
    for counter_entry in candidates:
        other_id = counter_entry["identifier"]
        if doc_id == other_id:
            continue
        comparison_id = str(doc_id) + "_" + str(other_id)
        # Skip a pair that is already recorded under the reversed key.
        if str(other_id) + "_" + str(doc_id) in matches:
            continue
        # Identical titles are treated as the same publication.
        if test_entry["title"] == counter_entry["title"]:
            continue

        other = {field: _as_collection(counter_entry[field]) for field in collection_fields}

        matches_entry = {}
        matches_entry["document_1"] = doc_id
        matches_entry["document_2"] = other_id
        matches_entry["document_1_title"] = test_entry["title"]
        matches_entry["document_1_file_path"] = test_entry["file_path"]
        matches_entry["document_1_abstract"] = test_entry["abstract"]
        matches_entry["document_2_title"] = counter_entry["title"]
        matches_entry["document_2_file_path"] = counter_entry["file_path"]
        matches_entry["document_2_abstract"] = counter_entry["abstract"]

        matches_entry["chapters_matches"] = _shared(own["chapters"], other["chapters"])
        matches_entry["references_matches"] = [
            ref for ref in _shared(own["references"], other["references"]) if len(ref) > 0
        ]
        matches_entry["num_references_matches"] = len(matches_entry["references_matches"])
        matches_entry["authors_matches"] = _shared(own["authors"], other["authors"])
        matches_entry["num_authors_matches"] = len(matches_entry["authors_matches"])

        # Compare the two documents' embeddings with each other; with a single
        # argument cosine_similarity compares the first embedding with itself.
        counter_entry_embedding = _as_embedding(counter_entry["embedding"])
        embedding_similarity = cosine_similarity(test_entry_embedding, counter_entry_embedding)
        matches_entry["cosine_similarity"] = embedding_similarity[0][0]

        matches_entry["keyword_matches"] = _shared(own["keywords"], other["keywords"])
        matches_entry["num_keyword_matches"] = len(matches_entry["keyword_matches"])
        matches_entry["keyphrase_matches"] = _shared(own["keyphrases"], other["keyphrases"])
        matches_entry["num_keyphrase_matches"] = len(matches_entry["keyphrase_matches"])

        matches_entry["keyword_matches_fulltext"] = _shared(
            own["keywords_fulltext"], other["keywords_fulltext"]
        )
        matches_entry["num_keyword_matches_fulltext"] = len(matches_entry["keyword_matches_fulltext"])
        matches_entry["keyphrase_matches_fulltext"] = _shared(
            own["keyphrases_fulltext"], other["keyphrases_fulltext"]
        )
        matches_entry["num_keyphrase_matches_fulltext"] = len(matches_entry["keyphrase_matches_fulltext"])

        matches[comparison_id] = matches_entry
        matches_list.append(matches_entry)

    df = pd.DataFrame.from_records(matches_list)
    # Each call writes its own file; draw random names until an unused one is found.
    while True:
        out_name = str(uuid.uuid4())
        out = Path(f"{str(output_path)}/publication_matches_{out_name}.tsv")
        if not out.exists():
            with open(out, "w", encoding="UTF-8") as f:
                df.to_csv(f, index=False, sep="\t")
            break

from tqdm.contrib.concurrent import process_map

# Pair every record with the records that follow it, so each unordered pair of
# publications is compared exactly once. The last record has no successors and
# therefore gets no work item.
subsets = [
    (entry, test_set[position + 1:])
    for position, entry in enumerate(test_set[:-1])
]

# Launch the worker pool only when this file is executed as a script. Under the
# "spawn" and "forkserver" start methods the workers import this module again,
# and an unguarded call here would be re-executed inside every worker.
if __name__ == "__main__":
    process_map(match, subsets, max_workers=35)
