diff --git a/modal-go/client.go b/modal-go/client.go index 3a18c5ac..7a96bee0 100644 --- a/modal-go/client.go +++ b/modal-go/client.go @@ -275,6 +275,7 @@ type retryCallOption struct { const ( apiEndpoint = "api.modal.com:443" maxMessageSize = 100 * 1024 * 1024 // 100 MB + windowSize = 64 * 1024 * 1024 // 64 MiB defaultRetryAttempts = 3 defaultRetryBaseDelay = 100 * time.Millisecond defaultRetryMaxDelay = 1 * time.Second @@ -334,6 +335,8 @@ func newClient(ctx context.Context, profile Profile, c *Client, customUnaryInter conn, err := grpc.NewClient( target, grpc.WithTransportCredentials(creds), + grpc.WithInitialWindowSize(windowSize), + grpc.WithInitialConnWindowSize(windowSize), grpc.WithDefaultCallOptions( grpc.MaxCallRecvMsgSize(maxMessageSize), grpc.MaxCallSendMsgSize(maxMessageSize), diff --git a/modal-go/config.go b/modal-go/config.go index 08b62b99..abcdf90b 100644 --- a/modal-go/config.go +++ b/modal-go/config.go @@ -8,29 +8,32 @@ import ( "fmt" "os" "path/filepath" + "strings" "github.com/pelletier/go-toml/v2" ) // Profile holds a fully-resolved configuration ready for use by the client. type Profile struct { - ServerURL string - TokenID string - TokenSecret string - Environment string - ImageBuilderVersion string - LogLevel string + ServerURL string + TokenID string + TokenSecret string + Environment string + ImageBuilderVersion string + LogLevel string + TaskCommandRouterInsecure bool } // rawProfile mirrors the TOML structure on disk. type rawProfile struct { - ServerURL string `toml:"server_url"` - TokenID string `toml:"token_id"` - TokenSecret string `toml:"token_secret"` - Environment string `toml:"environment"` - ImageBuilderVersion string `toml:"image_builder_version"` - LogLevel string `toml:"loglevel"` - Active bool `toml:"active"` + ServerURL string `toml:"server_url"` + TokenID string `toml:"token_id"` + TokenSecret string `toml:"token_secret"` + Environment string `toml:"environment"` + ImageBuilderVersion string `toml:"image_builder_version"` + LogLevel string `toml:"loglevel"` + Active bool `toml:"active"` + TaskCommandRouterInsecure bool `toml:"task_command_router_insecure"` } type config map[string]rawProfile @@ -96,13 +99,19 @@ func getProfile(name string, cfg config) Profile { imageBuilderVersion := firstNonEmpty(os.Getenv("MODAL_IMAGE_BUILDER_VERSION"), raw.ImageBuilderVersion) logLevel := firstNonEmpty(os.Getenv("MODAL_LOGLEVEL"), raw.LogLevel) + taskCommandRouterInsecure := raw.TaskCommandRouterInsecure + if envVal := os.Getenv("MODAL_TASK_COMMAND_ROUTER_INSECURE"); envVal != "" { + taskCommandRouterInsecure = strings.ToLower(envVal) == "true" || envVal == "1" + } + return Profile{ - ServerURL: serverURL, - TokenID: tokenID, - TokenSecret: tokenSecret, - Environment: environment, - ImageBuilderVersion: imageBuilderVersion, - LogLevel: logLevel, + ServerURL: serverURL, + TokenID: tokenID, + TokenSecret: tokenSecret, + Environment: environment, + ImageBuilderVersion: imageBuilderVersion, + LogLevel: logLevel, + TaskCommandRouterInsecure: taskCommandRouterInsecure, } } diff --git a/modal-go/errors.go b/modal-go/errors.go index a63c6bf7..54a679b5 100644 --- a/modal-go/errors.go +++ b/modal-go/errors.go @@ -100,3 +100,12 @@ type SandboxTimeoutError struct { func (e SandboxTimeoutError) Error() string { return "SandboxTimeoutError: " + e.Exception } + +// ExecTimeoutError is returned when a container exec exceeds its execution duration limit. 
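+// It is surfaced by ContainerProcess.Wait and by stdout/stderr reads once the
+// exec's deadline passes. Callers can detect it with errors.As, for example:
+//
+//	var execTimeout ExecTimeoutError
+//	if errors.As(err, &execTimeout) {
+//		// the exec exceeded its timeout
+//	}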
+type ExecTimeoutError struct { + Exception string +} + +func (e ExecTimeoutError) Error() string { + return "ExecTimeoutError: " + e.Exception +} diff --git a/modal-go/sandbox.go b/modal-go/sandbox.go index ab5bc7ae..d0563979 100644 --- a/modal-go/sandbox.go +++ b/modal-go/sandbox.go @@ -13,6 +13,7 @@ import ( "github.com/djherbis/buffer" "github.com/djherbis/nio/v3" + "github.com/google/uuid" pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" @@ -366,10 +367,13 @@ type Sandbox struct { Stdout io.ReadCloser Stderr io.ReadCloser - taskID string - tunnels map[int]*Tunnel + taskID string + tunnels map[int]*Tunnel + commandRouterClient *TaskCommandRouterClient client *Client + + mu sync.Mutex // protects commandRouterClient } func defaultSandboxPTYInfo() *pb.PTYInfo { @@ -463,12 +467,41 @@ type SandboxExecParams struct { PTY bool } -// buildContainerExecRequestProto builds a ContainerExecRequest proto from command and options. -func buildContainerExecRequestProto(taskID string, command []string, params SandboxExecParams) (*pb.ContainerExecRequest, error) { - var workdir *string - if params.Workdir != "" { - workdir = ¶ms.Workdir +// ValidateExecArgs checks if command arguments exceed ARG_MAX. +func ValidateExecArgs(args []string) error { + // The maximum number of bytes that can be passed to an exec on Linux. + // Though this is technically a 'server side' limit, it is unlikely to change. + // getconf ARG_MAX will show this value on a host. + // + // By probing in production, the limit is 131072 bytes (2**17). + // We need some bytes of overhead for the rest of the command line besides the args, + // e.g. 'runsc exec ...'. So we use 2**16 as the limit. + + argMaxBytes := 1 << 16 + + // Avoid "[Errno 7] Argument list too long" errors. + totalLen := 0 + for _, arg := range args { + totalLen += len(arg) + } + if totalLen > argMaxBytes { + return InvalidError{Exception: fmt.Sprintf( + "Total length of CMD arguments must be less than %d bytes. Got %d bytes.", + argMaxBytes, totalLen, + )} } + return nil +} + +// buildTaskExecStartRequestProto builds a TaskExecStartRequest proto from command and options. 
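+// It validates that params.Timeout is a non-negative whole number of seconds,
+// maps the Stdout/Stderr behaviors onto the proto stdout/stderr configs, and
+// leaves Workdir and TimeoutSecs unset unless they are provided.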
+func buildTaskExecStartRequestProto(taskID, execID string, command []string, params SandboxExecParams) (*pb.TaskExecStartRequest, error) { + if params.Timeout < 0 { + return nil, fmt.Errorf("timeout must be non-negative, got %v", params.Timeout) + } + if params.Timeout != 0 && params.Timeout%time.Second != 0 { + return nil, fmt.Errorf("timeout must be a whole number of seconds, got %v", params.Timeout) + } + secretIds := []string{} for _, secret := range params.Secrets { if secret != nil { @@ -476,27 +509,53 @@ func buildContainerExecRequestProto(taskID string, command []string, params Sand } } + var stdoutConfig pb.TaskExecStdoutConfig + switch params.Stdout { + case Pipe, "": + stdoutConfig = pb.TaskExecStdoutConfig_TASK_EXEC_STDOUT_CONFIG_PIPE + case Ignore: + stdoutConfig = pb.TaskExecStdoutConfig_TASK_EXEC_STDOUT_CONFIG_DEVNULL + default: + return nil, fmt.Errorf("unsupported stdout behavior: %s", params.Stdout) + } + + var stderrConfig pb.TaskExecStderrConfig + switch params.Stderr { + case Pipe, "": + stderrConfig = pb.TaskExecStderrConfig_TASK_EXEC_STDERR_CONFIG_PIPE + case Ignore: + stderrConfig = pb.TaskExecStderrConfig_TASK_EXEC_STDERR_CONFIG_DEVNULL + default: + return nil, fmt.Errorf("unsupported stderr behavior: %s", params.Stderr) + } + var ptyInfo *pb.PTYInfo if params.PTY { ptyInfo = defaultSandboxPTYInfo() } - if params.Timeout < 0 { - return nil, fmt.Errorf("timeout must be non-negative, got %v", params.Timeout) + builder := pb.TaskExecStartRequest_builder{ + TaskId: taskID, + ExecId: execID, + CommandArgs: command, + StdoutConfig: stdoutConfig, + StderrConfig: stderrConfig, + Workdir: nil, + SecretIds: secretIds, + PtyInfo: ptyInfo, + RuntimeDebug: false, } - if params.Timeout%time.Second != 0 { - return nil, fmt.Errorf("timeout must be a whole number of seconds, got %v", params.Timeout) + + if params.Workdir != "" { + builder.Workdir = ¶ms.Workdir } - timeoutSecs := uint32(params.Timeout / time.Second) - return pb.ContainerExecRequest_builder{ - TaskId: taskID, - Command: command, - Workdir: workdir, - TimeoutSecs: timeoutSecs, - SecretIds: secretIds, - PtyInfo: ptyInfo, - }.Build(), nil + if params.Timeout > 0 { + timeoutSecs := uint32(params.Timeout / time.Second) + builder.TimeoutSecs = &timeoutSecs + } + + return builder.Build(), nil } // Exec runs a command in the Sandbox and returns text streams. 
@@ -505,6 +564,10 @@ func (sb *Sandbox) Exec(ctx context.Context, command []string, params *SandboxEx params = &SandboxExecParams{} } + if err := ValidateExecArgs(command); err != nil { + return nil, err + } + if err := sb.ensureTaskID(ctx); err != nil { return nil, err } @@ -518,19 +581,34 @@ func (sb *Sandbox) Exec(ctx context.Context, command []string, params *SandboxEx mergedParams.Secrets = mergedSecrets mergedParams.Env = nil // nil'ing Env just to clarify it's not needed anymore - req, err := buildContainerExecRequestProto(sb.taskID, command, mergedParams) + commandRouterClient, err := sb.getOrCreateCommandRouterClient(ctx, sb.taskID) + if err != nil { + return nil, err + } + + execID := uuid.New().String() + req, err := buildTaskExecStartRequestProto(sb.taskID, execID, command, mergedParams) if err != nil { return nil, err } - resp, err := sb.client.cpClient.ContainerExec(ctx, req) + + _, err = commandRouterClient.ExecStart(ctx, req) if err != nil { return nil, err } + sb.client.logger.DebugContext(ctx, "Created ContainerProcess", - "exec_id", resp.GetExecId(), + "exec_id", execID, "sandbox_id", sb.SandboxID, "command", command) - return newContainerProcess(sb.client.cpClient, sb.client.logger, resp.GetExecId(), *params), nil + + var deadline *time.Time + if mergedParams.Timeout > 0 { + d := time.Now().Add(mergedParams.Timeout) + deadline = &d + } + + return newContainerProcess(commandRouterClient, sb.client.logger, sb.taskID, execID, mergedParams, deadline), nil } // SandboxCreateConnectTokenParams are optional parameters for CreateConnectToken. @@ -606,8 +684,50 @@ func (sb *Sandbox) ensureTaskID(ctx context.Context) error { return nil } +func (sb *Sandbox) getOrCreateCommandRouterClient(ctx context.Context, taskID string) (*TaskCommandRouterClient, error) { + sb.mu.Lock() + defer sb.mu.Unlock() + + if sb.commandRouterClient == nil { + client, err := TryInitTaskCommandRouterClient( + ctx, + sb.client.cpClient, + taskID, + sb.client.logger, + sb.client.profile, + ) + if err != nil { + return nil, err + } + if client == nil { + return nil, fmt.Errorf("command router access is not available for this sandbox") + } + sb.commandRouterClient = client + } + return sb.commandRouterClient, nil +} + +// Close task command router client +func (sb *Sandbox) closeTaskCommandRouterClient() error { + sb.mu.Lock() + defer sb.mu.Unlock() + if sb.commandRouterClient != nil { + err := sb.commandRouterClient.Close() + if err != nil { + return err + } + sb.commandRouterClient = nil + } + return nil +} + // Terminate stops the Sandbox. +// The Stdin, Stdout, Stderr streams are not closed. 
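+// Any task command router client created for this Sandbox is closed first,
+// which also cancels in-flight exec stdout/stderr streams.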
func (sb *Sandbox) Terminate(ctx context.Context) error { + if err := sb.closeTaskCommandRouterClient(); err != nil { + return err + } + _, err := sb.client.cpClient.SandboxTerminate(ctx, pb.SandboxTerminateRequest_builder{ SandboxId: sb.SandboxID, }.Build()) @@ -827,11 +947,13 @@ type ContainerProcess struct { Stdout io.ReadCloser Stderr io.ReadCloser - execID string - cpClient pb.ModalClientClient + taskID string + execID string + commandRouterClient *TaskCommandRouterClient + deadline *time.Time } -func newContainerProcess(cpClient pb.ModalClientClient, logger *slog.Logger, execID string, params SandboxExecParams) *ContainerProcess { +func newContainerProcess(commandRouterClient *TaskCommandRouterClient, logger *slog.Logger, taskID, execID string, params SandboxExecParams, deadline *time.Time) *ContainerProcess { stdoutBehavior := Pipe stderrBehavior := Pipe if params.Stdout != "" { @@ -841,15 +963,20 @@ func newContainerProcess(cpClient pb.ModalClientClient, logger *slog.Logger, exe stderrBehavior = params.Stderr } - cp := &ContainerProcess{execID: execID, cpClient: cpClient} - cp.Stdin = inputStreamCp(cpClient, execID) + cp := &ContainerProcess{ + taskID: taskID, + execID: execID, + commandRouterClient: commandRouterClient, + deadline: deadline, + } + cp.Stdin = inputStreamCp(commandRouterClient, taskID, execID) if stdoutBehavior == Ignore { cp.Stdout = io.NopCloser(bytes.NewReader(nil)) } else { cp.Stdout = &lazyStreamReader{ initFunc: func() io.ReadCloser { - return outputStreamCp(cpClient, logger, execID, pb.FileDescriptor_FILE_DESCRIPTOR_STDOUT) + return outputStreamCp(commandRouterClient, logger, taskID, execID, pb.FileDescriptor_FILE_DESCRIPTOR_STDOUT, deadline) }, } } @@ -858,7 +985,7 @@ func newContainerProcess(cpClient pb.ModalClientClient, logger *slog.Logger, exe } else { cp.Stderr = &lazyStreamReader{ initFunc: func() io.ReadCloser { - return outputStreamCp(cpClient, logger, execID, pb.FileDescriptor_FILE_DESCRIPTOR_STDERR) + return outputStreamCp(commandRouterClient, logger, taskID, execID, pb.FileDescriptor_FILE_DESCRIPTOR_STDERR, deadline) }, } } @@ -868,22 +995,11 @@ func newContainerProcess(cpClient pb.ModalClientClient, logger *slog.Logger, exe // Wait blocks until the container process exits and returns its exit code. 
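+// If the exec was started with a Timeout, Wait returns an ExecTimeoutError
+// once that deadline passes.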
func (cp *ContainerProcess) Wait(ctx context.Context) (int, error) { - for { - if err := ctx.Err(); err != nil { - return 0, err - } - - resp, err := cp.cpClient.ContainerExecWait(ctx, pb.ContainerExecWaitRequest_builder{ - ExecId: cp.execID, - Timeout: 55, - }.Build()) - if err != nil { - return 0, err - } - if resp.GetCompleted() { - return int(resp.GetExitCode()), nil - } + resp, err := cp.commandRouterClient.ExecWait(ctx, cp.taskID, cp.execID, cp.deadline) + if err != nil { + return 0, err } + return int(resp.GetCode()), nil } func inputStreamSb(cpClient pb.ModalClientClient, sandboxID string) io.WriteCloser { @@ -922,43 +1038,34 @@ func (sbs *sbStdin) Close() error { Index: sbs.index, Eof: true, }.Build()) + if st, ok := status.FromError(err); ok && st.Code() == codes.FailedPrecondition { + return nil + } return err } -func inputStreamCp(cpClient pb.ModalClientClient, execID string) io.WriteCloser { - return &cpStdin{execID: execID, messageIndex: 1, cpClient: cpClient} +func inputStreamCp(commandRouterClient *TaskCommandRouterClient, taskID, execID string) io.WriteCloser { + return &cpStdin{taskID: taskID, execID: execID, offset: 0, commandRouterClient: commandRouterClient} } type cpStdin struct { - execID string - messageIndex uint64 - cpClient pb.ModalClientClient + taskID string + execID string + commandRouterClient *TaskCommandRouterClient + offset uint64 } func (cps *cpStdin) Write(p []byte) (int, error) { - _, err := cps.cpClient.ContainerExecPutInput(context.Background(), pb.ContainerExecPutInputRequest_builder{ - ExecId: cps.execID, - Input: pb.RuntimeInputMessage_builder{ - Message: p, - MessageIndex: cps.messageIndex, - }.Build(), - }.Build()) + err := cps.commandRouterClient.ExecStdinWrite(context.Background(), cps.taskID, cps.execID, cps.offset, p, false) if err != nil { return 0, err } - cps.messageIndex++ + cps.offset += uint64(len(p)) return len(p), nil } func (cps *cpStdin) Close() error { - _, err := cps.cpClient.ContainerExecPutInput(context.Background(), pb.ContainerExecPutInputRequest_builder{ - ExecId: cps.execID, - Input: pb.RuntimeInputMessage_builder{ - MessageIndex: cps.messageIndex, - Eof: true, - }.Build(), - }.Build()) - return err + return cps.commandRouterClient.ExecStdinWrite(context.Background(), cps.taskID, cps.execID, cps.offset, nil, true) } // cancelOnCloseReader is used to cancel background goroutines when the stream is closed. 
@@ -1068,7 +1175,7 @@ func outputStreamSb(cpClient pb.ModalClientClient, logger *slog.Logger, sandboxI return &cancelOnCloseReader{ReadCloser: pr, cancel: cancel} } -func outputStreamCp(cpClient pb.ModalClientClient, logger *slog.Logger, execID string, fd pb.FileDescriptor) io.ReadCloser { +func outputStreamCp(commandRouterClient *TaskCommandRouterClient, logger *slog.Logger, taskID, execID string, fd pb.FileDescriptor, deadline *time.Time) io.ReadCloser { pr, pw := nio.Pipe(buffer.New(64 * 1024)) ctx, cancel := context.WithCancel(context.Background()) go func() { @@ -1078,61 +1185,21 @@ func outputStreamCp(cpClient pb.ModalClientClient, logger *slog.Logger, execID s } }() defer cancel() - var lastIndex uint64 - completed := false - retries := 10 - for !completed { - stream, err := cpClient.ContainerExecGetOutput(ctx, pb.ContainerExecGetOutputRequest_builder{ - ExecId: execID, - FileDescriptor: fd, - Timeout: 55, - GetRawBytes: true, - LastBatchIndex: lastIndex, - }.Build()) - if err != nil { + + resultCh := commandRouterClient.ExecStdioRead(ctx, taskID, execID, fd, deadline) + for result := range resultCh { + if result.Err != nil { if ctx.Err() != nil { return } - if isRetryableGrpc(err) && retries > 0 { - retries-- - continue - } - streamErr := fmt.Errorf("error getting output stream: %w", err) + streamErr := fmt.Errorf("error getting output stream: %w", result.Err) if closeErr := pw.CloseWithError(streamErr); closeErr != nil { logger.DebugContext(ctx, "failed to close pipe writer with error", "error", closeErr.Error(), "stream_error", streamErr.Error()) } return } - for { - batch, err := stream.Recv() - if err != nil { - if ctx.Err() != nil { - return - } - if err != io.EOF { - if isRetryableGrpc(err) && retries > 0 { - retries-- - } else { - streamErr := fmt.Errorf("error getting output stream: %w", err) - if closeErr := pw.CloseWithError(streamErr); closeErr != nil { - logger.DebugContext(ctx, "failed to close pipe writer with error", "error", closeErr.Error(), "stream_error", streamErr.Error()) - } - return - } - } - break // we need to retry, either from an EOF or gRPC error - } - lastIndex = batch.GetBatchIndex() - for _, item := range batch.GetItems() { - // On error, writer has been closed. Still consume the rest of the channel. 
- if _, err := pw.Write(item.GetMessageBytes()); err != nil { - logger.DebugContext(ctx, "failed to write to pipe", "error", err.Error()) - } - } - if batch.HasExitCode() { - completed = true - break - } + if _, err := pw.Write(result.Response.GetData()); err != nil { + logger.DebugContext(ctx, "failed to write to pipe", "error", err.Error()) } } }() diff --git a/modal-go/sandbox_test.go b/modal-go/sandbox_test.go index 75b8d942..18d418c7 100644 --- a/modal-go/sandbox_test.go +++ b/modal-go/sandbox_test.go @@ -1,7 +1,9 @@ package modal import ( + "bytes" "testing" + "time" pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto" "github.com/onsi/gomega" @@ -38,18 +40,18 @@ func TestSandboxCreateRequestProto_WithPTY(t *testing.T) { g.Expect(ptyInfo.GetPtyType()).To(gomega.Equal(pb.PTYInfo_PTY_TYPE_SHELL)) } -func TestContainerExecProto_WithoutPTY(t *testing.T) { +func TestTaskExecStartProto_WithoutPTY(t *testing.T) { g := gomega.NewWithT(t) - req, err := buildContainerExecRequestProto("task-123", []string{"bash"}, SandboxExecParams{}) + req, err := buildTaskExecStartRequestProto("task-123", "exec-456", []string{"bash"}, SandboxExecParams{}) g.Expect(err).ShouldNot(gomega.HaveOccurred()) ptyInfo := req.GetPtyInfo() g.Expect(ptyInfo).Should(gomega.BeNil()) } -func TestContainerExecProto_WithPTY(t *testing.T) { +func TestTaskExecStartProto_WithPTY(t *testing.T) { g := gomega.NewWithT(t) - req, err := buildContainerExecRequestProto("task-123", []string{"bash"}, SandboxExecParams{ + req, err := buildTaskExecStartRequestProto("task-123", "exec-456", []string{"bash"}, SandboxExecParams{ PTY: true, }) g.Expect(err).ShouldNot(gomega.HaveOccurred()) @@ -65,16 +67,103 @@ func TestContainerExecProto_WithPTY(t *testing.T) { g.Expect(ptyInfo.GetNoTerminateOnIdleStdin()).To(gomega.BeTrue()) } -func TestContainerExecRequestProto_DefaultValues(t *testing.T) { +func TestTaskExecStartRequestProto_DefaultValues(t *testing.T) { g := gomega.NewWithT(t) - req, err := buildContainerExecRequestProto("task-123", []string{"bash"}, SandboxExecParams{}) - g.Expect(err).ShouldNot(gomega.HaveOccurred()) + req, err := buildTaskExecStartRequestProto("task-123", "exec-456", []string{"bash"}, SandboxExecParams{}) + g.Expect(err).ToNot(gomega.HaveOccurred()) g.Expect(req.GetWorkdir()).To(gomega.BeEmpty()) - g.Expect(req.GetTimeoutSecs()).To(gomega.Equal(uint32(0))) + g.Expect(req.HasTimeoutSecs()).To(gomega.BeFalse()) g.Expect(req.GetSecretIds()).To(gomega.BeEmpty()) g.Expect(req.GetPtyInfo()).To(gomega.BeNil()) + g.Expect(req.GetStdoutConfig()).To(gomega.Equal(pb.TaskExecStdoutConfig_TASK_EXEC_STDOUT_CONFIG_PIPE)) + g.Expect(req.GetStderrConfig()).To(gomega.Equal(pb.TaskExecStderrConfig_TASK_EXEC_STDERR_CONFIG_PIPE)) +} + +func TestTaskExecStartRequestProto_WithStdoutIgnore(t *testing.T) { + g := gomega.NewWithT(t) + + req, err := buildTaskExecStartRequestProto("task-123", "exec-456", []string{"bash"}, SandboxExecParams{ + Stdout: Ignore, + }) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + g.Expect(req.GetStdoutConfig()).To(gomega.Equal(pb.TaskExecStdoutConfig_TASK_EXEC_STDOUT_CONFIG_DEVNULL)) + g.Expect(req.GetStderrConfig()).To(gomega.Equal(pb.TaskExecStderrConfig_TASK_EXEC_STDERR_CONFIG_PIPE)) +} + +func TestTaskExecStartRequestProto_WithStderrIgnore(t *testing.T) { + g := gomega.NewWithT(t) + + req, err := buildTaskExecStartRequestProto("task-123", "exec-456", []string{"bash"}, SandboxExecParams{ + Stderr: Ignore, + }) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + 
g.Expect(req.GetStdoutConfig()).To(gomega.Equal(pb.TaskExecStdoutConfig_TASK_EXEC_STDOUT_CONFIG_PIPE)) + g.Expect(req.GetStderrConfig()).To(gomega.Equal(pb.TaskExecStderrConfig_TASK_EXEC_STDERR_CONFIG_DEVNULL)) +} + +func TestTaskExecStartRequestProto_WithWorkdir(t *testing.T) { + g := gomega.NewWithT(t) + + req, err := buildTaskExecStartRequestProto("task-123", "exec-456", []string{"pwd"}, SandboxExecParams{ + Workdir: "/tmp", + }) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + g.Expect(req.GetWorkdir()).To(gomega.Equal("/tmp")) +} + +func TestTaskExecStartRequestProto_WithTimeout(t *testing.T) { + g := gomega.NewWithT(t) + timeout := 30 * time.Second + + req, err := buildTaskExecStartRequestProto("task-123", "exec-456", []string{"sleep", "10"}, SandboxExecParams{ + Timeout: timeout, + }) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + g.Expect(req.HasTimeoutSecs()).To(gomega.BeTrue()) + g.Expect(req.GetTimeoutSecs()).To(gomega.Equal(uint32(30))) +} + +func TestTaskExecStartRequestProto_InvalidTimeoutNegative(t *testing.T) { + g := gomega.NewWithT(t) + + _, err := buildTaskExecStartRequestProto("task-123", "exec-456", []string{"echo", "hi"}, SandboxExecParams{ + Timeout: -1 * time.Second, + }) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.ContainSubstring("must be non-negative")) +} + +func TestTaskExecStartRequestProto_InvalidTimeoutNotWholeSeconds(t *testing.T) { + g := gomega.NewWithT(t) + + _, err := buildTaskExecStartRequestProto("task-123", "exec-456", []string{"echo", "hi"}, SandboxExecParams{ + Timeout: 1500 * time.Millisecond, + }) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.ContainSubstring("whole number of seconds")) +} + +func TestValidateExecArgsWithArgsWithinLimit(t *testing.T) { + g := gomega.NewWithT(t) + + err := ValidateExecArgs([]string{"echo", "hello"}) + g.Expect(err).ToNot(gomega.HaveOccurred()) +} + +func TestValidateExecArgsWithArgsExceedingArgMax(t *testing.T) { + g := gomega.NewWithT(t) + + largeArg := bytes.Repeat([]byte{'a'}, 1<<16+1) + + err := ValidateExecArgs([]string{string(largeArg)}) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.ContainSubstring("Total length of CMD arguments must be less than")) } func TestSandboxCreateRequestProto_WithCPUAndCPULimit(t *testing.T) { diff --git a/modal-go/task_command_router_client.go b/modal-go/task_command_router_client.go new file mode 100644 index 00000000..4d227ab8 --- /dev/null +++ b/modal-go/task_command_router_client.go @@ -0,0 +1,650 @@ +package modal + +import ( + "context" + "crypto/tls" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "io" + "log/slog" + "net" + "net/url" + "strings" + "sync" + "sync/atomic" + "time" + + pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto" + "golang.org/x/sync/singleflight" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/metadata" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/emptypb" +) + +// tlsCredsNoALPN is a TLS credential that skips ALPN enforcement, implementing +// the google.golang.org/grpc/credentials#TransportCredentials interface. +// +// Starting in grpc-go v1.67, ALPN is enforced by default for TLS connections. +// However, the task command router server doesn't negotiate ALPN. +// This performs the TLS handshake without that check. 
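+// The client still offers "h2" via NextProtos; only grpc-go's post-handshake
+// ALPN enforcement is bypassed.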
+// See: https://github.com/grpc/grpc-go/issues/434 +type tlsCredsNoALPN struct { + insecureSkipVerify bool +} + +func (c *tlsCredsNoALPN) ClientHandshake(ctx context.Context, authority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) { + serverName, _, err := net.SplitHostPort(authority) + if err != nil { + serverName = authority + } + cfg := &tls.Config{ + ServerName: serverName, + NextProtos: []string{"h2"}, + InsecureSkipVerify: c.insecureSkipVerify, + } + + conn := tls.Client(rawConn, cfg) + if err := conn.HandshakeContext(ctx); err != nil { + _ = conn.Close() + return nil, nil, err + } + + return conn, credentials.TLSInfo{ + State: conn.ConnectionState(), + CommonAuthInfo: credentials.CommonAuthInfo{ + SecurityLevel: credentials.PrivacyAndIntegrity, + }, + }, nil +} + +func (c *tlsCredsNoALPN) ServerHandshake(net.Conn) (net.Conn, credentials.AuthInfo, error) { + return nil, nil, errors.New("tlsCredsNoALPN: server-side not supported") +} + +func (c *tlsCredsNoALPN) Info() credentials.ProtocolInfo { + return credentials.ProtocolInfo{SecurityProtocol: "tls", SecurityVersion: "1.2"} +} + +func (c *tlsCredsNoALPN) Clone() credentials.TransportCredentials { + return &tlsCredsNoALPN{insecureSkipVerify: c.insecureSkipVerify} +} + +func (c *tlsCredsNoALPN) OverrideServerName(string) error { + return errors.New("tlsCredsNoALPN: OverrideServerName not supported") +} + +// retryOptions configures retry behavior for callWithRetriesOnTransientErrors. +type retryOptions struct { + BaseDelay time.Duration + DelayFactor float64 + MaxRetries *int // nil means retry forever + Deadline *time.Time +} + +// defaultRetryOptions returns the default retry options. +func defaultRetryOptions() retryOptions { + maxRetries := 10 + return retryOptions{ + BaseDelay: 10 * time.Millisecond, + DelayFactor: 2.0, + MaxRetries: &maxRetries, + Deadline: nil, + } +} + +var commandRouterRetryableCodes = map[codes.Code]struct{}{ + codes.DeadlineExceeded: {}, + codes.Unavailable: {}, + codes.Canceled: {}, + codes.Internal: {}, + codes.Unknown: {}, +} + +// parseJwtExpiration extracts the expiration time from a JWT token. +// Returns (nil, nil) if the token has no exp claim. +// Returns an error if the token is malformed. +func parseJwtExpiration(jwt string) (*int64, error) { + parts := strings.Split(jwt, ".") + if len(parts) != 3 { + return nil, fmt.Errorf("malformed JWT: expected 3 parts, got %d", len(parts)) + } + + payloadB64 := parts[1] + switch len(payloadB64) % 4 { + case 2: + payloadB64 += "==" + case 3: + payloadB64 += "=" + } + + payloadJSON, err := base64.URLEncoding.DecodeString(payloadB64) + if err != nil { + return nil, fmt.Errorf("malformed JWT: base64 decode: %w", err) + } + + var payload struct { + Exp json.Number `json:"exp"` + } + if err := json.Unmarshal(payloadJSON, &payload); err != nil { + return nil, fmt.Errorf("malformed JWT: json unmarshal: %w", err) + } + + if payload.Exp == "" { + return nil, nil + } + + exp, err := payload.Exp.Int64() + if err != nil { + return nil, fmt.Errorf("malformed JWT: exp not an integer: %w", err) + } + + return &exp, nil +} + +var errDeadlineExceeded = errors.New("deadline exceeded") + +// callWithRetriesOnTransientErrors retries the given function on transient gRPC errors. 
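+// The delay between attempts starts at opts.BaseDelay and grows by opts.DelayFactor.
+// Retries stop when opts.MaxRetries is exhausted (nil means retry indefinitely) or
+// when opts.Deadline passes, whichever comes first; non-retryable status codes are
+// returned immediately.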
+func callWithRetriesOnTransientErrors[T any]( + ctx context.Context, + fn func() (*T, error), + opts retryOptions, +) (*T, error) { + delay := opts.BaseDelay + numRetries := 0 + + for { + if opts.Deadline != nil && time.Now().After(*opts.Deadline) { + return nil, errDeadlineExceeded + } + + result, err := fn() + if err == nil { + return result, nil + } + + st, ok := status.FromError(err) + if !ok { + return nil, err + } + + if _, retryable := commandRouterRetryableCodes[st.Code()]; !retryable { + return nil, err + } + + if opts.MaxRetries != nil && numRetries >= *opts.MaxRetries { + return nil, err + } + + if opts.Deadline != nil && time.Now().Add(delay).After(*opts.Deadline) { + return nil, errDeadlineExceeded + } + + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + } + + delay = time.Duration(float64(delay) * opts.DelayFactor) + numRetries++ + } +} + +// TaskCommandRouterClient provides a client for the TaskCommandRouter gRPC service. +type TaskCommandRouterClient struct { + stub pb.TaskCommandRouterClient + conn *grpc.ClientConn + serverClient pb.ModalClientClient + taskID string + serverURL string + jwt atomic.Pointer[string] + jwtExp atomic.Pointer[int64] + logger *slog.Logger + closed atomic.Bool + + // done is closed when Close() is called, signaling all goroutines to stop. + done chan struct{} + closeOnce sync.Once + + refreshJwtGroup singleflight.Group +} + +// TryInitTaskCommandRouterClient attempts to initialize a TaskCommandRouterClient. +// Returns nil if command router access is not available for this task. +func TryInitTaskCommandRouterClient( + ctx context.Context, + serverClient pb.ModalClientClient, + taskID string, + logger *slog.Logger, + profile Profile, +) (*TaskCommandRouterClient, error) { + resp, err := serverClient.TaskGetCommandRouterAccess(ctx, pb.TaskGetCommandRouterAccessRequest_builder{ + TaskId: taskID, + }.Build()) + if err != nil { + if st, ok := status.FromError(err); ok && st.Code() == codes.FailedPrecondition { + logger.DebugContext(ctx, "Command router access is not enabled for task", "task_id", taskID) + return nil, nil + } + return nil, err + } + + logger.DebugContext(ctx, "Using command router access for task", "task_id", taskID, "url", resp.GetUrl()) + + jwt := resp.GetJwt() + jwtExp, err := parseJwtExpiration(jwt) + if err != nil { + return nil, fmt.Errorf("parseJwtExpiration: %w", err) + } + + url, err := url.Parse(resp.GetUrl()) + if err != nil { + return nil, fmt.Errorf("failed to parse task router URL: %w", err) + } + + if url.Scheme != "https" { + return nil, fmt.Errorf("task router URL must be https, got: %s", resp.GetUrl()) + } + + host := url.Hostname() + port := url.Port() + if port == "" { + port = "443" + } + target := fmt.Sprintf("%s:%s", host, port) + + // Use custom TLS credentials that skip ALPN enforcement. + // The command router server may not negotiate ALPN, which causes + // grpc-go v1.67+ to fail the handshake. 
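+	// InsecureSkipVerify is enabled only when the profile opts in via
+	// task_command_router_insecure or MODAL_TASK_COMMAND_ROUTER_INSECURE.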
+ creds := &tlsCredsNoALPN{insecureSkipVerify: profile.TaskCommandRouterInsecure} + if profile.TaskCommandRouterInsecure { + logger.WarnContext(ctx, "Using insecure TLS (skip certificate verification) for task command router due to MODAL_TASK_COMMAND_ROUTER_INSECURE") + } + + conn, err := grpc.NewClient( + target, + grpc.WithTransportCredentials(creds), + grpc.WithInitialWindowSize(windowSize), + grpc.WithInitialConnWindowSize(windowSize), + grpc.WithDefaultCallOptions( + grpc.MaxCallRecvMsgSize(maxMessageSize), + grpc.MaxCallSendMsgSize(maxMessageSize), + ), + ) + if err != nil { + return nil, fmt.Errorf("failed to create task command router connection: %w", err) + } + + client := &TaskCommandRouterClient{ + stub: pb.NewTaskCommandRouterClient(conn), + conn: conn, + serverClient: serverClient, + taskID: taskID, + serverURL: resp.GetUrl(), + logger: logger, + done: make(chan struct{}), + } + client.jwt.Store(&jwt) + client.jwtExp.Store(jwtExp) + + logger.DebugContext(ctx, "Successfully initialized command router client", "task_id", taskID) + return client, nil +} + +// Close closes the gRPC connection and cancels all in-flight operations. +func (c *TaskCommandRouterClient) Close() error { + if !c.closed.CompareAndSwap(false, true) { + return nil + } + c.closeOnce.Do(func() { + close(c.done) + }) + if c.conn != nil { + return c.conn.Close() + } + return nil +} + +func (c *TaskCommandRouterClient) authContext(ctx context.Context) context.Context { + return metadata.AppendToOutgoingContext(ctx, "authorization", "Bearer "+*c.jwt.Load()) +} + +func (c *TaskCommandRouterClient) refreshJwt(ctx context.Context) error { + const jwtRefreshBufferSeconds = 30 + + if c.closed.Load() { + return errors.New("client is closed") + } + + // If the current JWT expiration is already far enough in the future, don't refresh. + if exp := c.jwtExp.Load(); exp != nil && *exp-time.Now().Unix() > jwtRefreshBufferSeconds { + c.logger.DebugContext(ctx, "Skipping JWT refresh because expiration is far enough in the future", "task_id", c.taskID) + return nil + } + + _, err, _ := c.refreshJwtGroup.Do("refresh", func() (any, error) { + if exp := c.jwtExp.Load(); exp != nil && *exp-time.Now().Unix() > jwtRefreshBufferSeconds { + return nil, nil + } + + resp, err := c.serverClient.TaskGetCommandRouterAccess(ctx, pb.TaskGetCommandRouterAccessRequest_builder{ + TaskId: c.taskID, + }.Build()) + if err != nil { + return nil, fmt.Errorf("failed to refresh JWT: %w", err) + } + + if resp.GetUrl() != c.serverURL { + return nil, errors.New("task router URL changed during session") + } + + jwt := resp.GetJwt() + c.jwt.Store(&jwt) + jwtExp, err := parseJwtExpiration(jwt) + if err != nil { + // Log warning but continue - we'll refresh on every auth failure instead of proactively. + c.logger.WarnContext(ctx, "parseJwtExpiration during refresh", "error", err) + } + c.jwtExp.Store(jwtExp) + return nil, nil + }) + return err +} + +type RetryableClient interface { + authContext(ctx context.Context) context.Context + refreshJwt(ctx context.Context) error +} + +func callWithAuthRetry[T any](ctx context.Context, c RetryableClient, fn func(context.Context) (*T, error)) (*T, error) { + resp, err := fn(c.authContext(ctx)) + if err != nil { + if st, ok := status.FromError(err); ok && st.Code() == codes.Unauthenticated { + if refreshErr := c.refreshJwt(ctx); refreshErr != nil { + return nil, refreshErr + } + return fn(c.authContext(ctx)) + } + } + return resp, err +} + +// ExecStart starts a command execution. 
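+// Transient gRPC errors are retried with defaultRetryOptions, and an
+// Unauthenticated response triggers a JWT refresh followed by one retry of the call.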
+func (c *TaskCommandRouterClient) ExecStart(ctx context.Context, request *pb.TaskExecStartRequest) (*pb.TaskExecStartResponse, error) { + return callWithRetriesOnTransientErrors(ctx, func() (*pb.TaskExecStartResponse, error) { + return callWithAuthRetry(ctx, c, func(authCtx context.Context) (*pb.TaskExecStartResponse, error) { + return c.stub.TaskExecStart(authCtx, request) + }) + }, defaultRetryOptions()) +} + +// ExecStdinWrite writes data to stdin of an exec. +func (c *TaskCommandRouterClient) ExecStdinWrite(ctx context.Context, taskID, execID string, offset uint64, data []byte, eof bool) error { + request := pb.TaskExecStdinWriteRequest_builder{ + TaskId: taskID, + ExecId: execID, + Offset: offset, + Data: data, + Eof: eof, + }.Build() + + _, err := callWithRetriesOnTransientErrors(ctx, func() (*pb.TaskExecStdinWriteResponse, error) { + return callWithAuthRetry(ctx, c, func(authCtx context.Context) (*pb.TaskExecStdinWriteResponse, error) { + return c.stub.TaskExecStdinWrite(authCtx, request) + }) + }, defaultRetryOptions()) + return err +} + +// ExecPoll polls for the exit status of an exec. +func (c *TaskCommandRouterClient) ExecPoll(ctx context.Context, taskID, execID string, deadline *time.Time) (*pb.TaskExecPollResponse, error) { + request := pb.TaskExecPollRequest_builder{ + TaskId: taskID, + ExecId: execID, + }.Build() + + if deadline != nil && time.Now().After(*deadline) { + return nil, ExecTimeoutError{Exception: fmt.Sprintf("deadline exceeded while polling for exec %s", execID)} + } + + opts := defaultRetryOptions() + opts.Deadline = deadline + + resp, err := callWithRetriesOnTransientErrors(ctx, func() (*pb.TaskExecPollResponse, error) { + return callWithAuthRetry(ctx, c, func(authCtx context.Context) (*pb.TaskExecPollResponse, error) { + return c.stub.TaskExecPoll(authCtx, request) + }) + }, opts) + + if err != nil { + st, ok := status.FromError(err) + if (ok && st.Code() == codes.DeadlineExceeded) || errors.Is(err, errDeadlineExceeded) { + return nil, ExecTimeoutError{Exception: fmt.Sprintf("deadline exceeded while polling for exec %s", execID)} + } + } + return resp, err +} + +// ExecWait waits for an exec to complete and returns the exit code. +func (c *TaskCommandRouterClient) ExecWait(ctx context.Context, taskID, execID string, deadline *time.Time) (*pb.TaskExecWaitResponse, error) { + request := pb.TaskExecWaitRequest_builder{ + TaskId: taskID, + ExecId: execID, + }.Build() + + if deadline != nil && time.Now().After(*deadline) { + return nil, ExecTimeoutError{Exception: fmt.Sprintf("deadline exceeded while waiting for exec %s", execID)} + } + + opts := retryOptions{ + BaseDelay: 1 * time.Second, // Retry after 1s since total time is expected to be long. + DelayFactor: 1, // Fixed delay. + MaxRetries: nil, // Retry forever. 
+ Deadline: deadline, + } + + resp, err := callWithRetriesOnTransientErrors(ctx, func() (*pb.TaskExecWaitResponse, error) { + return callWithAuthRetry(ctx, c, func(authCtx context.Context) (*pb.TaskExecWaitResponse, error) { + // Set a per-call timeout of 60 seconds + callCtx, cancel := context.WithTimeout(authCtx, 60*time.Second) + defer cancel() + return c.stub.TaskExecWait(callCtx, request) + }) + }, opts) + + if err != nil { + st, ok := status.FromError(err) + if (ok && st.Code() == codes.DeadlineExceeded) || errors.Is(err, errDeadlineExceeded) { + return nil, ExecTimeoutError{Exception: fmt.Sprintf("deadline exceeded while waiting for exec %s", execID)} + } + } + return resp, err +} + +// stdioReadResult represents a result from the stdio read stream. +type stdioReadResult struct { + Response *pb.TaskExecStdioReadResponse + Err error +} + +// ExecStdioRead reads stdout or stderr from an exec. +// The returned channel will be closed when the stream ends or an error occurs. +func (c *TaskCommandRouterClient) ExecStdioRead( + ctx context.Context, + taskID, execID string, + fd pb.FileDescriptor, + deadline *time.Time, +) <-chan stdioReadResult { + resultCh := make(chan stdioReadResult) + + go func() { + defer close(resultCh) + + var srFd pb.TaskExecStdioFileDescriptor + switch fd { + case pb.FileDescriptor_FILE_DESCRIPTOR_STDOUT: + srFd = pb.TaskExecStdioFileDescriptor_TASK_EXEC_STDIO_FILE_DESCRIPTOR_STDOUT + case pb.FileDescriptor_FILE_DESCRIPTOR_STDERR: + srFd = pb.TaskExecStdioFileDescriptor_TASK_EXEC_STDIO_FILE_DESCRIPTOR_STDERR + case pb.FileDescriptor_FILE_DESCRIPTOR_INFO, pb.FileDescriptor_FILE_DESCRIPTOR_UNSPECIFIED: + resultCh <- stdioReadResult{Err: fmt.Errorf("unsupported file descriptor: %v", fd)} + return + default: + resultCh <- stdioReadResult{Err: fmt.Errorf("invalid file descriptor: %v", fd)} + return + } + + // Create a context that cancels when either the caller's ctx is done or Close() is called. + // This ensures goroutines exit promptly when Close() is called. + ctx, cancel := context.WithCancel(ctx) + defer cancel() + go func() { + select { + case <-c.done: + cancel() + case <-ctx.Done(): + } + }() + + if deadline != nil { + var deadlineCancel context.CancelFunc + ctx, deadlineCancel = context.WithDeadline(ctx, *deadline) + defer deadlineCancel() + } + c.streamStdio(ctx, resultCh, taskID, execID, srFd) + }() + + return resultCh +} + +// MountDirectory mounts an image at a directory in the container. +func (c *TaskCommandRouterClient) MountDirectory(ctx context.Context, request *pb.TaskMountDirectoryRequest) error { + _, err := callWithRetriesOnTransientErrors(ctx, func() (*emptypb.Empty, error) { + return callWithAuthRetry(ctx, c, func(authCtx context.Context) (*emptypb.Empty, error) { + return c.stub.TaskMountDirectory(authCtx, request) + }) + }, defaultRetryOptions()) + return err +} + +// SnapshotDirectory snapshots a directory into a new image. 
+func (c *TaskCommandRouterClient) SnapshotDirectory(ctx context.Context, request *pb.TaskSnapshotDirectoryRequest) (*pb.TaskSnapshotDirectoryResponse, error) { + return callWithRetriesOnTransientErrors(ctx, func() (*pb.TaskSnapshotDirectoryResponse, error) { + return callWithAuthRetry(ctx, c, func(authCtx context.Context) (*pb.TaskSnapshotDirectoryResponse, error) { + return c.stub.TaskSnapshotDirectory(authCtx, request) + }) + }, defaultRetryOptions()) +} + +func (c *TaskCommandRouterClient) streamStdio( + ctx context.Context, + resultCh chan<- stdioReadResult, + taskID, execID string, + fd pb.TaskExecStdioFileDescriptor, +) { + deadline, hasDeadline := ctx.Deadline() + + var offset int64 + delay := 10 * time.Millisecond + delayFactor := 2.0 + numRetriesRemaining := 10 + didAuthRetry := false + + for { + if ctx.Err() != nil { + if hasDeadline && ctx.Err() == context.DeadlineExceeded { + resultCh <- stdioReadResult{Err: ExecTimeoutError{Exception: fmt.Sprintf("deadline exceeded while streaming stdio for exec %s", execID)}} + } else { + resultCh <- stdioReadResult{Err: ctx.Err()} + } + return + } + + callCtx := c.authContext(ctx) + + request := pb.TaskExecStdioReadRequest_builder{ + TaskId: taskID, + ExecId: execID, + Offset: uint64(offset), + FileDescriptor: fd, + }.Build() + + stream, err := c.stub.TaskExecStdioRead(callCtx, request) + if err != nil { + if st, ok := status.FromError(err); ok && st.Code() == codes.Unauthenticated && !didAuthRetry { + if refreshErr := c.refreshJwt(ctx); refreshErr != nil { + resultCh <- stdioReadResult{Err: refreshErr} + return + } + didAuthRetry = true + continue + } + if _, retryable := commandRouterRetryableCodes[status.Code(err)]; retryable && numRetriesRemaining > 0 { + if hasDeadline && time.Until(deadline) <= delay { + resultCh <- stdioReadResult{Err: ExecTimeoutError{Exception: fmt.Sprintf("deadline exceeded while streaming stdio for exec %s", execID)}} + return + } + c.logger.DebugContext(ctx, "Retrying stdio read with delay", "delay", delay, "error", err) + select { + case <-ctx.Done(): + resultCh <- stdioReadResult{Err: ctx.Err()} + return + case <-time.After(delay): + } + delay = time.Duration(float64(delay) * delayFactor) + numRetriesRemaining-- + continue + } + resultCh <- stdioReadResult{Err: err} + return + } + + for { + item, err := stream.Recv() + if err == io.EOF { + return + } + if err != nil { + if st, ok := status.FromError(err); ok && st.Code() == codes.Unauthenticated && !didAuthRetry { + if refreshErr := c.refreshJwt(ctx); refreshErr != nil { + resultCh <- stdioReadResult{Err: refreshErr} + return + } + didAuthRetry = true + break + } + if _, retryable := commandRouterRetryableCodes[status.Code(err)]; retryable && numRetriesRemaining > 0 { + if hasDeadline && time.Until(deadline) <= delay { + resultCh <- stdioReadResult{Err: ExecTimeoutError{Exception: fmt.Sprintf("deadline exceeded while streaming stdio for exec %s", execID)}} + return + } + c.logger.DebugContext(ctx, "Retrying stdio read with delay", "delay", delay, "error", err) + select { + case <-ctx.Done(): + resultCh <- stdioReadResult{Err: ctx.Err()} + return + case <-time.After(delay): + } + delay = time.Duration(float64(delay) * delayFactor) + numRetriesRemaining-- + break + } + resultCh <- stdioReadResult{Err: err} + return + } + + if didAuthRetry { + didAuthRetry = false + } + delay = 10 * time.Millisecond + offset += int64(len(item.GetData())) + + resultCh <- stdioReadResult{Response: item} + } + } +} diff --git a/modal-go/task_command_router_client_test.go 
b/modal-go/task_command_router_client_test.go new file mode 100644 index 00000000..6fd158e7 --- /dev/null +++ b/modal-go/task_command_router_client_test.go @@ -0,0 +1,268 @@ +package modal + +import ( + "context" + "encoding/base64" + "encoding/json" + "testing" + "time" + + "github.com/onsi/gomega" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func mockJWT(exp any) string { + header := base64.RawURLEncoding.EncodeToString([]byte(`{"alg":"HS256","typ":"JWT"}`)) + var payloadJSON []byte + if exp != nil { + payloadJSON, _ = json.Marshal(map[string]any{"exp": exp}) + } else { + payloadJSON, _ = json.Marshal(map[string]any{}) + } + payload := base64.RawURLEncoding.EncodeToString(payloadJSON) + signature := "fake-signature" + return header + "." + payload + "." + signature +} + +func TestParseJwtExpirationWithValidJWT(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + exp := time.Now().Unix() + 3600 + jwt := mockJWT(exp) + result, err := parseJwtExpiration(jwt) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(result).ToNot(gomega.BeNil()) + g.Expect(*result).To(gomega.Equal(exp)) +} + +func TestParseJwtExpirationWithoutExpClaim(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + jwt := mockJWT(nil) + result, err := parseJwtExpiration(jwt) + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(result).To(gomega.BeNil()) +} + +func TestParseJwtExpirationWithMalformedJWT(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + jwt := "only.two" + result, err := parseJwtExpiration(jwt) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(result).To(gomega.BeNil()) +} + +func TestParseJwtExpirationWithInvalidBase64(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + jwt := "invalid.!!!invalid!!!.signature" + result, err := parseJwtExpiration(jwt) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(result).To(gomega.BeNil()) +} + +func TestParseJwtExpirationWithNonNumericExp(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + jwt := mockJWT("not-a-number") + result, err := parseJwtExpiration(jwt) + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(result).To(gomega.BeNil()) +} + +func TestCallWithRetriesOnTransientErrorsSuccessOnFirstAttempt(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + callCount := 0 + result, err := callWithRetriesOnTransientErrors(ctx, func() (*string, error) { + callCount++ + output := "success" + return &output, nil + }, defaultRetryOptions()) + + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(*result).To(gomega.Equal("success")) + g.Expect(callCount).To(gomega.Equal(1)) +} + +func TestCallWithRetriesOnTransientErrorsRetriesOnTransientCodes(t *testing.T) { + t.Parallel() + + testCases := []struct { + name string + code codes.Code + message string + }{ + {"DeadlineExceeded", codes.DeadlineExceeded, "timeout"}, + {"Unavailable", codes.Unavailable, "unavailable"}, + {"Canceled", codes.Canceled, "cancelled"}, + {"Internal", codes.Internal, "internal error"}, + {"Unknown", codes.Unknown, "unknown error"}, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + callCount := 0 + result, err := callWithRetriesOnTransientErrors(ctx, func() (*string, error) { + callCount++ + var output string + if callCount == 1 { + output = "" + return &output, status.Error(tc.code, tc.message) + } + output = "success" + return &output, nil + }, retryOptions{BaseDelay: time.Millisecond, 
DelayFactor: 1, MaxRetries: intPtr(10)}) + + g.Expect(err).ToNot(gomega.HaveOccurred()) + g.Expect(*result).To(gomega.Equal("success")) + g.Expect(callCount).To(gomega.Equal(2)) + }) + } +} + +func TestCallWithRetriesOnTransientErrorsNonRetryableError(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + callCount := 0 + _, err := callWithRetriesOnTransientErrors(ctx, func() (*string, error) { + callCount++ + return nil, status.Error(codes.InvalidArgument, "invalid") + }, retryOptions{BaseDelay: time.Millisecond, DelayFactor: 1, MaxRetries: intPtr(10)}) + + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(callCount).To(gomega.Equal(1)) +} + +func TestCallWithRetriesOnTransientErrorsMaxRetriesExceeded(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + callCount := 0 + maxRetries := 3 + _, err := callWithRetriesOnTransientErrors(ctx, func() (*string, error) { + callCount++ + return nil, status.Error(codes.Unavailable, "unavailable") + }, retryOptions{BaseDelay: time.Millisecond, DelayFactor: 1, MaxRetries: &maxRetries}) + + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(callCount).To(gomega.Equal(maxRetries + 1)) +} + +func TestCallWithRetriesOnTransientErrorsDeadlineExceeded(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + callCount := 0 + deadline := time.Now().Add(50 * time.Millisecond) + _, err := callWithRetriesOnTransientErrors(ctx, func() (*string, error) { + callCount++ + return nil, status.Error(codes.Unavailable, "unavailable") + }, retryOptions{BaseDelay: 100 * time.Millisecond, DelayFactor: 1, MaxRetries: nil, Deadline: &deadline}) + + g.Expect(err).To(gomega.HaveOccurred()) + g.Expect(err.Error()).To(gomega.Equal("deadline exceeded")) +} + +func intPtr(i int) *int { + return &i +} + +type mockRetryableClient struct { + refreshJwtCallCount int + authContextCallCount int +} + +func (m *mockRetryableClient) authContext(ctx context.Context) context.Context { + m.authContextCallCount += 1 + return ctx +} + +func (m *mockRetryableClient) refreshJwt(ctx context.Context) error { + m.refreshJwtCallCount += 1 + return nil +} + +func newMockRetryableClient() *mockRetryableClient { + return &mockRetryableClient{refreshJwtCallCount: 0, authContextCallCount: 0} +} + +func TestCallWithAuthRetrySuccessFirstAttempt(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + + c := newMockRetryableClient() + result, err := callWithAuthRetry(ctx, c, func(authCtx context.Context) (*int, error) { + return intPtr(3), nil + }) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + g.Expect(c.authContextCallCount).To(gomega.Equal(1)) + g.Expect(c.refreshJwtCallCount).To(gomega.Equal(0)) + + g.Expect(result).ToNot(gomega.BeNil()) + g.Expect(*result).To(gomega.Equal(3)) +} + +func TestCallWithAuthRetryOnUNAUTHENTICATED(t *testing.T) { + t.Parallel() + g := gomega.NewWithT(t) + ctx := context.Background() + + callCount := 0 + + c := newMockRetryableClient() + result, err := callWithAuthRetry(ctx, c, func(authCtx context.Context) (*int, error) { + if callCount == 0 { + callCount += 1 + return nil, status.Error(codes.Unauthenticated, "Not authenticated") + } + return intPtr(3), nil + }) + g.Expect(err).ToNot(gomega.HaveOccurred()) + + g.Expect(c.authContextCallCount).To(gomega.Equal(2)) + g.Expect(c.refreshJwtCallCount).To(gomega.Equal(1)) + + g.Expect(result).ToNot(gomega.BeNil()) + g.Expect(*result).To(gomega.Equal(3)) + +} + +func 
TestCallWithAuthRetryDoesNotRetryOnNonUNAUTHENTICATED(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+
+    c := newMockRetryableClient()
+    _, err := callWithAuthRetry(ctx, c, func(authCtx context.Context) (*int, error) {
+        return nil, status.Error(codes.InvalidArgument, "Invalid argument")
+    })
+    g.Expect(err).To(gomega.HaveOccurred())
+
+    g.Expect(c.authContextCallCount).To(gomega.Equal(1))
+    g.Expect(c.refreshJwtCallCount).To(gomega.Equal(0))
+}
+
+func TestCallWithAuthRetryDoesNotRetryErrorIfUNAUTHENTICATEDAfterRetry(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+
+    c := newMockRetryableClient()
+    _, err := callWithAuthRetry(ctx, c, func(authCtx context.Context) (*int, error) {
+        return nil, status.Error(codes.Unauthenticated, "Not authenticated")
+    })
+    g.Expect(err).To(gomega.HaveOccurred())
+
+    g.Expect(c.authContextCallCount).To(gomega.Equal(2))
+    g.Expect(c.refreshJwtCallCount).To(gomega.Equal(1))
+}
diff --git a/modal-go/test/sandbox_filesystem_snapshot_test.go b/modal-go/test/sandbox_filesystem_snapshot_test.go
index 18c79275..eb80c6e7 100644
--- a/modal-go/test/sandbox_filesystem_snapshot_test.go
+++ b/modal-go/test/sandbox_filesystem_snapshot_test.go
@@ -25,11 +25,17 @@ func TestSnapshotFilesystem(t *testing.T) {
     g.Expect(err).ShouldNot(gomega.HaveOccurred())
     defer terminateSandbox(g, sb)
 
-    _, err = sb.Exec(ctx, []string{"sh", "-c", "echo -n 'test content' > /tmp/test.txt"}, nil)
+    p1, err := sb.Exec(ctx, []string{"sh", "-c", "echo -n 'test content' > /tmp/test.txt"}, nil)
     g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    rc, err := p1.Wait(ctx)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(rc).To(gomega.Equal(0))
 
-    _, err = sb.Exec(ctx, []string{"mkdir", "-p", "/tmp/testdir"}, nil)
+    p2, err := sb.Exec(ctx, []string{"mkdir", "-p", "/tmp/testdir"}, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    rc2, err := p2.Wait(ctx)
     g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(rc2).To(gomega.Equal(0))
 
     snapshotImage, err := sb.SnapshotFilesystem(ctx, 55*time.Second)
     g.Expect(err).ShouldNot(gomega.HaveOccurred())
diff --git a/modal-go/test/sandbox_test.go b/modal-go/test/sandbox_test.go
index bffd08c5..6c92ab18 100644
--- a/modal-go/test/sandbox_test.go
+++ b/modal-go/test/sandbox_test.go
@@ -832,3 +832,316 @@ func TestSandboxExperimentalDockerMock(t *testing.T) {
 
     g.Expect(mock.AssertExhausted()).ShouldNot(gomega.HaveOccurred())
 }
+
+func TestSandboxExecStdinStdout(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    defer terminateSandbox(g, sb)
+
+    p, err := sb.Exec(ctx, []string{"sh", "-c", "while read line; do echo $line; done"}, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    _, err = p.Stdin.Write([]byte("foo\n"))
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    _, err = p.Stdin.Write([]byte("bar\n"))
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    err = p.Stdin.Close()
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    output, err := io.ReadAll(p.Stdout)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(string(output)).To(gomega.Equal("foo\nbar\n"))
+}
+
+func TestSandboxExecWaitExitCode(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    defer terminateSandbox(g, sb)
+
+    p, err := sb.Exec(ctx, []string{"sh", "-c", "exit 42"}, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    exitCode, err := p.Wait(ctx)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(exitCode).To(gomega.Equal(42))
+}
+
+func TestSandboxExecDoubleRead(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    defer terminateSandbox(g, sb)
+
+    p, err := sb.Exec(ctx, []string{"echo", "hello"}, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    output1, err := io.ReadAll(p.Stdout)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(string(output1)).To(gomega.Equal("hello\n"))
+
+    output2, err := io.ReadAll(p.Stdout)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(string(output2)).To(gomega.Equal(""))
+
+    exitCode, err := p.Wait(ctx)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(exitCode).To(gomega.Equal(0))
+}
+
+func TestSandboxExecBinaryMode(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    defer terminateSandbox(g, sb)
+
+    p, err := sb.Exec(ctx, []string{"printf", "\\x01\\x02\\x03"}, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    output, err := io.ReadAll(p.Stdout)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(output).To(gomega.Equal([]byte{0x01, 0x02, 0x03}))
+
+    exitCode, err := p.Wait(ctx)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(exitCode).To(gomega.Equal(0))
+}
+
+func TestSandboxExecWithPty(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    defer terminateSandbox(g, sb)
+
+    p, err := sb.Exec(ctx, []string{"echo", "hello"}, &modal.SandboxExecParams{PTY: true})
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    exitCode, err := p.Wait(ctx)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(exitCode).To(gomega.Equal(0))
+}
+
+func TestSandboxExecWaitTimeout(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    defer terminateSandbox(g, sb)
+
+    p, err := sb.Exec(ctx, []string{"sleep", "999"}, &modal.SandboxExecParams{Timeout: 1 * time.Second})
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    t0 := time.Now()
+    exitCode, err := p.Wait(ctx)
+    elapsed := time.Since(t0)
+
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    g.Expect(elapsed).To(gomega.BeNumerically(">", 800*time.Millisecond))
+    g.Expect(elapsed).To(gomega.BeNumerically("<", 10*time.Second))
+    g.Expect(exitCode).To(gomega.Equal(0))
+}
+
+func TestSandboxExecOutputTimeout(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, nil)
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+    defer terminateSandbox(g, sb)
+
+    t0 := time.Now()
+    p, err := sb.Exec(ctx, []string{"sh", "-c", "echo hi; sleep 999"}, &modal.SandboxExecParams{Timeout: 1 * time.Second})
+    g.Expect(err).ShouldNot(gomega.HaveOccurred())
+
+    output, readErr := io.ReadAll(p.Stdout)
+    elapsed := time.Since(t0)
+
+    if readErr != nil {
+        g.Expect(readErr.Error()).To(gomega.ContainSubstring("deadline exceeded"))
+    } else {
+        g.Expect(string(output)).To(gomega.Equal("hi\n"))
+
+        exitCode, waitErr := p.Wait(ctx)
+        if waitErr != nil {
+            // Deadline may have passed between stdout read completing and Wait() being called
+            g.Expect(waitErr.Error()).To(gomega.ContainSubstring("deadline exceeded"))
+        } else {
+            g.Expect(exitCode).To(gomega.Equal(0))
+        }
+    }
+
+    g.Expect(elapsed).To(gomega.BeNumerically(">", 1*time.Second))
+    g.Expect(elapsed).To(gomega.BeNumerically("<", 15*time.Second))
+}
+
+func TestSandboxDoubleTerminate(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, nil)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    err = sb.Terminate(ctx)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    err = sb.Terminate(ctx)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+}
+
+func TestSandboxExecAfterTerminate(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, nil)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    err = sb.Terminate(ctx)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    _, err = sb.Exec(ctx, []string{"echo", "hello"}, nil)
+    g.Expect(err).To(gomega.HaveOccurred())
+}
+
+func TestSandboxReadStdoutAfterTerminate(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, &modal.SandboxCreateParams{
+        Command: []string{"sh", "-c", "echo hello-stdout; echo hello-stderr >&2"},
+    })
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    _, err = sb.Wait(ctx)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    err = sb.Terminate(ctx)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    stdout, err := io.ReadAll(sb.Stdout)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+    g.Expect(string(stdout)).To(gomega.Equal("hello-stdout\n"))
+
+    stderr, err := io.ReadAll(sb.Stderr)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+    g.Expect(string(stderr)).To(gomega.Equal("hello-stderr\n"))
+}
+
+func TestContainerProcessReadStdoutAfterSandboxTerminate(t *testing.T) {
+    t.Parallel()
+    g := gomega.NewWithT(t)
+    ctx := context.Background()
+    tc := newTestClient(t)
+
+    app, err := tc.Apps.FromName(ctx, "libmodal-test", &modal.AppFromNameParams{CreateIfMissing: true})
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    image := tc.Images.FromRegistry("alpine:3.21", nil)
+
+    sb, err := tc.Sandboxes.Create(ctx, app, image, nil)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    p, err := sb.Exec(ctx, []string{"sh", "-c", "echo exec-stdout; echo exec-stderr >&2"}, nil)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    exitCode, err := p.Wait(ctx)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+    g.Expect(exitCode).To(gomega.Equal(0))
+
+    err = sb.Terminate(ctx)
+    g.Expect(err).ToNot(gomega.HaveOccurred())
+
+    // Behavior for reading from stdout & stderr on ContainerProcess is inconsistent between modal-go
+    // and modal-js when the sandbox is terminated:
+    // - modal-js: Reading stdout/stderr continues to work after the sandbox is terminated. It'll
+    //   return an empty string.
+    // - modal-go: Reading stdout/stderr stops working because the goroutines are all canceled.
+    _, err = io.ReadAll(p.Stdout)
+    g.Expect(err).To(gomega.HaveOccurred())
+    g.Expect(err.Error()).To(gomega.ContainSubstring("context canceled"))
+
+    _, err = io.ReadAll(p.Stderr)
+    g.Expect(err).To(gomega.HaveOccurred())
+    g.Expect(err.Error()).To(gomega.ContainSubstring("context canceled"))
+}
diff --git a/modal-js/src/auth_token_manager.ts b/modal-js/src/auth_token_manager.ts
index 70eacabf..0b2fcc1a 100644
--- a/modal-js/src/auth_token_manager.ts
+++ b/modal-js/src/auth_token_manager.ts
@@ -10,7 +10,6 @@ export class AuthTokenManager {
   private logger: Logger;
   private currentToken: string = "";
   private tokenExpiry: number = 0;
-  private timeoutId: NodeJS.Timeout | null = null;
   private running: boolean = false;
   private fetchPromise: Promise<void> | null = null;
 
@@ -75,36 +74,6 @@ export class AuthTokenManager {
     );
   }
 
-  /**
-   * Background loop that refreshes tokens REFRESH_WINDOW seconds before they expire.
-   */
-  private async backgroundRefresh(): Promise<void> {
-    while (this.running) {
-      const now = Math.floor(Date.now() / 1000);
-      const refreshTime = this.tokenExpiry - REFRESH_WINDOW;
-      const delay = Math.max(0, refreshTime - now) * 1000;
-
-      // Sleep until it's time to refresh
-      await new Promise((resolve) => {
-        this.timeoutId = setTimeout(resolve, delay);
-        this.timeoutId.unref();
-      });
-
-      if (!this.running) {
-        return;
-      }
-
-      // Fetch new token
-      try {
-        await this.runFetch();
-      } catch (error) {
-        this.logger.error("Failed to refresh auth token", "error", error);
-        // Sleep for 5 seconds before trying again on failure
-        await new Promise((resolve) => setTimeout(resolve, 5000));
-      }
-    }
-  }
-
   /**
    * Fetches the initial token and starts the refresh loop.
    * Throws an error if the initial token fetch fails.
@@ -121,9 +90,6 @@ export class AuthTokenManager {
       this.running = false;
       throw error;
     }
-
-    // Start background refresh loop, do not await
-    this.backgroundRefresh();
   }
 
   /**
@@ -131,10 +97,6 @@ export class AuthTokenManager {
    */
   stop(): void {
     this.running = false;
-    if (this.timeoutId) {
-      clearTimeout(this.timeoutId);
-      this.timeoutId = null;
-    }
   }
 
   /**
diff --git a/modal-js/src/errors.ts b/modal-js/src/errors.ts
index 1aa8574a..35fc4025 100644
--- a/modal-js/src/errors.ts
+++ b/modal-js/src/errors.ts
@@ -77,3 +77,11 @@ export class SandboxTimeoutError extends Error {
     this.name = "SandboxTimeoutError";
   }
 }
+
+/** Exec operations that exceed the allowed time limit. */
+export class ExecTimeoutError extends Error {
+  constructor(message: string = "Exec operation timed out") {
+    super(message);
+    this.name = "ExecTimeoutError";
+  }
+}
diff --git a/modal-js/src/index.ts b/modal-js/src/index.ts
index 80a93702..0620646b 100644
--- a/modal-js/src/index.ts
+++ b/modal-js/src/index.ts
@@ -26,6 +26,7 @@ export {
   QueueEmptyError,
   QueueFullError,
   SandboxTimeoutError,
+  ExecTimeoutError,
 } from "./errors";
 export {
   Function_,
diff --git a/modal-js/src/sandbox.ts b/modal-js/src/sandbox.ts
index fb8bd32c..3d69b9f6 100644
--- a/modal-js/src/sandbox.ts
+++ b/modal-js/src/sandbox.ts
@@ -639,7 +639,7 @@ export function validateExecArgs(args: string[]): void {
   const totalArgLen = args.reduce((sum, arg) => sum + arg.length, 0);
   if (totalArgLen > ARG_MAX_BYTES) {
     throw new InvalidError(
-      `Total length of CMD arguments must be less than ${ARG_MAX_BYTES} bytes (ARG_MAX). ` +
+      `Total length of CMD arguments must be less than ${ARG_MAX_BYTES} bytes. ` +
         `Got ${totalArgLen} bytes.`,
     );
   }
@@ -940,7 +940,16 @@ export class Sandbox {
     return { url: resp.url, token: resp.token };
   }
 
+  /**
+   * Terminate the Sandbox.
+   * The stdin, stdout, stderr streams are not closed.
+   */
   async terminate(): Promise<void> {
+    if (this.#commandRouterClient) {
+      this.#commandRouterClient.close();
+      this.#commandRouterClient = undefined;
+    }
+
     await this.#client.cpClient.sandboxTerminate({ sandboxId: this.sandboxId });
     this.#taskId = undefined; // Reset task ID after termination
   }
diff --git a/modal-js/src/task_command_router_client.ts b/modal-js/src/task_command_router_client.ts
index 4588ae13..d016c278 100644
--- a/modal-js/src/task_command_router_client.ts
+++ b/modal-js/src/task_command_router_client.ts
@@ -35,9 +35,18 @@
 import type { ModalGrpcClient } from "./client";
 import { timeoutMiddleware, type TimeoutOptions } from "./client";
 import type { Logger } from "./logger";
 import type { Profile } from "./config";
+import { ExecTimeoutError } from "./errors";
 
 type TaskCommandRouterClient = Client;
 
+const RETRYABLE_STATUS_CODES = new Set([
+  Status.DEADLINE_EXCEEDED,
+  Status.UNAVAILABLE,
+  Status.CANCELLED,
+  Status.INTERNAL,
+  Status.UNKNOWN,
+]);
+
 export function parseJwtExpiration(
   jwtToken: string,
   logger: Logger,
@@ -74,14 +83,6 @@ export async function callWithRetriesOnTransientErrors(
   let delayMs = baseDelayMs;
   let numRetries = 0;
-  const retryableStatusCodes = new Set([
-    Status.DEADLINE_EXCEEDED,
-    Status.UNAVAILABLE,
-    Status.CANCELLED,
-    Status.INTERNAL,
-    Status.UNKNOWN,
-  ]);
-
   while (true) {
     if (deadlineMs !== null && Date.now() >= deadlineMs) {
       throw new Error("Deadline exceeded");
     }
@@ -92,7 +93,7 @@
     } catch (err) {
       if (
         err instanceof ClientError &&
-        retryableStatusCodes.has(err.code) &&
+        RETRYABLE_STATUS_CODES.has(err.code) &&
         (maxRetries === null || numRetries < maxRetries)
       ) {
         if (deadlineMs !== null && Date.now() + delayMs >= deadlineMs) {
@@ -109,6 +110,28 @@
   }
 }
 
+/**
+ * Calls a function and retries once on UNAUTHENTICATED errors after invoking an auth refresh handler.
+ * This is exported for testing purposes.
+ * @param func The function to call
+ * @param onAuthError Handler to call when an UNAUTHENTICATED error occurs (e.g., to refresh JWT)
+ * @returns The result of the function call
+ */
+export async function callWithAuthRetry<T>(
+  func: () => Promise<T>,
+  onAuthError: () => Promise<void>,
+): Promise<T> {
+  try {
+    return await func();
+  } catch (err) {
+    if (err instanceof ClientError && err.code === Status.UNAUTHENTICATED) {
+      await onAuthError();
+      return await func();
+    }
+    throw err;
+  }
+}
+
 /** @ignore */
 export class TaskCommandRouterClientImpl {
   private stub: TaskCommandRouterClient;
@@ -165,14 +188,28 @@ export class TaskCommandRouterClientImpl {
     const port = url.port ? parseInt(url.port) : 443;
     const serverUrl = `${host}:${port}`;
 
+    const channelOptions = {
+      "grpc.max_receive_message_length": 100 * 1024 * 1024,
+      "grpc.max_send_message_length": 100 * 1024 * 1024,
+      "grpc-node.flow_control_window": 64 * 1024 * 1024,
+    };
+
     let channel;
     if (profile.taskCommandRouterInsecure) {
       logger.warn(
         "Using insecure TLS for task command router due to MODAL_TASK_COMMAND_ROUTER_INSECURE",
       );
-      channel = createChannel(serverUrl, ChannelCredentials.createInsecure());
+      channel = createChannel(
+        serverUrl,
+        ChannelCredentials.createInsecure(),
+        channelOptions,
+      );
     } else {
-      channel = createChannel(serverUrl, ChannelCredentials.createSsl());
+      channel = createChannel(
+        serverUrl,
+        ChannelCredentials.createSsl(),
+        channelOptions,
+      );
     }
 
     const client = new TaskCommandRouterClientImpl(
@@ -294,7 +331,9 @@ export class TaskCommandRouterClientImpl {
     // The timeout here is really a backstop in the event of a hang contacting
     // the command router. Poll should usually be instantaneous.
     if (deadline && deadline <= Date.now()) {
-      throw new Error(`Deadline exceeded while polling for exec ${execId}`);
+      throw new ExecTimeoutError(
+        `Deadline exceeded while polling for exec ${execId}`,
+      );
     }
 
     try {
@@ -307,7 +346,9 @@ export class TaskCommandRouterClientImpl {
       );
     } catch (err) {
       if (err instanceof ClientError && err.code === Status.DEADLINE_EXCEEDED) {
-        throw new Error(`Deadline exceeded while polling for exec ${execId}`);
+        throw new ExecTimeoutError(
+          `Deadline exceeded while polling for exec ${execId}`,
+        );
       }
       throw err;
     }
@@ -321,7 +362,9 @@ export class TaskCommandRouterClientImpl {
     const request = TaskExecWaitRequest.create({ taskId, execId });
 
     if (deadline && deadline <= Date.now()) {
-      throw new Error(`Deadline exceeded while waiting for exec ${execId}`);
+      throw new ExecTimeoutError(
+        `Deadline exceeded while waiting for exec ${execId}`,
+      );
     }
 
     try {
@@ -339,7 +382,9 @@ export class TaskCommandRouterClientImpl {
       );
     } catch (err) {
       if (err instanceof ClientError && err.code === Status.DEADLINE_EXCEEDED) {
-        throw new Error(`Deadline exceeded while waiting for exec ${execId}`);
+        throw new ExecTimeoutError(
+          `Deadline exceeded while waiting for exec ${execId}`,
+        );
       }
       throw err;
     }
@@ -394,15 +439,7 @@ export class TaskCommandRouterClientImpl {
   }
 
   private async callWithAuthRetry<T>(func: () => Promise<T>): Promise<T> {
-    try {
-      return await func();
-    } catch (err) {
-      if (err instanceof ClientError && err.code === Status.UNAUTHENTICATED) {
-        await this.refreshJwt();
-        return await func();
-      }
-      throw err;
-    }
+    return await callWithAuthRetry(func, () => this.refreshJwt());
   }
 
   private async *streamStdio(
@@ -419,14 +456,6 @@ export class TaskCommandRouterClientImpl {
     // refresh yields an invalid JWT somehow or that the JWT is otherwise invalid.
     let didAuthRetry = false;
 
-    const retryableStatusCodes = new Set([
-      Status.DEADLINE_EXCEEDED,
-      Status.UNAVAILABLE,
-      Status.CANCELLED,
-      Status.INTERNAL,
-      Status.UNKNOWN,
-    ]);
-
     while (true) {
       try {
         const timeoutMs =
@@ -471,11 +500,11 @@
       } catch (err) {
         if (
           err instanceof ClientError &&
-          retryableStatusCodes.has(err.code) &&
+          RETRYABLE_STATUS_CODES.has(err.code) &&
           numRetriesRemaining > 0
         ) {
           if (deadline && deadline - Date.now() <= delayMs) {
-            throw new Error(
+            throw new ExecTimeoutError(
               `Deadline exceeded while streaming stdio for exec ${execId}`,
             );
           }
diff --git a/modal-js/test/auth_token_manager.test.ts b/modal-js/test/auth_token_manager.test.ts
index 48b9464e..860a5dfb 100644
--- a/modal-js/test/auth_token_manager.test.ts
+++ b/modal-js/test/auth_token_manager.test.ts
@@ -197,49 +197,6 @@ describe("AuthTokenManager", () => {
 
     await new Promise((resolve) => setTimeout(resolve, 100));
   });
-
-  test("TestAuthToken_HandlesEventLoopFreeze", async () => {
-    // Use fake timers so we can control both timers and system time
-    vi.useFakeTimers();
-    try {
-      const baseTime = new Date("2025-01-01T00:00:00Z");
-      vi.setSystemTime(baseTime);
-      const baseTimeSeconds = Math.floor(baseTime.getTime() / 1000);
-      const firstRefreshDelaySeconds = 5;
-
-      // Want it to trigger refresh firstRefreshDelaySeconds from "now".
-      const tokenOneExpirySeconds =
-        baseTimeSeconds + REFRESH_WINDOW + firstRefreshDelaySeconds;
-
-      // First fetch happens when the manager starts.
-      const tokenOne = createTestJWT(tokenOneExpirySeconds);
-      mockClient.setAuthToken(tokenOne);
-      await manager.start();
-      expect(mockClient.authTokenGet).toHaveBeenCalledTimes(1);
-
-      // Simulate an event-loop "freeze" where time moves forward but timeouts don't fire.
-      // getToken() should see tokenOne expired, and fetch tokenTwo.
-      const tokenTwo = createTestJWT(tokenOneExpirySeconds + 3600);
-      mockClient.setAuthToken(tokenTwo);
-      vi.setSystemTime(new Date((tokenOneExpirySeconds + 1) * 1000));
-      await expect(manager.getToken()).resolves.toBe(tokenTwo);
-      expect(mockClient.authTokenGet).toHaveBeenCalledTimes(2);
-
-      // Advance timers and check that the background refresh fetches tokenThree.
-      const tokenThree = createTestJWT(tokenOneExpirySeconds + 2 * 3600);
-      mockClient.setAuthToken(tokenThree);
-      await vi.advanceTimersByTimeAsync(firstRefreshDelaySeconds * 1000);
-      await eventually(
-        () =>
-          manager.getCurrentToken() === tokenThree &&
-          mockClient.authTokenGet.mock.calls.length === 3,
-        2000,
-        10,
-      );
-    } finally {
-      vi.useRealTimers();
-    }
-  });
 });
 
 describe("ModalClient with AuthTokenManager", () => {
diff --git a/modal-js/test/sandbox.test.ts b/modal-js/test/sandbox.test.ts
index 8712f2dd..1f104f91 100644
--- a/modal-js/test/sandbox.test.ts
+++ b/modal-js/test/sandbox.test.ts
@@ -992,3 +992,61 @@ test("SandboxExecOutputTimeout", async () => {
     expect(elapsed).toBeLessThan(4000);
   }
 });
+
+test("SandboxDoubleTerminate", async () => {
+  const app = await tc.apps.fromName("libmodal-test", {
+    createIfMissing: true,
+  });
+  const image = tc.images.fromRegistry("alpine:3.21");
+
+  const sb = await tc.sandboxes.create(app, image);
+
+  await sb.terminate();
+  await sb.terminate();
+});
+
+test("SandboxExecAfterTerminate", async () => {
+  const app = await tc.apps.fromName("libmodal-test", {
+    createIfMissing: true,
+  });
+  const image = tc.images.fromRegistry("alpine:3.21");
+
+  const sb = await tc.sandboxes.create(app, image);
+
+  await sb.terminate();
+
+  await expect(sb.exec(["echo", "hello"])).rejects.toThrow();
+});
+
+test("ContainerProcessReadStdoutAfterSandboxTerminate", async () => {
+  const app = await tc.apps.fromName("libmodal-test", {
+    createIfMissing: true,
+  });
+  const image = tc.images.fromRegistry("alpine:3.21");
+  const sb = await tc.sandboxes.create(app, image);
+
+  const p = await sb.exec([
+    "sh",
+    "-c",
+    "echo exec-stdout; echo exec-stderr >&2",
+  ]);
+  await p.wait();
+  await sb.terminate();
+
+  // Behavior for reading from stdout & stderr on ContainerProcess is inconsistent between modal-go
+  // and modal-js when the sandbox is terminated:
+  // - modal-js: Reading stdout/stderr continues to work after the sandbox is terminated. It'll
+  //   return an empty string.
+  // - modal-go: Reading stdout/stderr stops working because the goroutines are all canceled.
+  const stdout1 = await p.stdout.readText();
+  expect(stdout1).equal("exec-stdout\n");
+
+  const stdout2 = await p.stdout.readText();
+  expect(stdout2).equal("");
+
+  const stderr1 = await p.stderr.readText();
+  expect(stderr1).equal("exec-stderr\n");
+
+  const stderr2 = await p.stderr.readText();
+  expect(stderr2).equal("");
+});
diff --git a/modal-js/test/task_command_router_client.test.ts b/modal-js/test/task_command_router_client.test.ts
index bc9cf2a4..e4f27e9b 100644
--- a/modal-js/test/task_command_router_client.test.ts
+++ b/modal-js/test/task_command_router_client.test.ts
@@ -2,6 +2,7 @@ import { expect, test, vi } from "vitest";
 import {
   parseJwtExpiration,
   callWithRetriesOnTransientErrors,
+  callWithAuthRetry,
 } from "../src/task_command_router_client";
 import { ClientError, Status } from "nice-grpc";
 
@@ -105,3 +106,48 @@ test("callWithRetriesOnTransientErrors deadline exceeded", async () => {
     callWithRetriesOnTransientErrors(func, 100, 2, null, deadline),
   ).rejects.toThrow("Deadline exceeded");
 });
+
+test("callWithAuthRetry success on first attempt", async () => {
+  const func = vi.fn().mockResolvedValue("success");
+  const onAuthError = vi.fn();
+  const result = await callWithAuthRetry(func, onAuthError);
+  expect(result).toBe("success");
+  expect(func).toHaveBeenCalledTimes(1);
+  expect(onAuthError).not.toHaveBeenCalled();
+});
+
+test("callWithAuthRetry retries on UNAUTHENTICATED error", async () => {
+  const func = vi
+    .fn()
+    .mockRejectedValueOnce(
+      new ClientError("/test", Status.UNAUTHENTICATED, "unauthenticated"),
+    )
+    .mockResolvedValue("success");
+  const onAuthError = vi.fn().mockResolvedValue(undefined);
+  const result = await callWithAuthRetry(func, onAuthError);
+  expect(result).toBe("success");
+  expect(func).toHaveBeenCalledTimes(2);
+  expect(onAuthError).toHaveBeenCalledTimes(1);
+});
+
+test("callWithAuthRetry does not retry on non-UNAUTHENTICATED errors", async () => {
+  const error = new ClientError("/test", Status.INVALID_ARGUMENT, "invalid");
+  const func = vi.fn().mockRejectedValue(error);
+  const onAuthError = vi.fn();
+  await expect(callWithAuthRetry(func, onAuthError)).rejects.toThrow(error);
+  expect(func).toHaveBeenCalledTimes(1);
+  expect(onAuthError).not.toHaveBeenCalled();
+});
+
+test("callWithAuthRetry throws if still UNAUTHENTICATED after retry", async () => {
+  const error = new ClientError(
+    "/test",
+    Status.UNAUTHENTICATED,
+    "still unauthenticated",
+  );
+  const func = vi.fn().mockRejectedValue(error);
+  const onAuthError = vi.fn().mockResolvedValue(undefined);
+  await expect(callWithAuthRetry(func, onAuthError)).rejects.toThrow(error);
+  expect(func).toHaveBeenCalledTimes(2);
+  expect(onAuthError).toHaveBeenCalledTimes(1);
+});
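
For reference, a minimal sketch (not part of the diff above) of how the exported callWithAuthRetry helper composes with an arbitrary router call. The pollRouter and refreshJwt callbacks here are hypothetical stand-ins; only callWithAuthRetry, its two-argument shape, and its retry-once-on-UNAUTHENTICATED behavior are taken from this change.

// Illustrative only. Assumes the same import path the tests use.
import { callWithAuthRetry } from "../src/task_command_router_client";

async function pollWithAuth(
  pollRouter: () => Promise<unknown>, // hypothetical bound router stub call
  refreshJwt: () => Promise<void>, // hypothetical JWT refresh handler
): Promise<unknown> {
  // The first attempt runs as-is. If it fails with UNAUTHENTICATED, the
  // refresh handler is awaited and the call is retried exactly once; any
  // other error (or a second UNAUTHENTICATED) is rethrown unchanged.
  return callWithAuthRetry(pollRouter, refreshJwt);
}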
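And a sketch of the user-facing effect of the new ExecTimeoutError: exec calls that hit their deadline now surface a typed error instead of a plain Error. The "modal" import specifier and the { timeout } exec option name below are assumptions for illustration; only ExecTimeoutError, exec(), and wait() appear in this diff.

// Illustrative only, assuming the package entry point is "modal" and that
// modal-js exposes a timeout option analogous to Go's SandboxExecParams{Timeout: ...}.
import { ExecTimeoutError } from "modal";

async function execWithDeadline(sb: any): Promise<number | null> {
  const p = await sb.exec(["sleep", "999"], { timeout: 1000 }); // assumed option name
  try {
    return await p.wait();
  } catch (err) {
    if (err instanceof ExecTimeoutError) {
      // The exec exceeded its deadline while polling, waiting, or streaming stdio.
      return null;
    }
    throw err;
  }
}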