Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from flask import Flask, render_template, request, jsonify
|
2 |
+
from qdrant_client import QdrantClient
|
3 |
+
from qdrant_client import models
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import torch
|
6 |
+
from torch import Tensor
|
7 |
+
from transformers import AutoTokenizer, AutoModel
|
8 |
+
from qdrant_client.models import Batch, PointStruct
|
9 |
+
from pickle import load, dump
|
10 |
+
import numpy as np
|
11 |
+
import os, time, sys
|
12 |
+
from datetime import datetime as dt
|
13 |
+
from datetime import timedelta
|
14 |
+
from datetime import timezone
|
15 |
+
from faster_whisper import WhisperModel
|
16 |
+
import io
|
17 |
+
|
18 |
+
app = Flask(__name__)
|
19 |
+
|
20 |
+
# Faster Whisper setup
|
21 |
+
# model_size = 'small'
|
22 |
+
beamsize = 2
|
23 |
+
wmodel = WhisperModel("guillaumekln/faster-whisper-small", device="cpu", compute_type="int8")
|
24 |
+
|
25 |
+
# Initialize Qdrant Client and other required settings
|
26 |
+
qdrant_api_key = os.environ.get("qdrant_api_key")
|
27 |
+
qdrant_url = os.environ.get("qdrant_url")
|
28 |
+
|
29 |
+
client = QdrantClient(url=qdrant_url, port=443, api_key=qdrant_api_key, prefer_grpc=False)
|
30 |
+
|
31 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
32 |
+
|
33 |
+
def average_pool(last_hidden_states: Tensor,
|
34 |
+
attention_mask: Tensor) -> Tensor:
|
35 |
+
last_hidden = last_hidden_states.masked_fill(~attention_mask[..., None].bool(), 0.0)
|
36 |
+
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
|
37 |
+
|
38 |
+
tokenizer = AutoTokenizer.from_pretrained('intfloat/e5-base-v2')
|
39 |
+
model = AutoModel.from_pretrained('intfloat/e5-base-v2').to(device)
|
40 |
+
|
41 |
+
def e5embed(query):
|
42 |
+
batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
|
43 |
+
batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
|
44 |
+
outputs = model(**batch_dict)
|
45 |
+
embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
|
46 |
+
embeddings = F.normalize(embeddings, p=2, dim=1)
|
47 |
+
embeddings = embeddings.cpu().detach().numpy().flatten().tolist()
|
48 |
+
return embeddings
|
49 |
+
|
50 |
+
def get_id(collection):
|
51 |
+
resp = client.scroll(collection_name=collection, limit=10000, with_payload=True, with_vectors=False,)
|
52 |
+
max_id = max([r.id for r in resp[0]])+1
|
53 |
+
return int(max_id)
|
54 |
+
|
55 |
+
@app.route("/")
|
56 |
+
def index():
|
57 |
+
return render_template("index.html")
|
58 |
+
|
59 |
+
@app.route("/search", methods=["POST"])
|
60 |
+
def search():
|
61 |
+
query = request.form["query"]
|
62 |
+
collection_name = request.form["collection"]
|
63 |
+
topN = 200 # Define your topN value
|
64 |
+
|
65 |
+
|
66 |
+
print('QUERY: ',query)
|
67 |
+
if query.strip().startswith('tilc:'):
|
68 |
+
collection_name = 'tils'
|
69 |
+
qvector = "context"
|
70 |
+
query = query.replace('tilc:', '')
|
71 |
+
elif query.strip().startswith('til:'):
|
72 |
+
collection_name = 'tils'
|
73 |
+
qvector = "title"
|
74 |
+
query = query.replace('til:', '')
|
75 |
+
else: collection_name = 'jks'
|
76 |
+
|
77 |
+
timh = time.time()
|
78 |
+
sq = e5embed(query)
|
79 |
+
print('EMBEDDING TIME: ', time.time() - timh)
|
80 |
+
|
81 |
+
timh = time.time()
|
82 |
+
if collection_name == "jks": results = client.search(collection_name=collection_name, query_vector=sq, with_payload=True, limit=topN)
|
83 |
+
else: results = client.search(collection_name=collection_name, query_vector=(qvector, sq), with_payload=True, limit=100)
|
84 |
+
print('SEARCH TIME: ', time.time() - timh)
|
85 |
+
|
86 |
+
#print(results[0])
|
87 |
+
# try:
|
88 |
+
new_results = []
|
89 |
+
if collection_name == 'jks':
|
90 |
+
for r in results:
|
91 |
+
if 'date' not in r.payload: r.payload['date'] = '20200101'
|
92 |
+
new_results.append({"text": r.payload['text'], "date": str(int(r.payload['date'])), "id": r.id}) # Implement your Qdrant search here
|
93 |
+
else:
|
94 |
+
for r in results:
|
95 |
+
if 'context' in r.payload and r.payload['context'] != '':
|
96 |
+
if 'date' not in r.payload: r.payload['date'] = '20200101'
|
97 |
+
new_results.append({"text": r.payload['title'] + '<br>Context: ' + r.payload['context'], "url": r.payload['url'], "date": r.payload['date'], "id": r.id})
|
98 |
+
else:
|
99 |
+
if 'date' not in r.payload: r.payload['date'] = '20200101'
|
100 |
+
new_results.append({"text": r.payload['title'], "url": r.payload['url'], "date": r.payload['date'], "id": r.id})
|
101 |
+
return jsonify(new_results)
|
102 |
+
# except:
|
103 |
+
# return jsonify([])
|
104 |
+
|
105 |
+
@app.route("/add_item", methods=["POST"])
|
106 |
+
def add_item():
|
107 |
+
title = request.form["title"]
|
108 |
+
url = request.form["url"]
|
109 |
+
if url.strip() == '':
|
110 |
+
collection_name = 'jks'
|
111 |
+
cid = get_id(collection_name)
|
112 |
+
print('cid', cid, time.strftime("%Y%m%d"))
|
113 |
+
resp = client.upsert(collection_name=collection_name, points=Batch(ids=[cid], payloads=[{'text':title, 'date': time.strftime("%Y%m%d")}],vectors=[e5embed(title)]),)
|
114 |
+
else:
|
115 |
+
collection_name = 'tils'
|
116 |
+
cid = get_id('tils')
|
117 |
+
print('cid', cid, time.strftime("%Y%m%d"), collection_name)
|
118 |
+
til = {'title': title.replace('TIL that', '').replace('TIL:', '').replace('TIL ', '').strip(), 'url': url.replace('https://', '').replace('http://', ''), "date": time.strftime("%Y%m%d_%H%M")}
|
119 |
+
resp = client.upsert(collection_name="tils", points=[PointStruct(id=cid, payload=til, vector={"title": e5embed(til['title']),},)])
|
120 |
+
print('Upsert response:', resp)
|
121 |
+
return jsonify({"success": True, "index": collection_name})
|
122 |
+
|
123 |
+
|
124 |
+
@app.route("/delete_joke", methods=["POST"])
|
125 |
+
def delete_joke():
|
126 |
+
joke_id = request.form["id"]
|
127 |
+
collection_name = request.form["collection"]
|
128 |
+
print('Deleting no.', joke_id, 'from collection', collection_name)
|
129 |
+
client.delete(collection_name=collection_name, points_selector=models.PointIdsList(points=[int(joke_id)],),)
|
130 |
+
return jsonify({"deleted": True})
|
131 |
+
|
132 |
+
@app.route("/whisper_transcribe", methods=["POST"])
|
133 |
+
def whisper_transcribe():
|
134 |
+
if 'audio' not in request.files: return jsonify({'error': 'No file provided'}), 400
|
135 |
+
|
136 |
+
audio_file = request.files['audio']
|
137 |
+
allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
|
138 |
+
if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions): return jsonify({'error': 'Invalid file format'}), 400
|
139 |
+
|
140 |
+
print('Transcribing audio')
|
141 |
+
audio_bytes = audio_file.read()
|
142 |
+
audio_file = io.BytesIO(audio_bytes)
|
143 |
+
|
144 |
+
segments, info = wmodel.transcribe(audio_file, beam_size=beamsize) # beamsize is 2.
|
145 |
+
text = ''
|
146 |
+
starttime = time.time()
|
147 |
+
for segment in segments:
|
148 |
+
text += segment.text
|
149 |
+
print('Time to transcribe:', time.time() - starttime, 'seconds')
|
150 |
+
|
151 |
+
return jsonify({'transcription': text})
|
152 |
+
|
153 |
+
|
154 |
+
if __name__ == "__main__":
|
155 |
+
app.run(host="0.0.0.0", debug=True, port=7860)
|