Compare commits

...

10 commits

Author SHA1 Message Date
Joachim Bauch d8f2f265ab
Merge pull request #736 from strukturag/log-mcu-proxy-client-closed
Log something if mcu publisher / subscriber was closed.
2024-05-16 10:37:08 +02:00
Joachim Bauch ddbf1065f6
Merge pull request #707 from strukturag/validate-received-sdp
Validate received SDP earlier.
2024-05-16 10:19:15 +02:00
Joachim Bauch bad52af35a
Validate received SDP earlier. 2024-05-16 10:04:57 +02:00
Joachim Bauch c58564c0e8
Log something if mcu publisher / subscriber was closed. 2024-05-16 09:44:47 +02:00
Joachim Bauch 0b259a8171
Merge pull request #732 from strukturag/close-context
Add Context to clients / sessions.
2024-05-16 09:36:34 +02:00
Joachim Bauch 3fc5f5253d
Merge pull request #735 from strukturag/read-error-after-close
Don't log read error after we closed the connection.
2024-05-16 09:36:07 +02:00
Joachim Bauch 3e92664edc
Don't log read error after we closed the connection. 2024-05-16 09:23:32 +02:00
Joachim Bauch 0ee976d377
Add Context to clients / sessions.
The Context will be closed when the client disconnects / the session is removed,
so any pending requests can be cancelled.
2024-05-16 09:07:59 +02:00
Joachim Bauch 552474f6f0
Merge pull request #734 from strukturag/dependabot/go_modules/google.golang.org/grpc-1.64.0
build(deps): Bump google.golang.org/grpc from 1.63.2 to 1.64.0
2024-05-16 08:51:38 +02:00
dependabot[bot] 09e010ee14
build(deps): Bump google.golang.org/grpc from 1.63.2 to 1.64.0
Bumps [google.golang.org/grpc](https://github.com/grpc/grpc-go) from 1.63.2 to 1.64.0.
- [Release notes](https://github.com/grpc/grpc-go/releases)
- [Commits](https://github.com/grpc/grpc-go/compare/v1.63.2...v1.64.0)

---
updated-dependencies:
- dependency-name: google.golang.org/grpc
  dependency-type: direct:production
  update-type: version-update:semver-minor
...

Signed-off-by: dependabot[bot] <support@github.com>
2024-05-15 20:58:48 +00:00
15 changed files with 190 additions and 135 deletions

View file

@ -32,6 +32,7 @@ import (
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/pion/sdp/v3"
)
const (
@ -42,6 +43,11 @@ const (
HelloVersionV2 = "2.0"
)
var (
ErrNoSdp = NewError("no_sdp", "Payload does not contain a SDP.")
ErrInvalidSdp = NewError("invalid_sdp", "Payload does not contain a valid SDP.")
)
// ClientMessage is a message that is sent from a client to the server.
type ClientMessage struct {
json.Marshaler
@ -563,12 +569,39 @@ type MessageClientMessageData struct {
RoomType string `json:"roomType"`
Bitrate int `json:"bitrate,omitempty"`
Payload map[string]interface{} `json:"payload"`
offerSdp *sdp.SessionDescription // Only set if Type == "offer"
answerSdp *sdp.SessionDescription // Only set if Type == "answer"
}
func (m *MessageClientMessageData) CheckValid() error {
if !IsValidStreamType(m.RoomType) {
return fmt.Errorf("invalid room type: %s", m.RoomType)
}
if m.Type == "offer" || m.Type == "answer" {
sdpValue, found := m.Payload["sdp"]
if !found {
return ErrNoSdp
}
sdpText, ok := sdpValue.(string)
if !ok {
return ErrInvalidSdp
}
var sdp sdp.SessionDescription
if err := sdp.Unmarshal([]byte(sdpText)); err != nil {
return NewErrorDetail("invalid_sdp", "Error parsing SDP from payload.", map[string]interface{}{
"error": err.Error(),
})
}
switch m.Type {
case "offer":
m.offerSdp = &sdp
case "answer":
m.answerSdp = &sdp
}
}
return nil
}

View file

@ -23,8 +23,11 @@ package signaling
import (
"bytes"
"context"
"encoding/json"
"errors"
"log"
"net"
"strconv"
"strings"
"sync"
@ -93,6 +96,7 @@ type WritableClientMessage interface {
}
type HandlerClient interface {
Context() context.Context
RemoteAddr() string
Country() string
UserAgent() string
@ -121,6 +125,7 @@ type ClientGeoIpHandler interface {
}
type Client struct {
ctx context.Context
conn *websocket.Conn
addr string
agent string
@ -142,7 +147,7 @@ type Client struct {
messageChan chan *bytes.Buffer
}
func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler ClientHandler) (*Client, error) {
func NewClient(ctx context.Context, conn *websocket.Conn, remoteAddress string, agent string, handler ClientHandler) (*Client, error) {
remoteAddress = strings.TrimSpace(remoteAddress)
if remoteAddress == "" {
remoteAddress = "unknown remote address"
@ -153,6 +158,7 @@ func NewClient(conn *websocket.Conn, remoteAddress string, agent string, handler
}
client := &Client{
ctx: ctx,
agent: agent,
logRTT: true,
}
@ -181,6 +187,10 @@ func (c *Client) getHandler() ClientHandler {
return c.handler
}
func (c *Client) Context() context.Context {
return c.ctx
}
func (c *Client) IsConnected() bool {
return c.closed.Load() == 0
}
@ -354,7 +364,10 @@ func (c *Client) ReadPump() {
conn.SetReadDeadline(time.Now().Add(pongWait)) // nolint
messageType, reader, err := conn.NextReader()
if err != nil {
if _, ok := err.(*websocket.CloseError); !ok || websocket.IsUnexpectedCloseError(err,
// Gorilla websocket hides the original net.Error, so also compare error messages
if errors.Is(err, net.ErrClosed) || strings.Contains(err.Error(), net.ErrClosed.Error()) {
break
} else if _, ok := err.(*websocket.CloseError); !ok || websocket.IsUnexpectedCloseError(err,
websocket.CloseNormalClosure,
websocket.CloseGoingAway,
websocket.CloseNoStatusReceived) {

View file

@ -51,6 +51,8 @@ type ClientSession struct {
privateId string
publicId string
data *SessionIdData
ctx context.Context
closeFunc context.CancelFunc
clientType string
features []string
@ -91,12 +93,15 @@ type ClientSession struct {
}
func NewClientSession(hub *Hub, privateId string, publicId string, data *SessionIdData, backend *Backend, hello *HelloClientMessage, auth *BackendClientAuthResponse) (*ClientSession, error) {
ctx, closeFunc := context.WithCancel(context.Background())
s := &ClientSession{
hub: hub,
events: hub.events,
privateId: privateId,
publicId: publicId,
data: data,
ctx: ctx,
closeFunc: closeFunc,
clientType: hello.Auth.Type,
features: hello.Features,
@ -140,6 +145,10 @@ func NewClientSession(hub *Hub, privateId string, publicId string, data *Session
return s, nil
}
func (s *ClientSession) Context() context.Context {
return s.ctx
}
func (s *ClientSession) PrivateId() string {
return s.privateId
}
@ -337,7 +346,7 @@ func (s *ClientSession) getRoomJoinTime() time.Time {
func (s *ClientSession) releaseMcuObjects() {
if len(s.publishers) > 0 {
go func(publishers map[StreamType]McuPublisher) {
ctx := context.TODO()
ctx := context.Background()
for _, publisher := range publishers {
publisher.Close(ctx)
}
@ -346,7 +355,7 @@ func (s *ClientSession) releaseMcuObjects() {
}
if len(s.subscribers) > 0 {
go func(subscribers map[string]McuSubscriber) {
ctx := context.TODO()
ctx := context.Background()
for _, subscriber := range subscribers {
subscriber.Close(ctx)
}
@ -360,6 +369,7 @@ func (s *ClientSession) Close() {
}
func (s *ClientSession) closeAndWait(wait bool) {
s.closeFunc()
s.hub.removeSession(s)
s.mu.Lock()
@ -720,23 +730,6 @@ func (s *ClientSession) SubscriberClosed(subscriber McuSubscriber) {
}
}
type SdpError struct {
message string
}
func (e *SdpError) Error() string {
return e.message
}
type WrappedSdpError struct {
SdpError
err error
}
func (e *WrappedSdpError) Unwrap() error {
return e.err
}
type PermissionError struct {
permission Permission
}
@ -749,23 +742,10 @@ func (e *PermissionError) Error() string {
return fmt.Sprintf("permission \"%s\" not found", e.permission)
}
func (s *ClientSession) isSdpAllowedToSendLocked(payload map[string]interface{}) (MediaType, error) {
sdpValue, found := payload["sdp"]
if !found {
return 0, &SdpError{"payload does not contain a sdp"}
}
sdpText, ok := sdpValue.(string)
if !ok {
return 0, &SdpError{"payload does not contain a valid sdp"}
}
var sdp sdp.SessionDescription
if err := sdp.Unmarshal([]byte(sdpText)); err != nil {
return 0, &WrappedSdpError{
SdpError: SdpError{
message: fmt.Sprintf("could not parse sdp: %s", err),
},
err: err,
}
func (s *ClientSession) isSdpAllowedToSendLocked(sdp *sdp.SessionDescription) (MediaType, error) {
if sdp == nil {
// Should have already been checked when data was validated.
return 0, ErrNoSdp
}
var mediaTypes MediaType
@ -803,8 +783,8 @@ func (s *ClientSession) IsAllowedToSend(data *MessageClientMessageData) error {
// Client is allowed to publish any media (audio / video).
return nil
} else if data != nil && data.Type == "offer" {
// Parse SDP to check what user is trying to publish and check permissions accordingly.
if _, err := s.isSdpAllowedToSendLocked(data.Payload); err != nil {
// Check what user is trying to publish and check permissions accordingly.
if _, err := s.isSdpAllowedToSendLocked(data.offerSdp); err != nil {
return err
}
@ -834,7 +814,7 @@ func (s *ClientSession) checkOfferTypeLocked(streamType StreamType, data *Messag
return MediaTypeScreen, nil
} else if data != nil && data.Type == "offer" {
mediaTypes, err := s.isSdpAllowedToSendLocked(data.Payload)
mediaTypes, err := s.isSdpAllowedToSendLocked(data.offerSdp)
if err != nil {
return 0, err
}
@ -885,7 +865,7 @@ func (s *ClientSession) GetOrCreatePublisher(ctx context.Context, mcu Mcu, strea
if prev, found := s.publishers[streamType]; found {
// Another thread created the publisher while we were waiting.
go func(pub McuPublisher) {
closeCtx := context.TODO()
closeCtx := context.Background()
pub.Close(closeCtx)
}(publisher)
publisher = prev
@ -962,7 +942,7 @@ func (s *ClientSession) GetOrCreateSubscriber(ctx context.Context, mcu Mcu, id s
if prev, found := s.subscribers[getStreamId(id, streamType)]; found {
// Another thread created the subscriber while we were waiting.
go func(sub McuSubscriber) {
closeCtx := context.TODO()
closeCtx := context.Background()
sub.Close(closeCtx)
}(subscriber)
subscriber = prev
@ -1036,7 +1016,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
case "sendoffer":
// Process asynchronously to not block other messages received.
go func() {
ctx, cancel := context.WithTimeout(context.Background(), s.hub.mcuTimeout)
ctx, cancel := context.WithTimeout(s.Context(), s.hub.mcuTimeout)
defer cancel()
mc, err := s.GetOrCreateSubscriber(ctx, s.hub.mcu, message.SendOffer.SessionId, StreamType(message.SendOffer.Data.RoomType))
@ -1068,7 +1048,7 @@ func (s *ClientSession) processAsyncMessage(message *AsyncMessage) {
return
}
mc.SendMessage(context.TODO(), nil, message.SendOffer.Data, func(err error, response map[string]interface{}) {
mc.SendMessage(s.Context(), nil, message.SendOffer.Data, func(err error, response map[string]interface{}) {
if err != nil {
log.Printf("Could not send MCU message %+v for session %s to %s: %s", message.SendOffer.Data, message.SendOffer.SessionId, s.PublicId(), err)
if err := s.events.PublishSessionMessage(message.SendOffer.SessionId, s.backend, &AsyncMessage{

6
go.mod
View file

@ -22,7 +22,7 @@ require (
go.etcd.io/etcd/client/v3 v3.5.12
go.etcd.io/etcd/server/v3 v3.5.12
go.uber.org/zap v1.27.0
google.golang.org/grpc v1.63.2
google.golang.org/grpc v1.64.0
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0
google.golang.org/protobuf v1.34.1
)
@ -82,8 +82,8 @@ require (
golang.org/x/text v0.14.0 // indirect
golang.org/x/time v0.5.0 // indirect
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de // indirect
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 // indirect
gopkg.in/natefinch/lumberjack.v2 v2.0.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
sigs.k8s.io/yaml v1.2.0 // indirect

18
go.sum
View file

@ -1,7 +1,7 @@
cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.34.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw=
cloud.google.com/go v0.112.0 h1:tpFCD7hpHFlQ8yPwT3x+QeXqc2T6+n6T+hmABHfDUSM=
cloud.google.com/go/compute v1.24.0 h1:phWcR2eWzRJaL/kOiJwfFsPs4BaKq1j6vnpZrc1YlVg=
cloud.google.com/go/compute v1.25.1 h1:ZRpHJedLtTpKgr3RV1Fx23NuaAEN1Zfx9hw1u4aJdjU=
cloud.google.com/go/compute/metadata v0.2.3 h1:mg4jlk7mCAj6xXp9UJ4fjI9VUI5rubuGBW5aJ7UnBMY=
github.com/BurntSushi/toml v0.3.1 h1:WXkYYl6Yr3qBf1K79EBnL4mak0OimBfB0XUf9Vl28OQ=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
@ -15,7 +15,7 @@ github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj
github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw=
github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc=
github.com/cncf/xds/go v0.0.0-20231128003011-0fa0005c9caa h1:jQCWAUqqlij9Pgj2i/PB79y4KOPYVyFYdROxgaCwdTQ=
github.com/cncf/xds/go v0.0.0-20240318125728-8a4994d93e50 h1:DBmgJDC9dTfkVyGgipamEh2BpGYxScCH1TOF1LL1cXc=
github.com/cockroachdb/datadriven v1.0.2 h1:H9MtNqVoVhvd9nCBwOyDjUEdZCREqbIdCJD93PBm/jA=
github.com/coreos/go-semver v0.3.0 h1:wkHLiw0WNATZnSG7epLsujiMCgPAc9xhjJ4tgnAxmfM=
github.com/coreos/go-semver v0.3.0/go.mod h1:nnelYz7RCh+5ahJtPPxZlU+153eP4D4r3EedlOD2RNk=
@ -229,7 +229,7 @@ golang.org/x/net v0.23.0 h1:7EYJ93RZ9vYSZAIb2x3lnuvqO5zneoD6IvWjuhfxjTs=
golang.org/x/net v0.23.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U=
golang.org/x/oauth2 v0.0.0-20200107190931-bf48bf16ab8d/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw=
golang.org/x/oauth2 v0.17.0 h1:6m3ZPmLEFdVxKKWnKq4VqZ60gutO35zm+zrAHVmHyDQ=
golang.org/x/oauth2 v0.18.0 h1:09qnuIAgzdx1XplqJvW6CQqMCtGZykZWcXzPMPUusvI=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181221193216-37e7f081c4d4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
@ -273,18 +273,18 @@ google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfG
google.golang.org/genproto v0.0.0-20200513103714-09dca8ec2884/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c=
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de h1:F6qOa9AZTYJXOUEr4jDysRDLrm4PHePlge4v4TGAlxY=
google.golang.org/genproto v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:VUhTRKeHn9wwcdrk73nvdC9gF178Tzhmt/qyaFcPLSo=
google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de h1:jFNzHPIeuzhdRwVhbZdiym9q0ory/xY3sA+v2wPg8I0=
google.golang.org/genproto/googleapis/api v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:5iCWqnniDlqZHrd3neWVTOwvh/v6s3232omMecelax8=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de h1:cZGRis4/ot9uVm639a+rHCUaG0JJHEsdyzSQTMX+suY=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240227224415-6ceb2ff114de/go.mod h1:H4O17MA/PE9BsGx3w+a+W2VOLLD1Qf7oJneAoU6WktY=
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237 h1:RFiFrvy37/mpSpdySBDrUdipW/dHwsRwh3J3+A9VgT4=
google.golang.org/genproto/googleapis/api v0.0.0-20240318140521-94a12d6c2237/go.mod h1:Z5Iiy3jtmioajWHDGFk7CeugTyHtPvMHA4UTmUkyalE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237 h1:NnYq6UN9ReLM9/Y01KWNOWyI5xQ9kbIms5GGJVwS/Yc=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240318140521-94a12d6c2237/go.mod h1:WtryC6hu0hhx87FDGxWCDptyssuo68sk10vYjF+T9fY=
google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c=
google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg=
google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY=
google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk=
google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk=
google.golang.org/grpc v1.33.1/go.mod h1:fr5YgcSWrqhRRxogOsw7RzIpsmvOZ6IcH4kBYTpR3n0=
google.golang.org/grpc v1.63.2 h1:MUeiw1B2maTVZthpU5xvASfTh3LDbxHd6IJ6QQVU+xM=
google.golang.org/grpc v1.63.2/go.mod h1:WAX/8DgncnokcFUldAxq7GeB5DXHDbMF+lLvDomNkRA=
google.golang.org/grpc v1.64.0 h1:KH3VH9y/MgNQg1dE7b3XfVK0GsPSIzJwdF617gUSbvY=
google.golang.org/grpc v1.64.0/go.mod h1:oxjF8E3FBnjp+/gVFYdWacaLDx9na1aqy9oovLpxQYg=
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0 h1:rNBFJjBCOgVr9pWD7rs/knKL4FRTKgpZmsRfV214zcA=
google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.3.0/go.mod h1:Dk1tviKTvMCz5tvh7t+fh94dhmQVHuCt2OzJB3CTW9Y=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=

View file

@ -115,6 +115,10 @@ func (c *remoteGrpcClient) readPump() {
}
}
func (c *remoteGrpcClient) Context() context.Context {
return c.client.Context()
}
func (c *remoteGrpcClient) RemoteAddr() string {
return c.remoteAddr
}

118
hub.go
View file

@ -850,7 +850,7 @@ func (h *Hub) processRegister(c HandlerClient, message *ClientMessage, backend *
var totalCount atomic.Uint32
totalCount.Add(uint32(backend.Len()))
var wg sync.WaitGroup
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, cancel := context.WithTimeout(client.Context(), time.Second)
defer cancel()
for _, client := range h.rpcClients.GetClients() {
wg.Add(1)
@ -983,15 +983,15 @@ func (h *Hub) processMessage(client HandlerClient, data []byte) {
switch message.Type {
case "room":
h.processRoom(client, &message)
h.processRoom(session, &message)
case "message":
h.processMessageMsg(client, &message)
h.processMessageMsg(session, &message)
case "control":
h.processControlMsg(client, &message)
h.processControlMsg(session, &message)
case "internal":
h.processInternalMsg(client, &message)
h.processInternalMsg(session, &message)
case "transient":
h.processTransientMsg(client, &message)
h.processTransientMsg(session, &message)
case "bye":
h.processByeMsg(client, &message)
case "hello":
@ -1035,7 +1035,7 @@ func (h *Hub) tryProxyResume(c HandlerClient, resumeId string, message *ClientMe
return false
}
rpcCtx, rpcCancel := context.WithTimeout(context.Background(), 5*time.Second)
rpcCtx, rpcCancel := context.WithTimeout(c.Context(), 5*time.Second)
defer rpcCancel()
var wg sync.WaitGroup
@ -1174,7 +1174,7 @@ func (h *Hub) processHello(client HandlerClient, message *ClientMessage) {
}
}
func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
func (h *Hub) processHelloV1(ctx context.Context, client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
url := message.Hello.Auth.parsedUrl
backend := h.backend.GetBackend(url)
if backend == nil {
@ -1182,7 +1182,7 @@ func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Bac
}
// Run in timeout context to prevent blocking too long.
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
ctx, cancel := context.WithTimeout(ctx, h.backendTimeout)
defer cancel()
var auth BackendClientResponse
@ -1196,7 +1196,7 @@ func (h *Hub) processHelloV1(client HandlerClient, message *ClientMessage) (*Bac
return backend, &auth, nil
}
func (h *Hub) processHelloV2(client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
func (h *Hub) processHelloV2(ctx context.Context, client HandlerClient, message *ClientMessage) (*Backend, *BackendClientResponse, error) {
url := message.Hello.Auth.parsedUrl
backend := h.backend.GetBackend(url)
if backend == nil {
@ -1243,16 +1243,16 @@ func (h *Hub) processHelloV2(client HandlerClient, message *ClientMessage) (*Bac
}
// Run in timeout context to prevent blocking too long.
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
backendCtx, cancel := context.WithTimeout(ctx, h.backendTimeout)
defer cancel()
keyData, cached, found := h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
keyData, cached, found := h.backend.capabilities.GetStringConfig(backendCtx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
if !found {
if cached {
// The Nextcloud instance might just have enabled JWT but we probably use
// the cached capabilities without the public key. Make sure to re-fetch.
h.backend.capabilities.InvalidateCapabilities(url)
keyData, _, found = h.backend.capabilities.GetStringConfig(ctx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
keyData, _, found = h.backend.capabilities.GetStringConfig(backendCtx, url, ConfigGroupSignaling, ConfigKeyHelloV2TokenKey)
}
if !found {
return nil, fmt.Errorf("No key found for issuer")
@ -1306,7 +1306,7 @@ func (h *Hub) processHelloClient(client HandlerClient, message *ClientMessage) {
// Make sure the client must send another "hello" in case of errors.
defer h.startExpectHello(client)
var authFunc func(HandlerClient, *ClientMessage) (*Backend, *BackendClientResponse, error)
var authFunc func(context.Context, HandlerClient, *ClientMessage) (*Backend, *BackendClientResponse, error)
switch message.Hello.Version {
case HelloVersionV1:
// Auth information contains a ticket that must be validated against the
@ -1320,7 +1320,7 @@ func (h *Hub) processHelloClient(client HandlerClient, message *ClientMessage) {
return
}
backend, auth, err := authFunc(client, message)
backend, auth, err := authFunc(client.Context(), client, message)
if err != nil {
if e, ok := err.(*Error); ok {
client.SendMessage(message.NewErrorServerMessage(e))
@ -1422,18 +1422,14 @@ func (h *Hub) sendRoom(session *ClientSession, message *ClientMessage, room *Roo
return session.SendMessage(response)
}
func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
session, ok := client.GetSession().(*ClientSession)
func (h *Hub) processRoom(sess Session, message *ClientMessage) {
session, ok := sess.(*ClientSession)
if !ok {
return
}
roomId := message.Room.RoomId
if roomId == "" {
if session == nil {
return
}
// We can handle leaving a room directly.
if session.LeaveRoom(true) != nil {
// User was in a room before, so need to notify about leaving it.
@ -1446,13 +1442,6 @@ func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
return
}
if session == nil {
session.SendMessage(message.NewErrorServerMessage(
NewError("not_authenticated", "Need to authenticate before joining rooms."),
))
return
}
if room := h.getRoomForBackend(roomId, session.Backend()); room != nil && room.HasSession(session) {
// Session already is in that room, no action needed.
roomSessionId := message.Room.SessionId
@ -1487,7 +1476,7 @@ func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
}
} else {
// Run in timeout context to prevent blocking too long.
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
ctx, cancel := context.WithTimeout(session.Context(), h.backendTimeout)
defer cancel()
sessionId := message.Room.SessionId
@ -1507,7 +1496,7 @@ func (h *Hub) processRoom(client HandlerClient, message *ClientMessage) {
if message.Room.SessionId != "" {
// There can only be one connection per Nextcloud Talk session,
// disconnect any other connections without sending a "leave" event.
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
ctx, cancel := context.WithTimeout(session.Context(), time.Second)
defer cancel()
h.disconnectByRoomSessionId(ctx, message.Room.SessionId, session.Backend())
@ -1600,9 +1589,9 @@ func (h *Hub) processJoinRoom(session *ClientSession, message *ClientMessage, ro
r.AddSession(session, room.Room.Session)
}
func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
session, ok := client.GetSession().(*ClientSession)
if session == nil || !ok {
func (h *Hub) processMessageMsg(sess Session, message *ClientMessage) {
session, ok := sess.(*ClientSession)
if !ok {
// Client is not connected yet.
return
}
@ -1654,10 +1643,13 @@ func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
// User is stopping to share his screen. Firefox doesn't properly clean
// up the peer connections in all cases, so make sure to stop publishing
// in the MCU.
go func(c HandlerClient) {
time.Sleep(cleanupScreenPublisherDelay)
session, ok := c.GetSession().(*ClientSession)
if session == nil || !ok {
go func(session *ClientSession) {
sleepCtx, cancel := context.WithTimeout(session.Context(), cleanupScreenPublisherDelay)
defer cancel()
<-sleepCtx.Done()
if session.Context().Err() != nil {
// Session was closed while waiting.
return
}
@ -1670,7 +1662,7 @@ func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout)
defer cancel()
publisher.Close(ctx)
}(client)
}(session)
}
}
}
@ -1778,7 +1770,7 @@ func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
// client) to start his stream, so we must not block the active
// goroutine.
go func() {
ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout)
ctx, cancel := context.WithTimeout(session.Context(), h.mcuTimeout)
defer cancel()
mc, err := recipient.GetOrCreateSubscriber(ctx, h.mcu, session.PublicId(), StreamType(clientData.RoomType))
@ -1792,7 +1784,7 @@ func (h *Hub) processMessageMsg(client HandlerClient, message *ClientMessage) {
return
}
mc.SendMessage(context.TODO(), msg, clientData, func(err error, response map[string]interface{}) {
mc.SendMessage(session.Context(), msg, clientData, func(err error, response map[string]interface{}) {
if err != nil {
log.Printf("Could not send MCU message %+v for session %s to %s: %s", clientData, session.PublicId(), recipient.PublicId(), err)
sendMcuProcessingFailed(session, message)
@ -1870,13 +1862,9 @@ func isAllowedToControl(session Session) bool {
return false
}
func (h *Hub) processControlMsg(client HandlerClient, message *ClientMessage) {
func (h *Hub) processControlMsg(session Session, message *ClientMessage) {
msg := message.Control
session := client.GetSession()
if session == nil {
// Client is not connected yet.
return
} else if !isAllowedToControl(session) {
if !isAllowedToControl(session) {
log.Printf("Ignore control message %+v from %s", msg, session.PublicId())
return
}
@ -1983,10 +1971,10 @@ func (h *Hub) processControlMsg(client HandlerClient, message *ClientMessage) {
}
}
func (h *Hub) processInternalMsg(client HandlerClient, message *ClientMessage) {
func (h *Hub) processInternalMsg(sess Session, message *ClientMessage) {
msg := message.Internal
session, ok := client.GetSession().(*ClientSession)
if session == nil || !ok {
session, ok := sess.(*ClientSession)
if !ok {
// Client is not connected yet.
return
} else if session.ClientType() != HelloClientTypeInternal {
@ -2019,7 +2007,7 @@ func (h *Hub) processInternalMsg(client HandlerClient, message *ClientMessage) {
return
}
ctx, cancel := context.WithTimeout(context.Background(), h.backendTimeout)
ctx, cancel := context.WithTimeout(session.Context(), h.backendTimeout)
defer cancel()
virtualSessionId := GetVirtualSessionId(session, msg.SessionId)
@ -2200,14 +2188,7 @@ func isAllowedToUpdateTransientData(session Session) bool {
return false
}
func (h *Hub) processTransientMsg(client HandlerClient, message *ClientMessage) {
msg := message.TransientData
session := client.GetSession()
if session == nil {
// Client is not connected yet.
return
}
func (h *Hub) processTransientMsg(session Session, message *ClientMessage) {
room := session.GetRoom()
if room == nil {
response := message.NewErrorServerMessage(NewError("not_in_room", "No room joined yet."))
@ -2215,6 +2196,7 @@ func (h *Hub) processTransientMsg(client HandlerClient, message *ClientMessage)
return
}
msg := message.TransientData
switch msg.Type {
case "set":
if !isAllowedToUpdateTransientData(session) {
@ -2318,7 +2300,7 @@ func (h *Hub) isInSameCall(ctx context.Context, senderSession *ClientSession, re
}
func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMessage, message *MessageClientMessage, data *MessageClientMessageData) {
ctx, cancel := context.WithTimeout(context.Background(), h.mcuTimeout)
ctx, cancel := context.WithTimeout(session.Context(), h.mcuTimeout)
defer cancel()
var mc McuClient
@ -2352,11 +2334,6 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe
sendNotAllowed(session, client_message, "Not allowed to publish.")
return
}
if err, ok := err.(*SdpError); ok {
log.Printf("Session %s sent unsupported offer %s, ignoring (%s)", session.PublicId(), data.RoomType, err)
sendNotAllowed(session, client_message, "Not allowed to publish.")
return
}
case "selectStream":
if session.PublicId() == message.Recipient.SessionId {
log.Printf("Not selecting substream for own %s stream in session %s", data.RoomType, session.PublicId())
@ -2390,7 +2367,7 @@ func (h *Hub) processMcuMessage(session *ClientSession, client_message *ClientMe
return
}
mc.SendMessage(context.TODO(), message, data, func(err error, response map[string]interface{}) {
mc.SendMessage(session.Context(), message, data, func(err error, response map[string]interface{}) {
if err != nil {
log.Printf("Could not send MCU message %+v for session %s to %s: %s", data, session.PublicId(), message.Recipient.SessionId, err)
sendMcuProcessingFailed(session, client_message)
@ -2563,7 +2540,7 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
return
}
client, err := NewClient(conn, addr, agent, h)
client, err := NewClient(r.Context(), conn, addr, agent, h)
if err != nil {
log.Printf("Could not create client for %s: %s", addr, err)
return
@ -2575,11 +2552,10 @@ func (h *Hub) serveWs(w http.ResponseWriter, r *http.Request) {
defer h.writePumpActive.Add(-1)
client.WritePump()
}(h)
go func(h *Hub) {
h.readPumpActive.Add(1)
defer h.readPumpActive.Add(-1)
client.ReadPump()
}(h)
h.readPumpActive.Add(1)
defer h.readPumpActive.Add(-1)
client.ReadPump()
}
func (h *Hub) OnLookupCountry(client HandlerClient) string {

View file

@ -4697,6 +4697,30 @@ func TestClientRequestOfferNotInRoom(t *testing.T) {
if err := client2.RunUntilOffer(ctx, MockSdpOfferAudioAndVideo); err != nil {
t.Fatal(err)
}
if err := client2.SendMessage(MessageClientMessageRecipient{
Type: "session",
SessionId: hello1.Hello.SessionId,
}, MessageClientMessageData{
Type: "answer",
Sid: "12345",
RoomType: "screen",
Payload: map[string]interface{}{
"sdp": MockSdpAnswerAudioAndVideo,
},
}); err != nil {
t.Fatal(err)
}
// The sender won't get a reply...
ctx2, cancel2 := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel2()
if message, err := client2.RunUntilMessage(ctx2); err != nil && err != ErrNoMessageReceived && err != context.DeadlineExceeded {
t.Error(err)
} else if message != nil {
t.Errorf("Expected no message, got %+v", message)
}
})
}
}

View file

@ -162,6 +162,7 @@ func (p *mcuProxyPublisher) SetMedia(mt MediaType) {
}
func (p *mcuProxyPublisher) NotifyClosed() {
log.Printf("Publisher %s at %s was closed", p.proxyId, p.conn)
p.listener.PublisherClosed(p)
p.conn.removePublisher(p)
}
@ -185,7 +186,7 @@ func (p *mcuProxyPublisher) Close(ctx context.Context) {
return
}
log.Printf("Delete publisher %s at %s", p.proxyId, p.conn)
log.Printf("Deleted publisher %s at %s", p.proxyId, p.conn)
}
func (p *mcuProxyPublisher) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) {
@ -243,6 +244,7 @@ func (s *mcuProxySubscriber) Publisher() string {
}
func (s *mcuProxySubscriber) NotifyClosed() {
log.Printf("Subscriber %s at %s was closed", s.proxyId, s.conn)
s.listener.SubscriberClosed(s)
s.conn.removeSubscriber(s)
}
@ -266,7 +268,7 @@ func (s *mcuProxySubscriber) Close(ctx context.Context) {
return
}
log.Printf("Delete subscriber %s at %s", s.proxyId, s.conn)
log.Printf("Deleted subscriber %s at %s", s.proxyId, s.conn)
}
func (s *mcuProxySubscriber) SendMessage(ctx context.Context, message *MessageClientMessage, data *MessageClientMessageData, callback func(error, map[string]interface{})) {

View file

@ -253,6 +253,8 @@ func (s *TestMCUSubscriber) SendMessage(ctx context.Context, message *MessageCli
"type": "offer",
"sdp": sdp,
})
case "answer":
callback(nil, nil)
default:
callback(fmt.Errorf("Message type %s is not implemented", data.Type), nil)
}

View file

@ -777,9 +777,10 @@ func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, s
fallthrough
case "candidate":
mcuData = &signaling.MessageClientMessageData{
Type: payload.Type,
Sid: payload.Sid,
Payload: payload.Payload,
RoomType: string(mcuClient.StreamType()),
Type: payload.Type,
Sid: payload.Sid,
Payload: payload.Payload,
}
case "endOfCandidates":
// Ignore but confirm, not passed along to Janus anyway.
@ -796,14 +797,21 @@ func (s *ProxyServer) processPayload(ctx context.Context, client *ProxyClient, s
fallthrough
case "sendoffer":
mcuData = &signaling.MessageClientMessageData{
Type: payload.Type,
Sid: payload.Sid,
RoomType: string(mcuClient.StreamType()),
Type: payload.Type,
Sid: payload.Sid,
}
default:
session.sendMessage(message.NewErrorServerMessage(UnsupportedPayload))
return
}
if err := mcuData.CheckValid(); err != nil {
log.Printf("Received invalid payload %+v for %s client %s: %s", mcuData, mcuClient.StreamType(), payload.ClientId, err)
session.sendMessage(message.NewErrorServerMessage(UnsupportedPayload))
return
}
mcuClient.SendMessage(ctx, nil, mcuData, func(err error, response map[string]interface{}) {
var responseMsg *signaling.ProxyServerMessage
if err != nil {

View file

@ -51,6 +51,8 @@ func NewRemoteSession(hub *Hub, client *Client, remoteClient *GrpcClient, sessio
client.SetSessionId(sessionId)
client.SetHandler(remoteSession)
// Don't use "client.Context()" here as it could close the proxy connection
// before any final messages are forwarded to the remote end.
proxy, err := remoteClient.ProxySession(context.Background(), sessionId, remoteSession)
if err != nil {
return nil, err

View file

@ -22,6 +22,7 @@
package signaling
import (
"context"
"encoding/json"
"errors"
"net/url"
@ -32,6 +33,10 @@ type DummySession struct {
publicId string
}
func (s *DummySession) Context() context.Context {
return context.Background()
}
func (s *DummySession) PrivateId() string {
return ""
}

View file

@ -22,6 +22,7 @@
package signaling
import (
"context"
"encoding/json"
"net/url"
"time"
@ -53,6 +54,7 @@ type SessionIdData struct {
}
type Session interface {
Context() context.Context
PrivateId() string
PublicId() string
ClientType() string

View file

@ -85,6 +85,10 @@ func NewVirtualSession(session *ClientSession, privateId string, publicId string
return result, nil
}
func (s *VirtualSession) Context() context.Context {
return s.session.Context()
}
func (s *VirtualSession) PrivateId() string {
return s.privateId
}