Update app.py
app.py CHANGED
@@ -1,155 +1,45 @@
-from flask import Flask,
-from qdrant_client import QdrantClient
-from qdrant_client import models
-import torch.nn.functional as F
-import torch
-from torch import Tensor
-from transformers import AutoTokenizer, AutoModel
-from qdrant_client.models import Batch, PointStruct
-from pickle import load, dump
-import numpy as np
-import os, time, sys
-from datetime import datetime as dt
-from datetime import timedelta
-from datetime import timezone
+from flask import Flask, request, jsonify
 from faster_whisper import WhisperModel
+import torch
 import io
+import time

 app = Flask(__name__)

+# Device check for faster-whisper
+device = "cuda" if torch.cuda.is_available() else "cpu"
+compute_type = "float16" if device == "cuda" else "int8"
+print(f"Using device: {device} with compute_type: {compute_type}")
+
 # Faster Whisper setup
-# model_size = 'small'
 beamsize = 2
-wmodel = WhisperModel("guillaumekln/faster-whisper-small", device=
+wmodel = WhisperModel("guillaumekln/faster-whisper-small", device=device, compute_type=compute_type)
-
-# Initialize Qdrant Client and other required settings
-qdrant_api_key = os.environ.get("qdrant_api_key")
-qdrant_url = os.environ.get("qdrant_url")
-
-client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-def average_pool(last_hidden_states: Tensor,
-                 attention_mask: Tensor) -> Tensor:
-    last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
-    return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
-
-tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
-model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)
-
-def e5embed(query):
-    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
-    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
-    outputs = model(**batch_dict)
-    embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
-    embeddings = F.normalize(embeddings, p=2, dim=1)
-    embeddings = embeddings.cpu().detach().numpy().flatten().tolist()
-    return embeddings
-
-def get_id(collection):
-    resp = client.scroll(collection_name=collection, limit=10000, with_payload=True, with_vectors=False,)
-    max_id = max([r.id for r in resp[0]])+1
-    return int(max_id)
-
-@app.route("/")
-def index():
-    return render_template("index.html")
-
-@app.route("/search", methods=["POST"])
-def search():
-    query = request.form["query"]
-    collection_name = request.form["collection"]
-    topN = 200 # Define your topN value
-
-
-    print('QUERY: ',query)
-    if query.strip().startswith('tilc:'):
-        collection_name = 'tils'
-        qvector = "context"
-        query = query.replace('tilc:', '')
-    elif query.strip().startswith('til:'):
-        collection_name = 'tils'
-        qvector = "title"
-        query = query.replace('til:', '')
-    else: collection_name = 'jks'
-
-    timh = time.time()
-    sq = e5embed(query)
-    print('EMBEDDING TIME: ', time.time() - timh)
-
-    timh = time.time()
-    if collection_name == "jks": results = client.search(collection_name=collection_name, query_vector=sq, with_payload=True, limit=topN)
-    else: results = client.search(collection_name=collection_name, query_vector=(qvector, sq), with_payload=True, limit=100)
-    print('SEARCH TIME: ', time.time() - timh)
-
-    #print(results[0])
-    # try:
-    new_results = []
-    if collection_name == 'jks':
-        for r in results:
-            if 'date' not in r.payload: r.payload['date'] = '20200101'
-            new_results.append({"text": r.payload['text'], "date": str(int(r.payload['date'])), "id": r.id}) # Implement your Qdrant search here
-    else:
-        for r in results:
-            if 'context' in r.payload and r.payload['context'] != '':
-                if 'date' not in r.payload: r.payload['date'] = '20200101'
-                new_results.append({"text": r.payload['title'] + '<br>Context: ' + r.payload['context'], "url": r.payload['url'], "date": r.payload['date'], "id": r.id})
-            else:
-                if 'date' not in r.payload: r.payload['date'] = '20200101'
-                new_results.append({"text": r.payload['title'], "url": r.payload['url'], "date": r.payload['date'], "id": r.id})
-    return jsonify(new_results)
-    # except:
-    #     return jsonify([])
-
-@app.route("/add_item", methods=["POST"])
-def add_item():
-    title = request.form["title"]
-    url = request.form["url"]
-    if url.strip() == '':
-        collection_name = 'jks'
-        cid = get_id(collection_name)
-        print('cid', cid, time.strftime("%Y%m%d"))
-        resp = client.upsert(collection_name=collection_name, points=Batch(ids=[cid], payloads=[{'text':title, 'date': time.strftime("%Y%m%d")}],vectors=[e5embed(title)]),)
-    else:
-        collection_name = 'tils'
-        cid = get_id('tils')
-        print('cid', cid, time.strftime("%Y%m%d"), collection_name)
-        til = {'title': title.replace('TIL that', '').replace('TIL:', '').replace('TIL ', '').strip(), 'url': url.replace('https://', '').replace('http://', ''), "date": time.strftime("%Y%m%d_%H%M")}
-        resp = client.upsert(collection_name="tils", points=[PointStruct(id=cid, payload=til, vector={"title": e5embed(til['title']),},)])
-    print('Upsert response:', resp)
-    return jsonify({"success": True, "index": collection_name})
-
-
-@app.route("/delete_joke", methods=["POST"])
-def delete_joke():
-    joke_id = request.form["id"]
-    collection_name = request.form["collection"]
-    print('Deleting no.', joke_id, 'from collection', collection_name)
-    client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[int(joke_id)],),)
-    return jsonify({"deleted": True})

 @app.route("/whisper_transcribe", methods=["POST"])
 def whisper_transcribe():
-    if 'audio' not in request.files:
+    if 'audio' not in request.files:
+        return jsonify({'error': 'No file provided'}), 400

     audio_file = request.files['audio']
     allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
-    if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions):
+    if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions):
+        return jsonify({'error': 'Invalid file format'}), 400

-    print(
+    print(f"Transcribing audio on {device}")
     audio_bytes = audio_file.read()
     audio_file = io.BytesIO(audio_bytes)

-
-
-
-
-
-
-
-
-
+    try:
+        segments, info = wmodel.transcribe(audio_file, beam_size=beamsize)
+        text = ''
+        starttime = time.time()
+        for segment in segments:
+            text += segment.text
+        print(f"Time to transcribe: {time.time() - starttime} seconds")
+        return jsonify({'transcription': text})
+    except Exception as e:
+        print(f"Transcription error: {str(e)}")
+        return jsonify({'error': 'Transcription failed'}), 500

 if __name__ == "__main__":
     app.run(host="0.0.0.0", debug=True, port=7860)
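For reference, a minimal client-side sketch of how the remaining /whisper_transcribe endpoint could be exercised once the app is running. This is not part of the commit: it assumes the server is reachable on localhost:7860 (the port configured above), and the file name sample.mp3 and the use of the requests library are illustrative choices.

import requests

# Post an audio file under the 'audio' form field that whisper_transcribe expects.
# The extension must be one of mp3/wav/ogg/m4a, otherwise the endpoint returns 400.
with open("sample.mp3", "rb") as f:
    resp = requests.post(
        "http://localhost:7860/whisper_transcribe",
        files={"audio": ("sample.mp3", f, "audio/mpeg")},
    )

print(resp.status_code)
print(resp.json())  # {'transcription': '...'} on success, {'error': ...} otherwise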