Compare commits
5 Commits
Author | SHA1 | Date | |
---|---|---|---|
c004398887 | |||
774913d8ff | |||
97f1f09373 | |||
12c700a5c5 | |||
db413ce145 |
7
.vscode/launch.json
vendored
Normal file
7
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,7 @@
|
|||||||
|
{
|
||||||
|
// Use IntelliSense to learn about possible attributes.
|
||||||
|
// Hover to view descriptions of existing attributes.
|
||||||
|
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||||
|
"version": "0.2.0",
|
||||||
|
"configurations": []
|
||||||
|
}
|
15
ap_info.go
15
ap_info.go
@ -8,9 +8,14 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo, error) {
|
func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo, error) {
|
||||||
api.refreshMutex.RLock()
|
resp, err := api.getApInfo(siteID, macAddress)
|
||||||
defer api.refreshMutex.RUnlock()
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &resp.Result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (api *Api) getApInfo(siteID model.SiteID, macAddress string) (*model.Response[model.ApInfo], error) {
|
||||||
req := ezhttp.Request(
|
req := ezhttp.Request(
|
||||||
ezhttp.Template(api.tmpl),
|
ezhttp.Template(api.tmpl),
|
||||||
ezhttp.Method("GET"),
|
ezhttp.Method("GET"),
|
||||||
@ -21,7 +26,7 @@ func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo
|
|||||||
macAddress)),
|
macAddress)),
|
||||||
)
|
)
|
||||||
|
|
||||||
resp, err := ezhttp.Do(req)
|
resp, err := api.doRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -32,5 +37,7 @@ func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return &response.Result, nil
|
return handleResponseErrors(api, response, func() (*model.Response[model.ApInfo], error) {
|
||||||
|
return api.getApInfo(siteID, macAddress)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
74
api.go
74
api.go
@ -1,7 +1,6 @@
|
|||||||
package omadaapi
|
package omadaapi
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"sync"
|
"sync"
|
||||||
@ -40,31 +39,45 @@ func NewApi(config ApiConfig) (*Api, error) {
|
|||||||
config: config,
|
config: config,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := api.InitSession(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return api, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (api *Api) InitSession() error {
|
||||||
|
api.refreshMutex.Lock()
|
||||||
|
defer api.refreshMutex.Unlock()
|
||||||
|
return api.initSessionNoMutexLock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (api *Api) initSessionNoMutexLock() error {
|
||||||
loginResponse, err := api.Login()
|
loginResponse, err := api.Login()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("login request failed: %w", err)
|
return fmt.Errorf("login request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if loginResponse.ErrorCode != 0 {
|
if loginResponse.ErrorCode != 0 {
|
||||||
return nil, fmt.Errorf("login request failed: %s", loginResponse.Message)
|
return fmt.Errorf("login request failed: %s", loginResponse.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
authCodeResponse, err := api.AuthCode(loginResponse.Result.CsrfToken, loginResponse.Result.SessionID)
|
authCodeResponse, err := api.AuthCode(loginResponse.Result.CsrfToken, loginResponse.Result.SessionID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("auth code request failed: %w", err)
|
return fmt.Errorf("auth code request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if authCodeResponse.ErrorCode != 0 {
|
if authCodeResponse.ErrorCode != 0 {
|
||||||
return nil, fmt.Errorf("auth code request failed: %s", authCodeResponse.Message)
|
return fmt.Errorf("auth code request failed: %s", authCodeResponse.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
authTokenResponse, err := api.AuthToken(*authCodeResponse.Result)
|
authTokenResponse, err := api.AuthToken(*authCodeResponse.Result)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("auth token request failed: %w", err)
|
return fmt.Errorf("auth token request failed: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if authTokenResponse.ErrorCode != 0 {
|
if authTokenResponse.ErrorCode != 0 {
|
||||||
return nil, fmt.Errorf("auth token request failed: %s", authTokenResponse.Message)
|
return fmt.Errorf("auth token request failed: %s", authTokenResponse.Message)
|
||||||
}
|
}
|
||||||
|
|
||||||
api.expiration = time.Now().Add(time.Duration(authTokenResponse.Result.ExpiresIn) * time.Second)
|
api.expiration = time.Now().Add(time.Duration(authTokenResponse.Result.ExpiresIn) * time.Second)
|
||||||
@ -73,16 +86,14 @@ func NewApi(config ApiConfig) (*Api, error) {
|
|||||||
|
|
||||||
api.tmpl = ezhttp.Request(
|
api.tmpl = ezhttp.Request(
|
||||||
ezhttp.Template(api.tmpl),
|
ezhttp.Template(api.tmpl),
|
||||||
ezhttp.Headers("Authorization", "AccessToken="+api.accessToken),
|
ezhttp.RemoveAllHeaders(),
|
||||||
|
ezhttp.Auth(api.getAuthHeader),
|
||||||
)
|
)
|
||||||
|
|
||||||
return api, nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *Api) Login() (*model.LoginResponse, error) {
|
func (api *Api) Login() (*model.LoginResponse, error) {
|
||||||
api.refreshMutex.RLock()
|
|
||||||
defer api.refreshMutex.RUnlock()
|
|
||||||
|
|
||||||
reqBody := model.LoginRequest{
|
reqBody := model.LoginRequest{
|
||||||
Username: api.config.Username,
|
Username: api.config.Username,
|
||||||
Password: api.config.Password,
|
Password: api.config.Password,
|
||||||
@ -111,9 +122,6 @@ func (api *Api) Login() (*model.LoginResponse, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (api *Api) AuthCode(csrfToken, sessionID string) (*model.AuthCodeResponse, error) {
|
func (api *Api) AuthCode(csrfToken, sessionID string) (*model.AuthCodeResponse, error) {
|
||||||
api.refreshMutex.RLock()
|
|
||||||
defer api.refreshMutex.RUnlock()
|
|
||||||
|
|
||||||
req := ezhttp.Request(
|
req := ezhttp.Request(
|
||||||
ezhttp.Template(api.tmpl),
|
ezhttp.Template(api.tmpl),
|
||||||
ezhttp.Method("POST"),
|
ezhttp.Method("POST"),
|
||||||
@ -137,9 +145,6 @@ func (api *Api) AuthCode(csrfToken, sessionID string) (*model.AuthCodeResponse,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (api *Api) AuthToken(authCode string) (*model.AuthTokenResponse, error) {
|
func (api *Api) AuthToken(authCode string) (*model.AuthTokenResponse, error) {
|
||||||
api.refreshMutex.RLock()
|
|
||||||
defer api.refreshMutex.RUnlock()
|
|
||||||
|
|
||||||
req := ezhttp.Request(
|
req := ezhttp.Request(
|
||||||
ezhttp.Template(api.tmpl),
|
ezhttp.Template(api.tmpl),
|
||||||
ezhttp.Method("POST"),
|
ezhttp.Method("POST"),
|
||||||
@ -166,25 +171,6 @@ func (api *Api) AuthToken(authCode string) (*model.AuthTokenResponse, error) {
|
|||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *Api) MustAutoRefresh(ctx context.Context, refreshBeforeExpiration time.Duration) {
|
|
||||||
if err := api.AutoRefresh(ctx, refreshBeforeExpiration); err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (api *Api) AutoRefresh(ctx context.Context, refreshBeforeExpiration time.Duration) error {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil
|
|
||||||
case <-time.After(time.Until(api.expiration.Add(-refreshBeforeExpiration))):
|
|
||||||
if err := api.Refresh(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (api *Api) Refresh() error {
|
func (api *Api) Refresh() error {
|
||||||
api.refreshMutex.Lock()
|
api.refreshMutex.Lock()
|
||||||
defer api.refreshMutex.Unlock()
|
defer api.refreshMutex.Unlock()
|
||||||
@ -212,9 +198,23 @@ func (api *Api) Refresh() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if response.ErrorCode == ErrCodeRefreshTokenExpired {
|
||||||
|
return api.initSessionNoMutexLock()
|
||||||
|
}
|
||||||
|
|
||||||
api.expiration = time.Now().Add(time.Duration(response.Result.ExpiresIn) * time.Second)
|
api.expiration = time.Now().Add(time.Duration(response.Result.ExpiresIn) * time.Second)
|
||||||
api.accessToken = response.Result.AccessToken
|
api.accessToken = response.Result.AccessToken
|
||||||
api.refreshToken = response.Result.RefreshToken
|
api.refreshToken = response.Result.RefreshToken
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (api *Api) doRequest(r *http.Request) (*http.Response, error) {
|
||||||
|
api.refreshMutex.RLock()
|
||||||
|
defer api.refreshMutex.RUnlock()
|
||||||
|
return ezhttp.Do(r)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (api *Api) getAuthHeader() string {
|
||||||
|
return "AccessToken=" + api.accessToken
|
||||||
|
}
|
||||||
|
17
client.go
17
client.go
@ -4,12 +4,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"git.tordarus.net/tordarus/channel"
|
||||||
"git.tordarus.net/tordarus/ezhttp"
|
"git.tordarus.net/tordarus/ezhttp"
|
||||||
"git.tordarus.net/tordarus/omada-api/model"
|
"git.tordarus.net/tordarus/omada-api/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (api *Api) GetClients(siteID model.SiteID) <-chan *model.Client {
|
func (api *Api) GetClients(siteID model.SiteID) <-chan channel.Result[model.Client] {
|
||||||
out := make(chan *model.Client, 1000)
|
out := make(chan channel.Result[model.Client], 1000)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
@ -17,11 +18,12 @@ func (api *Api) GetClients(siteID model.SiteID) <-chan *model.Client {
|
|||||||
for page := 1; ; page++ {
|
for page := 1; ; page++ {
|
||||||
resp, err := api.getClients(page, siteID)
|
resp, err := api.getClients(page, siteID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
out <- channel.ResultOf[model.Client](nil, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range resp.Result.Data {
|
for _, v := range resp.Result.Data {
|
||||||
out <- &v
|
out <- channel.ResultOfValue(v, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows {
|
if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows {
|
||||||
@ -34,9 +36,6 @@ func (api *Api) GetClients(siteID model.SiteID) <-chan *model.Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[model.Client], error) {
|
func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[model.Client], error) {
|
||||||
api.refreshMutex.RLock()
|
|
||||||
defer api.refreshMutex.RUnlock()
|
|
||||||
|
|
||||||
req := ezhttp.Request(
|
req := ezhttp.Request(
|
||||||
ezhttp.Template(api.tmpl),
|
ezhttp.Template(api.tmpl),
|
||||||
ezhttp.Method("GET"),
|
ezhttp.Method("GET"),
|
||||||
@ -47,7 +46,7 @@ func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
resp, err := ezhttp.Do(req)
|
resp, err := api.doRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -58,5 +57,7 @@ func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, nil
|
return handlePagedResponseErrors(api, response, func() (*model.PagedResponse[model.Client], error) {
|
||||||
|
return api.getClients(page, siteID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
17
device.go
17
device.go
@ -4,12 +4,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"git.tordarus.net/tordarus/channel"
|
||||||
"git.tordarus.net/tordarus/ezhttp"
|
"git.tordarus.net/tordarus/ezhttp"
|
||||||
"git.tordarus.net/tordarus/omada-api/model"
|
"git.tordarus.net/tordarus/omada-api/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (api *Api) GetDevices(siteID model.SiteID) <-chan *model.Device {
|
func (api *Api) GetDevices(siteID model.SiteID) <-chan channel.Result[model.Device] {
|
||||||
out := make(chan *model.Device, 1000)
|
out := make(chan channel.Result[model.Device], 1000)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
@ -17,11 +18,12 @@ func (api *Api) GetDevices(siteID model.SiteID) <-chan *model.Device {
|
|||||||
for page := 1; ; page++ {
|
for page := 1; ; page++ {
|
||||||
resp, err := api.getDevices(page, siteID)
|
resp, err := api.getDevices(page, siteID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
out <- channel.ResultOf[model.Device](nil, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range resp.Result.Data {
|
for _, v := range resp.Result.Data {
|
||||||
out <- &v
|
out <- channel.ResultOfValue(v, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows {
|
if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows {
|
||||||
@ -34,9 +36,6 @@ func (api *Api) GetDevices(siteID model.SiteID) <-chan *model.Device {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[model.Device], error) {
|
func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[model.Device], error) {
|
||||||
api.refreshMutex.RLock()
|
|
||||||
defer api.refreshMutex.RUnlock()
|
|
||||||
|
|
||||||
req := ezhttp.Request(
|
req := ezhttp.Request(
|
||||||
ezhttp.Template(api.tmpl),
|
ezhttp.Template(api.tmpl),
|
||||||
ezhttp.Method("GET"),
|
ezhttp.Method("GET"),
|
||||||
@ -47,7 +46,7 @@ func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
resp, err := ezhttp.Do(req)
|
resp, err := api.doRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -58,5 +57,7 @@ func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, nil
|
return handlePagedResponseErrors(api, response, func() (*model.PagedResponse[model.Device], error) {
|
||||||
|
return api.getDevices(page, siteID)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
5
go.mod
5
go.mod
@ -2,4 +2,7 @@ module git.tordarus.net/tordarus/omada-api
|
|||||||
|
|
||||||
go 1.23.0
|
go 1.23.0
|
||||||
|
|
||||||
require git.tordarus.net/tordarus/ezhttp v0.0.5
|
require (
|
||||||
|
git.tordarus.net/tordarus/channel v0.1.19
|
||||||
|
git.tordarus.net/tordarus/ezhttp v0.0.9
|
||||||
|
)
|
||||||
|
6
go.sum
6
go.sum
@ -1,2 +1,4 @@
|
|||||||
git.tordarus.net/tordarus/ezhttp v0.0.5 h1:pxfEdfDeOHT/ATXYy5OQHmeBIho121SBuFvU4ISQ7w0=
|
git.tordarus.net/tordarus/channel v0.1.19 h1:d9xnSwFyvBh4B1/82mt0A7Gpm2nIZJTc+9ceJMIOu5Q=
|
||||||
git.tordarus.net/tordarus/ezhttp v0.0.5/go.mod h1:Zq9o0Hibny61GqSCwJHa0PfGjVoUFv/zt2PjiQHXvmY=
|
git.tordarus.net/tordarus/channel v0.1.19/go.mod h1:8/dWFTdGO7g4AeSZ7cF6GerkGbe9c4dBVMVDBxOd9m4=
|
||||||
|
git.tordarus.net/tordarus/ezhttp v0.0.9 h1:YwdQ4YcJwvpMw5CX5NcCEM23XQL+WCz5nWuc2dzX/84=
|
||||||
|
git.tordarus.net/tordarus/ezhttp v0.0.9/go.mod h1:Zq9o0Hibny61GqSCwJHa0PfGjVoUFv/zt2PjiQHXvmY=
|
||||||
|
17
site.go
17
site.go
@ -4,12 +4,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
|
||||||
|
"git.tordarus.net/tordarus/channel"
|
||||||
"git.tordarus.net/tordarus/ezhttp"
|
"git.tordarus.net/tordarus/ezhttp"
|
||||||
"git.tordarus.net/tordarus/omada-api/model"
|
"git.tordarus.net/tordarus/omada-api/model"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (api *Api) GetSites() <-chan *model.Site {
|
func (api *Api) GetSites() <-chan channel.Result[model.Site] {
|
||||||
out := make(chan *model.Site, 1000)
|
out := make(chan channel.Result[model.Site], 1000)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer close(out)
|
defer close(out)
|
||||||
@ -17,11 +18,12 @@ func (api *Api) GetSites() <-chan *model.Site {
|
|||||||
for page := 1; ; page++ {
|
for page := 1; ; page++ {
|
||||||
resp, err := api.getSites(page)
|
resp, err := api.getSites(page)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
out <- channel.ResultOf[model.Site](nil, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, v := range resp.Result.Data {
|
for _, v := range resp.Result.Data {
|
||||||
out <- &v
|
out <- channel.ResultOfValue(v, nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows {
|
if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows {
|
||||||
@ -34,9 +36,6 @@ func (api *Api) GetSites() <-chan *model.Site {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) {
|
func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) {
|
||||||
api.refreshMutex.RLock()
|
|
||||||
defer api.refreshMutex.RUnlock()
|
|
||||||
|
|
||||||
req := ezhttp.Request(
|
req := ezhttp.Request(
|
||||||
ezhttp.Template(api.tmpl),
|
ezhttp.Template(api.tmpl),
|
||||||
ezhttp.Method("GET"),
|
ezhttp.Method("GET"),
|
||||||
@ -47,7 +46,7 @@ func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) {
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
resp, err := ezhttp.Do(req)
|
resp, err := api.doRequest(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@ -58,5 +57,7 @@ func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
return response, nil
|
return handlePagedResponseErrors(api, response, func() (*model.PagedResponse[model.Site], error) {
|
||||||
|
return api.getSites(page)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
72
utils.go
Normal file
72
utils.go
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package omadaapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.tordarus.net/tordarus/omada-api/model"
|
||||||
|
)
|
||||||
|
|
||||||
|
const ErrCodeAccessTokenExpired = -44112
|
||||||
|
const ErrCodeRefreshTokenExpired = -44114
|
||||||
|
|
||||||
|
func handleResponseErrors[T any](api *Api, response *model.Response[T], retry func() (*model.Response[T], error)) (*model.Response[T], error) {
|
||||||
|
switch response.ErrorCode {
|
||||||
|
case 0:
|
||||||
|
return response, nil
|
||||||
|
case ErrCodeAccessTokenExpired:
|
||||||
|
if err := api.Refresh(); err != nil {
|
||||||
|
return nil, fmt.Errorf("could not refresh access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newResp, err := retry()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return handleResponseErrors(api, newResp, retry)
|
||||||
|
case ErrCodeRefreshTokenExpired:
|
||||||
|
if err := api.InitSession(); err != nil {
|
||||||
|
return nil, fmt.Errorf("could not initialize new session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newResp, err := retry()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return handleResponseErrors(api, newResp, retry)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid error code %d with message: %s", response.ErrorCode, response.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handlePagedResponseErrors[T any](api *Api, response *model.PagedResponse[T], retry func() (*model.PagedResponse[T], error)) (*model.PagedResponse[T], error) {
|
||||||
|
switch response.ErrorCode {
|
||||||
|
case 0:
|
||||||
|
return response, nil
|
||||||
|
case ErrCodeAccessTokenExpired:
|
||||||
|
if err := api.Refresh(); err != nil {
|
||||||
|
return nil, fmt.Errorf("could not refresh access token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newResp, err := retry()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return handlePagedResponseErrors(api, newResp, retry)
|
||||||
|
case ErrCodeRefreshTokenExpired:
|
||||||
|
if err := api.InitSession(); err != nil {
|
||||||
|
return nil, fmt.Errorf("could not initialize new session: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
newResp, err := retry()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return handlePagedResponseErrors(api, newResp, retry)
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid error code %d with message: %s", response.ErrorCode, response.Message)
|
||||||
|
}
|
||||||
|
}
|
Loading…
x
Reference in New Issue
Block a user