nuojohnchen committed on
Commit
a4ed224
·
verified ·
1 Parent(s): 9e9a188

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -94,19 +94,30 @@ def load_model(model_path, progress=gr.Progress()):
94
  progress(0.3, desc="Loading tokenizer...")
95
  config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
96
  if 'MoE' in model_path:
97
- config_moe = config
98
- config_moe.auto_map["AutoConfig"] = "./configuration_upcycling_qwen2_moe.UpcyclingQwen2MoeConfig"
99
- config_moe.auto_map["AutoModelForCausalLM"] = "./modeling_upcycling_qwen2_moe.UpcyclingQwen2MoeForCausalLM"
 
100
  current_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False,trust_remote_code=True)
101
 
102
  progress(0.5, desc="Loading model...")
103
- current_model = AutoModelForCausalLM.from_pretrained(
104
- model_path,
105
- device_map="auto",
106
- torch_dtype=torch.float16,
107
- config=config_moe if 'MoE' in model_path else config,
108
- trust_remote_code=True
109
- )
 
 
 
 
 
 
 
 
 
 
110
 
111
  current_model_path = model_path
112
  progress(1.0, desc="Model loading complete!")
 
94
  progress(0.3, desc="Loading tokenizer...")
95
  config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
96
  if 'MoE' in model_path:
97
+ from configuration_upcycling_qwen2_moe import UpcyclingQwen2MoeConfig
98
+ config = UpcyclingQwen2MoeConfig.from_pretrained(model_path, trust_remote_code=True)
99
+ # config_moe.auto_map["AutoConfig"] = "./configuration_upcycling_qwen2_moe.UpcyclingQwen2MoeConfig"
100
+ # config_moe.auto_map["AutoModelForCausalLM"] = "./modeling_upcycling_qwen2_moe.UpcyclingQwen2MoeForCausalLM"
101
  current_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False,trust_remote_code=True)
102
 
103
  progress(0.5, desc="Loading model...")
104
+ if 'MoE' in model_path:
105
+ from modeling_upcycling_qwen2_moe import UpcyclingQwen2MoeForCausalLM
106
+ current_model = UpcyclingQwen2MoeForCausalLM.from_pretrained(
107
+ model_path,
108
+ device_map="auto",
109
+ torch_dtype=torch.float16,
110
+ config=config,
111
+ trust_remote_code=True
112
+ )
113
+ else:
114
+ current_model = AutoModelForCausalLM.from_pretrained(
115
+ model_path,
116
+ device_map="auto",
117
+ torch_dtype=torch.float16,
118
+ config=config,
119
+ trust_remote_code=True
120
+ )
121
 
122
  current_model_path = model_path
123
  progress(1.0, desc="Model loading complete!")