From 94ba285af071de8d475b681ef81f53e522e34b0f Mon Sep 17 00:00:00 2001 From: Chris Koch Date: Tue, 22 Aug 2023 05:43:22 -0700 Subject: [PATCH] Count entire request handler goroutine as pending Signed-off-by: Chris Koch --- .../composefs/composefs_integration_test.go | 6 +- fsimpl/localfs/localfs_integration_test.go | 18 +++--- fsimpl/staticfs/staticfs_integration_test.go | 6 +- fsimpl/test/rwvmtests/rw_linux_test.go | 12 +++- p9/server.go | 61 ++++++++++++++++--- 5 files changed, 77 insertions(+), 26 deletions(-) diff --git a/fsimpl/composefs/composefs_integration_test.go b/fsimpl/composefs/composefs_integration_test.go index 38bdf34..1534217 100644 --- a/fsimpl/composefs/composefs_integration_test.go +++ b/fsimpl/composefs/composefs_integration_test.go @@ -4,6 +4,7 @@ package composefs import ( + "context" "encoding/json" "fmt" "net" @@ -29,7 +30,6 @@ func TestLinuxClient(t *testing.T) { if err != nil { t.Fatalf("err binding: %v", err) } - defer serverSocket.Close() serverPort := serverSocket.Addr().(*net.TCPAddr).Port localfsTmp := t.TempDir() @@ -69,7 +69,6 @@ func TestLinuxClient(t *testing.T) { // Run the server. s := p9.NewServer(attacher, p9.WithServerLogger(ulogtest.Logger{TB: t})) - go s.Serve(serverSocket) // Run the read tests from fsimpl/test/rovmtests. vmtest.RunGoTestsInVM(t, []string{"github.com/hugelgupf/p9/fsimpl/test/rovmtests"}, &vmtest.UrootFSOptions{ @@ -90,6 +89,9 @@ func TestLinuxClient(t *testing.T) { Mask: net.CIDRMask(24, 32), }), qemu.WithVMTimeout(30 * time.Second), + qemu.WithTask(func(ctx context.Context, n *qemu.Notifications) error { + return s.ServeContext(ctx, serverSocket) + }), }, }, }) diff --git a/fsimpl/localfs/localfs_integration_test.go b/fsimpl/localfs/localfs_integration_test.go index 46dbee9..3a0431c 100644 --- a/fsimpl/localfs/localfs_integration_test.go +++ b/fsimpl/localfs/localfs_integration_test.go @@ -4,6 +4,7 @@ package localfs import ( + "context" "fmt" "io/ioutil" "net" @@ -20,22 +21,15 @@ import ( ) func TestIntegration(t *testing.T) { - tempDir, err := ioutil.TempDir("", "localfs-") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tempDir) - serverSocket, err := net.Listen("tcp", ":0") if err != nil { t.Fatalf("err binding: %v", err) } - defer serverSocket.Close() serverPort := serverSocket.Addr().(*net.TCPAddr).Port // Run the server. + tempDir := t.TempDir() s := p9.NewServer(Attacher(tempDir), p9.WithServerLogger(ulogtest.Logger{TB: t})) - go s.Serve(serverSocket) // Run the read-write tests from fsimpl/test/rwvm. vmtest.RunGoTestsInVM(t, []string{"github.com/hugelgupf/p9/fsimpl/test/rwvmtests"}, &vmtest.UrootFSOptions{ @@ -57,6 +51,9 @@ func TestIntegration(t *testing.T) { Mask: net.CIDRMask(24, 32), }), qemu.WithVMTimeout(30 * time.Second), + qemu.WithTask(func(ctx context.Context, n *qemu.Notifications) error { + return s.ServeContext(ctx, serverSocket) + }), }, }, }) @@ -74,12 +71,10 @@ func TestBenchmark(t *testing.T) { if err != nil { t.Fatalf("err binding: %v", err) } - defer serverSocket.Close() serverPort := serverSocket.Addr().(*net.TCPAddr).Port // Run the server. No logger -- slows down the benchmark. s := p9.NewServer(Attacher(tempDir)) //, p9.WithServerLogger(ulogtest.Logger{TB: t})) - go s.Serve(serverSocket) // Run the read-write tests from fsimpl/test/rwvm. vmtest.RunGoTestsInVM(t, []string{"github.com/hugelgupf/p9/fsimpl/test/benchmark"}, &vmtest.UrootFSOptions{ @@ -101,6 +96,9 @@ func TestBenchmark(t *testing.T) { Mask: net.CIDRMask(24, 32), }), qemu.WithVMTimeout(30 * time.Second), + qemu.WithTask(func(ctx context.Context, n *qemu.Notifications) error { + return s.ServeContext(ctx, serverSocket) + }), }, }, }) diff --git a/fsimpl/staticfs/staticfs_integration_test.go b/fsimpl/staticfs/staticfs_integration_test.go index b3ee921..5064134 100644 --- a/fsimpl/staticfs/staticfs_integration_test.go +++ b/fsimpl/staticfs/staticfs_integration_test.go @@ -4,6 +4,7 @@ package staticfs import ( + "context" "encoding/json" "fmt" "net" @@ -28,7 +29,6 @@ func TestLinuxClient(t *testing.T) { if err != nil { t.Fatalf("err binding: %v", err) } - defer serverSocket.Close() serverPort := serverSocket.Addr().(*net.TCPAddr).Port wantRoot := []string{ @@ -73,7 +73,6 @@ func TestLinuxClient(t *testing.T) { // Run the server. s := p9.NewServer(attacher, p9.WithServerLogger(ulogtest.Logger{TB: t})) - go s.Serve(serverSocket) // Run the read tests from fsimpl/test/rovmtests. vmtest.RunGoTestsInVM(t, []string{"github.com/hugelgupf/p9/fsimpl/test/rovmtests"}, &vmtest.UrootFSOptions{ @@ -94,6 +93,9 @@ func TestLinuxClient(t *testing.T) { Mask: net.CIDRMask(24, 32), }), qemu.WithVMTimeout(30 * time.Second), + qemu.WithTask(func(ctx context.Context, n *qemu.Notifications) error { + return s.ServeContext(ctx, serverSocket) + }), }, }, }) diff --git a/fsimpl/test/rwvmtests/rw_linux_test.go b/fsimpl/test/rwvmtests/rw_linux_test.go index 329b556..c3703f2 100644 --- a/fsimpl/test/rwvmtests/rw_linux_test.go +++ b/fsimpl/test/rwvmtests/rw_linux_test.go @@ -13,6 +13,7 @@ import ( "path/filepath" "reflect" "sort" + "sync" "testing" "github.com/hugelgupf/p9/fsimpl/localfs" @@ -209,12 +210,18 @@ func TestGuestServer(t *testing.T) { if err != nil { t.Fatalf("err binding: %v", err) } - defer serverSocket.Close() serverPort := serverSocket.Addr().(*net.TCPAddr).Port // Run the server. s := p9.NewServer(localfs.Attacher(tmp), p9.WithServerLogger(ulogtest.Logger{TB: t})) - go s.Serve(serverSocket) + var wg sync.WaitGroup + wg.Add(1) + go func() { + s.Serve(serverSocket) + wg.Done() + }() + defer wg.Wait() + defer serverSocket.Close() targetDir := "/guesttarget" if err := os.MkdirAll(targetDir, 0755); err != nil { @@ -278,5 +285,4 @@ func TestGuestServer(t *testing.T) { t.Errorf("Listxattr = %v, want %v", xattrs, attrs) } }) - } diff --git a/p9/server.go b/p9/server.go index 78a2955..3ae8ec3 100644 --- a/p9/server.go +++ b/p9/server.go @@ -15,11 +15,13 @@ package p9 import ( + "context" "errors" "fmt" "io" "net" "runtime/debug" + "strings" "sync" "sync/atomic" @@ -474,6 +476,9 @@ func (cs *connState) WaitTag(t tag) { // The recvDone channel is signaled when recv is done (with a error if // necessary). The sendDone channel is signaled with the result of the send. func (cs *connState) handleRequest() bool { + cs.pendingWg.Add(1) + defer cs.pendingWg.Done() + // Obtain the right to receive a message from cs.t. atomic.AddInt32(&cs.recvIdle, 1) cs.recvMu.Lock() @@ -495,8 +500,10 @@ func (cs *connState) handleRequest() bool { // Receive a message. tag, m, err := recv(cs.server.log, cs.t, messageSize, msgDotLRegistry.get) if errSocket, ok := err.(ConnError); ok { - // Connection problem; stop serving. - cs.server.log.Printf("p9.recv: %v", errSocket.error) + if errSocket.error != io.EOF { + // Connection problem; stop serving. + cs.server.log.Printf("p9.recv: %v", errSocket.error) + } cs.recvShutdown = true cs.recvMu.Unlock() return false @@ -523,6 +530,7 @@ func (cs *connState) handleRequest() bool { // Try to start the tag. if !cs.StartTag(tag) { + cs.server.log.Printf("no valid tag [%05d]", tag) // Nothing we can do at this point; client is bogus. return true } @@ -549,9 +557,7 @@ func (cs *connState) handleRequest() bool { } func (cs *connState) handle(m message) (r message) { - cs.pendingWg.Add(1) defer func() { - cs.pendingWg.Done() if r == nil { // Don't allow a panic to propagate. err := recover() @@ -585,9 +591,9 @@ func (cs *connState) handleRequests() { } func (cs *connState) stop() { - // Wait for completion of all inflight requests. If a request is stuck, - // something has the opportunity to kill us with SIGABRT to get a stack - // dump of the offending handler. + // Wait for completion of all inflight request goroutines.. If a + // request is stuck, something has the opportunity to kill us with + // SIGABRT to get a stack dump of the offending handler. cs.pendingWg.Wait() // Ensure the connection is closed. @@ -619,19 +625,56 @@ func (s *Server) Handle(t io.ReadCloser, r io.WriteCloser) error { return nil } +func isErrClosing(err error) bool { + return strings.Contains(err.Error(), "use of closed network connection") +} + // Serve handles requests from the bound socket. // // The passed serverSocket _must_ be created in packet mode. func (s *Server) Serve(serverSocket net.Listener) error { + return s.ServeContext(nil, serverSocket) +} + +var errAlreadyClosed = errors.New("already closed") + +// ServeContext handles requests from the bound socket. +// +// The passed serverSocket _must_ be created in packet mode. +// +// When the context is done, the listener is closed and serve returns once +// every request has been handled. +func (s *Server) ServeContext(ctx context.Context, serverSocket net.Listener) error { var wg sync.WaitGroup defer wg.Wait() + var cancelCause context.CancelCauseFunc + if ctx != nil { + ctx, cancelCause = context.WithCancelCause(ctx) + + wg.Add(1) + go func() { + defer wg.Done() + <-ctx.Done() + + // Only close the server socket if it wasn't already closed. + if err := ctx.Err(); errors.Is(err, errAlreadyClosed) { + return + } + serverSocket.Close() + }() + } + for { conn, err := serverSocket.Accept() if err != nil { + if cancelCause != nil { + cancelCause(errAlreadyClosed) + } + if isErrClosing(err) { + return nil + } // Something went wrong. - // - // Socket closed? return err }