File size: 4,432 Bytes
8678b21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
import torch
import os
import json
import httpx
import time, random
import argparse
from itertools import chain


# Datasets keyed by dataset name (e.g. 'en_fact'); each value is the list of
# sample dicts returned by loadDataSets().
all_samples = {}
# Groq model outputs keyed by the same dataset name; each value is a list of
# response strings (or None for failed calls), index-aligned with all_samples.
groq_results = {}

def call_groq_api(prompt_messages, model_name="llama3-70b-8192"):
  """Send one chat-completion request to the Groq OpenAI-compatible API.

  Args:
      prompt_messages: list of ``{"role": ..., "content": ...}`` chat messages.
      model_name: Groq model identifier to query.

  Returns:
      The assistant message content (str) on success, or None on any failure
      (missing key, non-200 status, malformed response, or network error).
  """
  # SECURITY FIX: the API key was hard-coded in source (and is therefore
  # compromised — rotate it). Read it from the environment instead.
  api_key = os.environ.get("GROQ_API_KEY")
  if not api_key:
      print("❌ GROQ_API_KEY environment variable is not set.")
      return None

  try:
    response = httpx.post(
        "https://api.groq.com/openai/v1/chat/completions",
        headers={
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        },
        json={
            "model": model_name,
            "messages": prompt_messages,
            # temperature 0.0 keeps the evaluation runs deterministic-ish.
            "temperature": 0.0
        },
        timeout=120.0
    )
    # Check for non-200 responses
    if response.status_code != 200:
        print(f"❌ API Error: Status Code {response.status_code}")
        print(f"Response: {response.text}")
        return None

    response_json = response.json()

    # Defensive: some error payloads come back 200 but without 'choices'.
    if "choices" not in response_json or not response_json["choices"]:
        print("❌ Missing 'choices' in response.")
        print(f"Full response: {response_json}")
        return None

    return response_json["choices"][0]["message"]["content"]

  except httpx.RequestError as e:
      print(f"❌ Request failed: {e}")
      return None

  except Exception as e:
      print(f"❌ Unexpected error: {e}")
      return None
  

def build_noise_prompt(entry):
    """Assemble the two-message chat prompt for one sample.

    The context is built from the first two positive passage groups
    (flattened) plus the first two negative passages, followed by the
    sample's question.
    """
    # The first two 'positive' entries are lists of passages; flatten them.
    supporting = [passage for group in entry['positive'][:2] for passage in group]
    # 'negative' is already a flat list of passage strings.
    distracting = entry['negative'][:2]

    context = "\n".join(supporting + distracting)
    system_msg = {"role": "system", "content": "You are a question answering assistant. Use only the context provided."}
    user_msg = {"role": "user", "content": f"Context:\n{context}\n\nQuestion: {entry['query']}\nAnswer briefly."}
    return [system_msg, user_msg]


    
def run_groq_api(key):
    """Query the Groq model once per sample in ``all_samples[key]``.

    Collects every response (None for failed calls), stores the list in the
    module-level ``groq_results`` dict under *key*, and returns it.
    """
    responses = []
    for sample_no, sample in enumerate(all_samples[key], start=1):
        messages = build_noise_prompt(sample)

        # Earlier runs used 'llama3-70b-8192'.
        answer = call_groq_api(messages, 'moonshotai/kimi-k2-instruct')

        if answer:
            print(f"✅ Success: {answer}...")  # Preview
        else:
            print(f"⚠️ Failed to get result for sample #{sample_no}")

        responses.append(answer)

        # Randomized pause between requests to stay under the rate limit.
        time.sleep(random.uniform(2, 4))

    groq_results[key] = responses
    print("\n🎉 All samples processed.")
    return responses
    
def loadDataSets(filepath):
    """Read a JSON-lines file and return one parsed object per valid line.

    Lines that fail to parse are reported and skipped instead of aborting
    the whole load.
    """
    records = []
    with open(filepath, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            try:
                records.append(json.loads(raw_line))
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON from line in {filepath}: {e}")
    return records

def loadSamples(key):
    """End-to-end pipeline for one dataset: load, query Groq, save predictions."""
    print('==== Loading samples', key)

    # Dataset file is expected at data/<key>.json (JSON-lines format).
    all_samples[key] = loadDataSets(f'data/{key}.json')

    print('==== Requested sample is loaded')
    print('==== Hitting Grog request')
    run_groq_api(key)

    print('Saving under the Predictions ===== ')
    savePredictions(key)
    print('==== Predictions are saved!!!')


def savePredictions(key):
    """Write the Groq predictions for dataset *key* to a JSON file.

    Pairs each dataset entry in ``all_samples[key]`` with the model output in
    ``groq_results[key]`` (aligned by index) and dumps the combined records
    under prediction/.

    Fixes over the original:
      * "prediction" now holds the model output and "answer" the dataset's
        gold answer — the two keys were swapped.
      * The output filename is derived from *key* (and the actual request
        temperature 0.0) instead of being hard-coded to en_fact / temp0.2.
      * ensure_ascii=False so zh datasets stay human-readable.
    """
    model_outputs = groq_results[key]
    records = []
    for entry, output in zip(all_samples[key], model_outputs):
        records.append({
            "id": entry['id'],
            "query": entry['query'],
            # NOTE(review): the original stored the gold answer under
            # "prediction" and the model output under "answer"; swapped to
            # the conventional meaning — confirm against the eval script.
            "prediction": output,
            "answer": entry['answer'],
            "positive": entry['positive'],
            "negative": entry['negative'],
        })

    out_path = f"prediction/prediction_{key}_kimi_temp0.0_noise0.0_passage5_correct.0.json"
    with open(out_path, "w", encoding="utf-8") as f:
        json.dump(records, f, ensure_ascii=False)


def main(params):
    """Parse CLI arguments and run the evaluation pipeline.

    Args:
        params: list of argument strings (e.g. ``sys.argv[1:]``), forwarded
            to ``argparse``; ``None`` falls back to the real command line.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--dataset', type=str, default='en',
        help='evaluation dataset',  # typo 'evaluetion' fixed
        choices=['en', 'zh', 'en_int', 'zh_int', 'en_fact', 'zh_fact']
    )
    args = parser.parse_args(params)
    print(args.dataset)
    loadSamples(args.dataset)