前言
此方法适用于 `Pod` 不需要大量连接的情况:
- 有多个pod在执行任务, 偶尔需要连接其中一个pod查看进度/日志;
- 对pod执行一个脚本/命令;
不适用于大量连接建立的情况:
- pod启的数据库服务;
- pod启的Api服务;
- pod启的前端服务;
- pod启的Oss服务;
Portforward简介
Portforward 就是端口转发, 可以将本地机器的端口转发到 Kubernetes 集群中的 `Pod` 中, 主要用于调试和临时访问场景, 尤其是当你想要在不暴露服务的情况下访问 Pod 中的应用时; 比如:
- 数据库服务本地连接
- Api服务请求调试
主要命令格式:
kubectl port-forward <resource>/<pod-name> <local-port>:<remote-port>
支持 `Pod` 和 `Service`, 也支持多端口转发, 比如:
kubectl port-forward pod/my-pod 9090:8080
kubectl port-forward pod/my-pod 9090:8080 7070:7777
kubectl port-forward svc/my-svc 9090:8080
kubectl port-forward svc/my-svc 9090:8080 7070:7777
需求背景
我们后台管理了多个集群, 每个集群都有海量的 `Pod` 任务, 需要提供 SSH 服务供用户连接到 `Pod`;
有两种实现方式:
- 使用Exec(不支持虚拟机)
- Portforward
本篇主要讲 `Portforward`;
源码解析
`Portforward` 的实现方式主要是通过对 HTTP 请求进行连接升级, 支持多路流; 然后在本地打开监听端口, 接收 TCP 请求并为每个连接创建新的流进行交互; 下面贴一下主要的流程代码:
ForwardPorts
`Portforward` 的入口函数, 打开对 `Pod` 的流式连接, 准备进行端口转发;
// ForwardPorts (excerpt from client-go's PortForwarder) is the port-forward
// entry point: it upgrades the connection to the API server into a
// multiplexed stream connection, verifies the negotiated sub-protocol, and
// then hands over to forward to start listening on the local ports.
func (pf *PortForwarder) ForwardPorts() error {
defer pf.Close()
var err error
var protocol string
// Dial performs the connection upgrade and reports which sub-protocol the
// server selected.
pf.streamConn, protocol, err = pf.dialer.Dial(PortForwardProtocolV1Name)
if err != nil {
return fmt.Errorf("error upgrading connection: %s", err)
}
defer pf.streamConn.Close()
// Only PortForwardProtocolV1Name was offered, so any other answer means
// protocol negotiation failed.
if protocol != PortForwardProtocolV1Name {
return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", PortForwardProtocolV1Name, protocol)
}
return pf.forward()
}
forward
`forward` 获取端口映射参数, 开始监听指定的本地端口;
// forward (excerpt; the part marked `...` is elided in this article) tries to
// open a local listener for every requested mapping in pf.ports. A failure on
// one port is reported to pf.errOut (when set) and does not stop the remaining
// ports from being tried; listenSuccess records whether at least one listener
// was opened.
func (pf *PortForwarder) forward() error {
var err error
listenSuccess := false
for i := range pf.ports {
// Take the element's address so listenOnPort works on the entry in
// pf.ports rather than on a copy.
port := &pf.ports[i]
err = pf.listenOnPort(port)
switch {
case err == nil:
listenSuccess = true
default:
if pf.errOut != nil {
fmt.Fprintf(pf.errOut, "Unable to listen on port %d: %v\n", port.Local, err)
}
}
}
// NOTE(review): the elided code presumably fails when listenSuccess is false
// and otherwise waits until forwarding is stopped — confirm against client-go.
...
return nil
}
// getListener (excerpt; the part marked `...` is elided) opens a listener for
// the given network protocol on hostname:port.Local.
func (pf *PortForwarder) getListener(protocol string, hostname string, port *ForwardedPort) (net.Listener, error) {
listener, err := net.Listen(protocol, net.JoinHostPort(hostname, strconv.Itoa(int(port.Local))))
if err != nil {
return nil, fmt.Errorf("unable to create listener: Error %s", err)
}
...
return listener, nil
}
handleConnection
`waitForConnection` 通过监听端口获取 TCP 连接, 对每个连接开一个 goroutine 进行处理;
`handleConnection` 对每个 TCP 连接创建新的 Stream 流, 在 TCP 连接和 Stream 流之间双向拷贝数据;
// waitForConnection (excerpt from client-go) accepts TCP connections on
// listener and serves each one in its own goroutine via handleConnection. It
// returns when the stream connection to the server is closed or when Accept
// fails.
func (pf *PortForwarder) waitForConnection(listener net.Listener, port ForwardedPort) {
for {
select {
// Stop once the upstream stream connection is closed. This is only checked
// between Accept calls, because Accept itself blocks.
case <-pf.streamConn.CloseChan():
return
default:
conn, err := listener.Accept()
if err != nil {
// TODO consider using something like https://github.com/hydrogen18/stoppableListener?
// A closed listener is the normal shutdown path, so that error is not reported.
if !strings.Contains(strings.ToLower(err.Error()), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error accepting connection on port %d: %v", port.Local, err))
}
return
}
go pf.handleConnection(conn, port)
}
}
}
// handleConnection (excerpt from client-go; parts marked `...` are elided)
// bridges one local TCP connection to the server. It creates a data stream on
// pf.streamConn and copies bytes in both directions between conn and that
// stream. `headers` (and the error stream) are set up in the elided part above.
func (pf *PortForwarder) handleConnection(conn net.Conn, port ForwardedPort) {
...
// create data stream
headers.Set(v1.StreamType, v1.StreamTypeData)
dataStream, err := pf.streamConn.CreateStream(headers)
if err != nil {
runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
return
}
defer pf.streamConn.RemoveStreams(dataStream)
// localError is closed if copying local -> remote fails; remoteDone is closed
// when copying remote -> local finishes. The elided select waits on them.
localError := make(chan struct{})
remoteDone := make(chan struct{})
go func() {
// Copy from the remote side to the local port.
if _, err := io.Copy(conn, dataStream); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error copying from remote stream to local connection: %v", err))
}
// inform the select below that the remote copy is done
close(remoteDone)
}()
go func() {
// inform server we're not sending any more data after copy unblocks
defer dataStream.Close()
// Copy from the local port to the remote side.
if _, err := io.Copy(dataStream, conn); err != nil && !strings.Contains(err.Error(), "use of closed network connection") {
runtime.HandleError(fmt.Errorf("error copying from local connection to remote stream: %v", err))
// break out of the select below without waiting for the other copy to finish
close(localError)
}
}()
...
}
总结
看代码得知原理, 数据链路为 userClient -> serverListen -> pod;
知道了链路, 自然就能看出它最适合的场景: 持续、大量地新建 TCP 连接, 比如 Api/Oss 等服务; 而对于我的需求场景——偶尔才建立一次连接——为每个 Pod 在服务端额外占用一个监听端口并多走一跳中转就不太划算了;
所以我们能不能跳过 `ServerListen` 这层中转, 直接让 `userClient` 和 `Pod` 进行交互呢? 答案是可以的;
解决方案
回归我们的需求本身: 我们有大量用户和大量的 `pod`, 每个 `pod` 也只会有少量用户访问, 所以没必要用 `serverListen` 中转, 直接让用户连 `pod` 就可以了, 这样就省掉了 `ServerListen` 的两个端口!
代码也很简单, 只需要把 `handleConnection` 的代码抽出来, 将用户的连接跟 `pod` 的连接做交互就好了;
实现代码
简单贴一下实现代码, 自己在 `handle func(dataStream httpstream.Stream)` 中与 `net.Conn` 做交互就可以了;
// createSPDYConnection opens a port-forward session straight to port podPort of
// pod namespace/podName through the API server, without opening any local
// listening socket, and hands the resulting data stream to handle.
//
// It uses the package-level clientset and config. The call blocks until handle
// returns. An error is returned when podPort is not a valid TCP port, when the
// SPDY connection cannot be established, or when the server does not agree to
// the port-forward protocol. Failures while creating the individual streams are
// reported by handleStreamConnection itself.
func createSPDYConnection(namespace, podName string, podPort int, handle func(dataStream httpstream.Stream)) error {
	// ForwardedPort.Remote is a uint16: reject values that would silently wrap
	// around (e.g. 65558 -> 22) or that are not valid TCP ports at all.
	if podPort < 1 || podPort > 65535 {
		return fmt.Errorf("invalid pod port %d: must be in range 1-65535", podPort)
	}

	req := clientset.CoreV1().RESTClient().
		Post().
		Resource("pods").
		Namespace(namespace).
		Name(podName).
		SubResource("portforward").
		Param("ports", strconv.Itoa(podPort))

	// Create the SPDY transport and dialer.
	transport, upgrader, err := spdy.RoundTripperFor(config)
	if err != nil {
		return fmt.Errorf("failed to create round tripper: %w", err)
	}
	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", req.URL())

	// Upgrade the connection to the pod's port-forward endpoint.
	streamConn, protocol, err := dialer.Dial(portforward.PortForwardProtocolV1Name)
	if err != nil {
		return fmt.Errorf("failed to dial port forward: %w", err)
	}
	// handleStreamConnection closes streamConn on its normal path but not on its
	// early error returns; the deferred Close covers those.
	defer streamConn.Close()

	// Same check as client-go's ForwardPorts: only V1 was offered, so any other
	// answer means negotiation failed.
	if protocol != portforward.PortForwardProtocolV1Name {
		return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", portforward.PortForwardProtocolV1Name, protocol)
	}

	handleStreamConnection(streamConn, portforward.ForwardedPort{
		Local:  0, // no local listener is involved
		Remote: uint16(podPort),
	}, handle)
	return nil
}
// handleStreamConnection sets up one port-forward request on streamConn (an
// error stream and a data stream for port.Remote that share a request ID) and
// passes the data stream to handle. Unlike client-go's handleConnection, no
// bytes are copied here: handle owns the data stream until it returns, after
// which the data stream and streamConn are both closed. If either stream cannot
// be created, the failure is reported through runtime.HandleError and handle is
// not called.
func handleStreamConnection(streamConn httpstream.Connection, port portforward.ForwardedPort, handle func(dataStream httpstream.Stream)) {
	requestID := time.Now().UnixNano()

	// Create the error stream. The data stream below reuses these headers, so
	// both streams carry the same port and request ID.
	headers := http.Header{}
	headers.Set(v1.StreamType, v1.StreamTypeError)
	headers.Set(v1.PortHeader, strconv.Itoa(int(port.Remote)))
	headers.Set(v1.PortForwardRequestIDHeader, strconv.FormatInt(requestID, 10))
	errorStream, err := streamConn.CreateStream(headers)
	if err != nil {
		runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
		return
	}
	// We never write to the error stream, so close our side right away; the
	// server's report is still read by the goroutine below.
	errorStream.Close()
	go func() {
		message, err := io.ReadAll(errorStream)
		switch {
		case err != nil:
			log.Printf("error reading error stream: %v\n", err)
		case len(message) > 0:
			// A non-empty message is the server reporting a forwarding failure,
			// not a failure to read the stream.
			log.Printf("an error occurred forwarding %d -> %d: %s\n", port.Local, port.Remote, message)
		}
	}()

	// Create the data stream that carries the forwarded bytes.
	headers.Set(v1.StreamType, v1.StreamTypeData)
	dataStream, err := streamConn.CreateStream(headers)
	if err != nil {
		runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
		return
	}

	handle(dataStream)
	_ = dataStream.Close()
	_ = streamConn.Close()
}
Kubelet
k8s 源码中也有相同的用法, 虽然只是在测试代码中:
kubernetes/pkg/kubelet/server/server_test.go at master · kubernetes/kubernetes
// TestServePortForward is quoted from the kubelet server tests
// (kubernetes/pkg/kubelet/server/server_test.go). It shows the same
// client-side pattern used in this article:
//   - POST to the portForward endpoint;
//   - upgrade the response into a SPDY connection;
//   - create an "error" stream and then a "data" stream for the same port;
//   - exchange bytes directly over the data stream, with no local listener.
func TestServePortForward(t *testing.T) {
tests := map[string]struct {
port string
uid bool
clientData string
containerData string
shouldError bool
}{
"no port": {port: "", shouldError: true},
"none number port": {port: "abc", shouldError: true},
"negative port": {port: "-1", shouldError: true},
"too large port": {port: "65536", shouldError: true},
"0 port": {port: "0", shouldError: true},
"min port": {port: "1", shouldError: false},
"normal port": {port: "8000", shouldError: false},
"normal port with data forward": {port: "8000", clientData: "client data", containerData: "container data", shouldError: false},
"max port": {port: "65535", shouldError: false},
"normal port with uid": {port: "8000", uid: true, shouldError: false},
}
podNamespace := "other"
podName := "foo"
for desc := range tests {
test := tests[desc]
t.Run(desc, func(t *testing.T) {
// Set up the test servers (helpers defined elsewhere in server_test.go).
ss, err := newTestStreamingServer(0)
require.NoError(t, err)
defer ss.testHTTPServer.Close()
fw := newServerTestWithDebug(true, ss)
defer fw.testHTTPServer.Close()
portForwardFuncDone := make(chan struct{})
// Check that the pod name, namespace and (optionally) UID are taken from the URL.
fw.fakeKubelet.getPortForwardCheck = func(name, namespace string, uid types.UID, opts portforward.V4Options) {
assert.Equal(t, podName, name, "pod name")
assert.Equal(t, podNamespace, namespace, "pod namespace")
if test.uid {
assert.Equal(t, testUID, string(uid), "uid")
}
}
// Fake runtime end of the tunnel: verifies sandbox ID and port, then reads the
// client's bytes and writes the container's bytes over the forwarded stream.
ss.fakeRuntime.portForwardFunc = func(podSandboxID string, port int32, stream io.ReadWriteCloser) error {
defer close(portForwardFuncDone)
assert.Equal(t, testPodSandboxID, podSandboxID, "pod sandbox id")
// The port should be valid if it reaches here.
testPort, err := strconv.ParseInt(test.port, 10, 32)
require.NoError(t, err, "parse port")
assert.Equal(t, int32(testPort), port, "port")
if test.clientData != "" {
fromClient := make([]byte, 32)
n, err := stream.Read(fromClient)
assert.NoError(t, err, "reading client data")
assert.Equal(t, test.clientData, string(fromClient[0:n]), "client data")
}
if test.containerData != "" {
_, err := stream.Write([]byte(test.containerData))
assert.NoError(t, err, "writing container data")
}
return nil
}
// The pod UID is an optional last path segment.
var url string
if test.uid {
url = fmt.Sprintf("%s/portForward/%s/%s/%s", fw.testHTTPServer.URL, podNamespace, podName, testUID)
} else {
url = fmt.Sprintf("%s/portForward/%s/%s", fw.testHTTPServer.URL, podNamespace, podName)
}
// Client side: POST with the port-forward protocol, expect 101 Switching
// Protocols, and turn the response into a SPDY stream connection.
var (
upgradeRoundTripper httpstream.UpgradeRoundTripper
c *http.Client
)
upgradeRoundTripper, err = spdy.NewRoundTripper(&tls.Config{})
if err != nil {
t.Fatalf("Error creating SpdyRoundTripper: %v", err)
}
c = &http.Client{Transport: upgradeRoundTripper}
req := makeReq(t, "POST", url, "portforward.k8s.io")
resp, err := c.Do(req)
require.NoError(t, err, "POSTing")
defer resp.Body.Close()
assert.Equal(t, http.StatusSwitchingProtocols, resp.StatusCode, "status code")
conn, err := upgradeRoundTripper.NewConnection(resp)
require.NoError(t, err, "creating streaming connection")
defer conn.Close()
// The error stream is created first; the invalid-port cases are expected to
// fail at this step.
headers := http.Header{}
headers.Set("streamType", "error")
headers.Set("port", test.port)
_, err = conn.CreateStream(headers)
assert.Equal(t, test.shouldError, err != nil, "expect error")
if test.shouldError {
return
}
// Then the data stream for the same port carries the payload.
headers.Set("streamType", "data")
headers.Set("port", test.port)
dataStream, err := conn.CreateStream(headers)
require.NoError(t, err, "create stream")
if test.clientData != "" {
_, err := dataStream.Write([]byte(test.clientData))
assert.NoError(t, err, "writing client data")
}
if test.containerData != "" {
fromContainer := make([]byte, 32)
n, err := dataStream.Read(fromContainer)
assert.NoError(t, err, "reading container data")
assert.Equal(t, test.containerData, string(fromContainer[0:n]), "container data")
}
// Wait for the runtime side to finish its assertions.
<-portForwardFuncDone
})
}
}
搞个demo
最后再放一个最近做的东西: 一个连接 k8s `pod` 的 SSH 服务. 用户连接该 SSH 服务后会被转接到 `pod`, 中间可以在 SSH 握手后进行一些特殊处理, 比如身份校验、日志记录等;
package main
import (
"context"
"fmt"
"golang.org/x/crypto/ssh"
gossh "golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
"io"
v1 "k8s.io/api/core/v1"
"k8s.io/apimachinery/pkg/util/httpstream"
"k8s.io/apimachinery/pkg/util/runtime"
"k8s.io/client-go/kubernetes"
"k8s.io/client-go/rest"
"k8s.io/client-go/tools/clientcmd"
"k8s.io/client-go/tools/portforward"
"k8s.io/client-go/transport/spdy"
"log"
"net"
"net/http"
"os"
"strconv"
"strings"
"time"
)
// Package-level configuration and shared clients for the SSH-to-pod proxy.
var (
	// podName and podNamespace identify the pod every SSH session is proxied to.
	podName      = ""
	podNamespace = ""
	// localSSHPort is the address the proxy listens on for user SSH connections.
	localSSHPort = ":2225"
	// kubeConfigPath is the kubeconfig used to reach the API server.
	kubeConfigPath = "/home/fly/.kube/config"
	// privateKeyPath is the key file loaded in init. The parsed key is used both
	// as this proxy's SSH host key and to authenticate to the pod's sshd.
	// NOTE(review): reusing a personal id_rsa as a server host key is risky;
	// consider a dedicated host key.
	privateKeyPath = "/home/fly/.ssh/id_rsa"

	config    *rest.Config
	clientset *kubernetes.Clientset
	// authorizedKey holds the raw contents of privateKeyPath (a private key,
	// despite the name); privateKey is the parsed signer. Both are set in init.
	authorizedKey []byte
	privateKey    gossh.Signer
	err           error
)

// init builds the Kubernetes client and loads the SSH key, terminating the
// process if any step fails. Key errors are fatal here on purpose: a nil
// privateKey would otherwise surface only later, when a user connects.
func init() {
	config, err = clientcmd.BuildConfigFromFlags("", kubeConfigPath)
	if err != nil {
		log.Fatalf("k8s config err: %v \n", err)
	}
	clientset, err = kubernetes.NewForConfig(config)
	if err != nil {
		log.Fatalf("k8s client err: %v \n", err)
	}
	authorizedKey, err = os.ReadFile(privateKeyPath)
	if err != nil {
		log.Fatalf("read ssh key %s err: %v \n", privateKeyPath, err)
	}
	privateKey, err = gossh.ParsePrivateKey(authorizedKey)
	if err != nil {
		log.Fatalf("parse ssh key %s err: %v \n", privateKeyPath, err)
	}
}
// main listens for user SSH connections on localSSHPort and serves each
// accepted connection in its own goroutine. It never returns normally; setup
// failures terminate the process via log.Fatalf.
func main() {
	listener, err := net.Listen("tcp", localSSHPort)
	if err != nil {
		log.Fatalf("unable to listen on port %s: %v", localSSHPort, err)
	}
	defer listener.Close()
	log.Printf("the proxy service is listening on the port %s", localSSHPort)

	// Back off on consecutive Accept failures (e.g. file descriptor
	// exhaustion) instead of spinning in a tight loop that floods the log.
	// This is similar to what net/http's Server.Serve does.
	const (
		minAcceptDelay = 5 * time.Millisecond
		maxAcceptDelay = time.Second
	)
	var acceptDelay time.Duration
	for {
		clientConn, err := listener.Accept()
		if err != nil {
			log.Printf("failed to accept connection: %v", err)
			if acceptDelay == 0 {
				acceptDelay = minAcceptDelay
			} else if acceptDelay *= 2; acceptDelay > maxAcceptDelay {
				acceptDelay = maxAcceptDelay
			}
			time.Sleep(acceptDelay)
			continue
		}
		acceptDelay = 0
		go handleConnection(clientConn)
	}
}
// NetHandle bundles the per-session state of one proxied user connection that
// the forwarding closure in handleConnection needs.
type NetHandle struct {
	// ctx bounds the lifetime of the whole session.
	ctx context.Context
	// sshConn is the server-side SSH connection with the user.
	sshConn *ssh.ServerConn
	// chans and reqs are the user's channel-open and global-request streams,
	// as returned by ssh.NewServerConn.
	chans <-chan ssh.NewChannel
	reqs  <-chan *ssh.Request
	// dataStream is left nil by handleConnection.
	dataStream httpstream.Stream
}
// handleConnection serves one user connection end to end:
//  1. it completes the SSH server handshake with the user on conn;
//  2. it opens a port-forward stream to port 22 of pod podNamespace/podName;
//  3. it runs an SSH client handshake with the pod's sshd over that stream;
//  4. it relays global requests and channels in both directions until either
//     SSH connection ends or the 7-hour session limit expires.
//
// The user connection, the pod SSH connection and the port-forward connection
// are all closed before it returns.
func handleConnection(conn net.Conn) {
	ctx, cancel := context.WithTimeout(context.Background(), 7*time.Hour)
	defer cancel()

	// Create the SSH server side.
	// NOTE(review): NoClientAuth lets anyone connect without credentials; add a
	// PublicKeyCallback/PasswordCallback here before exposing the proxy.
	serverConfig := &ssh.ServerConfig{
		NoClientAuth: true,
	}
	serverConfig.AddHostKey(privateKey)

	// Perform the SSH handshake with the user.
	sshConn, chans, reqs, err := ssh.NewServerConn(conn, serverConfig)
	if err != nil {
		log.Printf("failed to receive ssh connection: %v", err)
		conn.Close()
		return
	}
	defer sshConn.Close()
	username := sshConn.User()
	log.Printf("ssh connection to users: %s", username)

	h := &NetHandle{
		ctx:        ctx,
		sshConn:    sshConn,
		chans:      chans,
		reqs:       reqs,
		dataStream: nil,
	}

	// handle runs once the port-forward data stream to the pod exists and owns
	// that stream until it returns.
	handle := func(dataStream httpstream.Stream) {
		clientConf := &ssh.ClientConfig{
			User: "ubuntu",
			Auth: []ssh.AuthMethod{ssh.PublicKeys(privateKey)},
			// NOTE(review): ClientConfig.Timeout bounds the TCP dial done by
			// ssh.Dial; ssh.NewClientConn below does not appear to apply it.
			Timeout:         5 * time.Second,
			HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		}
		streamConn := NewStreamConn(dataStream)
		log.Println("Encapsulate stream as net.conn, start forwarding")
		clientConn, clientChans, clientReqs, err := ssh.NewClientConn(streamConn, "vm:22", clientConf)
		if err != nil {
			log.Printf("new ssh client err: %v\n", err)
			return
		}
		defer clientConn.Close()

		// Relay global requests and channel opens in both directions.
		go forwardConnReqs(h.sshConn, clientReqs)
		go forwardConnReqs(clientConn, h.reqs)
		go forwardChans(h.ctx, h.sshConn, clientChans)
		go forwardChans(h.ctx, clientConn, h.chans)

		// Stay in the session until either SSH connection ends or ctx expires.
		waitCtx, waitCancel := context.WithCancel(h.ctx)
		defer waitCancel()
		go func() {
			_ = h.sshConn.Wait()
			waitCancel()
		}()
		go func() {
			_ = clientConn.Wait()
			waitCancel()
		}()
		<-waitCtx.Done()
	}

	// Previously this error was discarded, so a failed port-forward left no trace.
	if err := createSPDYConnection(podNamespace, podName, 22, handle); err != nil {
		log.Printf("port-forward to pod %s/%s for user %s failed: %v", podNamespace, podName, username, err)
	}
}
// ChannelOpener is the subset of ssh.Conn needed to open, on the other SSH
// connection, the mirror of a channel the peer has requested. Both
// *ssh.ServerConn and the connection returned by ssh.NewClientConn satisfy it.
type ChannelOpener interface {
	OpenChannel(name string, data []byte) (ssh.Channel, <-chan *ssh.Request, error)
}

// Compile-time check that any ssh.Conn can be used as a ChannelOpener.
var _ ChannelOpener = ssh.Conn(nil)
// forwardChans relays every channel-open request received on chans to dst,
// serving each one in its own goroutine via forwardChan. It returns once chans
// is closed.
func forwardChans(ctx context.Context, dst ChannelOpener, chans <-chan ssh.NewChannel) {
	for {
		incoming, open := <-chans
		if !open {
			return
		}
		go forwardChan(ctx, dst, incoming)
	}
}
// forwardChan mirrors a single channel-open request onto dst and then pumps
// data and channel requests between the two channels in both directions until
// both relays finish. If dst refuses the channel, the request is rejected with
// dst's error text.
func forwardChan(ctx context.Context, dst ChannelOpener, newChan ssh.NewChannel) {
	// Open the peer channel first so that a failure can still be reported to
	// the requester as a rejection.
	upstream, upstreamReqs, openErr := dst.OpenChannel(newChan.ChannelType(), newChan.ExtraData())
	if openErr != nil {
		_ = newChan.Reject(ssh.Prohibited, openErr.Error())
		return
	}
	defer upstream.Close()

	downstream, downstreamReqs, acceptErr := newChan.Accept()
	if acceptErr != nil {
		return
	}
	defer downstream.Close()

	group, groupCtx := errgroup.WithContext(ctx)
	group.Go(func() error { return copyWithReqs(groupCtx, downstream, upstream, upstreamReqs, "out") })
	group.Go(func() error { return copyWithReqs(groupCtx, upstream, downstream, downstreamReqs, "in") })
	// The relay errors only matter for tearing the pair down, which the
	// deferred Close calls take care of.
	group.Wait()
}
// copyWithReqs relays one direction of a proxied SSH channel. It copies
// channel data and extended (stderr) data from src to dst, and forwards the
// channel requests received on srcReqs to dst. It waits for all three relays
// and returns the first non-nil error reported by any of them (errgroup
// semantics). The final string parameter is an unused direction label
// ("in"/"out") supplied by forwardChan.
func copyWithReqs(ctx context.Context, dst, src ssh.Channel, srcReqs <-chan *ssh.Request, _ string) error {
// According to https://github.com/golang/go/issues/29733
// Before we close the channel, we have to wait until any exit- prefixed request has been forwarded.
// forwardChannelReqs signals on exitRequestForwarded after it has forwarded an exit- prefixed request.
// io.Copy may encounter an error and exit early (without consuming the channel), so we leave a slot in it.
exitRequestForwarded := make(chan struct{}, 1)
g, ctx := errgroup.WithContext(ctx)
// Close dst once the group context is cancelled, i.e. when any relay returns an
// error or when g.Wait returns, so a failure in one relay tears the channel down.
go func() { <-ctx.Done(); dst.Close() }()
g.Go(func() error { return forwardChannelReqs(ctx, dst, srcReqs, exitRequestForwarded) })
// Relay extended data (stderr).
g.Go(func() error {
_, err := io.Copy(dst.Stderr(), src.Stderr())
return err
})
g.Go(func() error {
// TODO if need audit. we need copy bytes to audit writer
_, err := io.Copy(dst, src)
switch err {
case nil:
// When receiving EOF (which means io.Copy returns nil), wait until the exit- prefixed request has been forwarded before we close the channel.
// For more detail, see https://github.com/golang/go/issues/29733
t := time.NewTimer(time.Second)
defer t.Stop()
select {
case <-t.C:
// We can't wait forever, exit anyway.
case <-exitRequestForwarded:
// Already forwarded
}
default:
// On a copy error there is nothing to wait for; close immediately.
}
// Signal to dst's peer that no more in-band data will be sent in this direction.
dst.CloseWrite()
return err
})
return g.Wait()
}
// forwardConnReqs relays connection-level (global) requests received on src
// to dst and sends dst's answer back when the sender asked for a reply.
// It returns when src is closed.
//
// If dst can no longer accept requests, the current and all later requests are
// refused instead of being abandoned. x/crypto/ssh requires incoming request
// channels to be serviced; otherwise the connection that owns src can hang.
func forwardConnReqs(dst ssh.Conn, src <-chan *ssh.Request) {
	for r := range src {
		ok, data, err := dst.SendRequest(r.Type, r.WantReply, r.Payload)
		if err != nil {
			if r.WantReply {
				_ = r.Reply(false, nil)
			}
			// Reply false to anything still queued and drain until src closes.
			ssh.DiscardRequests(src)
			return
		}
		if r.WantReply {
			// Reply only fails once src's connection is gone, after which src is
			// closed as well, so there is nothing left to service.
			if err := r.Reply(ok, data); err != nil {
				return
			}
		}
	}
}
// forwardChannelReqs relays channel requests from src to dst, answering each
// request that wants a reply with dst's verdict. It returns nil once src is
// closed, or the first send/reply error.
//
// If any relayed request had an "exit-" prefix (exit-status / exit-signal),
// the function signals exitRequestForwarded on return, without blocking, so
// copyWithReqs knows it may close the channel (see
// https://github.com/golang/go/issues/29733).
func forwardChannelReqs(_ context.Context, dst ssh.Channel, src <-chan *ssh.Request, exitRequestForwarded chan<- struct{}) error {
	sawExit := false
	defer func() {
		if !sawExit {
			return
		}
		// Non-blocking: the buffered slot may already be full or unread.
		select {
		case exitRequestForwarded <- struct{}{}:
		default:
		}
	}()

	for req := range src {
		// Mark before sending so the signal fires even if the send fails.
		sawExit = sawExit || strings.HasPrefix(req.Type, "exit-")
		accepted, sendErr := dst.SendRequest(req.Type, req.WantReply, req.Payload)
		if sendErr != nil {
			return sendErr
		}
		if !req.WantReply {
			continue
		}
		if replyErr := req.Reply(accepted, nil); replyErr != nil {
			return replyErr
		}
	}
	return nil
}
// createSPDYConnection opens a port-forward session straight to port podPort of
// pod namespace/podName through the API server, without opening any local
// listening socket, and hands the resulting data stream to handle.
//
// It uses the package-level clientset and config. The call blocks until handle
// returns. An error is returned when podPort is not a valid TCP port, when the
// SPDY connection cannot be established, or when the server does not agree to
// the port-forward protocol. Failures while creating the individual streams are
// reported by handleStreamConnection itself.
func createSPDYConnection(namespace, podName string, podPort int, handle func(dataStream httpstream.Stream)) error {
	// ForwardedPort.Remote is a uint16: reject values that would silently wrap
	// around (e.g. 65558 -> 22) or that are not valid TCP ports at all.
	if podPort < 1 || podPort > 65535 {
		return fmt.Errorf("invalid pod port %d: must be in range 1-65535", podPort)
	}

	req := clientset.CoreV1().RESTClient().
		Post().
		Resource("pods").
		Namespace(namespace).
		Name(podName).
		SubResource("portforward").
		Param("ports", strconv.Itoa(podPort))

	// Create the SPDY transport and dialer.
	transport, upgrader, err := spdy.RoundTripperFor(config)
	if err != nil {
		return fmt.Errorf("failed to create round tripper: %w", err)
	}
	dialer := spdy.NewDialer(upgrader, &http.Client{Transport: transport}, "POST", req.URL())

	// Upgrade the connection to the pod's port-forward endpoint.
	streamConn, protocol, err := dialer.Dial(portforward.PortForwardProtocolV1Name)
	if err != nil {
		return fmt.Errorf("failed to dial port forward: %w", err)
	}
	// handleStreamConnection closes streamConn on its normal path but not on its
	// early error returns; the deferred Close covers those.
	defer streamConn.Close()

	// Same check as client-go's ForwardPorts: only V1 was offered, so any other
	// answer means negotiation failed.
	if protocol != portforward.PortForwardProtocolV1Name {
		return fmt.Errorf("unable to negotiate protocol: client supports %q, server returned %q", portforward.PortForwardProtocolV1Name, protocol)
	}

	handleStreamConnection(streamConn, portforward.ForwardedPort{
		Local:  0, // no local listener is involved
		Remote: uint16(podPort),
	}, handle)
	return nil
}
// handleStreamConnection sets up one port-forward request on streamConn (an
// error stream and a data stream for port.Remote that share a request ID) and
// passes the data stream to handle. Unlike client-go's handleConnection, no
// bytes are copied here: handle owns the data stream until it returns, after
// which the data stream and streamConn are both closed. If either stream cannot
// be created, the failure is reported through runtime.HandleError and handle is
// not called.
func handleStreamConnection(streamConn httpstream.Connection, port portforward.ForwardedPort, handle func(dataStream httpstream.Stream)) {
	requestID := time.Now().UnixNano()

	// Create the error stream. The data stream below reuses these headers, so
	// both streams carry the same port and request ID.
	headers := http.Header{}
	headers.Set(v1.StreamType, v1.StreamTypeError)
	headers.Set(v1.PortHeader, strconv.Itoa(int(port.Remote)))
	headers.Set(v1.PortForwardRequestIDHeader, strconv.FormatInt(requestID, 10))
	errorStream, err := streamConn.CreateStream(headers)
	if err != nil {
		runtime.HandleError(fmt.Errorf("error creating error stream for port %d -> %d: %v", port.Local, port.Remote, err))
		return
	}
	// We never write to the error stream, so close our side right away; the
	// server's report is still read by the goroutine below.
	errorStream.Close()
	go func() {
		message, err := io.ReadAll(errorStream)
		switch {
		case err != nil:
			log.Printf("error reading error stream: %v\n", err)
		case len(message) > 0:
			// A non-empty message is the server reporting a forwarding failure,
			// not a failure to read the stream.
			log.Printf("an error occurred forwarding %d -> %d: %s\n", port.Local, port.Remote, message)
		}
	}()

	// Create the data stream that carries the forwarded bytes.
	headers.Set(v1.StreamType, v1.StreamTypeData)
	dataStream, err := streamConn.CreateStream(headers)
	if err != nil {
		runtime.HandleError(fmt.Errorf("error creating forwarding stream for port %d -> %d: %v", port.Local, port.Remote, err))
		return
	}

	handle(dataStream)
	_ = dataStream.Close()
	_ = streamConn.Close()
}
// streamNetConn adapts an httpstream.Stream to the net.Conn interface, so that
// a port-forward data stream can be handed to code that expects a network
// connection, such as ssh.NewClientConn.
type streamNetConn struct {
	stream httpstream.Stream
}

// Compile-time check that streamNetConn satisfies net.Conn.
var _ net.Conn = (*streamNetConn)(nil)

// streamAddr is a placeholder net.Addr for the endpoints of a stream, which
// have no network address of their own.
type streamAddr struct{}

// Network implements net.Addr.
func (streamAddr) Network() string { return "httpstream" }

// String implements net.Addr.
func (streamAddr) String() string { return "httpstream" }

// Read implements net.Conn by reading from the underlying stream.
func (c *streamNetConn) Read(b []byte) (int, error) {
	return c.stream.Read(b)
}

// Write implements net.Conn by writing to the underlying stream.
func (c *streamNetConn) Write(b []byte) (int, error) {
	return c.stream.Write(b)
}

// Close implements net.Conn by closing the underlying stream.
func (c *streamNetConn) Close() error {
	return c.stream.Close()
}

// LocalAddr implements net.Conn. It returns a placeholder address instead of
// nil, because callers commonly invoke methods such as String on the result
// and would panic on a nil net.Addr.
func (c *streamNetConn) LocalAddr() net.Addr {
	return streamAddr{}
}

// RemoteAddr implements net.Conn. Like LocalAddr, it returns a non-nil
// placeholder address.
func (c *streamNetConn) RemoteAddr() net.Addr {
	return streamAddr{}
}

// SetDeadline implements net.Conn. httpstream.Stream exposes no deadline API,
// so this is a no-op that always returns nil; reads and writes may block
// indefinitely.
func (c *streamNetConn) SetDeadline(t time.Time) error {
	return nil
}

// SetReadDeadline implements net.Conn. It is a no-op; see SetDeadline.
func (c *streamNetConn) SetReadDeadline(t time.Time) error {
	return nil
}

// SetWriteDeadline implements net.Conn. It is a no-op; see SetDeadline.
func (c *streamNetConn) SetWriteDeadline(t time.Time) error {
	return nil
}

// NewStreamConn wraps stream as a net.Conn.
func NewStreamConn(stream httpstream.Stream) *streamNetConn {
	return &streamNetConn{stream: stream}
}