Update device setting for inference
inference.py  CHANGED  (+11 -1)
@@ -16,6 +16,15 @@ class SentimentInference:
         model_yaml_cfg = config_data.get('model', {})
         inference_yaml_cfg = config_data.get('inference', {})
 
+        # Determine device early
+        if torch.cuda.is_available():
+            self.device = torch.device("cuda")
+        elif torch.backends.mps.is_available(): # Check for MPS (Apple Silicon GPU)
+            self.device = torch.device("mps")
+        else:
+            self.device = torch.device("cpu")
+        print(f"[INFERENCE_LOG] Using device: {self.device}")
+
         model_hf_repo_id = model_yaml_cfg.get('name_or_path')
         tokenizer_hf_repo_id = model_yaml_cfg.get('tokenizer_name_or_path', model_hf_repo_id)
         local_model_weights_path = inference_yaml_cfg.get('model_path') # Path for local .pt file
@@ -127,12 +136,13 @@ class SentimentInference:
             print(f"[INFERENCE_LOG] HUB_LOAD_FALLBACK: Critical error during fallback load: {e_fallback}")
             raise e_fallback # Re-raise if fallback also fails catastrophically
 
+        self.model.to(self.device) # Move model to the determined device
         self.model.eval()
 
     def predict(self, text: str) -> Dict[str, Any]:
         inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=self.max_length, padding=True)
         with torch.no_grad():
-            outputs = self.model(input_ids=inputs['input_ids'], attention_mask=inputs['attention_mask'])
+            outputs = self.model(input_ids=inputs['input_ids'].to(self.device), attention_mask=inputs['attention_mask'].to(self.device))
             logits = outputs.get("logits") # Use .get for safety
             if logits is None:
                 raise ValueError("Model output did not contain 'logits'. Check model's forward pass.")