import csv
import json
import os

from tensorflow.keras.preprocessing.sequence import pad_sequences

from BiLSTMCRF import BiLSTMCRF

# Paths to the trained weights and the character vocabulary; relative paths
# are resolved against the current working directory.
model_path = 'model/model.h5'
vocab_path = 'model/vocab.txt'
# BIO tag name -> integer tag id produced by the model. The module inverts
# this mapping (label_dict) to decode predictions, so ids must be unique.
class_dict = {
    "O": 0,
    "B-NUMBER": 1,
    "I-NUMBER": 2,
    "B-SIZE": 3,
    "I-SIZE": 4,
    "B-ENE": 5,
    "I-ENE": 6,
    "B-ANATOMY": 7,
    "I-ANATOMY": 8,
    "B-SQUAMOUS": 9,
    "I-SQUAMOUS": 10,
    "B-INVASION": 11,
    "I-INVASION": 12,
    "B-PN": 13,
    "I-PN": 14,
    "B-LEVEL": 15,
    "I-LEVEL": 16,
    "B-OTHER": 17,
    "I-OTHER": 18,
    "B-DOI": 19,
    "I-DOI": 20
}
# Fixed sequence length used for padding and passed to the model.
maxLen = 500
# Number of tags. Derived from class_dict (rather than hard-coded) so adding
# or removing a tag cannot leave the two out of sync.
classSum = len(class_dict)


def build_input(text):
    """Encode *text* as one padded id sequence for the model.

    Each character is looked up in the module-level ``word_dict``. A character
    that is not in the vocabulary is mapped to the id of the ``'UNK'`` entry.
    The sequence is padded at the end, or cut at the end, to exactly
    ``maxLen`` positions.

    Args:
        text: Input string; each character occupies one model position.

    Returns:
        The ``pad_sequences`` result for a batch of one sequence (presumably
        an integer array of shape (1, maxLen) — TODO confirm for the Keras
        version in use).
    """
    # NOTE(review): if the vocabulary has no 'UNK' entry, unknown characters
    # map to None and pad_sequences will reject the input.
    unk_id = word_dict.get('UNK')
    ids = [word_dict.get(char, unk_id) for char in text]
    # Truncate at the END (Keras defaults to 'pre'). Then model position k
    # still corresponds to text[k] for texts longer than maxLen, which
    # callers rely on when they slice predictions with [:len(text)].
    return pad_sequences([ids], padding='post', truncating='post',
                         maxlen=maxLen)


def load_worddict(path=None):
    """Load the character vocabulary and map each entry to its line index.

    Args:
        path: UTF-8 vocabulary file with one entry per line. Defaults to the
            module-level ``vocab_path``.

    Returns:
        dict mapping each whitespace-stripped line to its 0-based line
        number. If an entry occurs more than once, its last line number wins.
    """
    if path is None:
        path = vocab_path
    # The context manager guarantees the file handle is closed.
    with open(path, encoding='utf-8') as vocab_file:
        # NOTE(review): strip() also turns an entry that is itself whitespace
        # (e.g. a literal space character) into ''. Kept for compatibility
        # with the trained vocabulary ids.
        return {line.strip(): index for index, line in enumerate(vocab_file)}


def predict(text):
    """Tag every character of *text* with the loaded model.

    Args:
        text: Input string; see ``build_input`` for how it is encoded.

    Returns:
        A tuple ``(pairs, tags)``:
        - ``tags``: list of tag names (keys of ``class_dict``), one for each
          of the first ``len(text)`` model positions;
        - ``pairs``: list of ``(char, tag)`` tuples zipping *text* with
          ``tags``. The zip stops at the shorter of the two.
    """
    model_input = build_input(text)
    # Row 0 of the batch of one. Presumably a length-maxLen sequence of
    # integer tag ids — TODO confirm against BiLSTMCRF.predict.
    raw = model.predict(model_input)[0]
    # Only the first len(text) positions hold real characters; the rest is
    # padding, so decode just that prefix.
    tags = [label_dict[tag_id] for tag_id in raw[:len(text)]]
    pairs = list(zip(text, tags))
    return pairs, tags


def output(txt, cnt):
    """Collect entity spans from a BIO tag sequence.

    A ``B-<TYPE>`` tag opens an entity. Following non-``O``, non-``B`` tags
    (normally ``I-...``) extend it. An entity is closed by an ``O`` tag, by
    the next ``B-`` tag, or by the end of the sequence. ``I-`` tags seen while
    no entity is open are ignored, and the ``I-`` type is not checked.

    Args:
        txt: The tagged text; entity surface strings are sliced from it.
        cnt: Sequence of tag strings aligned with ``txt`` position by position.

    Returns:
        List of ``[type, text, start, end]`` lists, where ``start`` and ``end``
        are inclusive character indices and ``text == txt[start:end + 1]``.
    """
    entities = []
    ent_type = None
    start = None  # None means no entity is currently open.

    for i, tag in enumerate(cnt):
        prefix, _, tag_type = tag.partition("-")
        if tag == 'O' or prefix == 'B':
            if start is not None:
                # The open entity ends on the previous position.
                entities.append([ent_type, txt[start:i], start, i - 1])
                start = None
            if prefix == 'B' and tag != 'O':
                start = i
                ent_type = tag_type
        # Any other tag (e.g. I-...) just extends the open entity, if any.

    # Flush an entity that runs to the end of the tag sequence; without this,
    # a trailing entity would be silently dropped.
    if start is not None:
        last = len(cnt) - 1
        entities.append([ent_type, txt[start:last + 1], start, last])
    return entities


# Character -> id lookup, loaded once when the module is imported.
word_dict = load_worddict()
# Embedding vocabulary size: one slot more than the loaded vocabulary
# (presumably reserved, e.g. for the padding id — TODO confirm with training).
vocabSize = len(word_dict) + 1
# Inverse of class_dict (tag id -> tag name), used to decode model output.
label_dict = dict(zip(class_dict.values(), class_dict.keys()))

# Build the network with the same hyper-parameters used for training, then
# restore its trained weights from disk.
model = BiLSTMCRF(
    vocabSize=vocabSize,
    maxLen=maxLen,
    tagIndexDict=class_dict,
    tagSum=classSum,
)
model.load_weights(model_path)

if __name__ == '__main__':
    # Demo: tag a sample pathology report, print each (char, tag) pair, then
    # print the extracted entity spans.
    s = """
“右舌”鳞状细胞癌(复发),高-中分化,灶性多核巨细胞浸润,肿瘤侵犯神经。送检淋巴结:“左颌下”1只、“颏下”1只均阴性(-)
"""
    char_tag_pairs, tag_seq = predict(s)
    for pair in char_tag_pairs:
        print(pair)
    entities = output(s, tag_seq)
    print(entities)