package am

import (
	"crypto/subtle"
	"fmt"
	"log"
	"net/http"
	"net/url"
	"path/filepath"
	"strconv"
)

// formGetInt returns the already-parsed form value name as a base-10 int64.
// Parse errors are deliberately ignored so that a missing or malformed value
// acts as "unset": syntax errors yield 0, out-of-range values saturate to the
// int64 bound (strconv.ParseInt semantics). r.Form must be parsed beforehand.
func formGetInt(r *http.Request, name string) int64 {
	i, _ := strconv.ParseInt(r.Form.Get(name), 10, 64)
	return i
}

// execTmpl renders the named template from the package-level tmpl set into w.
// Execution errors are only logged: the template may already have written
// partial output, so no error response is attempted.
func execTmpl(w http.ResponseWriter, name string, data interface{}) {
	if err := tmpl.ExecuteTemplate(w, name, data); err != nil {
		log.Printf("Failed to execute template %s: %v", name, err)
	}
}

// respondNotAuthorized sends a 401 with a Basic-auth challenge so browsers
// prompt for credentials.
func respondNotAuthorized(w http.ResponseWriter) {
	w.Header().Set("WWW-Authenticate", `Basic realm="am"`)
	w.WriteHeader(http.StatusUnauthorized)
	w.Write([]byte("Unauthorised.\n"))
}

// respondInvalidCSRF sends a 403 for requests that fail the CSRF check.
func respondInvalidCSRF(w http.ResponseWriter) {
	w.WriteHeader(http.StatusForbidden)
	w.Write([]byte("Forbidden.\n"))
}

// respondRedirect issues a 303 See Other to fmt.Sprintf(format, args...).
// Callers must escape any user-controlled args (e.g. with url.PathEscape);
// they are interpolated verbatim.
func respondRedirect(w http.ResponseWriter, r *http.Request, format string, args ...interface{}) {
	http.Redirect(w, r, fmt.Sprintf(format, args...), http.StatusSeeOther)
}

// getReqUser authenticates the request's Basic-auth credentials against the
// database. It returns the user and true on success, or a zero User and false
// when credentials are absent or rejected.
func getReqUser(w http.ResponseWriter, r *http.Request) (User, bool) {
	username, pass, ok := r.BasicAuth()
	if !ok {
		return User{}, false
	}
	user, err := db.UserGetWithPwd(username, pass)
	if err != nil {
		return User{}, false
	}
	return user, true
}

// getCSRF returns the request's CSRF token from the "am_csrf" cookie. If the
// cookie is missing or not 32 characters long, a fresh token is generated and
// set as the cookie on w (so it must be called before the body is written).
func getCSRF(w http.ResponseWriter, r *http.Request) string {
	cookie, err := r.Cookie("am_csrf")
	if err == nil && len(cookie.Value) == 32 {
		return cookie.Value
	}
	token := newCSRF()
	http.SetCookie(w, &http.Cookie{
		Name:  "am_csrf",
		Value: token,
		Path:  "/",
		// The token reaches pages through the template's CSRF field, so
		// scripts never need to read the cookie; Lax keeps it off
		// cross-site subresource and POST requests.
		HttpOnly: true,
		SameSite: http.SameSiteLaxMode,
	})
	return token
}

// isSafeMethod reports whether r uses a method that must not change state
// (GET or HEAD). Handlers in this file only mutate on other methods, and
// checkCSRF demands a token for exactly those methods.
func isSafeMethod(r *http.Request) bool {
	return r.Method == http.MethodGet || r.Method == http.MethodHead
}

// checkCSRF implements a double-submit check: for any state-changing method
// the form field "CSRF" must equal the "am_csrf" cookie. Safe methods
// (GET/HEAD) always pass. As a side effect it parses r.Form, which the
// mutating handlers rely on.
func checkCSRF(w http.ResponseWriter, r *http.Request) bool {
	if isSafeMethod(r) {
		return true
	}
	cookieVal := getCSRF(w, r)
	if err := r.ParseForm(); err != nil {
		return false
	}
	formVal := r.Form.Get("CSRF")
	if formVal == "" {
		return false
	}
	// Constant-time so response timing does not leak the cookie value.
	return subtle.ConstantTimeCompare([]byte(formVal), []byte(cookieVal)) == 1
}

// handle_user registers h at path on http.DefaultServeMux, requiring valid
// Basic-auth credentials and, for state-changing methods, a valid CSRF token.
// (Name kept for existing callers.)
func handle_user(path string, h func(w http.ResponseWriter, r *http.Request)) {
	http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
		if _, ok := getReqUser(w, r); !ok {
			respondNotAuthorized(w)
			return
		}
		if !checkCSRF(w, r) {
			respondInvalidCSRF(w)
			return
		}
		h(w, r)
	})
}

// handle_admin is like handle_user but additionally requires the
// authenticated user to have the Admin flag. (Name kept for existing callers.)
func handle_admin(path string, h func(w http.ResponseWriter, r *http.Request)) {
	http.HandleFunc(path, func(w http.ResponseWriter, r *http.Request) {
		if u, ok := getReqUser(w, r); !ok || !u.Admin {
			respondNotAuthorized(w)
			return
		}
		if !checkCSRF(w, r) {
			respondInvalidCSRF(w)
			return
		}
		h(w, r)
	})
}

// ----------------------------------------------------------------------------

// handleRoot redirects to the log list.
func handleRoot(w http.ResponseWriter, r *http.Request) {
	respondRedirect(w, r, "/log/list")
}

// ----------------------------------------------------------------------------

// handleUserInsert shows the new-user form on GET/HEAD and creates the user
// on any other (CSRF-checked) method, then redirects to the user's page.
func handleUserInsert(w http.ResponseWriter, r *http.Request) {
	if isSafeMethod(r) {
		execTmpl(w, "UserInsert", struct{ CSRF string }{getCSRF(w, r)})
		return
	}
	user := User{
		Username: r.Form.Get("Username"),
		Admin:    r.Form.Get("Admin") != "",
	}
	if err := db.UserInsert(user, r.Form.Get("Password")); err != nil {
		execTmpl(w, "Error", err)
		return
	}
	respondRedirect(w, r, "/user/view/%s", url.PathEscape(user.Username))
}

// handleUserList renders all users.
func handleUserList(w http.ResponseWriter, r *http.Request) {
	l, err := db.UserList()
	if err != nil {
		execTmpl(w, "Error", err)
		return
	}
	execTmpl(w, "UserList", l)
}

// handleUserView renders the user named by the last path segment.
func handleUserView(w http.ResponseWriter, r *http.Request) {
	name := filepath.Base(r.URL.Path)
	u, err := db.UserGet(name)
	if err != nil {
		execTmpl(w, "Error", err)
		return
	}
	execTmpl(w, "UserView", u)
}

// handleUserUpdate shows the edit form on GET/HEAD. On other methods it
// updates the Admin flag if it changed and the password if a non-empty
// NewPassword was submitted, then redirects to the user's page.
func handleUserUpdate(w http.ResponseWriter, r *http.Request) {
	name := filepath.Base(r.URL.Path)
	u, err := db.UserGet(name)
	if err != nil {
		execTmpl(w, "Error", err)
		return
	}
	if isSafeMethod(r) {
		execTmpl(w, "UserUpdate", struct {
			User User
			CSRF string
		}{u, getCSRF(w, r)})
		return
	}

	// An unchecked checkbox is absent from the form, hence the != "" test.
	admin := r.Form.Get("Admin") != ""
	if u.Admin != admin {
		if err := db.UserUpdateAdmin(u.Username, admin); err != nil {
			execTmpl(w, "Error", err)
			return
		}
	}
	if pwd := r.Form.Get("NewPassword"); pwd != "" {
		if err := db.UserUpdatePwd(u.Username, pwd); err != nil {
			execTmpl(w, "Error", err)
			return
		}
	}
	respondRedirect(w, r, "/user/view/%s", url.PathEscape(u.Username))
}

// handleUserDelete shows a confirmation page on GET/HEAD and deletes the user
// on other methods, then redirects to the user list.
func handleUserDelete(w http.ResponseWriter, r *http.Request) {
	name := filepath.Base(r.URL.Path)
	u, err := db.UserGet(name)
	if err != nil {
		execTmpl(w, "Error", err)
		return
	}
	if isSafeMethod(r) {
		execTmpl(w, "UserDelete", struct {
			User User
			CSRF string
		}{u, getCSRF(w, r)})
		return
	}
	if err := db.UserDelete(name); err != nil {
		execTmpl(w, "Error", err)
		return
	}
	respondRedirect(w, r, "/user/list")
}

// ----------------------------------------------------------------------------

// handleSourceInsert shows the new-source form on GET/HEAD and creates the
// source on other methods; db.SourceInsert is expected to fill in SourceID,
// which is then used for the redirect.
func handleSourceInsert(w http.ResponseWriter, r *http.Request) {
	if isSafeMethod(r) {
		execTmpl(w, "SourceInsert", struct{ CSRF string }{getCSRF(w, r)})
		return
	}
	s := Source{
		Name:         r.Form.Get("Name"),
		Description:  r.Form.Get("Description"),
		AlertTimeout: formGetInt(r, "AlertTimeout"),
	}
	if err := db.SourceInsert(&s); err != nil {
		execTmpl(w, "Error", err)
		return
	}
	respondRedirect(w, r, "/source/view/%s", s.SourceID)
}

// handleSourceList renders all sources.
func handleSourceList(w http.ResponseWriter, r *http.Request) {
	l, err := db.SourceList()
	if err != nil {
		execTmpl(w, "Error", err)
		return
	}
	execTmpl(w, "SourceList", l)
}

// handleSourceView renders the source whose ID is the last path segment.
func handleSourceView(w http.ResponseWriter, r *http.Request) {
	id := filepath.Base(r.URL.Path)
	s, err := db.SourceGet(id)
	if err != nil {
		execTmpl(w, "Error", err)
		return
	}
	execTmpl(w, "SourceView", s)
}

// handleSourceUpdate shows the edit form on GET/HEAD and saves Description
// and AlertTimeout on other methods, then redirects to the source's page.
func handleSourceUpdate(w http.ResponseWriter, r *http.Request) {
	id := filepath.Base(r.URL.Path)
	s, err := db.SourceGet(id)
	if err != nil {
		execTmpl(w, "Error", err)
		return
	}
	if isSafeMethod(r) {
		execTmpl(w, "SourceUpdate", struct {
			Source Source
			CSRF   string
		}{s, getCSRF(w, r)})
		return
	}
	s.Description = r.Form.Get("Description")
	s.AlertTimeout = formGetInt(r, "AlertTimeout")
	if err := db.SourceUpdate(s); err != nil {
		execTmpl(w, "Error", err)
		return
	}
	respondRedirect(w, r, "/source/view/%s", s.SourceID)
}

// handleSourceDelete shows a confirmation page on GET/HEAD and deletes the
// source on other methods, then redirects to the source list.
func handleSourceDelete(w http.ResponseWriter, r *http.Request) {
	id := filepath.Base(r.URL.Path)
	s, err := db.SourceGet(id)
	if err != nil {
		execTmpl(w, "Error", err)
		return
	}
	if isSafeMethod(r) {
		execTmpl(w, "SourceDelete", struct {
			Source Source
			CSRF   string
		}{s, getCSRF(w, r)})
		return
	}
	if err := db.SourceDelete(id); err != nil {
		execTmpl(w, "Error", err)
		return
	}
	respondRedirect(w, r, "/source/list")
}

// ----------------------------------------------------------------------------

// logListArgs holds the log-list filters as submitted in the query/form; it is
// echoed back to the template so the filter controls keep their state.
type logListArgs struct {
	BeforeID int64
	SourceID string
	Type     string // One of "all", "alert", "log"; anything else acts as "all".
}

// handleLogList renders one page (up to 200 entries) of log entries filtered
// by BeforeID, SourceID and Type, plus a NextURL query string for the next
// page when more entries exist (empty otherwise).
func handleLogList(w http.ResponseWriter, r *http.Request) {
	// checkCSRF only parses the form for non-safe methods, so parse here.
	// Errors are ignored on purpose: bad filters fall back to zero values.
	_ = r.ParseForm()
	args := logListArgs{
		BeforeID: formGetInt(r, "BeforeID"),
		SourceID: r.Form.Get("SourceID"),
		Type:     r.Form.Get("Type"),
	}

	limit := int64(200)
	listArgs := LogListArgs{
		BeforeID: args.BeforeID,
		// Fetch one extra row to learn whether a next page exists.
		Limit:    limit + 1,
		SourceID: args.SourceID,
	}
	switch args.Type {
	case "alert":
		b := true
		listArgs.Alert = &b
	case "log":
		b := false
		listArgs.Alert = &b
	}

	l, err := db.LogList(listArgs)
	if err != nil {
		execTmpl(w, "Error", err)
		return
	}

	nextURL := ""
	if len(l) > int(limit) {
		l = l[:len(l)-1]
		// Build the query with url.Values so user-supplied SourceID/Type
		// cannot inject extra parameters or markup.
		q := url.Values{}
		q.Set("Limit", strconv.FormatInt(limit, 10))
		q.Set("BeforeID", strconv.FormatInt(int64(l[len(l)-1].LogID), 10))
		q.Set("SourceID", args.SourceID)
		q.Set("Type", args.Type)
		nextURL = "?" + q.Encode()
	}

	sources, err := db.SourceList()
	if err != nil {
		execTmpl(w, "Error", err)
		return
	}

	execTmpl(w, "LogList", struct {
		Entries []EntryListRow
		Sources []Source
		Args    logListArgs
		NextURL string
	}{l, sources, args, nextURL})
}