hamza2923 committed on
Commit 0117db0 · verified
Parent: 55aec8e

Create app.py

Files changed (1)
app.py +155 -0
app.py ADDED
@@ -0,0 +1,155 @@
+ from flask import Flask, render_template, request, jsonify
+ from qdrant_client import QdrantClient
+ from qdrant_client import models
+ import torch.nn.functional as F
+ import torch
+ from torch import Tensor
+ from transformers import AutoTokenizer, AutoModel
+ from qdrant_client.models import Batch, PointStruct
+ from pickle import load, dump
+ import numpy as np
+ import os, time, sys
+ from datetime import datetime as dt
+ from datetime import timedelta
+ from datetime import timezone
+ from faster_whisper import WhisperModel
+ import io
+
+ app = Flask(__name__)
+
+ # Faster Whisper setup
+ # model_size = 'small'
+ beamsize = 2
+ wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")
+
+ # Initialize Qdrant Client and other required settings
+ qdrant_api_key = os.environ.get("qdrant_api_key")
+ qdrant_url = os.environ.get("qdrant_url")
+
+ client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def average_pool(last_hidden_states: Tensor,
+                  attention_mask: Tensor) -> Tensor:
+     last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
+     return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
+
+ tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
+ model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)
+
+ def e5embed(query):
+     # Embed text with e5-base-v2: tokenize, mean-pool over non-padding tokens,
+     # L2-normalize, and return a plain Python list that Qdrant accepts as a vector.
+     batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
+     batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
+     outputs = model(**batch_dict)
+     embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+     embeddings = F.normalize(embeddings, p=2, dim=1)
+     embeddings = embeddings.cpu().detach().numpy().flatten().tolist()
+     return embeddings
+
+ def get_id(collection):
+     # Allocate the next point id as max(existing ids) + 1.
+     # Note: scroll() is capped at 10000 points here, so this scheme assumes the
+     # collection stays below that size.
+     resp = client.scroll(collection_name=collection, limit=10000, with_payload=True, with_vectors=False)
+     max_id = max([r.id for r in resp[0]]) + 1
+     return int(max_id)
+
+ @app.route("/")
+ def index():
+     return render_template("index.html")
+
+ @app.route("/search", methods=["POST"])
+ def search():
+     query = request.form["query"]
+     collection_name = request.form["collection"]
+     topN = 200  # max results returned for the 'jks' collection
+
+     print('QUERY:', query)
+     # 'til:' / 'tilc:' prefixes redirect the search to the 'tils' collection,
+     # querying its named 'title' or 'context' vector; everything else searches 'jks'.
+     if query.strip().startswith('tilc:'):
+         collection_name = 'tils'
+         qvector = "context"
+         query = query.replace('tilc:', '')
+     elif query.strip().startswith('til:'):
+         collection_name = 'tils'
+         qvector = "title"
+         query = query.replace('til:', '')
+     else:
+         collection_name = 'jks'
+
+     timh = time.time()
+     sq = e5embed(query)
+     print('EMBEDDING TIME:', time.time() - timh)
+
+     timh = time.time()
+     if collection_name == "jks":
+         results = client.search(collection_name=collection_name, query_vector=sq, with_payload=True, limit=topN)
+     else:
+         results = client.search(collection_name=collection_name, query_vector=(qvector, sq), with_payload=True, limit=100)
+     print('SEARCH TIME:', time.time() - timh)
+
+     # Shape the Qdrant hits into the JSON the frontend expects, falling back to a
+     # placeholder date when a payload has none.
+     new_results = []
+     if collection_name == 'jks':
+         for r in results:
+             if 'date' not in r.payload: r.payload['date'] = '20200101'
+             new_results.append({"text": r.payload['text'], "date": str(int(r.payload['date'])), "id": r.id})
+     else:
+         for r in results:
+             if 'date' not in r.payload: r.payload['date'] = '20200101'
+             if 'context' in r.payload and r.payload['context'] != '':
+                 new_results.append({"text": r.payload['title'] + '<br>Context: ' + r.payload['context'], "url": r.payload['url'], "date": r.payload['date'], "id": r.id})
+             else:
+                 new_results.append({"text": r.payload['title'], "url": r.payload['url'], "date": r.payload['date'], "id": r.id})
+     return jsonify(new_results)
+
+ @app.route("/add_item", methods=["POST"])
+ def add_item():
+     title = request.form["title"]
+     url = request.form["url"]
+     if url.strip() == '':
+         # No URL: treat it as a joke and store it in 'jks' with a single unnamed vector.
+         collection_name = 'jks'
+         cid = get_id(collection_name)
+         print('cid', cid, time.strftime("%Y%m%d"))
+         resp = client.upsert(collection_name=collection_name, points=Batch(ids=[cid], payloads=[{'text': title, 'date': time.strftime("%Y%m%d")}], vectors=[e5embed(title)]))
+     else:
+         # With a URL: treat it as a TIL and store it in 'tils' under the named 'title' vector.
+         collection_name = 'tils'
+         cid = get_id('tils')
+         print('cid', cid, time.strftime("%Y%m%d"), collection_name)
+         til = {'title': title.replace('TIL that', '').replace('TIL:', '').replace('TIL ', '').strip(), 'url': url.replace('https://', '').replace('http://', ''), "date": time.strftime("%Y%m%d_%H%M")}
+         resp = client.upsert(collection_name="tils", points=[PointStruct(id=cid, payload=til, vector={"title": e5embed(til['title'])})])
+     print('Upsert response:', resp)
+     return jsonify({"success": True, "index": collection_name})
+
+
+ @app.route("/delete_joke", methods=["POST"])
+ def delete_joke():
+     joke_id = request.form["id"]
+     collection_name = request.form["collection"]
+     print('Deleting no.', joke_id, 'from collection', collection_name)
+     client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[int(joke_id)]))
+     return jsonify({"deleted": True})
+
+ @app.route("/whisper_transcribe", methods=["POST"])
+ def whisper_transcribe():
+     if 'audio' not in request.files: return jsonify({'error': 'No file provided'}), 400
+
+     audio_file = request.files['audio']
+     allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
+     if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions): return jsonify({'error': 'Invalid file format'}), 400
+
+     print('Transcribing audio')
+     audio_bytes = audio_file.read()
+     audio_file = io.BytesIO(audio_bytes)
+
+     # faster-whisper returns a lazy generator: the actual transcription work happens
+     # while iterating over the segments below, which is what the timer measures.
+     segments, info = wmodel.transcribe(audio_file, beam_size=beamsize)
+     text = ''
+     starttime = time.time()
+     for segment in segments:
+         text += segment.text
+     print('Time to transcribe:', time.time() - starttime, 'seconds')
+
+     return jsonify({'transcription': text})
+
+
+ if __name__ == "__main__":
+     app.run(host="0.0.0.0", debug=True, port=7860)
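
The routes above assume the 'jks' and 'tils' collections already exist in Qdrant: 'jks' stores one unnamed vector per point, while 'tils' is queried and upserted through named 'title' and 'context' vectors. The collection setup is not part of this commit; the following is only a minimal sketch of how such collections might be created, assuming cosine distance and the 768-dimensional embeddings that intfloat/e5-base-v2 produces.

import os
from qdrant_client import QdrantClient, models

# Same connection settings as app.py reads from the environment.
client = QdrantClient(url=os.environ.get("qdrant_url"), port=443, api_key=os.environ.get("qdrant_api_key"))

# 'jks': a single unnamed 768-d vector per point, matching query_vector=sq in /search.
client.create_collection(
    collection_name="jks",
    vectors_config=models.VectorParams(size=768, distance=models.Distance.COSINE),
)

# 'tils': named vectors, matching query_vector=("title", sq) / ("context", sq).
client.create_collection(
    collection_name="tils",
    vectors_config={
        "title": models.VectorParams(size=768, distance=models.Distance.COSINE),
        "context": models.VectorParams(size=768, distance=models.Distance.COSINE),
    },
)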
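
For quick testing, the endpoints can be exercised with plain form posts. A sketch using the requests library (not used by the app itself), assuming the server is running locally on port 7860 as configured in app.run() and that an index.html template exists for the / route; the titles, ids, URLs, and file names below are placeholders.

import requests

base = "http://localhost:7860"

# Semantic search: 'til:'/'tilc:' prefixes target the 'tils' collection, anything else searches 'jks'.
r = requests.post(f"{base}/search", data={"query": "til: mean pooling for embeddings", "collection": "jks"})
print(r.json())

# Add a joke (empty url) or a TIL (non-empty url).
requests.post(f"{base}/add_item", data={"title": "TIL that e5 embeddings are 768-dimensional", "url": "example.com/e5"})

# Delete a point by id from a given collection.
requests.post(f"{base}/delete_joke", data={"id": "3", "collection": "jks"})

# Transcribe an audio file (mp3/wav/ogg/m4a) with faster-whisper.
with open("note.mp3", "rb") as f:
    print(requests.post(f"{base}/whisper_transcribe", files={"audio": ("note.mp3", f)}).json())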