diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..5c7247b --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,7 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [] +} \ No newline at end of file diff --git a/ap_info.go b/ap_info.go index 4362d66..8e3661c 100644 --- a/ap_info.go +++ b/ap_info.go @@ -8,8 +8,15 @@ import ( ) func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo, error) { + resp, err := api.getApInfo(siteID, macAddress) + if err != nil { + return nil, err + } + return &resp.Result, nil +} + +func (api *Api) getApInfo(siteID model.SiteID, macAddress string) (*model.Response[model.ApInfo], error) { api.refreshMutex.RLock() - defer api.refreshMutex.RUnlock() req := ezhttp.Request( ezhttp.Template(api.tmpl), @@ -23,14 +30,18 @@ func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo resp, err := ezhttp.Do(req) if err != nil { + api.refreshMutex.RUnlock() return nil, err } defer resp.Body.Close() + api.refreshMutex.RUnlock() response, err := ezhttp.ParseJsonResponse[model.Response[model.ApInfo]](resp) if err != nil { return nil, err } - return &response.Result, nil + return handleResponseErrors(api, response, func() (*model.Response[model.ApInfo], error) { + return api.getApInfo(siteID, macAddress) + }) } diff --git a/api.go b/api.go index a56d7fc..69abcee 100644 --- a/api.go +++ b/api.go @@ -1,7 +1,6 @@ package omadaapi import ( - "context" "fmt" "net/http" "sync" @@ -166,25 +165,6 @@ func (api *Api) AuthToken(authCode string) (*model.AuthTokenResponse, error) { return response, nil } -func (api *Api) MustAutoRefresh(ctx context.Context, refreshBeforeExpiration time.Duration) { - if err := api.AutoRefresh(ctx, refreshBeforeExpiration); err != nil { - panic(err) - } -} - -func (api *Api) AutoRefresh(ctx context.Context, refreshBeforeExpiration time.Duration) error { - for { - select { - case <-ctx.Done(): - return nil - case <-time.After(time.Until(api.expiration.Add(-refreshBeforeExpiration))): - if err := api.Refresh(); err != nil { - return err - } - } - } -} - func (api *Api) Refresh() error { api.refreshMutex.Lock() defer api.refreshMutex.Unlock() diff --git a/client.go b/client.go index 5350cdc..fb550d0 100644 --- a/client.go +++ b/client.go @@ -4,27 +4,26 @@ import ( "fmt" "strconv" + "git.tordarus.net/tordarus/channel" "git.tordarus.net/tordarus/ezhttp" "git.tordarus.net/tordarus/omada-api/model" ) -func (api *Api) GetClients(siteID model.SiteID) (<-chan *model.Client, <-chan error) { - out := make(chan *model.Client, 1000) - errChan := make(chan error) +func (api *Api) GetClients(siteID model.SiteID) <-chan channel.Result[model.Client] { + out := make(chan channel.Result[model.Client], 1000) go func() { defer close(out) - defer close(errChan) for page := 1; ; page++ { resp, err := api.getClients(page, siteID) if err != nil { - errChan <- err + out <- channel.ResultOf[model.Client](nil, err) return } for _, v := range resp.Result.Data { - out <- &v + out <- channel.ResultOfValue(v, nil) } if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows { @@ -33,12 +32,11 @@ func (api *Api) GetClients(siteID model.SiteID) (<-chan *model.Client, <-chan er } }() - return out, errChan + return out } func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[model.Client], error) { api.refreshMutex.RLock() - defer api.refreshMutex.RUnlock() req := ezhttp.Request( ezhttp.Template(api.tmpl), @@ -52,14 +50,18 @@ func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[ resp, err := ezhttp.Do(req) if err != nil { + api.refreshMutex.RUnlock() return nil, err } defer resp.Body.Close() + api.refreshMutex.RUnlock() response, err := ezhttp.ParseJsonResponse[model.PagedResponse[model.Client]](resp) if err != nil { return nil, err } - return response, nil + return handlePagedResponseErrors(api, response, func() (*model.PagedResponse[model.Client], error) { + return api.getClients(page, siteID) + }) } diff --git a/device.go b/device.go index ab0331a..03b0bfd 100644 --- a/device.go +++ b/device.go @@ -4,27 +4,26 @@ import ( "fmt" "strconv" + "git.tordarus.net/tordarus/channel" "git.tordarus.net/tordarus/ezhttp" "git.tordarus.net/tordarus/omada-api/model" ) -func (api *Api) GetDevices(siteID model.SiteID) (<-chan *model.Device, <-chan error) { - out := make(chan *model.Device, 1000) - errChan := make(chan error) +func (api *Api) GetDevices(siteID model.SiteID) <-chan channel.Result[model.Device] { + out := make(chan channel.Result[model.Device], 1000) go func() { defer close(out) - defer close(errChan) for page := 1; ; page++ { resp, err := api.getDevices(page, siteID) if err != nil { - errChan <- err + out <- channel.ResultOf[model.Device](nil, err) return } for _, v := range resp.Result.Data { - out <- &v + out <- channel.ResultOfValue(v, nil) } if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows { @@ -33,12 +32,11 @@ func (api *Api) GetDevices(siteID model.SiteID) (<-chan *model.Device, <-chan er } }() - return out, errChan + return out } func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[model.Device], error) { api.refreshMutex.RLock() - defer api.refreshMutex.RUnlock() req := ezhttp.Request( ezhttp.Template(api.tmpl), @@ -52,14 +50,18 @@ func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[ resp, err := ezhttp.Do(req) if err != nil { + api.refreshMutex.RUnlock() return nil, err } defer resp.Body.Close() + api.refreshMutex.RUnlock() response, err := ezhttp.ParseJsonResponse[model.PagedResponse[model.Device]](resp) if err != nil { return nil, err } - return response, nil + return handlePagedResponseErrors(api, response, func() (*model.PagedResponse[model.Device], error) { + return api.getDevices(page, siteID) + }) } diff --git a/go.mod b/go.mod index 9fdb365..fa136bc 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module git.tordarus.net/tordarus/omada-api go 1.23.0 -require git.tordarus.net/tordarus/ezhttp v0.0.5 +require ( + git.tordarus.net/tordarus/channel v0.1.19 + git.tordarus.net/tordarus/ezhttp v0.0.5 +) diff --git a/go.sum b/go.sum index 891a526..0026ab5 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,4 @@ +git.tordarus.net/tordarus/channel v0.1.19 h1:d9xnSwFyvBh4B1/82mt0A7Gpm2nIZJTc+9ceJMIOu5Q= +git.tordarus.net/tordarus/channel v0.1.19/go.mod h1:8/dWFTdGO7g4AeSZ7cF6GerkGbe9c4dBVMVDBxOd9m4= git.tordarus.net/tordarus/ezhttp v0.0.5 h1:pxfEdfDeOHT/ATXYy5OQHmeBIho121SBuFvU4ISQ7w0= git.tordarus.net/tordarus/ezhttp v0.0.5/go.mod h1:Zq9o0Hibny61GqSCwJHa0PfGjVoUFv/zt2PjiQHXvmY= diff --git a/site.go b/site.go index 6330248..5ffc25b 100644 --- a/site.go +++ b/site.go @@ -4,27 +4,26 @@ import ( "fmt" "strconv" + "git.tordarus.net/tordarus/channel" "git.tordarus.net/tordarus/ezhttp" "git.tordarus.net/tordarus/omada-api/model" ) -func (api *Api) GetSites() (<-chan *model.Site, <-chan error) { - out := make(chan *model.Site, 1000) - errChan := make(chan error) +func (api *Api) GetSites() <-chan channel.Result[model.Site] { + out := make(chan channel.Result[model.Site], 1000) go func() { defer close(out) - defer close(errChan) for page := 1; ; page++ { resp, err := api.getSites(page) if err != nil { - errChan <- err + out <- channel.ResultOf[model.Site](nil, err) return } for _, v := range resp.Result.Data { - out <- &v + out <- channel.ResultOfValue(v, nil) } if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows { @@ -33,12 +32,11 @@ func (api *Api) GetSites() (<-chan *model.Site, <-chan error) { } }() - return out, errChan + return out } func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) { api.refreshMutex.RLock() - defer api.refreshMutex.RUnlock() req := ezhttp.Request( ezhttp.Template(api.tmpl), @@ -52,14 +50,18 @@ func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) { resp, err := ezhttp.Do(req) if err != nil { + api.refreshMutex.RUnlock() return nil, err } defer resp.Body.Close() + api.refreshMutex.RUnlock() response, err := ezhttp.ParseJsonResponse[model.PagedResponse[model.Site]](resp) if err != nil { return nil, err } - return response, nil + return handlePagedResponseErrors(api, response, func() (*model.PagedResponse[model.Site], error) { + return api.getSites(page) + }) } diff --git a/utils.go b/utils.go index 1af61a6..2cc7a88 100644 --- a/utils.go +++ b/utils.go @@ -1,10 +1,49 @@ package omadaapi -func PanicOnError[T any](valueChan <-chan T, errChan <-chan error) <-chan T { - go func() { - for err := range errChan { - panic(err) +import ( + "fmt" + + "git.tordarus.net/tordarus/omada-api/model" +) + +const ErrCodeAccessTokenExpired = -44112 + +func handleResponseErrors[T any](api *Api, response *model.Response[T], retry func() (*model.Response[T], error)) (*model.Response[T], error) { + switch response.ErrorCode { + case 0: + return response, nil + case ErrCodeAccessTokenExpired: + if err := api.Refresh(); err != nil { + return nil, fmt.Errorf("could not refresh access token: %w", err) } - }() - return valueChan + + newResp, err := retry() + if err != nil { + return nil, err + } + + return handleResponseErrors(api, newResp, retry) + default: + return nil, fmt.Errorf("invalid error code %d with message: %s", response.ErrorCode, response.Message) + } +} + +func handlePagedResponseErrors[T any](api *Api, response *model.PagedResponse[T], retry func() (*model.PagedResponse[T], error)) (*model.PagedResponse[T], error) { + switch response.ErrorCode { + case 0: + return response, nil + case ErrCodeAccessTokenExpired: + if err := api.Refresh(); err != nil { + return nil, fmt.Errorf("could not refresh access token: %w", err) + } + + newResp, err := retry() + if err != nil { + return nil, err + } + + return handlePagedResponseErrors(api, newResp, retry) + default: + return nil, fmt.Errorf("invalid error code %d with message: %s", response.ErrorCode, response.Message) + } }