hamza2923 committed
Commit 56a884d · verified · 1 Parent(s): 0117db0

Update app.py

Files changed (1)
  1. app.py +25 -135
app.py CHANGED
@@ -1,155 +1,45 @@
-from flask import Flask, render_template, request, jsonify
-from qdrant_client import QdrantClient
-from qdrant_client import models
-import torch.nn.functional as F
-import torch
-from torch import Tensor
-from transformers import AutoTokenizer, AutoModel
-from qdrant_client.models import Batch, PointStruct
-from pickle import load, dump
-import numpy as np
-import os, time, sys
-from datetime import datetime as dt
-from datetime import timedelta
-from datetime import timezone
+from flask import Flask, request, jsonify
 from faster_whisper import WhisperModel
+import torch
 import io
+import time
 
 app = Flask(__name__)
 
+# Device check for faster-whisper
+device = "cuda" if torch.cuda.is_available() else "cpu"
+compute_type = "float16" if device == "cuda" else "int8"
+print(f"Using device: {device} with compute_type: {compute_type}")
+
 # Faster Whisper setup
-# model_size = 'small'
 beamsize = 2
-wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")
+wmodel = WhisperModel("guillaumekln/faster-whisper-small", device=device, compute_type=compute_type)
-
-# Initialize Qdrant Client and other required settings
-qdrant_api_key = os.environ.get("qdrant_api_key")
-qdrant_url = os.environ.get("qdrant_url")
-
-client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-def average_pool(last_hidden_states: Tensor,
-                 attention_mask: Tensor) -> Tensor:
-    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
-    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
-
-tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
-model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)
-
-def e5embed(query):
-    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
-    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
-    outputs = model(**batch_dict)
-    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
-    embeddings = F.normalize(embeddings, p=2, dim=1)
-    embeddings = embeddings.cpu().detach().numpy().flatten().tolist()
-    return embeddings
-
-def get_id(collection):
-    resp = client.scroll(collection_name=collection, limit=10000, with_payload=True, with_vectors=False,)
-    max_id = max([r.id for r in resp[0]])+1
-    return int(max_id)
-
-@app.route("/")
-def index():
-    return render_template("index.html")
-
-@app.route("/search", methods=["POST"])
-def search():
-    query = request.form["query"]
-    collection_name = request.form["collection"]
-    topN = 200 # Define your topN value
-
-    print('QUERY: ',query)
-    if query.strip().startswith('tilc:'):
-        collection_name = 'tils'
-        qvector = "context"
-        query = query.replace('tilc:', '')
-    elif query.strip().startswith('til:'):
-        collection_name = 'tils'
-        qvector = "title"
-        query = query.replace('til:', '')
-    else: collection_name = 'jks'
-
-    timh = time.time()
-    sq = e5embed(query)
-    print('EMBEDDING TIME: ', time.time() - timh)
-
-    timh = time.time()
-    if collection_name == "jks": results = client.search(collection_name=collection_name, query_vector=sq, with_payload=True, limit=topN)
-    else: results = client.search(collection_name=collection_name, query_vector=(qvector, sq), with_payload=True, limit=100)
-    print('SEARCH TIME: ', time.time() - timh)
-
-    #print(results[0])
-    # try:
-    new_results = []
-    if collection_name == 'jks':
-        for r in results:
-            if 'date' not in r.payload: r.payload['date'] = '20200101'
-            new_results.append({"text": r.payload['text'], "date": str(int(r.payload['date'])), "id": r.id}) # Implement your Qdrant search here
-    else:
-        for r in results:
-            if 'context' in r.payload and r.payload['context'] != '':
-                if 'date' not in r.payload: r.payload['date'] = '20200101'
-                new_results.append({"text": r.payload['title'] + '<br>Context: ' + r.payload['context'], "url": r.payload['url'], "date": r.payload['date'], "id": r.id})
-            else:
-                if 'date' not in r.payload: r.payload['date'] = '20200101'
-                new_results.append({"text": r.payload['title'], "url": r.payload['url'], "date": r.payload['date'], "id": r.id})
-    return jsonify(new_results)
-    # except:
-    #     return jsonify([])
-
-@app.route("/add_item", methods=["POST"])
-def add_item():
-    title = request.form["title"]
-    url = request.form["url"]
-    if url.strip() == '':
-        collection_name = 'jks'
-        cid = get_id(collection_name)
-        print('cid', cid, time.strftime("%Y%m%d"))
-        resp = client.upsert(collection_name=collection_name, points=Batch(ids=[cid], payloads=[{'text':title, 'date': time.strftime("%Y%m%d")}],vectors=[e5embed(title)]),)
-    else:
-        collection_name = 'tils'
-        cid = get_id('tils')
-        print('cid', cid, time.strftime("%Y%m%d"), collection_name)
-        til = {'title': title.replace('TIL that', '').replace('TIL:', '').replace('TIL ', '').strip(), 'url': url.replace('https://', '').replace('http://', ''), "date": time.strftime("%Y%m%d_%H%M")}
-        resp = client.upsert(collection_name="tils", points=[PointStruct(id=cid, payload=til, vector={"title": e5embed(til['title']),},)])
-    print('Upsert response:', resp)
-    return jsonify({"success": True, "index": collection_name})
-
-
-@app.route("/delete_joke", methods=["POST"])
-def delete_joke():
-    joke_id = request.form["id"]
-    collection_name = request.form["collection"]
-    print('Deleting no.', joke_id, 'from collection', collection_name)
-    client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[int(joke_id)],),)
-    return jsonify({"deleted": True})
 
 @app.route("/whisper_transcribe", methods=["POST"])
 def whisper_transcribe():
-    if 'audio' not in request.files: return jsonify({'error': 'No file provided'}), 400
+    if 'audio' not in request.files:
+        return jsonify({'error': 'No file provided'}), 400
 
     audio_file = request.files['audio']
     allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
-    if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions): return jsonify({'error': 'Invalid file format'}), 400
+    if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions):
+        return jsonify({'error': 'Invalid file format'}), 400
 
-    print('Transcribing audio')
+    print(f"Transcribing audio on {device}")
     audio_bytes = audio_file.read()
    audio_file = io.BytesIO(audio_bytes)
 
-    segments, info = wmodel.transcribe(audio_file, beam_size=beamsize) # beamsize is 2.
-    text = ''
-    starttime = time.time()
-    for segment in segments:
-        text += segment.text
-    print('Time to transcribe:', time.time() - starttime, 'seconds')
-
-    return jsonify({'transcription': text})
-
+    try:
+        segments, info = wmodel.transcribe(audio_file, beam_size=beamsize)
+        text = ''
+        starttime = time.time()
+        for segment in segments:
+            text += segment.text
+        print(f"Time to transcribe: {time.time() - starttime} seconds")
+        return jsonify({'transcription': text})
+    except Exception as e:
+        print(f"Transcription error: {str(e)}")
+        return jsonify({'error': 'Transcription failed'}), 500
 
 if __name__ == "__main__":
     app.run(host="0.0.0.0", debug=True, port=7860)
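
For reference, a minimal client-side sketch of how the /whisper_transcribe route kept by this commit can be exercised. Assumptions not in the commit itself: the app is running locally on port 7860 (as in the __main__ block), the requests library is installed, and sample.mp3 is a placeholder filename.

import requests

# Hypothetical smoke test for the /whisper_transcribe route above.
# The upload must be multipart form data under the field name "audio",
# and the filename extension must be one of mp3/wav/ogg/m4a to pass
# the server-side checks.
with open("sample.mp3", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/whisper_transcribe",
        files={"audio": ("sample.mp3", f, "audio/mpeg")},
    )

print(resp.status_code)  # 200 on success; 400 or 500 on the error paths
print(resp.json())       # {'transcription': '...'} on success

Note that the server reads the whole upload into memory (io.BytesIO) before handing it to faster-whisper, so memory use grows with the size of the audio file.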