找回密码
 会员注册
查看: 29|回复: 0

机器学习Google开源大模型Gemma2:原理、微调训练及推理部署实战

[复制链接]

4

主题

0

回帖

13

积分

新手上路

积分
13
发表于 2024-9-12 15:12:59 | 显示全部楼层 |阅读模式
目录一、引言二、模型简介2.1 Gemma2概述2.2Gemma2 模型架构三、训练与推理3.1Gemma2 模型训练3.1.1下载基座模型3.1.2 导入依赖库3.1.3量化配置3.1.4分词器和模型实例化3.1.5引入PEFT进行LORA配置 3.1.6样本数据清洗与加载3.1.7模型训练与保存3.1.8完整训练代码 3.1.9启动训练以及收敛过程 3.1.10训练显存占用  3.2Gemma2基座与微调模型合并推理3.2.1导入库3.2.2导入基座模型3.2.3合并基座模型与微调模型3.2.4基于对话模版进行对话生成 3.2.5推理显存占用3.2.6推理效果3.2.7微调与推理完整代码四、总结一、引言Gemma是Google推出的轻量级、先进的开放模型系列,采用与Gemini模型相同的研究成果和技术构建而成。它们是仅使用解码器的文本到文本大型语言模型(提供英语版本),为预训练变体和指令调整变体具有开放权重。Gemma模型非常适合各种文本生成任务,包括问题解答、摘要和推理。由于它们相对较小,因此可以将其部署在资源有限的环境(如笔记本电脑、桌面设备或您自己的云基础架构)中,让更多人能够使用先进的AI模型,并帮助促进每个人的创新。二、模型简介2.1 Gemma2概述Gemma2与他的上一代Gemma以及Qwen2等均采用decoder-only网络结构,主要参数情况如下:与Gemma相同点: 上下文长度为8192个token使用旋转位置嵌入(RoPE)近似GeGLU非线性与Gemma不同点:局部滑动窗口和全局注意力。研究团队在每隔一层中交替使用局部滑动窗口注意力和全局注意力。局部注意力层的滑动窗口大小设置为4096个token,而全局注意力层的跨度设置为8192个token。Logit软封顶。根据Gemini1.5的方法,研究团队在每个注意力层和最终层限制logit,使得logit的值保持在−soft_cap和+soft_cap之间。对于9B和27B模型,研究团队将注意力对数封顶设置为50.0,最终对数封顶设置为30.0。截至本文发表时,注意力logit软封顶与常见的FlashAttention实现不兼容,因此他们已从使用FlashAttention的库中移除了此功能。研究团队对模型生成进行了有无注意力logit软封顶的消融实验,发现大多数预训练和后期评估中,生成质量几乎不受影响。本文中的所有评估均使用包含注意力logit软封顶的完整模型架构。然而,某些下游性能可能仍会受到此移除的轻微影响。使用RMSNorm进行post-norm和pre-norm。为了稳定训练,研究团队使用RMSNorm对每个变换子层、注意力层和前馈层的输入和输出进行归一化。分组查询注意力。27B和9B模型均使用GQA,num_groups=2,基于消融实验表明在保持下游性能的同时提高了推理速度。  分组查询注意力(GroupedQueryAttention)是一种在大型语言模型中的多查询注意力(MQA)和多头注意力(MHA)之间进行插值的方法,它的目标是在保持MQA速度的同时实现MHA的质量  效果对比:Gemma29B模型在多个维度超过近尺寸的Llama38B,27B尺寸模型在多个评价标准下超过314B的Grok-1:2.2Gemma2 模型架构通过AutoModelForCausalLM模型头查看模型结构:Gemma2ForCausalLM((model):Gemma2Model((embed_tokens):Embedding(256000,4608,padding_idx=0)(layers):ModuleList((0-45):46xGemma2DecoderLayer((self_attn):Gemma2SdpaAttention((q_proj)inear(in_features=4608,out_features=4096,bias=False)(k_proj)inear(in_features=4608,out_features=2048,bias=False)(v_proj)inear(in_features=4608,out_features=2048,bias=False)(o_proj)inear(in_features=4096,out_features=4608,bias=False)(rotary_emb):Gemma2RotaryEmbedding())(mlp):Gemma2MLP((gate_proj)inear(in_features=4608,out_features=36864,bias=False)(up_proj)inear(in_features=4608,out_features=36864,bias=False)(down_proj)inear(in_features=36864,out_features=4608,bias=False)(act_fn)ytorchGELUTanh())(input_layernorm):Gemma2RMSNorm()(post_attention_layernorm):Gemma2RMSNorm()(pre_feedforward_layernorm):Gemma2RMSNorm()(post_feedforward_layernorm):Gemma2RMSNorm()))(norm):Gemma2RMSNorm())(lm_head)inear(in_features=4608,out_features=256000,bias=False))46层Gemma2DecoderLayer,每层包含1个自注意力层Gemma2SdpaAttention、1个mlp层Gemma2MLP使用RMSNorm进行post-norm和pre-norm。为了稳定训练,研究团队使用RMSNorm对每个变换子层、注意力层和前馈层的输入和输出进行归一化三、训练与推理3.1Gemma2 模型训练在之前的文章中,我介绍过采用LlamaFactory的webui以及命令行进行模型训练,今天基于transformers库原生微调Gemma2。3.1.1下载基座模型我们仍然秉承一贯的作风,为网络不稳定的同学提供了modelscope下载方案:frommodelscopeimportsnapshot_downloadmodel_dir=snapshot_download('LLM-Research/gemma-2-27b-it')3.1.2 导入依赖库importtorchimporttransformersfromtransformersimportAutoTokenizer,AutoModelForCausalLM,BitsAndBytesConfig3.1.3量化配置quantization_config=BitsAndBytesConfig(load_in_4bit=True,#或者load_in_8bit=True,根据需要设置llm_int8_enable_fp32_cpu_offload=True,bnb_4bit_compute_dtype=torch.bfloat16,#虽然我们以4位加载和存储模型,但我们在需要时会部分反量化他,并以16位精度进行计算bnb_4bit_quant_type="nf4",#nf量化类型bnb_4bit_use_double_quant=True,#双重量化,量化一次后再量化,进一步解决显存)3.1.4分词器和模型实例化tokenizer=AutoTokenizer.from_pretrained(model_dir,trust_remote_code=True)model=AutoModelForCausalLM.from_pretrained(model_dir,trust_remote_code=True,device_map=device,torch_dtype=torch.bfloat16,quantization_config=quantization_config,attn_implementation='eager')model.gradient_checkpointing_enable3.1.5引入PEFT进行LORA配置frompeftimportLoraConfig,get_peft_model,prepare_model_for_kbit_trainingmodel=prepare_model_for_kbit_training(model)config=LoraConfig(r=32,lora_alpha=16,target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",)model=get_peft_model(model,config) 3.1.6样本数据清洗与加载fromdatasetsimportload_dataset,load_from_diskdata=load_dataset('json',data_files="./quotes.jsonl")data=data.map(lambdasamples:tokenizer(samples["quote"]),batched=True)print(data)3.1.7模型训练与保存trainer=transformers.Trainer(model=model,train_dataset=data["train"],args=transformers.TrainingArguments(per_device_train_batch_size=1,gradient_accumulation_steps=4,warmup_steps=10,max_steps=50,learning_rate=3e-4,fp16=True,logging_steps=1,output_dir="outputs/checkpoint-1"+time_str,optim="paged_adamw_8bit",save_strategy='steps',save_steps=10,),data_collator=transformers.DataCollatorForLanguageModeling(tokenizer,mlm=False),)model.config.use_cache=False#silencethewarnings.Pleasere-enableforinference!trainer.train()trainer.save_model(trainer.args.output_dir)注意:per_device_train_batch_size=1:开始设置为4会出现'grad_norm':nan,'learning_rate':0的情况。3.1.8完整训练代码 fromdatetimeimportdatetimenow=datetime.now()time_str=now.strftime('%Y-%m-%d%H:%M:%S')print(time_str)#0,downloadmodelfrommodelscopeimportsnapshot_downloadmodel_dir=snapshot_download('LLM-Research/gemma-2-27b-it')#model_dir=snapshot_download('qwen/Qwen2-7B-Instruct')importtorchimporttransformersfromtransformersimportAutoTokenizer,AutoModelForCausalLM,BitsAndBytesConfigdevice="auto"quantization_config=BitsAndBytesConfig(load_in_4bit=True,#或者load_in_8bit=True,根据需要设置llm_int8_enable_fp32_cpu_offload=True,bnb_4bit_compute_dtype=torch.bfloat16,#虽然我们以4位加载和存储模型,但我们在需要时会部分反量化他,并以16位精度进行计算bnb_4bit_quant_type="nf4",#nf量化类型bnb_4bit_use_double_quant=True,#双重量化,量化一次后再量化,进一步解决显存)tokenizer=AutoTokenizer.from_pretrained(model_dir,trust_remote_code=True)model=AutoModelForCausalLM.from_pretrained(model_dir,trust_remote_code=True,device_map=device,torch_dtype=torch.bfloat16,quantization_config=quantization_config,attn_implementation='eager')model.gradient_checkpointing_enablefrompeftimportLoraConfig,get_peft_model,prepare_model_for_kbit_trainingmodel=prepare_model_for_kbit_training(model)config=LoraConfig(r=32,lora_alpha=16,target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",)model=get_peft_model(model,config)fromdatasetsimportload_dataset,load_from_diskdata=load_dataset('json',data_files="./quotes.jsonl")data=data.map(lambdasamples:tokenizer(samples["quote"]),batched=True)print(data)trainer=transformers.Trainer(model=model,train_dataset=data["train"],args=transformers.TrainingArguments(per_device_train_batch_size=1,gradient_accumulation_steps=4,warmup_steps=10,max_steps=50,learning_rate=3e-4,fp16=True,logging_steps=1,output_dir="outputs/checkpoint-1"+time_str,optim="paged_adamw_8bit",save_strategy='steps',save_steps=10,),data_collator=transformers.DataCollatorForLanguageModeling(tokenizer,mlm=False),)model.config.use_cache=False#silencethewarnings.Pleasere-enableforinference!trainer.train()trainer.save_model(trainer.args.output_dir)3.1.9启动训练以及收敛过程 采用CUDA_VISIBLE_DEVICES=1,2,3 pythongemma2_train.py启动 3.1.10训练显存占用  3张显卡启动:针对27B尺寸模型进行int4位微调,占用显存约28.9G。如果bf16微调,大约需要54G。相比于LLama3、Qwen2等72B尺寸模型的优势就是仅消耗单卡A100即可bf16微调训练。3.2Gemma2基座与微调模型合并推理3.2.1导入库这里比较重要的是peft中的PeftModel和PeftConfig,PeftModel用于合并基座与微调模型,PeftConfig用于提取Peft微调模型的配置文件importtorchfrompeftimportPeftModel,PeftConfigfromtransformersimportAutoModelForCausalLM,AutoTokenizer3.2.2导入基座模型peft_model_dir=trainer.args.output_dirconfig=PeftConfig.from_pretrained(peft_model_dir)print(config)model=AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,return_dict=True,device_map=device,torch_dtype=torch.float16,quantization_config=quantization_config)tokenizer=AutoTokenizer.from_pretrained(config.base_model_name_or_path)3.2.3合并基座模型与微调模型model=PeftModel.from_pretrained(model,peft_model_dir)3.2.4基于对话模版进行对话生成chat=[{"role":"user","content":"详细介绍一下大语言模型,评价下与深度学习的差异"},]prompt=tokenizer.apply_chat_template(chat,tokenize=True,add_generation_prompt=True,return_tensors="pt").to(model.device)outputs=model.generate(prompt,max_length=2500)outputs=[output_ids[len(input_ids):]forinput_ids,output_idsinzip(prompt,outputs)]print(tokenizer.batch_decode(outputs,skip_special_tokens=True)[0]) 3.2.5推理显存占用基座模型和微调模型合并后,大约需要40G??3.2.6推理效果3.2.7微调与推理完整代码fromdatetimeimportdatetimenow=datetime.now()time_str=now.strftime('%Y-%m-%d%H:%M:%S')print(time_str)#0,downloadmodelfrommodelscopeimportsnapshot_downloadmodel_dir=snapshot_download('LLM-Research/gemma-2-27b-it')#model_dir=snapshot_download('qwen/Qwen2-7B-Instruct')importtorchimporttransformersfromtransformersimportAutoTokenizer,AutoModelForCausalLM,BitsAndBytesConfigdevice="auto"quantization_config=BitsAndBytesConfig(load_in_4bit=True,#或者load_in_8bit=True,根据需要设置llm_int8_enable_fp32_cpu_offload=True,bnb_4bit_compute_dtype=torch.bfloat16,#虽然我们以4位加载和存储模型,但我们在需要时会部分反量化他,并以16位精度进行计算bnb_4bit_quant_type="nf4",#nf量化类型bnb_4bit_use_double_quant=True,#双重量化,量化一次后再量化,进一步解决显存)tokenizer=AutoTokenizer.from_pretrained(model_dir,trust_remote_code=True)model=AutoModelForCausalLM.from_pretrained(model_dir,trust_remote_code=True,device_map=device,torch_dtype=torch.bfloat16,quantization_config=quantization_config,attn_implementation='eager')model.gradient_checkpointing_enablefrompeftimportLoraConfig,get_peft_model,prepare_model_for_kbit_trainingmodel=prepare_model_for_kbit_training(model)config=LoraConfig(r=32,lora_alpha=16,target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",)model=get_peft_model(model,config)fromdatasetsimportload_dataset,load_from_diskdata=load_dataset('json',data_files="./quotes.jsonl")data=data.map(lambdasamples:tokenizer(samples["quote"]),batched=True)print(data)trainer=transformers.Trainer(model=model,train_dataset=data["train"],args=transformers.TrainingArguments(per_device_train_batch_size=1,gradient_accumulation_steps=4,warmup_steps=10,max_steps=50,learning_rate=3e-4,fp16=True,logging_steps=1,output_dir="outputs/checkpoint-1"+time_str,optim="paged_adamw_8bit",save_strategy='steps',save_steps=10,),data_collator=transformers.DataCollatorForLanguageModeling(tokenizer,mlm=False),)model.config.use_cache=False#silencethewarnings.Pleasere-enableforinference!#trainer.train()trainer.save_model(trainer.args.output_dir)#mergemodelandinferenceimporttorchfrompeftimportPeftModel,PeftConfigfromtransformersimportAutoModelForCausalLM,AutoTokenizer#peft_model_dir=trainer.args.output_dirpeft_model_dir="/aigc_dev/gemma2/outputs/checkpoint-12024-07-0421:57:45"config=PeftConfig.from_pretrained(peft_model_dir)print(config)model=AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path,return_dict=True,device_map=device,torch_dtype=torch.bfloat16,quantization_config=quantization_config)tokenizer=AutoTokenizer.from_pretrained(config.base_model_name_or_path)#LoadtheLoramodelmodel=PeftModel.from_pretrained(model,peft_model_dir)chat=[{"role":"user","content":"详细介绍一下大语言模型,评价下与深度学习的差异"},]prompt=tokenizer.apply_chat_template(chat,tokenize=True,add_generation_prompt=True,return_tensors="pt").to(model.device)outputs=model.generate(prompt,max_length=2500)outputs=[output_ids[len(input_ids):]forinput_ids,output_idsinzip(prompt,outputs)]print(tokenizer.batch_decode(outputs,skip_special_tokens=True)[0])四、总结在模型结构上,Gemma2与Qwen2非常相似,除了decoder-only、RoPE、分组查询注意力机制等技术相同,线性层(Lora的目标层)均为["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"]中文对话效果上经过多个样例测试个人感觉不如国产的Qwen2、GLM4、DeepSeek等。GOOGLE作为互联网技术老大哥,在大模型的角逐中,并没有那么强势。可叹啊!感谢您的阅读,如果喜欢的话,期待您的三连+投票。如果您还有时间,可以看看我的其他文章:《AI—工程篇》AI智能体研发之路-工程篇(一):Docker助力AI智能体开发提效AI智能体研发之路-工程篇(二):Dify智能体开发平台一键部署AI智能体研发之路-工程篇(三):大模型推理服务框架Ollama一键部署AI智能体研发之路-工程篇(四):大模型推理服务框架Xinference一键部署AI智能体研发之路-工程篇(五):大模型推理服务框架LocalAI一键部署《AI—模型篇》AI智能体研发之路-模型篇(一):大模型训练框架LLaMA-Factory在国内网络环境下的安装、部署及使用AI智能体研发之路-模型篇(二):DeepSeek-V2-Chat训练与推理实战AI智能体研发之路-模型篇(三):中文大模型开、闭源之争AI智能体研发之路-模型篇(四):一文入门pytorch开发AI智能体研发之路-模型篇(五):pytorchvstensorflow框架DNN网络结构源码级对比AI智能体研发之路-模型篇(六):【机器学习】基于tensorflow实现你的第一个DNN网络AI智能体研发之路-模型篇(七):【机器学习】基于YOLOv10实现你的第一个视觉AI大模型AI智能体研发之路-模型篇(八):【机器学习】Qwen1.5-14B-Chat大模型训练与推理实战AI智能体研发之路-模型篇(九):【机器学习】GLM4-9B-Chat大模型/GLM-4V-9B多模态大模型概述、原理及推理实战《AI—Transformers应用》【AI大模型】Transformers大模型库(一):Tokenizer【AI大模型】Transformers大模型库(二):AutoModelForCausalLM【AI大模型】Transformers大模型库(三):特殊标记(specialtokens)【AI大模型】Transformers大模型库(四):AutoTokenizer
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 会员注册

本版积分规则

QQ|手机版|心飞设计-版权所有:微度网络信息技术服务中心 ( 鲁ICP备17032091号-12 )|网站地图

GMT+8, 2024-12-27 15:51 , Processed in 0.730410 second(s), 25 queries .

Powered by Discuz! X3.5

© 2001-2024 Discuz! Team.

快速回复 返回顶部 返回列表