-
Notifications
You must be signed in to change notification settings - Fork 9
/
Copy pathserver.py
219 lines (188 loc) · 7.46 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import List, Dict, Optional
import numpy as np
import json
from datetime import datetime
import asyncio
from pir import (
SimplePIRParams, gen_params, gen_hint,
answer as pir_answer
)
from update import update_embeddings
from utils import strings_to_matrix
from contextlib import asynccontextmanager
import os
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handle startup and shutdown events.

    Startup: runs ``update_embeddings()`` followed by ``state.load_data()``
    once. Any exception is printed rather than raised, so the server still
    starts with whatever state it has. It then launches
    ``state.update_loop()`` as a background task.

    Shutdown: cancels the background task and waits for it to finish. The
    cleanup sits in a ``finally`` so it also runs when an exception is thrown
    into this generator at the ``yield`` point.

    Args:
        app: The FastAPI application; not used by this function.
    """
    # Startup
    try:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"\n[{timestamp}] Running initial update on startup...")
        update_embeddings()
        state.load_data()
        print(f"[{timestamp}] Initial update complete!")
    except Exception as e:
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"[{timestamp}] Error during initial update: {e}")
    state._update_task = asyncio.create_task(state.update_loop())
    try:
        yield  # Server is running
    finally:
        # Shutdown: always stop the periodic updater, even on abnormal exit.
        task = state._update_task
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
app = FastAPI(
    title="Private Market Data Search",
    lifespan=lifespan,
)

# Enable CORS for every origin, method and header.
# NOTE(review): the CORS spec forbids the "*" wildcard on credentialed
# requests. Confirm how the installed Starlette version treats
# allow_credentials=True together with allow_origins=["*"], and whether
# credentials are needed by clients at all.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Request/Response Models
# Request body of POST /embedding/query and POST /article/query.
class PIRQuery(BaseModel):
    # The client's PIR query vector; the handlers convert it with np.array()
    # and pass it to pir_answer().
    query: List[int]
# Request body of POST /embedding/update. The handler accepts it but does not
# read any of its fields.
class UpdateRequest(BaseModel):
    type: str = "update"
# Response of GET /embedding/setup (all fields) and GET /article/setup
# (params, hint and num_articles only; the other optional fields stay None).
class SetupResponse(BaseModel):
    # PIR parameters: n, m, q, p, std_dev and seed.
    params: Dict
    # PIR hint matrix, serialized with ndarray.tolist().
    hint: List[List[float]]
    centroids: Optional[List[List[float]]] = None
    # Contents of embeddings/metadata.json (includes an 'articles' list).
    metadata: Optional[Dict] = None
    embeddings: Optional[List[List[float]]] = None
    num_articles: Optional[int] = None
# Response of the PIR query endpoints: the pir_answer() result as a flat list.
class PIRResponse(BaseModel):
    answer: List[int]
# Response of POST /embedding/update: the current centroids, metadata and
# embeddings, without PIR parameters or hint.
class UpdateResponse(BaseModel):
    centroids: List[List[float]]
    metadata: Dict
    embeddings: List[List[float]]
# Server state
class ServerState:
    """Holds the databases, PIR parameters and hints served by the API.

    The state starts out empty: zero-size arrays and ``gen_params(m=1)``
    placeholders. It is populated by :meth:`load_data`, either directly or
    from the periodic :meth:`update_loop` background task.
    """

    def __init__(self):
        # Handle of the background update_loop() task, if one is running.
        self._update_task = None
        # Create necessary directories
        os.makedirs('embeddings', exist_ok=True)
        os.makedirs('articles', exist_ok=True)
        # Initialize with empty state
        self.embeddings_db = np.zeros((0, 0))
        self.centroids = np.zeros((0, 0))
        self.metadata = {'articles': []}
        self.embeddings_params = gen_params(m=1)  # Minimal params for initial state
        self.embeddings_hint = np.zeros((0, 0))
        self.articles_db = np.zeros((0, 0))
        self.articles_params = gen_params(m=1)  # Minimal params for initial state
        self.articles_hint = np.zeros((0, 0))
        self.num_articles = 0

    def load_data(self):
        """Load embeddings, metadata, centroids and articles from disk.

        Every file is read and every PIR parameter/hint is computed before any
        attribute is replaced. A failure part-way through (for example a
        missing article file) therefore leaves the previous state fully
        intact. A missing file is reported with a message. Any other error
        propagates to the caller.
        """
        snapshot = self._load_snapshot()
        if snapshot is not None:
            self._apply_snapshot(snapshot)

    def _load_snapshot(self):
        """Read the on-disk data and build fresh PIR parameters and hints.

        This method never mutates ``self``, so it can run in a worker thread
        while request handlers keep reading the current state.

        Returns:
            A dict mapping attribute names to their new values, or ``None``
            when a required file is missing.
        """
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        print(f"\n[{timestamp}] Loading embeddings database...")
        try:
            embeddings_db = np.load('embeddings/embeddings.npy')
            centroids = np.load('embeddings/centroids.npy')
            with open('embeddings/metadata.json', 'r', encoding='utf-8') as f:
                metadata = json.load(f)
            articles = []
            for article_info in metadata['articles']:
                with open(article_info['filepath'], 'r', encoding='utf-8') as f:
                    articles.append(f.read())
        except FileNotFoundError:
            print(f"[{timestamp}] No existing embeddings found. Will create in first update.")
            return None

        # Initialize PIR only once every input has been read successfully.
        embeddings_params = gen_params(m=embeddings_db.shape[0])
        embeddings_hint = gen_hint(embeddings_params, embeddings_db)
        articles_db, matrix_size = strings_to_matrix(articles)
        articles_params = gen_params(m=matrix_size)
        articles_hint = gen_hint(articles_params, articles_db)

        print(f"[{timestamp}] Embeddings shape: {embeddings_db.shape}")
        print(f"[{timestamp}] Centroids shape: {centroids.shape}")
        print(f"[{timestamp}] Articles loaded: {len(articles)}")
        return {
            'embeddings_db': embeddings_db,
            'centroids': centroids,
            'metadata': metadata,
            'embeddings_params': embeddings_params,
            'embeddings_hint': embeddings_hint,
            'articles_db': articles_db,
            'articles_params': articles_params,
            'articles_hint': articles_hint,
            'num_articles': len(articles),
        }

    def _apply_snapshot(self, snapshot):
        """Install every attribute from *snapshot* without yielding control."""
        for name, value in snapshot.items():
            setattr(self, name, value)

    def _run_update(self):
        """Regenerate the on-disk data, then read it back as a snapshot.

        This method is blocking and is meant to run in a worker thread.
        """
        update_embeddings()
        return self._load_snapshot()

    async def update_loop(self):
        """Periodically update data.

        Every 60 seconds this regenerates and reloads the data. The blocking
        work runs in the default thread-pool executor so the event loop keeps
        serving requests. The resulting snapshot is then installed on the
        event-loop thread in one synchronous step, so a request handler never
        observes a half-updated state. Errors are printed and the loop keeps
        running; cancellation stops it.
        """
        loop = asyncio.get_running_loop()
        while True:
            await asyncio.sleep(60)  # Wait for 1 minute
            try:
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print(f"\n[{timestamp}] Starting scheduled update...")
                snapshot = await loop.run_in_executor(None, self._run_update)
                if snapshot is not None:
                    self._apply_snapshot(snapshot)
                print(f"[{timestamp}] Update complete!")
            except asyncio.CancelledError:
                # Propagate cancellation (on Python 3.7 it subclasses Exception).
                raise
            except Exception as e:
                timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
                print(f"[{timestamp}] Error during update: {e}")
# Initialize server state
# Module-level instance read by lifespan() and the request handlers.
# Constructing it creates the 'embeddings' and 'articles' directories.
state = ServerState()
@app.get("/")
async def root():
    """Identify the service with a fixed greeting message."""
    greeting = "Private Market Data Search API"
    return {"message": greeting}
@app.get("/health")
async def health():
    """Health check endpoint; unconditionally reports a healthy status."""
    current_status = "healthy"
    return {"status": current_status}
@app.get("/embedding/setup", response_model=SetupResponse)
async def embedding_setup():
    """Return what a client needs before issuing embedding PIR queries.

    The payload holds the PIR parameters, the hint, the centroids, the
    metadata and the full embeddings matrix. Numpy values are converted to
    plain Python types.
    """
    pir = state.embeddings_params
    params = {field: int(getattr(pir, field)) for field in ('n', 'm', 'q', 'p')}
    params['std_dev'] = float(pir.std_dev)
    params['seed'] = int(pir.seed)
    return dict(
        params=params,
        hint=state.embeddings_hint.tolist(),
        centroids=state.centroids.tolist(),
        metadata=state.metadata,
        embeddings=state.embeddings_db.tolist(),
    )
@app.get("/article/setup", response_model=SetupResponse)
async def article_setup():
    """Return what a client needs before issuing article PIR queries.

    The payload holds the article PIR parameters, the hint and the number of
    loaded articles. Numpy values are converted to plain Python types.
    """
    pir = state.articles_params
    params = {field: int(getattr(pir, field)) for field in ('n', 'm', 'q', 'p')}
    params['std_dev'] = float(pir.std_dev)
    params['seed'] = int(pir.seed)
    return dict(
        params=params,
        hint=state.articles_hint.tolist(),
        num_articles=state.num_articles,
    )
@app.post("/embedding/query", response_model=PIRResponse)
async def embedding_query(query: PIRQuery):
    """Handle PIR query for embeddings.

    Evaluates ``pir_answer`` on the client's query against the embeddings
    database and returns the answer vector.

    Raises:
        HTTPException: 503 if the embeddings database is still the empty
            placeholder (nothing has been loaded yet). 400 if ``pir_answer``
            rejects the query with ``ValueError``.
    """
    db = state.embeddings_db
    if db.size == 0:
        # ServerState starts with np.zeros((0, 0)) until a load succeeds.
        raise HTTPException(status_code=503, detail="Embeddings database is not loaded yet")
    query_array = np.array(query.query)
    try:
        ans = pir_answer(query_array, db, state.embeddings_params.q)
    except ValueError as err:
        # NOTE(review): presumably numpy rejecting a query whose shape does not
        # match the database. Confirm pir_answer raises no other ValueError.
        raise HTTPException(status_code=400, detail=f"Invalid PIR query: {err}") from err
    return {'answer': ans.tolist()}
@app.post("/article/query", response_model=PIRResponse)
async def article_query(query: PIRQuery):
    """Handle PIR query for articles.

    Evaluates ``pir_answer`` on the client's query against the articles
    database and returns the answer vector.

    Raises:
        HTTPException: 503 if the articles database is still empty (nothing
            has been loaded yet). 400 if ``pir_answer`` rejects the query with
            ``ValueError``.
    """
    db = state.articles_db
    if db.size == 0:
        # ServerState starts with np.zeros((0, 0)) until a load succeeds.
        raise HTTPException(status_code=503, detail="Articles database is not loaded yet")
    query_array = np.array(query.query)
    try:
        ans = pir_answer(query_array, db, state.articles_params.q)
    except ValueError as err:
        # NOTE(review): presumably numpy rejecting a query whose shape does not
        # match the database. Confirm pir_answer raises no other ValueError.
        raise HTTPException(status_code=400, detail=f"Invalid PIR query: {err}") from err
    return {'answer': ans.tolist()}
@app.post("/embedding/update", response_model=UpdateResponse)
async def embedding_update(request: UpdateRequest):
    """Return the current centroids, metadata and embeddings.

    The request body is accepted but not inspected.
    """
    centroid_rows = state.centroids.tolist()
    current_metadata = state.metadata
    embedding_rows = state.embeddings_db.tolist()
    return dict(
        centroids=centroid_rows,
        metadata=current_metadata,
        embeddings=embedding_rows,
    )
if __name__ == "__main__":
    # uvicorn is imported here so it is only required when this file is run
    # as a script.
    import uvicorn

    bind_host, bind_port = "127.0.0.1", 8000
    # Start the server
    uvicorn.run(app, host=bind_host, port=bind_port)