Code for loading a fine-tuned ChatGLM-6B model and asking it questions
import os
import platform
import signal

import torch
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
)

# Path to the base ChatGLM-6B weights (redacted in the original post)
model_dir = "/xxx/home/work/chatglm-6b"

print("load tokenizer")
tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)
print("load tokenizer end")

PRE_SEQ_LEN = 128


class model_args:
    # ptuning_checkpoint = None  # set to None to load the base model without the P-Tuning prefix
    ptuning_checkpoint = "/xxx/trainOut/job_and_school-chatglm-6b-pt-128-2e-2/checkpoint-3000"
    model_name_or_path = model_dir
    pre_seq_len = PRE_SEQ_LEN
    prefix_projection = False


# Load pretrained model and tokenizer.
# pre_seq_len / prefix_projection must be set on the config BEFORE from_pretrained,
# otherwise the model is built without a prefix encoder and loading the
# P-Tuning checkpoint fails with:
# Traceback (most recent call last):
#   File "/xxx/home/work/chat-glm-6-b-2/cli_demo.py", line 34, in <module>
#     model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
#   File "/xxx/condaEnvs/mossChat/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1614, in __getattr__
#     raise AttributeError("{} object has no attribute {}".format(
# AttributeError: ChatGLMModel object has no attribute prefix_encoder
config = AutoConfig.from_pretrained(model_args.model_name_or_path, trust_remote_code=True)
config.pre_seq_len = model_args.pre_seq_len
config.prefix_projection = model_args.prefix_projection

if model_args.ptuning_checkpoint is not None:
    # Evaluation: load the base model, then the extra state dict of the prefix encoder
    model = AutoModel.from_pretrained(
        model_args.model_name_or_path, config=config, trust_remote_code=True
    ).half().cuda()
    prefix_state_dict = torch.load(os.path.join(model_args.ptuning_checkpoint, "pytorch_model.bin"))
    new_prefix_state_dict = {}
    for k, v in prefix_state_dict.items():
        # keep only the prefix-encoder weights and strip their prefix
        if k.startswith("transformer.prefix_encoder."):
            new_prefix_state_dict[k[len("transformer.prefix_encoder."):]] = v
    model.transformer.prefix_encoder.load_state_dict(new_prefix_state_dict)
else:
    model = AutoModel.from_pretrained(
        model_args.model_name_or_path, trust_remote_code=True
    ).half().cuda()

print("load model end")
print("model eval")
model = model.eval()
print("model eval end")

os_name = platform.system()
clear_command = "cls" if os_name == "Windows" else "clear"
stop_stream = False


def build_prompt(history):
    prompt = "欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序"
    for query, response in history:
        prompt += f"\n\n用户:{query}"
        prompt += f"\n\nChatGLM-6B:{response}"
    return prompt


def signal_handler(signal, frame):
    global stop_stream
    stop_stream = True


def main():
    history = []
    global stop_stream
    print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
    while True:
        query = input("\n用户:")
        if query.strip() == "stop":
            break
        if query.strip() == "clear":
            history = []
            os.system(clear_command)
            print("欢迎使用 ChatGLM-6B 模型,输入内容即可进行对话,clear 清空对话历史,stop 终止程序")
            continue
        count = 0
        for response, history in model.stream_chat(tokenizer, query, history=history):
            if stop_stream:
                stop_stream = False
                break
            else:
                count += 1
                if count % 8 == 0:
                    os.system(clear_command)
                    print(build_prompt(history), flush=True)
                    signal.signal(signal.SIGINT, signal_handler)
        os.system(clear_command)
        print(build_prompt(history), flush=True)


if __name__ == "__main__":
    main()
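If you just want to ask the fine-tuned model one or two questions instead of running the interactive loop, ChatGLM-6B also exposes a non-streaming chat method. A minimal sketch, assuming `model` and `tokenizer` are already loaded as in the script above (the questions here are only placeholders):

# Single-turn question: chat() returns the complete response at once,
# plus the updated history for multi-turn context.
response, history = model.chat(tokenizer, "你好,请介绍一下你自己", history=[])
print(response)

# Follow-up question that reuses the returned history.
response, history = model.chat(tokenizer, "你能做什么?", history=history)
print(response)

Passing the returned history back in is what gives the model multi-turn context; pass an empty list to start a fresh conversation.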