import torch
from transformers import AutoTokenizer, AutoModel

# Load the CoSent model and tokenizer.
# NOTE: this runs at import time — it hits the network/disk cache to download
# the weights, so importing this module has heavyweight side effects.
# NOTE(review): could not confirm from this file that "webis/CoSent-GPT2"
# exists on the HuggingFace Hub — verify the model id before deploying.
model_name = "webis/CoSent-GPT2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)


def get_sentence_embedding(sentence):
    """Encode *sentence* with the module-level model and return one embedding.

    The returned tensor is the hidden state of the first token,
    shape ``(1, hidden_size)``.

    NOTE(review): position 0 is CLS-style pooling; for a causal GPT-2
    backbone the first token attends to no context, so mean- or
    last-token pooling may be intended — confirm against the model card.
    """
    encoded = tokenizer(sentence, return_tensors="pt")
    # Inference only: disable autograd bookkeeping.
    with torch.no_grad():
        output = model(**encoded)
    hidden_states = output.last_hidden_state
    return hidden_states[:, 0, :]


def calculate_cosine_similarity(embedding1, embedding2):
    """Return the cosine similarity of two ``(1, dim)`` embeddings as a float.

    Uses torch's built-in cosine similarity along dim=1 (the default) and
    unwraps the single-element result with ``.item()``.
    """
    similarity = torch.nn.functional.cosine_similarity(embedding1, embedding2)
    return similarity.item()


# Example sentences (intentionally: two share the word "apple" with different
# senses, two share the "is a fruit" meaning).
sentence1 = "苹果是一种水果。"
sentence2 = "苹果是一家科技公司。"
sentence3 = "橙子是一种水果。"

# Compute one embedding per sentence (each call runs a full forward pass).
embedding1 = get_sentence_embedding(sentence1)
embedding2 = get_sentence_embedding(sentence2)
embedding3 = get_sentence_embedding(sentence3)

# Pairwise cosine similarities between the three embeddings.
similarity1_2 = calculate_cosine_similarity(embedding1, embedding2)
similarity1_3 = calculate_cosine_similarity(embedding1, embedding3)
similarity2_3 = calculate_cosine_similarity(embedding2, embedding3)

print(f"句子1与句子2的相似度: {similarity1_2:.4f}")
print(f"句子1与句子3的相似度: {similarity1_3:.4f}")
print(f"句子2与句子3的相似度: {similarity2_3:.4f}")