# giridhar_rgb / dataloader.py
# Loads a JSONL benchmark split, queries the Groq chat API for each sample,
# and writes the predictions to disk.
import torch
import os
import json
import httpx
import time, random
import argparse
from itertools import chain
# In-memory cache of loaded datasets, keyed by dataset name (e.g. 'en_fact');
# each value is the list of sample dicts returned by loadDataSets().
all_samples = {}
# Model outputs per dataset key, filled in by run_groq_api(); each value is a
# list of response strings (or None on failure) aligned with all_samples[key].
groq_results = {}
def call_groq_api(prompt_messages, model_name="llama3-70b-8192"):
    """Send a chat-completion request to the Groq API.

    Args:
        prompt_messages: list of {"role": ..., "content": ...} chat messages.
        model_name: Groq model identifier to query.

    Returns:
        The assistant's reply text, or None on any transport/API/parse error
        (errors are printed, never raised, so batch loops keep going).
    """
    # SECURITY FIX: the API key was previously hard-coded in this file (and is
    # therefore exposed in the repo history — that key must be revoked).
    # Read it from the environment instead.
    api_key = os.environ.get("GROQ_API_KEY", "")
    try:
        response = httpx.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json"
            },
            json={
                "model": model_name,
                "messages": prompt_messages,
                "temperature": 0.0  # deterministic output for evaluation
            },
            timeout=120.0
        )
        # Check for non-200 responses
        if response.status_code != 200:
            print(f"❌ API Error: Status Code {response.status_code}")
            print(f"Response: {response.text}")
            return None
        response_json = response.json()
        # Check if 'choices' exists in the response
        if "choices" not in response_json or not response_json["choices"]:
            print("❌ Missing 'choices' in response.")
            print(f"Full response: {response_json}")
            return None
        return response_json["choices"][0]["message"]["content"]
    except httpx.RequestError as e:
        # Network-level failure (DNS, timeout, connection reset, ...).
        print(f"❌ Request failed: {e}")
        return None
    except Exception as e:
        # Last-resort guard so one bad response cannot abort a whole batch.
        print(f"❌ Unexpected error: {e}")
        return None
def build_noise_prompt(entry):
    """Build a two-message chat prompt from one dataset sample.

    Takes the first two positive passage groups (nested lists of strings)
    plus the first two negative passages (already flat strings), joins them
    into a newline-separated context, and pairs it with the sample's query.

    Args:
        entry: sample dict with 'positive', 'negative', and 'query' keys.

    Returns:
        A [system, user] message list ready for the chat-completions API.
    """
    # Flatten the first two positive groups into individual passage strings.
    passages = [doc for group in entry['positive'][:2] for doc in group]
    # Negative passages need no flattening.
    passages += entry['negative'][:2]
    context = "\n".join(passages)
    return [
        {"role": "system", "content": "You are a question answering assistant. Use only the context provided."},
        {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {entry['query']}\nAnswer briefly."}
    ]
def run_groq_api(key):
    """Query the Groq API for every sample loaded under *key*.

    Builds a prompt per sample, collects each response (None on failure)
    in order, stores the list in the module-level groq_results cache, and
    returns it. Sleeps a random 2-4 s between calls to stay under rate
    limits.
    """
    answers = []
    for i, sample in enumerate(all_samples[key], start=1):
        messages = build_noise_prompt(sample)
        answer = call_groq_api(messages, 'moonshotai/kimi-k2-instruct')
        if answer:
            print(f"✅ Success: {answer}...")  # Preview
        else:
            print(f"⚠️ Failed to get result for sample #{i}")
        answers.append(answer)
        # Short randomized pause between calls to avoid overloading the API.
        time.sleep(random.uniform(2, 4))
    groq_results[key] = answers
    print("\n🎉 All samples processed.")
    return answers
def loadDataSets(filepath):
    """Read a JSON-Lines file and return the parsed records.

    Each line is decoded independently; lines that fail to parse are
    reported to stdout and skipped rather than aborting the load.

    Args:
        filepath: path to a UTF-8 JSONL file.

    Returns:
        List of parsed objects, one per valid line, in file order.
    """
    records = []
    with open(filepath, 'r', encoding='utf-8') as fh:
        for raw_line in fh:
            try:
                records.append(json.loads(raw_line))
            except json.JSONDecodeError as e:
                # Skip malformed lines but keep processing the rest.
                print(f"Error decoding JSON from line in {filepath}: {e}")
    return records
def loadSamples(key):
    """Run the full pipeline for one dataset split.

    Loads data/<key>.json into the module-level all_samples cache, queries
    the Groq API for every sample via run_groq_api(), then writes the
    predictions to disk via savePredictions().

    Args:
        key: dataset split name (e.g. 'en', 'en_fact'); selects the
            data/<key>.json input file.
    """
    print('==== Loading samples', key)
    all_samples[key] = loadDataSets(f'data/{key}.json')
    # all_samples['en_fact'] = loadDataSets('data/en_fact.json')
    # all_samples['en_refine'] = loadDataSets('data/en_refine.json')
    # all_samples['en_int'] = loadDataSets('data/en_int.json')
    print('==== Requested sample is loaded')
    print('==== Hitting Grog request')
    run_groq_api(key)
    print('Saving under the Predictions ===== ')
    savePredictions(key)
    print('==== Predictions are saved!!!')
def savePredictions(key):
    """Join each loaded sample with its model output and write to disk.

    Reads the module-level all_samples[key] and groq_results[key] (which
    run_groq_api() keeps index-aligned) and writes one merged record per
    sample to prediction/prediction_<key>_kimi_...json.

    Args:
        key: dataset split name previously processed by run_groq_api().
    """
    predictions = groq_results[key]
    samples = all_samples[key]
    records = []
    # zip keeps samples and predictions index-aligned without manual indexing.
    for sample, prediction in zip(samples, predictions):
        records.append(
            {
                "id": sample['id'],
                "query": sample['query'],
                # BUG FIX: "prediction" and "answer" were swapped — the model
                # output was stored under "answer" and the gold answer under
                # "prediction". Store them under the correct keys.
                "prediction": prediction,
                "answer": sample['answer'],
                "positive": sample['positive'],
                "negative": sample['negative'],
            }
        )
    # BUG FIX: the filename hard-coded "en_fact", so different splits
    # silently overwrote each other; embed the actual key instead.
    out_path = f"prediction/prediction_{key}_kimi_temp0.2_noise0.0_passage5_correct.0.json"
    # Make sure the output directory exists before writing.
    os.makedirs("prediction", exist_ok=True)
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False)
def main(params):
    """Parse command-line arguments and run the pipeline for one dataset.

    Args:
        params: argv-style list of argument strings (e.g. ['--dataset', 'en']),
            passed straight to argparse.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset', type=str, default='en',
        # BUG FIX: help text typo "evaluetion" -> "evaluation".
        help='evaluation dataset',
        choices=['en', 'zh', 'en_int', 'zh_int', 'en_fact', 'zh_fact']
    )
    args = parser.parse_args(params)
    print(args.dataset)
    loadSamples(args.dataset)