act_runner/vendor/connectrpc.com/connect/handler.go
skeris 7434dceaaa
Some checks failed
checks / check and test (push) Has been cancelled
release-nightly / goreleaser (push) Has been cancelled
release-nightly / release-image (push) Has been cancelled
add vendor dir
2024-11-22 16:14:01 +03:00

329 lines
11 KiB
Go

// Copyright 2021-2024 The Connect Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package connect
import (
"context"
"fmt"
"net/http"
)
// A Handler is the server-side implementation of a single RPC defined by a
// service schema.
//
// By default, Handlers support the Connect, gRPC, and gRPC-Web protocols with
// the binary Protobuf and JSON codecs. They support gzip compression using the
// standard library's [compress/gzip].
type Handler struct {
	// Description of the procedure; ServeHTTP consults its StreamType to
	// reject bidirectional RPCs over HTTP/1.x.
	spec Spec
	// The (possibly interceptor-wrapped) function ServeHTTP runs once a
	// protocol connection has been established.
	implementation StreamingHandlerFunc
	// Keyed by HTTP method (ServeHTTP indexes it with request.Method); each
	// slice is searched in order for a handler that accepts the payload.
	protocolHandlers map[string][]protocolHandler // Method to protocol handlers
	// Value sent in the Allow header on 405 Method Not Allowed responses.
	allowMethod string // Allow header
	// Value sent in the Accept-Post header on 415 Unsupported Media Type
	// responses when no protocol handler accepts the Content-Type.
	acceptPost string // Accept-Post header
}
// NewUnaryHandler constructs a [Handler] for a request-response procedure.
//
// The strongly-typed unary function is first adapted to an untyped
// [UnaryFunc] so that a configured [Interceptor] can wrap it via WrapUnary.
// The resulting function is then driven over a streaming connection: exactly
// one request message is received, the function is invoked, and the response
// headers, trailers, and message are copied back onto the connection.
//
// If unary returns a nil *Response together with a nil error, the handler
// panics with a message naming the procedure.
func NewUnaryHandler[Req, Res any](
	procedure string,
	unary func(context.Context, *Request[Req]) (*Response[Res], error),
	options ...HandlerOption,
) *Handler {
	config := newHandlerConfig(procedure, StreamTypeUnary, options)

	// Untyped adapter: checks for an already-done context, recovers the
	// concrete request type, and guards against a nil response.
	var call UnaryFunc = func(ctx context.Context, anyReq AnyRequest) (AnyResponse, error) {
		if ctxErr := ctx.Err(); ctxErr != nil {
			return nil, ctxErr
		}
		req, isTyped := anyReq.(*Request[Req])
		if !isTyped {
			return nil, errorf(CodeInternal, "unexpected handler request type %T", anyReq)
		}
		res, err := unary(ctx, req)
		if err == nil && res == nil {
			// Serialization would panic later anyway; panicking here lets the
			// message name the offending procedure, which eases debugging.
			panic(fmt.Sprintf("%s returned nil *connect.Response and nil error", procedure)) //nolint: forbidigo
		}
		return res, err
	}
	if config.Interceptor != nil {
		call = config.Interceptor.WrapUnary(call)
	}

	// Stream adapter: receive the single request, run the (possibly
	// intercepted) call, then propagate metadata and send the response.
	serve := func(ctx context.Context, conn StreamingHandlerConn) error {
		req, err := receiveUnaryRequest[Req](conn, config.Initializer)
		if err != nil {
			return err
		}
		res, err := call(ctx, req)
		if err != nil {
			return err
		}
		mergeHeaders(conn.ResponseHeader(), res.Header())
		mergeHeaders(conn.ResponseTrailer(), res.Trailer())
		return conn.Send(res.Any())
	}

	handlers := config.newProtocolHandlers()
	return &Handler{
		spec:             config.newSpec(),
		implementation:   serve,
		protocolHandlers: mappedMethodHandlers(handlers),
		allowMethod:      sortedAllowMethodValue(handlers),
		acceptPost:       sortedAcceptPostValue(handlers),
	}
}
// NewClientStreamHandler constructs a [Handler] for a client streaming procedure.
//
// The implementation receives a [ClientStream] bound to the connection and
// returns a single response. That response's headers and trailers are merged
// into the connection's, and its message is sent. A nil *Response paired with
// a nil error causes a panic naming the procedure.
func NewClientStreamHandler[Req, Res any](
	procedure string,
	implementation func(context.Context, *ClientStream[Req]) (*Response[Res], error),
	options ...HandlerOption,
) *Handler {
	config := newHandlerConfig(procedure, StreamTypeClient, options)
	serve := func(ctx context.Context, conn StreamingHandlerConn) error {
		res, err := implementation(ctx, &ClientStream[Req]{
			conn:        conn,
			initializer: config.Initializer,
		})
		switch {
		case err != nil:
			return err
		case res == nil:
			// Serialization would panic on a nil response anyway; panicking
			// here lets the message include the procedure name.
			panic(fmt.Sprintf("%s returned nil *connect.Response and nil error", procedure)) //nolint: forbidigo
		}
		mergeHeaders(conn.ResponseHeader(), res.header)
		mergeHeaders(conn.ResponseTrailer(), res.trailer)
		return conn.Send(res.Msg)
	}
	return newStreamHandler(config, serve)
}
// NewServerStreamHandler constructs a [Handler] for a server streaming procedure.
//
// A single request message is received from the connection before the
// implementation is called. The implementation then writes its responses
// through the supplied [ServerStream].
func NewServerStreamHandler[Req, Res any](
	procedure string,
	implementation func(context.Context, *Request[Req], *ServerStream[Res]) error,
	options ...HandlerOption,
) *Handler {
	config := newHandlerConfig(procedure, StreamTypeServer, options)
	serve := func(ctx context.Context, conn StreamingHandlerConn) error {
		req, err := receiveUnaryRequest[Req](conn, config.Initializer)
		if err != nil {
			return err
		}
		stream := &ServerStream[Res]{conn: conn}
		return implementation(ctx, req, stream)
	}
	return newStreamHandler(config, serve)
}
// NewBidiStreamHandler constructs a [Handler] for a bidirectional streaming procedure.
//
// The implementation is handed a [BidiStream] bound to the connection and is
// responsible for both receiving requests and sending responses on it.
func NewBidiStreamHandler[Req, Res any](
	procedure string,
	implementation func(context.Context, *BidiStream[Req, Res]) error,
	options ...HandlerOption,
) *Handler {
	config := newHandlerConfig(procedure, StreamTypeBidi, options)
	serve := func(ctx context.Context, conn StreamingHandlerConn) error {
		stream := &BidiStream[Req, Res]{
			conn:        conn,
			initializer: config.Initializer,
		}
		return implementation(ctx, stream)
	}
	return newStreamHandler(config, serve)
}
// ServeHTTP implements [http.Handler].
//
// Before running the RPC it rejects requests it cannot serve:
//   - bidirectional procedures over HTTP/1.x get 505 HTTP Version Not
//     Supported with "Connection: close";
//   - an HTTP method with no registered protocol handlers gets 405 Method Not
//     Allowed with an Allow header;
//   - a Content-Type that no protocol handler for the method accepts gets 415
//     Unsupported Media Type with an Accept-Post header;
//   - a GET request that carries a body gets 415.
//
// Otherwise it applies the protocol's timeout (if any) and creates the
// protocol connection. It then either closes that connection with the
// timeout error or runs h.implementation and hands its error to the
// connection's Close.
func (h *Handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) {
	// We don't need to defer functions to close the request body or read to
	// EOF: the stream we construct later on already does that, and we only
	// return early when dealing with misbehaving clients. In those cases, it's
	// okay if we can't re-use the connection.
	isBidi := (h.spec.StreamType & StreamTypeBidi) == StreamTypeBidi
	if isBidi && request.ProtoMajor < 2 {
		// Clients coded to expect full-duplex connections may hang if they've
		// mistakenly negotiated HTTP/1.1. To unblock them, we must close the
		// underlying TCP connection.
		responseWriter.Header().Set("Connection", "close")
		responseWriter.WriteHeader(http.StatusHTTPVersionNotSupported)
		return
	}
	// Only the protocol handlers registered for this HTTP method are candidates.
	protocolHandlers := h.protocolHandlers[request.Method]
	if len(protocolHandlers) == 0 {
		responseWriter.Header().Set("Allow", h.allowMethod)
		responseWriter.WriteHeader(http.StatusMethodNotAllowed)
		return
	}
	contentType := canonicalizeContentType(getHeaderCanonical(request.Header, headerContentType))
	// Find our implementation of the RPC protocol in use. The first handler
	// (in slice order) that accepts the payload wins.
	var protocolHandler protocolHandler
	for _, handler := range protocolHandlers {
		if handler.CanHandlePayload(request, contentType) {
			protocolHandler = handler
			break
		}
	}
	if protocolHandler == nil {
		responseWriter.Header().Set("Accept-Post", h.acceptPost)
		responseWriter.WriteHeader(http.StatusUnsupportedMediaType)
		return
	}
	if request.Method == http.MethodGet {
		// A body must not be present.
		hasBody := request.ContentLength > 0
		if request.ContentLength < 0 {
			// No content-length header.
			// Test if body is empty by trying to read a single byte.
			// The read error is deliberately ignored: only whether a byte
			// arrived matters here.
			var b [1]byte
			n, _ := request.Body.Read(b[:])
			hasBody = n > 0
		}
		if hasBody {
			responseWriter.WriteHeader(http.StatusUnsupportedMediaType)
			return
		}
		// The body is known to be empty; close it now, best-effort.
		_ = request.Body.Close()
	}
	// Establish a stream and serve the RPC.
	// Store the canonicalized Content-Type and the request's Host back into
	// request.Header (presumably so protocol code can read them from the
	// header map — NOTE(review): confirm against protocol implementations).
	setHeaderCanonical(request.Header, headerContentType, contentType)
	setHeaderCanonical(request.Header, headerHost, request.Host)
	ctx, cancel, timeoutErr := protocolHandler.SetTimeout(request) //nolint: contextcheck
	if timeoutErr != nil {
		// The timeout could not be applied; fall back to the request's own
		// context so a connection can still be created to report the error.
		ctx = request.Context()
	}
	if cancel != nil {
		defer cancel()
	}
	connCloser, ok := protocolHandler.NewConn(
		responseWriter,
		request.WithContext(ctx),
	)
	if !ok {
		// Failed to create stream, usually because client used an unknown
		// compression algorithm. Nothing further to do.
		return
	}
	if timeoutErr != nil {
		// Report the timeout problem on the connection instead of running
		// the implementation.
		_ = connCloser.Close(timeoutErr)
		return
	}
	// The implementation's error (nil on success) is handed to Close; the
	// error returned by Close itself is ignored.
	_ = connCloser.Close(h.implementation(ctx, connCloser))
}
// handlerConfig gathers everything needed to build a Handler. It is created by
// newHandlerConfig, which installs defaults and then applies each
// HandlerOption to it.
type handlerConfig struct {
	// Compression pools, keyed by name (presumably the algorithm name such as
	// "gzip" — TODO confirm); newHandlerConfig starts with an empty map.
	CompressionPools map[string]*compressionPool
	// Passed with CompressionPools to newReadOnlyCompressionPools; presumably
	// records registration order/preference — NOTE(review): confirm.
	CompressionNames []string
	// Codecs available to the protocol handlers; presumably keyed by codec
	// name. newHandlerConfig starts with an empty map.
	Codecs map[string]Codec
	// Forwarded unchanged to every protocol handler.
	CompressMinBytes int
	// Optional. When non-nil it wraps the unary function (WrapUnary) or the
	// streaming implementation (WrapStreamingHandler).
	Interceptor Interceptor
	// The procedure path produced by extractProtoPath; copied into the Spec.
	Procedure string
	// Copied into the Spec as-is.
	Schema any
	// Passed to receiveUnaryRequest and to ClientStream/BidiStream values.
	Initializer maybeInitializer
	// Forwarded unchanged to every protocol handler.
	RequireConnectProtocolHeader bool
	// Copied into the Spec and forwarded to every protocol handler.
	IdempotencyLevel IdempotencyLevel
	// Shared by all protocol handlers; newHandlerConfig sets it to
	// newBufferPool().
	BufferPool *bufferPool
	// Forwarded unchanged to every protocol handler (names suggest message
	// size limits in bytes — TODO confirm).
	ReadMaxBytes int
	SendMaxBytes int
	// Unary, client, server, or bidi streaming; copied into the Spec.
	StreamType StreamType
}
// newHandlerConfig builds the configuration for a handler serving procedure.
//
// Defaults are installed first: the binary Protobuf codec, the JSON codecs,
// gzip compression, and a fresh buffer pool. The caller's options are then
// applied in order, so they can override or extend those defaults.
func newHandlerConfig(procedure string, streamType StreamType, options []HandlerOption) *handlerConfig {
	cfg := &handlerConfig{
		Procedure:        extractProtoPath(procedure),
		CompressionPools: make(map[string]*compressionPool),
		Codecs:           make(map[string]Codec),
		BufferPool:       newBufferPool(),
		StreamType:       streamType,
	}
	// Built-in defaults.
	withProtoBinaryCodec().applyToHandler(cfg)
	withProtoJSONCodecs().applyToHandler(cfg)
	withGzip().applyToHandler(cfg)
	// Caller-supplied options take effect after (and may override) defaults.
	for _, option := range options {
		option.applyToHandler(cfg)
	}
	return cfg
}
// newSpec copies the descriptive parts of the configuration (procedure,
// schema, stream type, and idempotency level) into a fresh [Spec]; any other
// Spec fields keep their zero values.
func (c *handlerConfig) newSpec() Spec {
	var spec Spec
	spec.Procedure = c.Procedure
	spec.Schema = c.Schema
	spec.StreamType = c.StreamType
	spec.IdempotencyLevel = c.IdempotencyLevel
	return spec
}
// newProtocolHandlers creates one protocolHandler per supported wire
// protocol, in the order Connect, gRPC, gRPC-Web. All of them share the same
// spec, read-only codecs and compression pools, buffer pool, and limits.
// Each handler receives its own params value.
func (c *handlerConfig) newProtocolHandlers() []protocolHandler {
	// newSpec returns a value, so computing it once yields identical copies.
	spec := c.newSpec()
	codecs := newReadOnlyCodecs(c.Codecs)
	compressors := newReadOnlyCompressionPools(c.CompressionPools, c.CompressionNames)
	protocols := []protocol{
		&protocolConnect{},
		&protocolGRPC{web: false},
		&protocolGRPC{web: true},
	}
	handlers := make([]protocolHandler, len(protocols))
	for i, proto := range protocols {
		handlers[i] = proto.NewHandler(&protocolHandlerParams{
			Spec:                         spec,
			Codecs:                       codecs,
			CompressionPools:             compressors,
			CompressMinBytes:             c.CompressMinBytes,
			BufferPool:                   c.BufferPool,
			ReadMaxBytes:                 c.ReadMaxBytes,
			SendMaxBytes:                 c.SendMaxBytes,
			RequireConnectProtocolHeader: c.RequireConnectProtocolHeader,
			IdempotencyLevel:             c.IdempotencyLevel,
		})
	}
	return handlers
}
// newStreamHandler finishes building a streaming [Handler]. It wraps
// implementation with the configured interceptor's WrapStreamingHandler (when
// an interceptor is set). It then derives the per-method protocol dispatch
// table and the Allow and Accept-Post header values from the protocol
// handlers.
func newStreamHandler(
	config *handlerConfig,
	implementation StreamingHandlerFunc,
) *Handler {
	wrapped := implementation
	if config.Interceptor != nil {
		wrapped = config.Interceptor.WrapStreamingHandler(wrapped)
	}
	handlers := config.newProtocolHandlers()
	return &Handler{
		spec:             config.newSpec(),
		implementation:   wrapped,
		protocolHandlers: mappedMethodHandlers(handlers),
		allowMethod:      sortedAllowMethodValue(handlers),
		acceptPost:       sortedAcceptPostValue(handlers),
	}
}