package httpconn

import (
	"bufio"
	"context"
	"crypto/tls"
	"errors"
	"io"
	"net"
	"net/http"
	"net/url"
	"time"
)

var (
	// ErrUnknownScheme is returned by Dial when the URL scheme is neither
	// "http" nor "https".
	ErrUnknownScheme = errors.New("unknown scheme")
)

// Dialer opens TCP or TLS connections and completes an HTTP CONNECT
// handshake on them before returning the raw connection. NewDialer returns
// one configured with a default timeout.
type Dialer struct {
	// timeout is applied while dialing and while waiting for the CONNECT
	// response; change it with SetTimeout.
	timeout time.Duration
}

// NewDialer returns a Dialer whose timeout is set to 10 seconds.
func NewDialer() *Dialer {
	dialer := new(Dialer)
	dialer.timeout = 10 * time.Second
	return dialer
}

// SetTimeout replaces the duration the Dialer uses when establishing a
// connection and while waiting for the server's CONNECT response.
func (d *Dialer) SetTimeout(timeout time.Duration) {
	d.timeout = timeout
}

// Dial parses rawURL, connects to its host and performs an HTTP CONNECT
// handshake for the URL's path, returning the connection on success.
//
// The scheme selects the transport: "https" dials over TLS and "http" over
// plain TCP. When the URL has no explicit port, 443 (https) or 80 (http) is
// used. Any other scheme yields ErrUnknownScheme.
func (d *Dialer) Dial(rawURL string) (net.Conn, error) {
	u, err := url.Parse(rawURL)
	if err != nil {
		return nil, err
	}

	// EscapedPath keeps the CONNECT request line well-formed even when the
	// path contains characters such as spaces.
	path := u.EscapedPath()

	switch u.Scheme {
	case "https":
		return d.DialHTTPS(hostPort(u, "443"), path)
	case "http":
		return d.DialHTTP(hostPort(u, "80"), path)
	default:
		return nil, ErrUnknownScheme
	}
}

// hostPort returns u's hostname joined with its explicit port, or with
// defaultPort when the URL specifies none. net.JoinHostPort re-brackets
// IPv6 literals so the result is always a valid dial address.
func hostPort(u *url.URL, defaultPort string) string {
	port := u.Port()
	if port == "" {
		port = defaultPort
	}
	return net.JoinHostPort(u.Hostname(), port)
}

// DialHTTPS opens a TLS connection to host, which must be in host:port form
// (e.g. "example.com:443"). It then performs an HTTP CONNECT handshake for
// path over that connection.
//
// The Dialer's timeout bounds the TCP connect plus TLS handshake. A zero
// timeout disables that bound, which matches the net.DialTimeout behaviour
// used by DialHTTP. The certificate is verified against the hostname in host.
func (d *Dialer) DialHTTPS(host, path string) (net.Conn, error) {
	// tls.Dialer applies NetDialer.Timeout to the whole dial including the
	// handshake, and (like net.DialTimeout) treats zero as "no timeout".
	// An unconditional context.WithTimeout would instead fail immediately
	// for a zero timeout.
	dd := tls.Dialer{NetDialer: &net.Dialer{Timeout: d.timeout}}
	conn, err := dd.DialContext(context.Background(), "tcp", host)
	if err != nil {
		return nil, err
	}
	return d.finishDialing(conn, host, path)
}

// DialHTTP opens a plain TCP connection to host (host:port form, as required
// by net.Dial) and performs an HTTP CONNECT handshake for path over it. The
// Dialer's timeout is passed to net.DialTimeout for the TCP connect.
func (d *Dialer) DialHTTP(host, path string) (net.Conn, error) {
	c, dialErr := net.DialTimeout("tcp", host, d.timeout)
	if dialErr == nil {
		return d.finishDialing(c, host, path)
	}
	return nil, dialErr
}

// finishDialing performs the HTTP CONNECT handshake on an established conn.
// It sends "CONNECT <path> HTTP/1.0" with a Host header and requires a 200
// response before handing the connection back.
//
// While the handshake runs, the connection deadline is set to now plus the
// Dialer's timeout; a zero timeout leaves the connection without a deadline.
// The deadline is cleared on success.
//
// On any failure conn is closed and a non-nil error is returned. Callers
// therefore never receive a nil connection together with a nil error.
//
// If the server sent data after the response headers and that data was
// already buffered, the returned conn is a wrapper that yields those bytes
// first.
func (d *Dialer) finishDialing(conn net.Conn, host, path string) (net.Conn, error) {
	fail := func(err error) (net.Conn, error) {
		conn.Close() // best effort; the triggering error is more useful
		return nil, err
	}

	if d.timeout != 0 {
		if err := conn.SetDeadline(time.Now().Add(d.timeout)); err != nil {
			return fail(err)
		}
	}

	// HTTP requires CRLF line endings; send the whole request in one write.
	req := "CONNECT " + path + " HTTP/1.0\r\nHost: " + host + "\r\n\r\n"
	if _, err := io.WriteString(conn, req); err != nil {
		return fail(err)
	}

	// Require successful HTTP response before using the conn.
	br := bufio.NewReader(conn)
	resp, err := http.ReadResponse(br, &http.Request{Method: "CONNECT"})
	if err != nil {
		return fail(err)
	}
	if resp.StatusCode != http.StatusOK {
		return fail(errors.New("unexpected CONNECT response status: " + resp.Status))
	}

	if err := conn.SetDeadline(time.Time{}); err != nil {
		return fail(err)
	}

	// br may have read past the response headers; keep those bytes instead
	// of silently dropping them from the tunnelled stream.
	if br.Buffered() > 0 {
		return &bufferedConn{Conn: conn, r: br}, nil
	}
	return conn, nil
}

// bufferedConn is a net.Conn whose reads go through r, a bufio.Reader over
// Conn, so data buffered while parsing the CONNECT response is not lost.
type bufferedConn struct {
	net.Conn
	r *bufio.Reader
}

// Read drains any data already buffered in r before reading from the
// underlying connection.
func (c *bufferedConn) Read(p []byte) (int, error) {
	return c.r.Read(p)
}

// Dial connects to rawURL and completes the CONNECT handshake using a
// Dialer freshly created by NewDialer.
func Dial(rawURL string) (net.Conn, error) {
	dialer := NewDialer()
	return dialer.Dial(rawURL)
}

// DialHTTPS is a convenience for NewDialer().DialHTTPS: it opens a TLS
// connection to host and performs the CONNECT handshake for path.
func DialHTTPS(host, path string) (net.Conn, error) {
	dialer := NewDialer()
	return dialer.DialHTTPS(host, path)
}

// DialHTTP is a convenience for NewDialer().DialHTTP: it opens a plain TCP
// connection to host and performs the CONNECT handshake for path.
func DialHTTP(host, path string) (net.Conn, error) {
	dialer := NewDialer()
	return dialer.DialHTTP(host, path)
}