From 774913d8ff82160b0b65892ae4732a79f76f85e4 Mon Sep 17 00:00:00 2001 From: Tordarus Date: Wed, 12 Feb 2025 13:25:01 +0100 Subject: [PATCH] fixed many bugs --- ap_info.go | 6 +----- api.go | 53 ++++++++++++++++++++++++++++++++++++----------------- client.go | 6 +----- device.go | 6 +----- go.mod | 2 +- go.sum | 4 ++-- site.go | 6 +----- utils.go | 23 +++++++++++++++++++++++ 8 files changed, 66 insertions(+), 40 deletions(-) diff --git a/ap_info.go b/ap_info.go index 8e3661c..1bb77be 100644 --- a/ap_info.go +++ b/ap_info.go @@ -16,8 +16,6 @@ func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo } func (api *Api) getApInfo(siteID model.SiteID, macAddress string) (*model.Response[model.ApInfo], error) { - api.refreshMutex.RLock() - req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("GET"), @@ -28,13 +26,11 @@ func (api *Api) getApInfo(siteID model.SiteID, macAddress string) (*model.Respon macAddress)), ) - resp, err := ezhttp.Do(req) + resp, err := api.doRequest(req) if err != nil { - api.refreshMutex.RUnlock() return nil, err } defer resp.Body.Close() - api.refreshMutex.RUnlock() response, err := ezhttp.ParseJsonResponse[model.Response[model.ApInfo]](resp) if err != nil { diff --git a/api.go b/api.go index 69abcee..df729ef 100644 --- a/api.go +++ b/api.go @@ -39,31 +39,45 @@ func NewApi(config ApiConfig) (*Api, error) { config: config, } + if err := api.InitSession(); err != nil { + return nil, err + } + + return api, nil +} + +func (api *Api) InitSession() error { + api.refreshMutex.Lock() + defer api.refreshMutex.Unlock() + return api.initSessionNoMutexLock() +} + +func (api *Api) initSessionNoMutexLock() error { loginResponse, err := api.Login() if err != nil { - return nil, fmt.Errorf("login request failed: %w", err) + return fmt.Errorf("login request failed: %w", err) } if loginResponse.ErrorCode != 0 { - return nil, fmt.Errorf("login request failed: %s", loginResponse.Message) + return fmt.Errorf("login request failed: %s", loginResponse.Message) } authCodeResponse, err := api.AuthCode(loginResponse.Result.CsrfToken, loginResponse.Result.SessionID) if err != nil { - return nil, fmt.Errorf("auth code request failed: %w", err) + return fmt.Errorf("auth code request failed: %w", err) } if authCodeResponse.ErrorCode != 0 { - return nil, fmt.Errorf("auth code request failed: %s", authCodeResponse.Message) + return fmt.Errorf("auth code request failed: %s", authCodeResponse.Message) } authTokenResponse, err := api.AuthToken(*authCodeResponse.Result) if err != nil { - return nil, fmt.Errorf("auth token request failed: %w", err) + return fmt.Errorf("auth token request failed: %w", err) } if authTokenResponse.ErrorCode != 0 { - return nil, fmt.Errorf("auth token request failed: %s", authTokenResponse.Message) + return fmt.Errorf("auth token request failed: %s", authTokenResponse.Message) } api.expiration = time.Now().Add(time.Duration(authTokenResponse.Result.ExpiresIn) * time.Second) @@ -72,16 +86,13 @@ func NewApi(config ApiConfig) (*Api, error) { api.tmpl = ezhttp.Request( ezhttp.Template(api.tmpl), - ezhttp.Headers("Authorization", "AccessToken="+api.accessToken), + ezhttp.Auth(api.getAuthHeader), ) - return api, nil + return nil } func (api *Api) Login() (*model.LoginResponse, error) { - api.refreshMutex.RLock() - defer api.refreshMutex.RUnlock() - reqBody := model.LoginRequest{ Username: api.config.Username, Password: api.config.Password, @@ -110,9 +121,6 @@ func (api *Api) Login() (*model.LoginResponse, error) { } func (api *Api) AuthCode(csrfToken, sessionID string) (*model.AuthCodeResponse, error) { - api.refreshMutex.RLock() - defer api.refreshMutex.RUnlock() - req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("POST"), @@ -136,9 +144,6 @@ func (api *Api) AuthCode(csrfToken, sessionID string) (*model.AuthCodeResponse, } func (api *Api) AuthToken(authCode string) (*model.AuthTokenResponse, error) { - api.refreshMutex.RLock() - defer api.refreshMutex.RUnlock() - req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("POST"), @@ -192,9 +197,23 @@ func (api *Api) Refresh() error { return err } + if response.ErrorCode == ErrCodeRefreshTokenExpired { + return api.initSessionNoMutexLock() + } + api.expiration = time.Now().Add(time.Duration(response.Result.ExpiresIn) * time.Second) api.accessToken = response.Result.AccessToken api.refreshToken = response.Result.RefreshToken return nil } + +func (api *Api) doRequest(r *http.Request) (*http.Response, error) { + api.refreshMutex.RLock() + defer api.refreshMutex.RUnlock() + return ezhttp.Do(r) +} + +func (api *Api) getAuthHeader() string { + return "AccessToken=" + api.accessToken +} diff --git a/client.go b/client.go index fb550d0..69cc770 100644 --- a/client.go +++ b/client.go @@ -36,8 +36,6 @@ func (api *Api) GetClients(siteID model.SiteID) <-chan channel.Result[model.Clie } func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[model.Client], error) { - api.refreshMutex.RLock() - req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("GET"), @@ -48,13 +46,11 @@ func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[ ), ) - resp, err := ezhttp.Do(req) + resp, err := api.doRequest(req) if err != nil { - api.refreshMutex.RUnlock() return nil, err } defer resp.Body.Close() - api.refreshMutex.RUnlock() response, err := ezhttp.ParseJsonResponse[model.PagedResponse[model.Client]](resp) if err != nil { diff --git a/device.go b/device.go index 03b0bfd..6adf56d 100644 --- a/device.go +++ b/device.go @@ -36,8 +36,6 @@ func (api *Api) GetDevices(siteID model.SiteID) <-chan channel.Result[model.Devi } func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[model.Device], error) { - api.refreshMutex.RLock() - req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("GET"), @@ -48,13 +46,11 @@ func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[ ), ) - resp, err := ezhttp.Do(req) + resp, err := api.doRequest(req) if err != nil { - api.refreshMutex.RUnlock() return nil, err } defer resp.Body.Close() - api.refreshMutex.RUnlock() response, err := ezhttp.ParseJsonResponse[model.PagedResponse[model.Device]](resp) if err != nil { diff --git a/go.mod b/go.mod index fa136bc..2dc1df8 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,5 @@ go 1.23.0 require ( git.tordarus.net/tordarus/channel v0.1.19 - git.tordarus.net/tordarus/ezhttp v0.0.5 + git.tordarus.net/tordarus/ezhttp v0.0.8 ) diff --git a/go.sum b/go.sum index 0026ab5..56d6463 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,4 @@ git.tordarus.net/tordarus/channel v0.1.19 h1:d9xnSwFyvBh4B1/82mt0A7Gpm2nIZJTc+9ceJMIOu5Q= git.tordarus.net/tordarus/channel v0.1.19/go.mod h1:8/dWFTdGO7g4AeSZ7cF6GerkGbe9c4dBVMVDBxOd9m4= -git.tordarus.net/tordarus/ezhttp v0.0.5 h1:pxfEdfDeOHT/ATXYy5OQHmeBIho121SBuFvU4ISQ7w0= -git.tordarus.net/tordarus/ezhttp v0.0.5/go.mod h1:Zq9o0Hibny61GqSCwJHa0PfGjVoUFv/zt2PjiQHXvmY= +git.tordarus.net/tordarus/ezhttp v0.0.8 h1:S+LxXTnLaoSlkeA+vYB52clHVtszXSr/TFm8nl5KHEg= +git.tordarus.net/tordarus/ezhttp v0.0.8/go.mod h1:Zq9o0Hibny61GqSCwJHa0PfGjVoUFv/zt2PjiQHXvmY= diff --git a/site.go b/site.go index 5ffc25b..1512755 100644 --- a/site.go +++ b/site.go @@ -36,8 +36,6 @@ func (api *Api) GetSites() <-chan channel.Result[model.Site] { } func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) { - api.refreshMutex.RLock() - req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("GET"), @@ -48,13 +46,11 @@ func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) { ), ) - resp, err := ezhttp.Do(req) + resp, err := api.doRequest(req) if err != nil { - api.refreshMutex.RUnlock() return nil, err } defer resp.Body.Close() - api.refreshMutex.RUnlock() response, err := ezhttp.ParseJsonResponse[model.PagedResponse[model.Site]](resp) if err != nil { diff --git a/utils.go b/utils.go index 2cc7a88..2565bda 100644 --- a/utils.go +++ b/utils.go @@ -7,6 +7,7 @@ import ( ) const ErrCodeAccessTokenExpired = -44112 +const ErrCodeRefreshTokenExpired = -44114 func handleResponseErrors[T any](api *Api, response *model.Response[T], retry func() (*model.Response[T], error)) (*model.Response[T], error) { switch response.ErrorCode { @@ -22,6 +23,17 @@ func handleResponseErrors[T any](api *Api, response *model.Response[T], retry fu return nil, err } + return handleResponseErrors(api, newResp, retry) + case ErrCodeRefreshTokenExpired: + if err := api.InitSession(); err != nil { + return nil, fmt.Errorf("could not initialize new session: %w", err) + } + + newResp, err := retry() + if err != nil { + return nil, err + } + return handleResponseErrors(api, newResp, retry) default: return nil, fmt.Errorf("invalid error code %d with message: %s", response.ErrorCode, response.Message) @@ -42,6 +54,17 @@ func handlePagedResponseErrors[T any](api *Api, response *model.PagedResponse[T] return nil, err } + return handlePagedResponseErrors(api, newResp, retry) + case ErrCodeRefreshTokenExpired: + if err := api.InitSession(); err != nil { + return nil, fmt.Errorf("could not initialize new session: %w", err) + } + + newResp, err := retry() + if err != nil { + return nil, err + } + return handlePagedResponseErrors(api, newResp, retry) default: return nil, fmt.Errorf("invalid error code %d with message: %s", response.ErrorCode, response.Message)