Compare commits

...

3 Commits
v0.0.7 ... main

9 changed files with 161 additions and 86 deletions

7
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,7 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": []
}

View File

@ -8,9 +8,14 @@ import (
) )
func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo, error) { func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo, error) {
api.refreshMutex.RLock() resp, err := api.getApInfo(siteID, macAddress)
defer api.refreshMutex.RUnlock() if err != nil {
return nil, err
}
return &resp.Result, nil
}
func (api *Api) getApInfo(siteID model.SiteID, macAddress string) (*model.Response[model.ApInfo], error) {
req := ezhttp.Request( req := ezhttp.Request(
ezhttp.Template(api.tmpl), ezhttp.Template(api.tmpl),
ezhttp.Method("GET"), ezhttp.Method("GET"),
@ -21,7 +26,7 @@ func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo
macAddress)), macAddress)),
) )
resp, err := ezhttp.Do(req) resp, err := api.doRequest(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -32,5 +37,7 @@ func (api *Api) GetApInfo(siteID model.SiteID, macAddress string) (*model.ApInfo
return nil, err return nil, err
} }
return &response.Result, nil return handleResponseErrors(api, response, func() (*model.Response[model.ApInfo], error) {
return api.getApInfo(siteID, macAddress)
})
} }

74
api.go
View File

@ -1,7 +1,6 @@
package omadaapi package omadaapi
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"sync" "sync"
@ -40,31 +39,45 @@ func NewApi(config ApiConfig) (*Api, error) {
config: config, config: config,
} }
if err := api.InitSession(); err != nil {
return nil, err
}
return api, nil
}
func (api *Api) InitSession() error {
api.refreshMutex.Lock()
defer api.refreshMutex.Unlock()
return api.initSessionNoMutexLock()
}
func (api *Api) initSessionNoMutexLock() error {
loginResponse, err := api.Login() loginResponse, err := api.Login()
if err != nil { if err != nil {
return nil, fmt.Errorf("login request failed: %w", err) return fmt.Errorf("login request failed: %w", err)
} }
if loginResponse.ErrorCode != 0 { if loginResponse.ErrorCode != 0 {
return nil, fmt.Errorf("login request failed: %s", loginResponse.Message) return fmt.Errorf("login request failed: %s", loginResponse.Message)
} }
authCodeResponse, err := api.AuthCode(loginResponse.Result.CsrfToken, loginResponse.Result.SessionID) authCodeResponse, err := api.AuthCode(loginResponse.Result.CsrfToken, loginResponse.Result.SessionID)
if err != nil { if err != nil {
return nil, fmt.Errorf("auth code request failed: %w", err) return fmt.Errorf("auth code request failed: %w", err)
} }
if authCodeResponse.ErrorCode != 0 { if authCodeResponse.ErrorCode != 0 {
return nil, fmt.Errorf("auth code request failed: %s", authCodeResponse.Message) return fmt.Errorf("auth code request failed: %s", authCodeResponse.Message)
} }
authTokenResponse, err := api.AuthToken(*authCodeResponse.Result) authTokenResponse, err := api.AuthToken(*authCodeResponse.Result)
if err != nil { if err != nil {
return nil, fmt.Errorf("auth token request failed: %w", err) return fmt.Errorf("auth token request failed: %w", err)
} }
if authTokenResponse.ErrorCode != 0 { if authTokenResponse.ErrorCode != 0 {
return nil, fmt.Errorf("auth token request failed: %s", authTokenResponse.Message) return fmt.Errorf("auth token request failed: %s", authTokenResponse.Message)
} }
api.expiration = time.Now().Add(time.Duration(authTokenResponse.Result.ExpiresIn) * time.Second) api.expiration = time.Now().Add(time.Duration(authTokenResponse.Result.ExpiresIn) * time.Second)
@ -73,16 +86,14 @@ func NewApi(config ApiConfig) (*Api, error) {
api.tmpl = ezhttp.Request( api.tmpl = ezhttp.Request(
ezhttp.Template(api.tmpl), ezhttp.Template(api.tmpl),
ezhttp.Headers("Authorization", "AccessToken="+api.accessToken), ezhttp.RemoveAllHeaders(),
ezhttp.Auth(api.getAuthHeader),
) )
return api, nil return nil
} }
func (api *Api) Login() (*model.LoginResponse, error) { func (api *Api) Login() (*model.LoginResponse, error) {
api.refreshMutex.RLock()
defer api.refreshMutex.RUnlock()
reqBody := model.LoginRequest{ reqBody := model.LoginRequest{
Username: api.config.Username, Username: api.config.Username,
Password: api.config.Password, Password: api.config.Password,
@ -111,9 +122,6 @@ func (api *Api) Login() (*model.LoginResponse, error) {
} }
func (api *Api) AuthCode(csrfToken, sessionID string) (*model.AuthCodeResponse, error) { func (api *Api) AuthCode(csrfToken, sessionID string) (*model.AuthCodeResponse, error) {
api.refreshMutex.RLock()
defer api.refreshMutex.RUnlock()
req := ezhttp.Request( req := ezhttp.Request(
ezhttp.Template(api.tmpl), ezhttp.Template(api.tmpl),
ezhttp.Method("POST"), ezhttp.Method("POST"),
@ -137,9 +145,6 @@ func (api *Api) AuthCode(csrfToken, sessionID string) (*model.AuthCodeResponse,
} }
func (api *Api) AuthToken(authCode string) (*model.AuthTokenResponse, error) { func (api *Api) AuthToken(authCode string) (*model.AuthTokenResponse, error) {
api.refreshMutex.RLock()
defer api.refreshMutex.RUnlock()
req := ezhttp.Request( req := ezhttp.Request(
ezhttp.Template(api.tmpl), ezhttp.Template(api.tmpl),
ezhttp.Method("POST"), ezhttp.Method("POST"),
@ -166,25 +171,6 @@ func (api *Api) AuthToken(authCode string) (*model.AuthTokenResponse, error) {
return response, nil return response, nil
} }
func (api *Api) MustAutoRefresh(ctx context.Context, refreshBeforeExpiration time.Duration) {
if err := api.AutoRefresh(ctx, refreshBeforeExpiration); err != nil {
panic(err)
}
}
func (api *Api) AutoRefresh(ctx context.Context, refreshBeforeExpiration time.Duration) error {
for {
select {
case <-ctx.Done():
return nil
case <-time.After(time.Until(api.expiration.Add(-refreshBeforeExpiration))):
if err := api.Refresh(); err != nil {
return err
}
}
}
}
func (api *Api) Refresh() error { func (api *Api) Refresh() error {
api.refreshMutex.Lock() api.refreshMutex.Lock()
defer api.refreshMutex.Unlock() defer api.refreshMutex.Unlock()
@ -212,9 +198,23 @@ func (api *Api) Refresh() error {
return err return err
} }
if response.ErrorCode == ErrCodeRefreshTokenExpired {
return api.initSessionNoMutexLock()
}
api.expiration = time.Now().Add(time.Duration(response.Result.ExpiresIn) * time.Second) api.expiration = time.Now().Add(time.Duration(response.Result.ExpiresIn) * time.Second)
api.accessToken = response.Result.AccessToken api.accessToken = response.Result.AccessToken
api.refreshToken = response.Result.RefreshToken api.refreshToken = response.Result.RefreshToken
return nil return nil
} }
func (api *Api) doRequest(r *http.Request) (*http.Response, error) {
api.refreshMutex.RLock()
defer api.refreshMutex.RUnlock()
return ezhttp.Do(r)
}
func (api *Api) getAuthHeader() string {
return "AccessToken=" + api.accessToken
}

View File

@ -4,27 +4,26 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"git.tordarus.net/tordarus/channel"
"git.tordarus.net/tordarus/ezhttp" "git.tordarus.net/tordarus/ezhttp"
"git.tordarus.net/tordarus/omada-api/model" "git.tordarus.net/tordarus/omada-api/model"
) )
func (api *Api) GetClients(siteID model.SiteID) (<-chan *model.Client, <-chan error) { func (api *Api) GetClients(siteID model.SiteID) <-chan channel.Result[model.Client] {
out := make(chan *model.Client, 1000) out := make(chan channel.Result[model.Client], 1000)
errChan := make(chan error)
go func() { go func() {
defer close(out) defer close(out)
defer close(errChan)
for page := 1; ; page++ { for page := 1; ; page++ {
resp, err := api.getClients(page, siteID) resp, err := api.getClients(page, siteID)
if err != nil { if err != nil {
errChan <- err out <- channel.ResultOf[model.Client](nil, err)
return return
} }
for _, v := range resp.Result.Data { for _, v := range resp.Result.Data {
out <- &v out <- channel.ResultOfValue(v, nil)
} }
if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows { if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows {
@ -33,13 +32,10 @@ func (api *Api) GetClients(siteID model.SiteID) (<-chan *model.Client, <-chan er
} }
}() }()
return out, errChan return out
} }
func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[model.Client], error) { func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[model.Client], error) {
api.refreshMutex.RLock()
defer api.refreshMutex.RUnlock()
req := ezhttp.Request( req := ezhttp.Request(
ezhttp.Template(api.tmpl), ezhttp.Template(api.tmpl),
ezhttp.Method("GET"), ezhttp.Method("GET"),
@ -50,7 +46,7 @@ func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[
), ),
) )
resp, err := ezhttp.Do(req) resp, err := api.doRequest(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,5 +57,7 @@ func (api *Api) getClients(page int, siteID model.SiteID) (*model.PagedResponse[
return nil, err return nil, err
} }
return response, nil return handlePagedResponseErrors(api, response, func() (*model.PagedResponse[model.Client], error) {
return api.getClients(page, siteID)
})
} }

View File

@ -4,27 +4,26 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"git.tordarus.net/tordarus/channel"
"git.tordarus.net/tordarus/ezhttp" "git.tordarus.net/tordarus/ezhttp"
"git.tordarus.net/tordarus/omada-api/model" "git.tordarus.net/tordarus/omada-api/model"
) )
func (api *Api) GetDevices(siteID model.SiteID) (<-chan *model.Device, <-chan error) { func (api *Api) GetDevices(siteID model.SiteID) <-chan channel.Result[model.Device] {
out := make(chan *model.Device, 1000) out := make(chan channel.Result[model.Device], 1000)
errChan := make(chan error)
go func() { go func() {
defer close(out) defer close(out)
defer close(errChan)
for page := 1; ; page++ { for page := 1; ; page++ {
resp, err := api.getDevices(page, siteID) resp, err := api.getDevices(page, siteID)
if err != nil { if err != nil {
errChan <- err out <- channel.ResultOf[model.Device](nil, err)
return return
} }
for _, v := range resp.Result.Data { for _, v := range resp.Result.Data {
out <- &v out <- channel.ResultOfValue(v, nil)
} }
if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows { if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows {
@ -33,13 +32,10 @@ func (api *Api) GetDevices(siteID model.SiteID) (<-chan *model.Device, <-chan er
} }
}() }()
return out, errChan return out
} }
func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[model.Device], error) { func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[model.Device], error) {
api.refreshMutex.RLock()
defer api.refreshMutex.RUnlock()
req := ezhttp.Request( req := ezhttp.Request(
ezhttp.Template(api.tmpl), ezhttp.Template(api.tmpl),
ezhttp.Method("GET"), ezhttp.Method("GET"),
@ -50,7 +46,7 @@ func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[
), ),
) )
resp, err := ezhttp.Do(req) resp, err := api.doRequest(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,5 +57,7 @@ func (api *Api) getDevices(page int, siteID model.SiteID) (*model.PagedResponse[
return nil, err return nil, err
} }
return response, nil return handlePagedResponseErrors(api, response, func() (*model.PagedResponse[model.Device], error) {
return api.getDevices(page, siteID)
})
} }

5
go.mod
View File

@ -2,4 +2,7 @@ module git.tordarus.net/tordarus/omada-api
go 1.23.0 go 1.23.0
require git.tordarus.net/tordarus/ezhttp v0.0.5 require (
git.tordarus.net/tordarus/channel v0.1.19
git.tordarus.net/tordarus/ezhttp v0.0.9
)

6
go.sum
View File

@ -1,2 +1,4 @@
git.tordarus.net/tordarus/ezhttp v0.0.5 h1:pxfEdfDeOHT/ATXYy5OQHmeBIho121SBuFvU4ISQ7w0= git.tordarus.net/tordarus/channel v0.1.19 h1:d9xnSwFyvBh4B1/82mt0A7Gpm2nIZJTc+9ceJMIOu5Q=
git.tordarus.net/tordarus/ezhttp v0.0.5/go.mod h1:Zq9o0Hibny61GqSCwJHa0PfGjVoUFv/zt2PjiQHXvmY= git.tordarus.net/tordarus/channel v0.1.19/go.mod h1:8/dWFTdGO7g4AeSZ7cF6GerkGbe9c4dBVMVDBxOd9m4=
git.tordarus.net/tordarus/ezhttp v0.0.9 h1:YwdQ4YcJwvpMw5CX5NcCEM23XQL+WCz5nWuc2dzX/84=
git.tordarus.net/tordarus/ezhttp v0.0.9/go.mod h1:Zq9o0Hibny61GqSCwJHa0PfGjVoUFv/zt2PjiQHXvmY=

22
site.go
View File

@ -4,27 +4,26 @@ import (
"fmt" "fmt"
"strconv" "strconv"
"git.tordarus.net/tordarus/channel"
"git.tordarus.net/tordarus/ezhttp" "git.tordarus.net/tordarus/ezhttp"
"git.tordarus.net/tordarus/omada-api/model" "git.tordarus.net/tordarus/omada-api/model"
) )
func (api *Api) GetSites() (<-chan *model.Site, <-chan error) { func (api *Api) GetSites() <-chan channel.Result[model.Site] {
out := make(chan *model.Site, 1000) out := make(chan channel.Result[model.Site], 1000)
errChan := make(chan error)
go func() { go func() {
defer close(out) defer close(out)
defer close(errChan)
for page := 1; ; page++ { for page := 1; ; page++ {
resp, err := api.getSites(page) resp, err := api.getSites(page)
if err != nil { if err != nil {
errChan <- err out <- channel.ResultOf[model.Site](nil, err)
return return
} }
for _, v := range resp.Result.Data { for _, v := range resp.Result.Data {
out <- &v out <- channel.ResultOfValue(v, nil)
} }
if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows { if resp.Result.CurrentPage*resp.Result.CurrentSize >= resp.Result.TotalRows {
@ -33,13 +32,10 @@ func (api *Api) GetSites() (<-chan *model.Site, <-chan error) {
} }
}() }()
return out, errChan return out
} }
func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) { func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) {
api.refreshMutex.RLock()
defer api.refreshMutex.RUnlock()
req := ezhttp.Request( req := ezhttp.Request(
ezhttp.Template(api.tmpl), ezhttp.Template(api.tmpl),
ezhttp.Method("GET"), ezhttp.Method("GET"),
@ -50,7 +46,7 @@ func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) {
), ),
) )
resp, err := ezhttp.Do(req) resp, err := api.doRequest(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -61,5 +57,7 @@ func (api *Api) getSites(page int) (*model.PagedResponse[model.Site], error) {
return nil, err return nil, err
} }
return response, nil return handlePagedResponseErrors(api, response, func() (*model.PagedResponse[model.Site], error) {
return api.getSites(page)
})
} }

View File

@ -1,10 +1,72 @@
package omadaapi package omadaapi
func PanicOnError[T any](valueChan <-chan T, errChan <-chan error) <-chan T { import (
go func() { "fmt"
for err := range errChan {
panic(err) "git.tordarus.net/tordarus/omada-api/model"
)
const ErrCodeAccessTokenExpired = -44112
const ErrCodeRefreshTokenExpired = -44114
func handleResponseErrors[T any](api *Api, response *model.Response[T], retry func() (*model.Response[T], error)) (*model.Response[T], error) {
switch response.ErrorCode {
case 0:
return response, nil
case ErrCodeAccessTokenExpired:
if err := api.Refresh(); err != nil {
return nil, fmt.Errorf("could not refresh access token: %w", err)
} }
}()
return valueChan newResp, err := retry()
if err != nil {
return nil, err
}
return handleResponseErrors(api, newResp, retry)
case ErrCodeRefreshTokenExpired:
if err := api.InitSession(); err != nil {
return nil, fmt.Errorf("could not initialize new session: %w", err)
}
newResp, err := retry()
if err != nil {
return nil, err
}
return handleResponseErrors(api, newResp, retry)
default:
return nil, fmt.Errorf("invalid error code %d with message: %s", response.ErrorCode, response.Message)
}
}
func handlePagedResponseErrors[T any](api *Api, response *model.PagedResponse[T], retry func() (*model.PagedResponse[T], error)) (*model.PagedResponse[T], error) {
switch response.ErrorCode {
case 0:
return response, nil
case ErrCodeAccessTokenExpired:
if err := api.Refresh(); err != nil {
return nil, fmt.Errorf("could not refresh access token: %w", err)
}
newResp, err := retry()
if err != nil {
return nil, err
}
return handlePagedResponseErrors(api, newResp, retry)
case ErrCodeRefreshTokenExpired:
if err := api.InitSession(); err != nil {
return nil, fmt.Errorf("could not initialize new session: %w", err)
}
newResp, err := retry()
if err != nil {
return nil, err
}
return handlePagedResponseErrors(api, newResp, retry)
default:
return nil, fmt.Errorf("invalid error code %d with message: %s", response.ErrorCode, response.Message)
}
} }