# Groq batch-prediction script: loads a JSONL dataset split, queries the Groq
# chat-completions API once per sample, and saves the predictions to disk.
# (Replaced stray "Spaces / Sleeping" hosting-UI residue that preceded the code.)
import torch | |
import os | |
import json | |
import httpx | |
import time, random | |
import argparse | |
from itertools import chain | |
all_samples = {} | |
groq_results = {} | |
def call_groq_api(prompt_messages, model_name="llama3-70b-8192"):
    """Send a chat-completion request to the Groq OpenAI-compatible API.

    Args:
        prompt_messages: list of {"role": ..., "content": ...} chat messages.
        model_name: Groq model identifier to query.

    Returns:
        The assistant message content (str) on success, or None on any
        HTTP/parse failure (errors are printed, never raised).
    """
    # SECURITY FIX: an API key was previously hard-coded on this line. A key
    # committed to source must be treated as leaked and revoked. Read it from
    # the environment instead.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        print("❌ GROQ_API_KEY environment variable is not set.")
        return None
    try:
        response = httpx.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json",
            },
            json={
                "model": model_name,
                "messages": prompt_messages,
                "temperature": 0.0,  # deterministic decoding for evaluation
            },
            timeout=120.0,
        )
        # Non-200 responses (rate limits, auth failures) are reported, not raised.
        if response.status_code != 200:
            print(f"❌ API Error: Status Code {response.status_code}")
            print(f"Response: {response.text}")
            return None
        response_json = response.json()
        # Defensive: some error payloads come back without 'choices'.
        if "choices" not in response_json or not response_json["choices"]:
            print("❌ Missing 'choices' in response.")
            print(f"Full response: {response_json}")
            return None
        return response_json["choices"][0]["message"]["content"]
    except httpx.RequestError as e:
        print(f"❌ Request failed: {e}")
        return None
    except Exception as e:
        print(f"❌ Unexpected error: {e}")
        return None
def build_noise_prompt(entry):
    """Build a two-message chat prompt from one sample's retrieved passages.

    Flattens the first two positive passage groups, appends the first two
    negative passages, joins them as the context, and asks the model to
    answer the sample's query using only that context.
    """
    context_lines = []
    for passage_group in entry['positive'][:2]:  # positives are nested lists
        context_lines.extend(passage_group)
    context_lines.extend(entry['negative'][:2])  # negatives are already flat
    context = "\n".join(context_lines)
    system_msg = {"role": "system", "content": "You are a question answering assistant. Use only the context provided."}
    user_msg = {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {entry['query']}\nAnswer briefly."}
    return [system_msg, user_msg]
def run_groq_api(key):
    """Query the Groq API for every loaded sample under `key`.

    Stores the per-sample outputs (None on failure) in groq_results[key],
    in sample order, and returns the same list.
    """
    collected = []
    for position, sample in enumerate(all_samples[key]):
        messages = build_noise_prompt(sample)
        answer = call_groq_api(messages, 'moonshotai/kimi-k2-instruct')
        if answer:
            print(f"✅ Success: {answer}...")  # Preview
        else:
            print(f"⚠️ Failed to get result for sample #{position + 1}")
        collected.append(answer)
        # Brief randomized pause between calls to avoid hammering the rate limiter.
        time.sleep(random.uniform(2, 4))
    groq_results[key] = collected
    print("\n🎉 All samples processed.")
    return collected
def loadDataSets(filepath):
    """Parse a JSON-Lines file into a list of decoded records.

    Each line is decoded independently; lines that fail to parse are
    reported and skipped rather than aborting the whole load.
    """
    records = []
    with open(filepath, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            try:
                records.append(json.loads(raw_line))
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON from line in {filepath}: {e}")
    return records
def loadSamples(key):
    """Run the full pipeline for one dataset split.

    Loads data/<key>.json (JSON-Lines), queries the Groq API for every
    sample, and writes the predictions to disk via savePredictions().
    """
    print('==== Loading samples', key)
    all_samples[key] = loadDataSets(f'data/{key}.json')
    print('==== Requested sample is loaded')
    print('==== Hitting Groq request')  # fixed log typo: "Grog" -> "Groq"
    run_groq_api(key)
    print('Saving under the Predictions ===== ')
    savePredictions(key)
    print('==== Predictions are saved!!!')
def savePredictions(key):
    """Write model predictions for dataset `key` to a JSON file.

    Pairs each sample in all_samples[key] with the model output stored at
    the same index of groq_results[key] and dumps the combined rows.
    """
    model_outputs = groq_results[key]
    samples = all_samples[key]
    rows = []
    for idx, entry in enumerate(samples):
        rows.append(
            {
                "id": entry['id'],
                "query": entry['query'],
                # BUG FIX: these two fields were swapped — "prediction" held
                # the dataset's gold answer and "answer" held the model output.
                "prediction": model_outputs[idx],
                "answer": entry['answer'],
                "positive": entry['positive'],
                "negative": entry['negative'],
            }
        )
    # NOTE(review): output path is hard-coded for the en_fact/kimi run and
    # ignores `key` — confirm before running other splits, or files from
    # different datasets will overwrite each other.
    with open("prediction/prediction_en_fact_kimi_temp0.2_noise0.0_passage5_correct.0.json", "w") as f:
        json.dump(rows, f)
def main(params):
    """Parse command-line arguments and run the pipeline for one dataset.

    Args:
        params: argument list for argparse (e.g. ['--dataset', 'en_fact']);
            pass None to fall back to sys.argv[1:].
    """
    # Removed leftover debug print ('test meee').
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset', type=str, default='en',
        help='evaluation dataset',  # fixed typo "evaluetion"
        choices=['en', 'zh', 'en_int', 'zh_int', 'en_fact', 'zh_fact']
    )
    args = parser.parse_args(params)
    print(args.dataset)
    loadSamples(args.dataset)
    # NOTE(review): no `if __name__ == "__main__":` guard or call to main()
    # is visible in this chunk — confirm how this script is invoked.