Skip to content

Commit

Permalink
perf: check token status
Browse files Browse the repository at this point in the history
  • Loading branch information
LeeEirc committed Oct 9, 2024
1 parent b53c461 commit ad3e277
Show file tree
Hide file tree
Showing 13 changed files with 1,799 additions and 445 deletions.
88 changes: 76 additions & 12 deletions cmd/common/beat_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@ import (

func NewBeatService(apiClient *service.JMService) *BeatService {
return &BeatService{
sessMap: make(map[string]struct{}),
sessMap: make(map[string]*SessionToken),
apiClient: apiClient,
taskChan: make(chan *model.TerminalTask, 5),
}
}

type SessionToken struct {
model.Session
TokenId string
invalid bool
}

type BeatService struct {
sessMap map[string]struct{}
sessMap map[string]*SessionToken

apiClient *service.JMService

Expand Down Expand Up @@ -88,18 +94,14 @@ func (b *BeatService) receiveWsTask(ws *websocket.Conn, done chan struct{}) {
}
if len(tasks) != 0 {
for i := range tasks {
select {
case b.taskChan <- &tasks[i]:
default:
logger.Infof("Discard task %v", tasks[i])
}
b.sendTask(&tasks[i])
}
}
}
}

func (b *BeatService) GetStatusData() interface{} {
sessions := b.getSessions()
sessions := b.getSessionIds()
payload := model.HeartbeatData{
SessionOnlineIds: sessions,
CpuUsed: common.CpuLoad1Usage(),
Expand All @@ -113,7 +115,7 @@ func (b *BeatService) GetStatusData() interface{} {
}
}

func (b *BeatService) getSessions() []string {
func (b *BeatService) getSessionIds() []string {
b.Lock()
defer b.Unlock()
sids := make([]string, 0, len(b.sessMap))
Expand All @@ -123,12 +125,20 @@ func (b *BeatService) getSessions() []string {
return sids
}

var empty = struct{}{}
func (b *BeatService) StoreSessionId(sess *SessionToken) {
b.Lock()
defer b.Unlock()
b.sessMap[sess.ID] = sess
}

func (b *BeatService) StoreSessionId(sid string) {
func (b *BeatService) GetSessions() []*SessionToken {
b.Lock()
defer b.Unlock()
b.sessMap[sid] = empty
sids := make([]*SessionToken, 0, len(b.sessMap))
for sid := range b.sessMap {
sids = append(sids, b.sessMap[sid])
}
return sids
}

func (b *BeatService) RemoveSessionId(sid string) {
Expand All @@ -144,3 +154,57 @@ func (b *BeatService) GetTerminalTaskChan() <-chan *model.TerminalTask {
func (b *BeatService) FinishTask(taskId string) error {
return b.apiClient.FinishTask(taskId)
}

func (b *BeatService) KeepCheckTokens() {
for {
time.Sleep(5 * time.Minute)
sessions := b.GetSessions()
tokens := make(map[string]model.TokenCheckStatus, len(sessions))
for _, s := range sessions {
ret, ok := tokens[s.TokenId]
if ok {
b.handleTokenCheck(s, &ret)
continue
}
ret, err := b.apiClient.CheckTokenStatus(s.TokenId)
if err != nil && ret.Code == "" {
logger.Errorf("Check token status failed: %s", err)
continue
}
tokens[s.TokenId] = ret
b.handleTokenCheck(s, &ret)
}
}
}

func (b *BeatService) sendTask(task *model.TerminalTask) {
select {
case b.taskChan <- task:
default:
logger.Errorf("Discard task %v", task)
}
}

func (b *BeatService) handleTokenCheck(session *SessionToken, tokenStatus *model.TokenCheckStatus) {
var action string
switch tokenStatus.Code {
case model.CodePermOk:
action = model.TaskPermValid
if !session.invalid {
return
}
session.invalid = false
default:
if session.invalid {
return
}
session.invalid = true
action = model.TaskPermExpired
}
task := model.TerminalTask{
Name: action,
Args: session.ID,
TokenStatus: *tokenStatus,
}
b.sendTask(&task)
}
1 change: 1 addition & 0 deletions cmd/impl/convert_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func ConvertToSession(sees *pb.Session) model.Session {
AssetID: sees.AssetId,
AccountID: sees.AccountId,
Type: model.NORMALType,
TokenId: sees.TokenId,
}
}

Expand Down
19 changes: 17 additions & 2 deletions cmd/impl/jms.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,12 @@ func (j *JMServer) CreateSession(ctx context.Context, req *pb.SessionCreateReque
return &pb.SessionCreateResponse{Status: &status}, nil
}
status.Ok = true
j.beat.StoreSessionId(apiResp.ID)
logger.Debugf("Creat session %s", apiResp.ID)
sessionToken := common.SessionToken{
Session: apiResp,
TokenId: req.Data.TokenId,
}
j.beat.StoreSessionId(&sessionToken)
logger.Debugf("Creat session %s", apiSess.ID)
return &pb.SessionCreateResponse{Status: &status,
Data: ConvertToProtobufSession(apiResp)}, nil
}
Expand Down Expand Up @@ -199,6 +203,17 @@ func (j *JMServer) sendStreamTask(ctx context.Context, stream pb.Service_Dispatc
case model.TaskUnlockSession:
pbTask.Action = pb.TaskAction_UnlockSession
pbTask.CreatedBy = task.Kwargs.CreatedByUser
case model.TaskPermExpired:
pbTask.Action = pb.TaskAction_TokenPermExpired
pbTask.TokenStatus = &pb.TokenStatus{
Code: "",
Detail: "",
IsExpired: false,
}

case model.TaskPermValid:
pbTask.Action = pb.TaskAction_TokenPermValid

default:
logger.Errorf("Unknown task name %s", task.Name)
continue
Expand Down
1 change: 1 addition & 0 deletions cmd/root.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ var rootCmd = &cobra.Command{
beat := common.NewBeatService(apiClient)
{
go beat.KeepHeartBeat()
go beat.KeepCheckTokens()
}
ctx := common.GetSignalCtx()
grpcImplSrv := impl.NewJMServer(apiClient, uploader, beat)
Expand Down
4 changes: 3 additions & 1 deletion pkg/common/random.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@ import (

const letters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"

var localRand = rand.New(rand.NewSource(time.Now().UnixNano()))

func RandomStr(length int) string {
rand.Seed(time.Now().UnixNano())
localRand.Seed(time.Now().UnixNano())
b := make([]byte, length)
for i := range b {
b[i] = letters[rand.Intn(len(letters))]
Expand Down
1 change: 1 addition & 0 deletions pkg/jms-sdk-go/model/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ type Session struct {
AssetID string `json:"asset_id"`
AccountID string `json:"account_id"`
Type LabelFiled `json:"type"`
TokenId string `json:"token_id"`
}

type ReplayVersion string
Expand Down
16 changes: 11 additions & 5 deletions pkg/jms-sdk-go/model/terminal.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,23 @@ type Terminal struct {
}

type TerminalTask struct {
ID string `json:"id"`
Name string `json:"name"`
Args string `json:"args"`
Kwargs TaskKwargs `json:"kwargs"`
IsFinished bool
ID string `json:"id"`
Name string `json:"name"`
Args string `json:"args"`
Kwargs TaskKwargs `json:"kwargs"`

TokenStatus TokenCheckStatus `json:"-"`
}

const (
TaskKillSession = "kill_session"
TaskLockSession = "lock_session"
TaskUnlockSession = "unlock_session"

// TaskPermExpired TaskPermValid 非 api 数据,仅用于内部处理

TaskPermExpired = "perm_expired"
TaskPermValid = "perm_valid"
)

type TaskKwargs struct {
Expand Down
14 changes: 14 additions & 0 deletions pkg/jms-sdk-go/model/token.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,3 +37,17 @@ type ConnectTokenInfo struct {
AccountName string `json:"account_name"`
Protocol string `json:"protocol"`
}

// token 授权和过期状态

type TokenCheckStatus struct {
Detail string `json:"detail"`
Code string `json:"code"`
Expired bool `json:"expired"`
}

const (
CodePermOk = "perm_ok"
CodePermAccountInvalid = "perm_account_invalid"
CodePermExpired = "perm_expired"
)
6 changes: 6 additions & 0 deletions pkg/jms-sdk-go/service/jms_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,3 +32,9 @@ type TokenRenewalResponse struct {
Ok bool `json:"ok"`
Msg string `json:"msg"`
}

func (s *JMService) CheckTokenStatus(tokenId string) (res model.TokenCheckStatus, err error) {
reqURL := fmt.Sprintf(SuperConnectTokenCheckURL, tokenId)
_, err = s.authClient.Get(reqURL, &res)
return
}
1 change: 1 addition & 0 deletions pkg/jms-sdk-go/service/url.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ const (
SuperConnectTokenSecretURL = "/api/v1/authentication/super-connection-token/secret/"
SuperConnectTokenInfoURL = "/api/v1/authentication/super-connection-token/"
SuperTokenRenewalURL = "/api/v1/authentication/super-connection-token/renewal/"
SuperConnectTokenCheckURL = "/api/v1/authentication/super-connection-token/%s/check/"

UserPermsAssetsURL = "/api/v1/perms/users/%s/assets/"

Expand Down
Loading

0 comments on commit ad3e277

Please sign in to comment.