From 921138f45066ad6b85c430b4c3cb2011ea458c71 Mon Sep 17 00:00:00 2001 From: Tordarus Date: Sun, 2 Feb 2025 10:54:39 +0100 Subject: [PATCH] Refresh and AutoRefresh implemented --- ap_info.go | 3 +++ api.go | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ client.go | 3 +++ device.go | 3 +++ site.go | 3 +++ 5 files changed, 77 insertions(+) diff --git a/ap_info.go b/ap_info.go index 082eb40..7a2597d 100644 --- a/ap_info.go +++ b/ap_info.go @@ -8,6 +8,9 @@ import ( ) func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo, error) { + api.refreshMutex.RLock() + defer api.refreshMutex.RUnlock() + req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("GET"), diff --git a/api.go b/api.go index 8d02944..4baf609 100644 --- a/api.go +++ b/api.go @@ -1,8 +1,11 @@ package omadaapi import ( + "context" "fmt" "net/http" + "sync" + "time" "git.tordarus.net/tordarus/ezhttp" "git.tordarus.net/tordarus/omada-api/model" @@ -15,6 +18,9 @@ type Api struct { accessToken string refreshToken string + + expiration time.Time + refreshMutex sync.RWMutex } type ApiConfig struct { @@ -61,6 +67,7 @@ func NewApi(config ApiConfig) (*Api, error) { return nil, fmt.Errorf("auth token request failed: %s", authTokenResponse.Message) } + api.expiration = time.Now().Add(time.Duration(authTokenResponse.Result.ExpiresIn-1) * time.Second) api.accessToken = authTokenResponse.Result.AccessToken api.refreshToken = authTokenResponse.Result.RefreshToken @@ -73,6 +80,9 @@ func NewApi(config ApiConfig) (*Api, error) { } func (api *Api) Login() (*model.LoginResponse, error) { + api.refreshMutex.RLock() + defer api.refreshMutex.RUnlock() + reqBody := model.LoginRequest{ Username: api.config.Username, Password: api.config.Password, @@ -101,6 +111,9 @@ func (api *Api) Login() (*model.LoginResponse, error) { } func (api *Api) AuthCode(csrfToken, sessionID string) (*model.AuthCodeResponse, error) { + api.refreshMutex.RLock() + defer api.refreshMutex.RUnlock() + req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("POST"), @@ -124,6 +137,9 @@ func (api *Api) AuthCode(csrfToken, sessionID string) (*model.AuthCodeResponse, } func (api *Api) AuthToken(authCode string) (*model.AuthTokenResponse, error) { + api.refreshMutex.RLock() + defer api.refreshMutex.RUnlock() + req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("POST"), @@ -149,3 +165,52 @@ func (api *Api) AuthToken(authCode string) (*model.AuthTokenResponse, error) { return response, nil } + +func (api *Api) AutoRefresh(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return nil + case <-time.After(time.Until(api.expiration)): + if err := api.Refresh(); err != nil { + return err + } + } + } +} + +func (api *Api) Refresh() error { + fmt.Println("refresh") + api.refreshMutex.Lock() + defer api.refreshMutex.Unlock() + + req := ezhttp.Request( + ezhttp.Template(api.tmpl), + ezhttp.Method("POST"), + ezhttp.AppendPath("/openapi/authorize/token"), + ezhttp.Query( + "client_id", api.config.ClientID, + "client_secret", api.config.ClientSecret, + "refresh_token", api.refreshToken, + "grant_type", "refresh_token", + ), + ) + + resp, err := ezhttp.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + response, err := ezhttp.ParseJsonResponse[model.AuthTokenResponse](resp.Body) + if err != nil { + return err + } + + api.expiration = time.Now().Add(time.Duration(response.Result.ExpiresIn-1) * time.Second) + api.accessToken = response.Result.AccessToken + api.refreshToken = response.Result.RefreshToken + + fmt.Println("refresh successful") + return nil +} diff --git a/client.go b/client.go index 3b8fdcc..d94d9b9 100644 --- a/client.go +++ b/client.go @@ -34,6 +34,9 @@ func (api *Api) GetClients(siteID model.SiteID) <-chan *model.Client { } func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[model.Client], error) { + api.refreshMutex.RLock() + defer api.refreshMutex.RUnlock() + req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("GET"), diff --git a/device.go b/device.go index 22bcf0e..e53ad75 100644 --- a/device.go +++ b/device.go @@ -34,6 +34,9 @@ func (api *Api) GetDevices(siteID model.SiteID) <-chan *model.Device { } func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[model.Device], error) { + api.refreshMutex.RLock() + defer api.refreshMutex.RUnlock() + req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("GET"), diff --git a/site.go b/site.go index d4dfe18..285ceb1 100644 --- a/site.go +++ b/site.go @@ -34,6 +34,9 @@ func (api *Api) GetSites() <-chan *model.Site { } func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) { + api.refreshMutex.RLock() + defer api.refreshMutex.RUnlock() + req := ezhttp.Request( ezhttp.Template(api.tmpl), ezhttp.Method("GET"),