Skip to content

Commit

Permalink
Count entire request handler goroutine as pending
Browse files Browse the repository at this point in the history
Signed-off-by: Chris Koch <chrisko@google.com>
  • Loading branch information
hugelgupf committed Aug 22, 2023
1 parent 74365fc commit 94ba285
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 26 deletions.
6 changes: 4 additions & 2 deletions fsimpl/composefs/composefs_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package composefs

import (
"context"
"encoding/json"
"fmt"
"net"
Expand All @@ -29,7 +30,6 @@ func TestLinuxClient(t *testing.T) {
if err != nil {
t.Fatalf("err binding: %v", err)
}
defer serverSocket.Close()
serverPort := serverSocket.Addr().(*net.TCPAddr).Port

localfsTmp := t.TempDir()
Expand Down Expand Up @@ -69,7 +69,6 @@ func TestLinuxClient(t *testing.T) {

// Run the server.
s := p9.NewServer(attacher, p9.WithServerLogger(ulogtest.Logger{TB: t}))
go s.Serve(serverSocket)

// Run the read tests from fsimpl/test/rovmtests.
vmtest.RunGoTestsInVM(t, []string{"github.com/hugelgupf/p9/fsimpl/test/rovmtests"}, &vmtest.UrootFSOptions{
Expand All @@ -90,6 +89,9 @@ func TestLinuxClient(t *testing.T) {
Mask: net.CIDRMask(24, 32),
}),
qemu.WithVMTimeout(30 * time.Second),
qemu.WithTask(func(ctx context.Context, n *qemu.Notifications) error {
return s.ServeContext(ctx, serverSocket)
}),
},
},
})
Expand Down
18 changes: 8 additions & 10 deletions fsimpl/localfs/localfs_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package localfs

import (
"context"
"fmt"
"io/ioutil"
"net"
Expand All @@ -20,22 +21,15 @@ import (
)

func TestIntegration(t *testing.T) {
tempDir, err := ioutil.TempDir("", "localfs-")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempDir)

serverSocket, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatalf("err binding: %v", err)
}
defer serverSocket.Close()
serverPort := serverSocket.Addr().(*net.TCPAddr).Port

// Run the server.
tempDir := t.TempDir()
s := p9.NewServer(Attacher(tempDir), p9.WithServerLogger(ulogtest.Logger{TB: t}))
go s.Serve(serverSocket)

// Run the read-write tests from fsimpl/test/rwvm.
vmtest.RunGoTestsInVM(t, []string{"github.com/hugelgupf/p9/fsimpl/test/rwvmtests"}, &vmtest.UrootFSOptions{
Expand All @@ -57,6 +51,9 @@ func TestIntegration(t *testing.T) {
Mask: net.CIDRMask(24, 32),
}),
qemu.WithVMTimeout(30 * time.Second),
qemu.WithTask(func(ctx context.Context, n *qemu.Notifications) error {
return s.ServeContext(ctx, serverSocket)
}),
},
},
})
Expand All @@ -74,12 +71,10 @@ func TestBenchmark(t *testing.T) {
if err != nil {
t.Fatalf("err binding: %v", err)
}
defer serverSocket.Close()
serverPort := serverSocket.Addr().(*net.TCPAddr).Port

// Run the server. No logger -- slows down the benchmark.
s := p9.NewServer(Attacher(tempDir)) //, p9.WithServerLogger(ulogtest.Logger{TB: t}))
go s.Serve(serverSocket)

// Run the read-write tests from fsimpl/test/rwvm.
vmtest.RunGoTestsInVM(t, []string{"github.com/hugelgupf/p9/fsimpl/test/benchmark"}, &vmtest.UrootFSOptions{
Expand All @@ -101,6 +96,9 @@ func TestBenchmark(t *testing.T) {
Mask: net.CIDRMask(24, 32),
}),
qemu.WithVMTimeout(30 * time.Second),
qemu.WithTask(func(ctx context.Context, n *qemu.Notifications) error {
return s.ServeContext(ctx, serverSocket)
}),
},
},
})
Expand Down
6 changes: 4 additions & 2 deletions fsimpl/staticfs/staticfs_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
package staticfs

import (
"context"
"encoding/json"
"fmt"
"net"
Expand All @@ -28,7 +29,6 @@ func TestLinuxClient(t *testing.T) {
if err != nil {
t.Fatalf("err binding: %v", err)
}
defer serverSocket.Close()
serverPort := serverSocket.Addr().(*net.TCPAddr).Port

wantRoot := []string{
Expand Down Expand Up @@ -73,7 +73,6 @@ func TestLinuxClient(t *testing.T) {

// Run the server.
s := p9.NewServer(attacher, p9.WithServerLogger(ulogtest.Logger{TB: t}))
go s.Serve(serverSocket)

// Run the read tests from fsimpl/test/rovmtests.
vmtest.RunGoTestsInVM(t, []string{"github.com/hugelgupf/p9/fsimpl/test/rovmtests"}, &vmtest.UrootFSOptions{
Expand All @@ -94,6 +93,9 @@ func TestLinuxClient(t *testing.T) {
Mask: net.CIDRMask(24, 32),
}),
qemu.WithVMTimeout(30 * time.Second),
qemu.WithTask(func(ctx context.Context, n *qemu.Notifications) error {
return s.ServeContext(ctx, serverSocket)
}),
},
},
})
Expand Down
12 changes: 9 additions & 3 deletions fsimpl/test/rwvmtests/rw_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"path/filepath"
"reflect"
"sort"
"sync"
"testing"

"github.com/hugelgupf/p9/fsimpl/localfs"
Expand Down Expand Up @@ -209,12 +210,18 @@ func TestGuestServer(t *testing.T) {
if err != nil {
t.Fatalf("err binding: %v", err)
}
defer serverSocket.Close()
serverPort := serverSocket.Addr().(*net.TCPAddr).Port

// Run the server.
s := p9.NewServer(localfs.Attacher(tmp), p9.WithServerLogger(ulogtest.Logger{TB: t}))
go s.Serve(serverSocket)
var wg sync.WaitGroup
wg.Add(1)
go func() {
s.Serve(serverSocket)
wg.Done()
}()
defer wg.Wait()
defer serverSocket.Close()

targetDir := "/guesttarget"
if err := os.MkdirAll(targetDir, 0755); err != nil {
Expand Down Expand Up @@ -278,5 +285,4 @@ func TestGuestServer(t *testing.T) {
t.Errorf("Listxattr = %v, want %v", xattrs, attrs)
}
})

}
61 changes: 52 additions & 9 deletions p9/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
package p9

import (
"context"
"errors"
"fmt"
"io"
"net"
"runtime/debug"
"strings"
"sync"
"sync/atomic"

Expand Down Expand Up @@ -474,6 +476,9 @@ func (cs *connState) WaitTag(t tag) {
// The recvDone channel is signaled when recv is done (with a error if
// necessary). The sendDone channel is signaled with the result of the send.
func (cs *connState) handleRequest() bool {
cs.pendingWg.Add(1)
defer cs.pendingWg.Done()

// Obtain the right to receive a message from cs.t.
atomic.AddInt32(&cs.recvIdle, 1)
cs.recvMu.Lock()
Expand All @@ -495,8 +500,10 @@ func (cs *connState) handleRequest() bool {
// Receive a message.
tag, m, err := recv(cs.server.log, cs.t, messageSize, msgDotLRegistry.get)
if errSocket, ok := err.(ConnError); ok {
// Connection problem; stop serving.
cs.server.log.Printf("p9.recv: %v", errSocket.error)
if errSocket.error != io.EOF {
// Connection problem; stop serving.
cs.server.log.Printf("p9.recv: %v", errSocket.error)
}
cs.recvShutdown = true
cs.recvMu.Unlock()
return false
Expand All @@ -523,6 +530,7 @@ func (cs *connState) handleRequest() bool {

// Try to start the tag.
if !cs.StartTag(tag) {
cs.server.log.Printf("no valid tag [%05d]", tag)
// Nothing we can do at this point; client is bogus.
return true
}
Expand All @@ -549,9 +557,7 @@ func (cs *connState) handleRequest() bool {
}

func (cs *connState) handle(m message) (r message) {
cs.pendingWg.Add(1)
defer func() {
cs.pendingWg.Done()
if r == nil {
// Don't allow a panic to propagate.
err := recover()
Expand Down Expand Up @@ -585,9 +591,9 @@ func (cs *connState) handleRequests() {
}

func (cs *connState) stop() {
// Wait for completion of all inflight requests. If a request is stuck,
// something has the opportunity to kill us with SIGABRT to get a stack
// dump of the offending handler.
// Wait for completion of all inflight request goroutines.. If a
// request is stuck, something has the opportunity to kill us with
// SIGABRT to get a stack dump of the offending handler.
cs.pendingWg.Wait()

// Ensure the connection is closed.
Expand Down Expand Up @@ -619,19 +625,56 @@ func (s *Server) Handle(t io.ReadCloser, r io.WriteCloser) error {
return nil
}

func isErrClosing(err error) bool {
return strings.Contains(err.Error(), "use of closed network connection")
}

// Serve handles requests from the bound socket.
//
// The passed serverSocket _must_ be created in packet mode.
func (s *Server) Serve(serverSocket net.Listener) error {
return s.ServeContext(nil, serverSocket)
}

var errAlreadyClosed = errors.New("already closed")

// ServeContext handles requests from the bound socket.
//
// The passed serverSocket _must_ be created in packet mode.
//
// When the context is done, the listener is closed and serve returns once
// every request has been handled.
func (s *Server) ServeContext(ctx context.Context, serverSocket net.Listener) error {
var wg sync.WaitGroup
defer wg.Wait()

var cancelCause context.CancelCauseFunc
if ctx != nil {
ctx, cancelCause = context.WithCancelCause(ctx)

wg.Add(1)
go func() {
defer wg.Done()
<-ctx.Done()

// Only close the server socket if it wasn't already closed.
if err := ctx.Err(); errors.Is(err, errAlreadyClosed) {
return
}
serverSocket.Close()
}()
}

for {
conn, err := serverSocket.Accept()
if err != nil {
if cancelCause != nil {
cancelCause(errAlreadyClosed)
}
if isErrClosing(err) {
return nil
}
// Something went wrong.
//
// Socket closed?
return err
}

Expand Down

0 comments on commit 94ba285

Please sign in to comment.