Spaces:
Sleeping
Sleeping
File size: 4,432 Bytes
8678b21 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 |
import torch
import os
import json
import httpx
import time, random
import argparse
from itertools import chain
# Module-level caches, keyed by dataset name (e.g. 'en_fact'):
# raw samples loaded from data/<key>.json by loadSamples().
all_samples = {}
# Groq model outputs per dataset key, parallel (index-aligned) to all_samples[key].
groq_results = {}
def call_groq_api(prompt_messages, model_name="llama3-70b-8192"):
    """Send one chat-completion request to the Groq OpenAI-compatible API.

    Parameters
    ----------
    prompt_messages : list[dict]
        Chat messages in OpenAI format ({"role": ..., "content": ...}).
    model_name : str
        Groq model identifier to query (temperature is pinned to 0.0).

    Returns
    -------
    str | None
        The assistant message content, or None on any failure
        (missing key, non-200 status, malformed body, network error).
    """
    # SECURITY FIX: the API key was previously hard-coded in this file,
    # leaking a live secret into version control. Read it from the
    # environment instead; the old key must be revoked.
    api_key = os.environ.get("GROQ_API_KEY")
    if not api_key:
        print("❌ GROQ_API_KEY environment variable is not set.")
        return None
    try:
        response = httpx.post(
            "https://api.groq.com/openai/v1/chat/completions",
            headers={
                "Authorization": f"Bearer {api_key}",
                "Content-Type": "application/json"
            },
            json={
                "model": model_name,
                "messages": prompt_messages,
                "temperature": 0.0
            },
            timeout=120.0
        )
        # Check for non-200 responses
        if response.status_code != 200:
            print(f"❌ API Error: Status Code {response.status_code}")
            print(f"Response: {response.text}")
            return None
        response_json = response.json()
        # Check if 'choices' exists in the response
        if "choices" not in response_json or not response_json["choices"]:
            print("❌ Missing 'choices' in response.")
            print(f"Full response: {response_json}")
            return None
        return response_json["choices"][0]["message"]["content"]
    except httpx.RequestError as e:
        print(f"❌ Request failed: {e}")
        return None
    except Exception as e:
        # Last-resort guard so one bad sample never aborts a whole run.
        print(f"❌ Unexpected error: {e}")
        return None
def build_noise_prompt(entry):
    """Assemble the two-message chat prompt for one noise-robustness sample.

    Flattens the first two positive passage groups (each a list of strings),
    appends the first two negative passages (already flat strings), joins
    everything into a newline-separated context block, and pairs it with the
    sample's query.
    """
    passages = []
    for group in entry['positive'][:2]:
        passages.extend(group)  # positive entries are nested lists of passages
    passages.extend(entry['negative'][:2])  # negatives are already flat
    context = "\n".join(passages)
    system_msg = {
        "role": "system",
        "content": "You are a question answering assistant. Use only the context provided.",
    }
    user_msg = {
        "role": "user",
        "content": f"Context:\n{context}\n\nQuestion: {entry['query']}\nAnswer briefly.",
    }
    return [system_msg, user_msg]
def run_groq_api(key):
    """Run the Groq model over every loaded sample for dataset `key`.

    Each sample is converted to a noise prompt and sent to the API; the raw
    answer (or None on failure) is collected in order. Results are cached in
    the module-level `groq_results` dict and also returned.
    """
    answers = []
    for i, sample in enumerate(all_samples[key], start=1):
        messages = build_noise_prompt(sample)
        # Model under evaluation ('llama3-70b-8192' was used in earlier runs).
        answer = call_groq_api(messages, 'moonshotai/kimi-k2-instruct')
        if answer:
            print(f"✅ Success: {answer}...")  # Preview
        else:
            print(f"⚠️ Failed to get result for sample #{i}")
        answers.append(answer)
        # Random pause between calls keeps us under the API rate limit.
        time.sleep(random.uniform(2, 4))
    groq_results[key] = answers
    print("\n🎉 All samples processed.")
    return answers
def loadDataSets(filepath):
    """Load a JSON-lines dataset file into a list of dicts.

    Malformed lines are reported and skipped rather than aborting the
    whole load, so one bad record never discards the dataset.
    """
    records = []
    with open(filepath, 'r', encoding='utf-8') as fh:
        for raw_line in fh:
            try:
                records.append(json.loads(raw_line))
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON from line in {filepath}: {e}")
    return records
def loadSamples(key):
    """End-to-end pipeline for one dataset key: load, query Groq, save.

    Reads data/<key>.json into the module-level cache, runs every sample
    through the Groq API, then writes the collected predictions to disk.
    """
    print('==== Loading samples', key)
    all_samples[key] = loadDataSets(f'data/{key}.json')
    print('==== Requested sample is loaded')
    print('==== Hitting Grog request')
    run_groq_api(key)
    print('Saving under the Predictions ===== ')
    savePredictions(key)
    print('==== Predictions are saved!!!')
def savePredictions(key):
    """Write the collected Groq outputs for dataset `key` to a JSON file.

    Zips the loaded samples (all_samples[key]) with the model answers
    gathered by run_groq_api (groq_results[key], index-aligned) and dumps
    one list of records under prediction/.

    NOTE(review): the original code stored the gold answer under
    "prediction" and the model output under "answer"; that mapping is
    preserved here, but it looks swapped — confirm against the downstream
    evaluation script before renaming the fields.
    """
    model_answers = groq_results[key]
    samples = all_samples[key]
    records = [
        {
            "id": sample['id'],
            "query": sample['query'],
            "prediction": sample['answer'],  # gold answer from the dataset (see NOTE)
            "answer": model_answers[idx],    # Groq model output (see NOTE)
            "positive": sample['positive'],
            "negative": sample['negative'],
        }
        for idx, sample in enumerate(samples)
    ]
    # FIX: the output path previously hard-coded "en_fact", so every dataset
    # silently overwrote the same file; parameterize it on `key` instead
    # (identical path when key == 'en_fact').
    out_path = f"prediction/prediction_{key}_kimi_temp0.2_noise0.0_passage5_correct.0.json"
    # ensure_ascii=False keeps zh-dataset text readable instead of \uXXXX escapes.
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False)
def main(params):
    """CLI entry point: parse the dataset choice and run the full pipeline.

    Parameters
    ----------
    params : list[str] | None
        Argument vector to parse; None falls back to sys.argv (argparse
        default behavior).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset', type=str, default='en',
        help='evaluation dataset',  # typo fix: was "evaluetion"
        choices=['en', 'zh', 'en_int', 'zh_int', 'en_fact', 'zh_fact']
    )
    args = parser.parse_args(params)
    print(args.dataset)
    loadSamples(args.dataset)