diff --git a/webutil/formscanner.go b/webutil/formscanner.go index a77c2d0..845767d 100644 --- a/webutil/formscanner.go +++ b/webutil/formscanner.go @@ -4,7 +4,6 @@ import ( "errors" "fmt" "net/url" - "strconv" ) var ErrUnsupportedType = errors.New("unsupported type") @@ -23,89 +22,13 @@ func (s *FormScanner) Scan(name string, val any) *FormScanner { return s } - setError := func(name string, err error) { - s.err = fmt.Errorf("Error in field %s: %w", name, err) - } - switch v := val.(type) { case *bool: *v = s.form.Has(name) - case *string: - *v = s.form.Get(name) - case *int: - if i, err := strconv.ParseInt(s.form.Get(name), 10, 64); err != nil { - setError(name, err) - } else { - *v = int(i) - } - case *int8: - if i, err := strconv.ParseInt(s.form.Get(name), 10, 8); err != nil { - setError(name, err) - } else { - *v = int8(i) - } - case *int16: - if i, err := strconv.ParseInt(s.form.Get(name), 10, 16); err != nil { - setError(name, err) - } else { - *v = int16(i) - } - case *int32: - if i, err := strconv.ParseInt(s.form.Get(name), 10, 32); err != nil { - setError(name, err) - } else { - *v = int32(i) - } - case *int64: - if i, err := strconv.ParseInt(s.form.Get(name), 10, 64); err != nil { - setError(name, err) - } else { - *v = int64(i) - } - case *uint: - if i, err := strconv.ParseUint(s.form.Get(name), 10, 64); err != nil { - setError(name, err) - } else { - *v = uint(i) - } - case *uint8: - if i, err := strconv.ParseUint(s.form.Get(name), 10, 8); err != nil { - setError(name, err) - } else { - *v = uint8(i) - } - case *uint16: - if i, err := strconv.ParseUint(s.form.Get(name), 10, 16); err != nil { - setError(name, err) - } else { - *v = uint16(i) - } - case *uint32: - if i, err := strconv.ParseUint(s.form.Get(name), 10, 32); err != nil { - setError(name, err) - } else { - *v = uint32(i) - } - case *uint64: - if i, err := strconv.ParseUint(s.form.Get(name), 10, 64); err != nil { - setError(name, err) - } else { - *v = uint64(i) - } - case *float32: - if f, err := strconv.ParseFloat(s.form.Get(name), 32); err != nil { - setError(name, err) - } else { - *v = float32(f) - } - case *float64: - if f, err := strconv.ParseFloat(s.form.Get(name), 64); err != nil { - setError(name, err) - } else { - *v = float64(f) - } default: - setError(name, ErrUnsupportedType) + if err := scan(s.form.Get(name), v); err != nil { + s.err = fmt.Errorf("Error in field %s: %w", name, err) + } } return s diff --git a/webutil/pathscaner.go b/webutil/pathscaner.go new file mode 100644 index 0000000..95d57c4 --- /dev/null +++ b/webutil/pathscaner.go @@ -0,0 +1,31 @@ +package webutil + +import ( + "fmt" + "net/http" +) + +type PathScanner struct { + r *http.Request + err error +} + +func NewPathScanner(r *http.Request) *PathScanner { + return &PathScanner{r: r} +} + +func (s *PathScanner) Scan(name string, val any) *PathScanner { + if s.err != nil { + return s + } + + if err := scan(s.r.PathValue(name), val); err != nil { + s.err = fmt.Errorf("Error in field %s: %w", name, err) + } + + return s +} + +func (s *PathScanner) Error() error { + return s.err +} diff --git a/webutil/scanner.go b/webutil/scanner.go new file mode 100644 index 0000000..286612e --- /dev/null +++ b/webutil/scanner.go @@ -0,0 +1,85 @@ +package webutil + +import "strconv" + +func scan(raw string, val any) error { + switch v := val.(type) { + case *string: + *v = raw + case *int: + if i, err := strconv.ParseInt(raw, 10, 64); err != nil { + return err + } else { + *v = int(i) + } + case *int8: + if i, err := strconv.ParseInt(raw, 10, 8); err != nil { + return err + } else { + *v = int8(i) + } + case *int16: + if i, err := strconv.ParseInt(raw, 10, 16); err != nil { + return err + } else { + *v = int16(i) + } + case *int32: + if i, err := strconv.ParseInt(raw, 10, 32); err != nil { + return err + } else { + *v = int32(i) + } + case *int64: + if i, err := strconv.ParseInt(raw, 10, 64); err != nil { + return err + } else { + *v = int64(i) + } + case *uint: + if i, err := strconv.ParseUint(raw, 10, 64); err != nil { + return err + } else { + *v = uint(i) + } + case *uint8: + if i, err := strconv.ParseUint(raw, 10, 8); err != nil { + return err + } else { + *v = uint8(i) + } + case *uint16: + if i, err := strconv.ParseUint(raw, 10, 16); err != nil { + return err + } else { + *v = uint16(i) + } + case *uint32: + if i, err := strconv.ParseUint(raw, 10, 32); err != nil { + return err + } else { + *v = uint32(i) + } + case *uint64: + if i, err := strconv.ParseUint(raw, 10, 64); err != nil { + return err + } else { + *v = uint64(i) + } + case *float32: + if f, err := strconv.ParseFloat(raw, 32); err != nil { + return err + } else { + *v = float32(f) + } + case *float64: + if f, err := strconv.ParseFloat(raw, 64); err != nil { + return err + } else { + *v = float64(f) + } + default: + return ErrUnsupportedType + } + return nil +}