From 939802292de0bfae3e8b0ff5af1da07b95dc9ebe Mon Sep 17 00:00:00 2001 From: Lucas BEE Date: Thu, 13 Oct 2016 13:19:55 +0000 Subject: [PATCH] Close properly SSH clients created by NativeClient Signed-off-by: Lucas BEE --- libmachine/ssh/client.go | 48 +++++++++++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 11 deletions(-) diff --git a/libmachine/ssh/client.go b/libmachine/ssh/client.go index d093f646..78d84b02 100644 --- a/libmachine/ssh/client.go +++ b/libmachine/ssh/client.go @@ -46,6 +46,7 @@ type NativeClient struct { Hostname string Port int openSession *ssh.Session + openClient *ssh.Client } type Auth struct { @@ -156,43 +157,49 @@ func NewNativeConfig(user string, auth *Auth) (ssh.ClientConfig, error) { } func (client *NativeClient) dialSuccess() bool { - if _, err := ssh.Dial("tcp", net.JoinHostPort(client.Hostname, strconv.Itoa(client.Port)), &client.Config); err != nil { + conn, err := ssh.Dial("tcp", net.JoinHostPort(client.Hostname, strconv.Itoa(client.Port)), &client.Config) + if err != nil { log.Debugf("Error dialing TCP: %s", err) return false } + closeConn(conn) return true } -func (client *NativeClient) session(command string) (*ssh.Session, error) { +func (client *NativeClient) session(command string) (*ssh.Client, *ssh.Session, error) { if err := mcnutils.WaitFor(client.dialSuccess); err != nil { - return nil, fmt.Errorf("Error attempting SSH client dial: %s", err) + return nil, nil, fmt.Errorf("Error attempting SSH client dial: %s", err) } conn, err := ssh.Dial("tcp", net.JoinHostPort(client.Hostname, strconv.Itoa(client.Port)), &client.Config) if err != nil { - return nil, fmt.Errorf("Mysterious error dialing TCP for SSH (we already succeeded at least once) : %s", err) + return nil, nil, fmt.Errorf("Mysterious error dialing TCP for SSH (we already succeeded at least once) : %s", err) } + session, err := conn.NewSession() - return conn.NewSession() + return conn, session, err } func (client *NativeClient) Output(command string) (string, error) { - session, err := client.session(command) + conn, session, err := client.session(command) if err != nil { return "", nil } + defer closeConn(conn) + defer session.Close() output, err := session.CombinedOutput(command) - defer session.Close() return string(output), err } func (client *NativeClient) OutputWithPty(command string) (string, error) { - session, err := client.session(command) + conn, session, err := client.session(command) if err != nil { return "", nil } + defer closeConn(conn) + defer session.Close() fd := int(os.Stdin.Fd()) @@ -214,13 +221,12 @@ func (client *NativeClient) OutputWithPty(command string) (string, error) { } output, err := session.CombinedOutput(command) - defer session.Close() return string(output), err } func (client *NativeClient) Start(command string) (io.ReadCloser, io.ReadCloser, error) { - session, err := client.session(command) + conn, session, err := client.session(command) if err != nil { return nil, nil, err } @@ -237,15 +243,27 @@ func (client *NativeClient) Start(command string) (io.ReadCloser, io.ReadCloser, return nil, nil, err } + client.openClient = conn client.openSession = session return ioutil.NopCloser(stdout), ioutil.NopCloser(stderr), nil } func (client *NativeClient) Wait() error { err := client.openSession.Wait() + if err != nil { + return err + } + _ = client.openSession.Close() + + err = client.openClient.Close() + if err != nil { + return err + } + client.openSession = nil - return err + client.openClient = nil + return nil } func (client *NativeClient) Shell(args ...string) error { @@ -256,6 +274,7 @@ func (client *NativeClient) Shell(args ...string) error { if err != nil { return err } + defer closeConn(conn) session, err := conn.NewSession() if err != nil { @@ -414,3 +433,10 @@ func (client *ExternalClient) Wait() error { client.cmd = nil return err } + +func closeConn(c io.Closer) { + err := c.Close() + if err != nil { + log.Debugf("Error closing SSH Client: %s", err) + } +}