diff --git a/cmd/state-svc/internal/resolver/resolver.go b/cmd/state-svc/internal/resolver/resolver.go index 8660c31e35..a58dd7008a 100644 --- a/cmd/state-svc/internal/resolver/resolver.go +++ b/cmd/state-svc/internal/resolver/resolver.go @@ -23,14 +23,16 @@ import ( type Resolver struct { cfg *config.Instance cache *cache.Cache + done chan bool } // var _ genserver.ResolverRoot = &Resolver{} // Must implement ResolverRoot -func New(cfg *config.Instance) *Resolver { +func New(cfg *config.Instance, done chan bool) *Resolver { return &Resolver{ cfg, cache.New(12*time.Hour, time.Hour), + done, } } @@ -38,6 +40,8 @@ func New(cfg *config.Instance) *Resolver { // So far no need for this, so we're pointing back at ourselves.. func (r *Resolver) Query() genserver.QueryResolver { return r } +func (r *Resolver) Subscription() genserver.SubscriptionResolver { return r } + func (r *Resolver) Version(ctx context.Context) (*graph.Version, error) { logging.Debug("Version resolver") return &graph.Version{ @@ -129,3 +133,18 @@ func (r *Resolver) Projects(ctx context.Context) ([]*graph.Project, error) { return projects, nil } + +func (r *Resolver) Quit(ctx context.Context) (<-chan bool, error) { + logging.Debug("Quit resolver") + done := make(chan bool) + + // Each subscriber gets its own copy of a done channel. We wait on the main + // done channel and send a response once that channel is closed or a response + // is sent + go func() { + <-r.done + done <- true + }() + + return done, nil +} diff --git a/cmd/state-svc/internal/server/generated/generated.go b/cmd/state-svc/internal/server/generated/generated.go index 3404af4789..f4ea3c6f7a 100644 --- a/cmd/state-svc/internal/server/generated/generated.go +++ b/cmd/state-svc/internal/server/generated/generated.go @@ -6,6 +6,7 @@ import ( "bytes" "context" "errors" + "io" "strconv" "sync" "sync/atomic" @@ -36,6 +37,7 @@ type Config struct { type ResolverRoot interface { Query() QueryResolver + Subscription() SubscriptionResolver } type DirectiveRoot struct { @@ -76,6 +78,10 @@ type ComplexityRoot struct { Version func(childComplexity int) int } + Subscription struct { + Quit func(childComplexity int) int + } + Version struct { State func(childComplexity int) int } @@ -87,6 +93,9 @@ type QueryResolver interface { Update(ctx context.Context, channel *string, version *string) (*graph.DeferredUpdate, error) Projects(ctx context.Context) ([]*graph.Project, error) } +type SubscriptionResolver interface { + Quit(ctx context.Context) (<-chan bool, error) +} type executableSchema struct { resolvers ResolverRoot @@ -241,6 +250,13 @@ func (e *executableSchema) Complexity(typeName, field string, childComplexity in return e.complexity.StateVersion.Version(childComplexity), true + case "Subscription.quit": + if e.complexity.Subscription.Quit == nil { + break + } + + return e.complexity.Subscription.Quit(childComplexity), true + case "Version.state": if e.complexity.Version.State == nil { break @@ -268,6 +284,23 @@ func (e *executableSchema) Exec(ctx context.Context) graphql.ResponseHandler { var buf bytes.Buffer data.MarshalGQL(&buf) + return &graphql.Response{ + Data: buf.Bytes(), + } + } + case ast.Subscription: + next := ec._Subscription(ctx, rc.Operation.SelectionSet) + + var buf bytes.Buffer + return func(ctx context.Context) *graphql.Response { + buf.Reset() + data := next() + + if data == nil { + return nil + } + data.MarshalGQL(&buf) + return &graphql.Response{ Data: buf.Bytes(), } @@ -335,6 +368,10 @@ type Query { update(channel: String, version: String): DeferredUpdate projects: [Project]! } + +type Subscription { + quit: Boolean! +} `, BuiltIn: false}, } var parsedSchema = gqlparser.MustLoadSchema(sources...) @@ -1154,6 +1191,51 @@ func (ec *executionContext) _StateVersion_date(ctx context.Context, field graphq return ec.marshalNString2string(ctx, field.Selections, res) } +func (ec *executionContext) _Subscription_quit(ctx context.Context, field graphql.CollectedField) (ret func() graphql.Marshaler) { + defer func() { + if r := recover(); r != nil { + ec.Error(ctx, ec.Recover(ctx, r)) + ret = nil + } + }() + fc := &graphql.FieldContext{ + Object: "Subscription", + Field: field, + Args: nil, + IsMethod: true, + IsResolver: true, + } + + ctx = graphql.WithFieldContext(ctx, fc) + resTmp, err := ec.ResolverMiddleware(ctx, func(rctx context.Context) (interface{}, error) { + ctx = rctx // use context from middleware stack in children + return ec.resolvers.Subscription().Quit(rctx) + }) + if err != nil { + ec.Error(ctx, err) + return nil + } + if resTmp == nil { + if !graphql.HasFieldError(ctx, fc) { + ec.Errorf(ctx, "must not be null") + } + return nil + } + return func() graphql.Marshaler { + res, ok := <-resTmp.(<-chan bool) + if !ok { + return nil + } + return graphql.WriterFunc(func(w io.Writer) { + w.Write([]byte{'{'}) + graphql.MarshalString(field.Alias).MarshalGQL(w) + w.Write([]byte{':'}) + ec.marshalNBoolean2bool(ctx, field.Selections, res).MarshalGQL(w) + w.Write([]byte{'}'}) + }) + } +} + func (ec *executionContext) _Version_state(ctx context.Context, field graphql.CollectedField, obj *graph.Version) (ret graphql.Marshaler) { defer func() { if r := recover(); r != nil { @@ -2524,6 +2606,26 @@ func (ec *executionContext) _StateVersion(ctx context.Context, sel ast.Selection return out } +var subscriptionImplementors = []string{"Subscription"} + +func (ec *executionContext) _Subscription(ctx context.Context, sel ast.SelectionSet) func() graphql.Marshaler { + fields := graphql.CollectFields(ec.OperationContext, sel, subscriptionImplementors) + ctx = graphql.WithFieldContext(ctx, &graphql.FieldContext{ + Object: "Subscription", + }) + if len(fields) != 1 { + ec.Errorf(ctx, "must subscribe to exactly one stream") + return nil + } + + switch fields[0].Name { + case "quit": + return ec._Subscription_quit(ctx, fields[0]) + default: + panic("unknown field " + strconv.Quote(fields[0].Name)) + } +} + var versionImplementors = []string{"Version"} func (ec *executionContext) _Version(ctx context.Context, sel ast.SelectionSet, obj *graph.Version) graphql.Marshaler { diff --git a/cmd/state-svc/internal/server/router.go b/cmd/state-svc/internal/server/router.go index b6d16148d3..ff8ab159d4 100644 --- a/cmd/state-svc/internal/server/router.go +++ b/cmd/state-svc/internal/server/router.go @@ -18,8 +18,13 @@ func (s *Server) setupRouting() { return nil }) + s.httpServer.GET("/subscriptions", func(c echo.Context) error { + s.graphServer.ServeHTTP(c.Response(), c.Request()) + return nil + }) + s.httpServer.GET(QuitRoute, func(c echo.Context) error { - s.shutdown() + s.quit() return nil }) } diff --git a/cmd/state-svc/internal/server/server.go b/cmd/state-svc/internal/server/server.go index e3c2c442dd..efbaaf9b1d 100644 --- a/cmd/state-svc/internal/server/server.go +++ b/cmd/state-svc/internal/server/server.go @@ -3,6 +3,7 @@ package server import ( "context" "net" + "net/http" "strconv" "time" @@ -10,6 +11,7 @@ import ( "github.com/99designs/gqlgen/graphql/handler/extension" "github.com/99designs/gqlgen/graphql/handler/lru" "github.com/99designs/gqlgen/graphql/handler/transport" + "github.com/gorilla/websocket" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" @@ -21,21 +23,23 @@ import ( ) type Server struct { - shutdown context.CancelFunc + cancel context.CancelFunc + done chan bool graphServer *handler.Server listener net.Listener httpServer *echo.Echo port int } -func New(cfg *config.Instance, shutdown context.CancelFunc) (*Server, error) { +func New(cfg *config.Instance, cancel context.CancelFunc) (*Server, error) { listener, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { return nil, errs.Wrap(err, "Failed to listen") } - s := &Server{shutdown: shutdown} - s.graphServer = newGraphServer(cfg) + s := &Server{cancel: cancel} + s.done = make(chan bool) + s.graphServer = newGraphServer(cfg, s.done) s.listener = listener s.httpServer = newHTTPServer(listener) @@ -61,6 +65,11 @@ func (s *Server) Start() error { return s.httpServer.Start(s.listener.Addr().String()) } +func (s *Server) quit() { + close(s.done) + s.cancel() +} + func (s *Server) Shutdown() error { logging.Debug("shutting down server") ctx, cancel := context.WithTimeout(context.Background(), time.Second) @@ -72,9 +81,17 @@ func (s *Server) Shutdown() error { return nil } -func newGraphServer(cfg *config.Instance) *handler.Server { - graphServer := handler.NewDefaultServer(genserver.NewExecutableSchema(genserver.Config{Resolvers: resolver.New(cfg)})) - graphServer.AddTransport(&transport.Websocket{}) +func newGraphServer(cfg *config.Instance, done chan bool) *handler.Server { + graphServer := handler.NewDefaultServer(genserver.NewExecutableSchema(genserver.Config{Resolvers: resolver.New(cfg, done)})) + graphServer.AddTransport(&transport.Websocket{ + Upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + // For development. User proper CORS for prod + return true + }, + }, + KeepAlivePingInterval: 10 * time.Second, + }) graphServer.SetQueryCache(lru.New(1000)) graphServer.Use(extension.Introspection{}) graphServer.Use(extension.AutomaticPersistedQuery{ diff --git a/cmd/state-svc/schema/schema.graphqls b/cmd/state-svc/schema/schema.graphqls index 5de8f8df7f..19312c1104 100644 --- a/cmd/state-svc/schema/schema.graphqls +++ b/cmd/state-svc/schema/schema.graphqls @@ -35,3 +35,7 @@ type Query { update(channel: String, version: String): DeferredUpdate projects: [Project]! } + +type Subscription { + quit: Boolean! +} diff --git a/cmd/state-tray/main.go b/cmd/state-tray/main.go index 000dc9d3b7..242f9a75a7 100644 --- a/cmd/state-tray/main.go +++ b/cmd/state-tray/main.go @@ -172,7 +172,10 @@ func run() (rerr error) { systray.AddSeparator() - mQuit := systray.AddMenuItem(locale.Tl("tray_exit", "Exit"), "") + quit, err := setupQuit(model) + if err != nil { + return errs.Wrap(err, "Could not setup quit channel") + } for { select { @@ -242,13 +245,39 @@ func run() (rerr error) { if err := execute(updlgInfo.Exec(), nil); err != nil { return errs.New("Could not execute: %s", updlgInfo.Name()) } - case <-mQuit.ClickedCh: + case <-quit: logging.Debug("Quit event") + model.CloseSubscriptions() systray.Quit() } } } +func setupQuit(model *model.SvcModel) (chan struct{}, error) { + mQuit := systray.AddMenuItem(locale.Tl("tray_exit", "Exit"), "") + quitSub, err := model.Quit(context.Background()) + if err != nil { + return nil, errs.Wrap(err, "Could not subscribe so service quit event") + } + + quit := make(chan struct{}) + + go func() { + for { + select { + case <-mQuit.ClickedCh: + logging.Debug("Quit clicked") + quit <- struct{}{} + case <-quitSub: + logging.Debug("Quit from subscription") + quit <- struct{}{} + } + } + }() + + return quit, nil +} + func onExit() { logging.Debug("systray.OnExit() was called.") cfg, err := config.New() diff --git a/cmd/state-tray/update.go b/cmd/state-tray/update.go index 83f7369993..4bb04d7000 100644 --- a/cmd/state-tray/update.go +++ b/cmd/state-tray/update.go @@ -15,7 +15,7 @@ const ( ) func superviseUpdate(mdl *model.SvcModel, notice *updateNotice) func() { - var done chan struct{} + done := make(chan struct{}) go func() { for { diff --git a/go.mod b/go.mod index b70da79c7b..3729b60ef3 100644 --- a/go.mod +++ b/go.mod @@ -38,12 +38,13 @@ require ( github.com/go-openapi/validate v0.20.2 github.com/gobuffalo/packr v1.10.7 github.com/golang/protobuf v1.4.3 // indirect - github.com/google/uuid v1.1.2 + github.com/google/uuid v1.3.0 github.com/gorilla/websocket v1.4.2 github.com/hashicorp/go-cleanhttp v0.5.1 github.com/hashicorp/go-retryablehttp v0.6.7 github.com/hashicorp/go-version v1.1.0 github.com/hashicorp/golang-lru v0.5.4 // indirect + github.com/hasura/go-graphql-client v0.2.0 github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95 github.com/hpcloud/tail v1.0.0 github.com/iafan/cwalk v0.0.0-20191125092548-dd7f505d2f66 @@ -52,6 +53,7 @@ require ( github.com/jessevdk/go-flags v1.4.0 github.com/kami-zh/go-capturer v0.0.0-20171211120116-e492ea43421d github.com/kardianos/osext v0.0.0-20170510131534-ae77be60afb1 + github.com/klauspost/compress v1.13.4 // indirect github.com/labstack/echo/v4 v4.2.1 github.com/machinebox/graphql v0.2.2 github.com/mailru/easyjson v0.7.7 // indirect @@ -87,7 +89,7 @@ require ( github.com/yuin/goldmark v1.3.5 go.mongodb.org/mongo-driver v1.5.3 // indirect golang.org/x/crypto v0.0.0-20200820211705-5c72a883971a - golang.org/x/net v0.0.0-20210614182718-04defd469f4e + golang.org/x/net v0.0.0-20210825183410-e898025ed96a golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c golang.org/x/text v0.3.6 google.golang.org/genproto v0.0.0-20200815001618-f69a88009b70 @@ -98,4 +100,5 @@ require ( gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/yaml.v2 v2.4.0 modernc.org/sqlite v1.11.2 + nhooyr.io/websocket v1.8.7 // indirect ) diff --git a/go.sum b/go.sum index eddca846c2..7ae559c873 100644 --- a/go.sum +++ b/go.sum @@ -191,6 +191,8 @@ github.com/getlantern/ops v0.0.0-20190325191751-d70cb0d6f85f/go.mod h1:D5ao98qkA github.com/getlantern/systray v1.1.0 h1:U0wCEqseLi2ok1fE6b88gJklzriavPJixZysZPkZd/Y= github.com/getlantern/systray v1.1.0/go.mod h1:AecygODWIsBquJCJFop8MEQcJbWFfw/1yWbVabNgpCM= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= github.com/gliderlabs/ssh v0.2.2 h1:6zsha5zo/TWhRhwqCD3+EarCAgZ2yN28ipRnGPnwkI0= github.com/gliderlabs/ssh v0.2.2/go.mod h1:U7qILu1NlMHj9FlMhZLlkCdDnU1DBEAqr0aevW3Awn0= github.com/globalsign/mgo v0.0.0-20180905125535-1ca0a4f7cbcb/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= @@ -298,8 +300,12 @@ github.com/go-openapi/validate v0.19.15/go.mod h1:tbn/fdOwYHgrhPBzidZfJC2MIVvs9G github.com/go-openapi/validate v0.20.1/go.mod h1:b60iJT+xNNLfaQJUqLI7946tYiFEOuE9E4k54HpKcJ0= github.com/go-openapi/validate v0.20.2 h1:AhqDegYV3J3iQkMPJSXkvzymHKMTw0BST3RK3hTT4ts= github.com/go-openapi/validate v0.20.2/go.mod h1:e7OJoKNgd0twXZwIn0A43tHbvIcr/rZIVCbJBpTUoY0= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/colors v1.2.0 h1:0EdjTXKrr2g1L/LQTYtIqabeHpZuGZz1U4osS1T8+5M= github.com/go-playground/colors v1.2.0/go.mod h1:miw1R2JIE19cclPxsXqNdzLZsk4DP4iF+m88bRc7kfM= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= github.com/go-sql-driver/mysql v1.5.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-stack/stack v1.8.0 h1:5SgMzNM5HxrEjV0ww2lTmX6E2Izsfxas4+YHWRs3Lsk= github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= @@ -329,6 +335,9 @@ github.com/gobuffalo/packr v1.10.7/go.mod h1:rYwMLC6NXbAbkKb+9j3NTKbxSswkKLlelZY github.com/gobuffalo/packr/v2 v2.0.9/go.mod h1:emmyGweYTm6Kdper+iywB6YK5YzuKchGtJQZ0Odn4pQ= github.com/gobuffalo/packr/v2 v2.2.0/go.mod h1:CaAwI0GPIAv+5wKLtv8Afwl+Cm78K/I/VCm/3ptBN+0= github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= github.com/gogo/protobuf v1.0.0/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= @@ -362,6 +371,8 @@ github.com/golang/protobuf v1.4.3 h1:JjCZWpVbqXDqFVmTfYWEVTMIYrL/NPdPSCHPJ0T/raM github.com/golang/protobuf v1.4.3/go.mod h1:oDoupMAO8OvCJWAcko0GGGIgR6R6ocIYbsSw735rRwI= github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/golang/snappy v0.0.3 h1:fHPg5GQYlCeLIPB9BZqMVR5nR9A+IM5zcgeTdjMYmLA= +github.com/golang/snappy v0.0.3/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= @@ -374,6 +385,7 @@ github.com/google/go-cmp v0.5.1/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.3 h1:x95R7cp+rSeeqAMI2knLtQ0DKlaBhv2NrtrOvafPHRo= github.com/google/go-cmp v0.5.3/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/google/martian v2.1.0+incompatible/go.mod h1:9I4somxYTbIHy5NJKHRl3wXiIaQGbYVAs8BPL6v8lEs= github.com/google/martian/v3 v3.0.0/go.mod h1:y5Zk1BBys9G+gd6Jrk0W3cC1+ELVxBWuIGO+w/tUAp0= github.com/google/pprof v0.0.0-20181206194817-3ea8567a2e57/go.mod h1:zfwlbNMJ+OItoe0UupaVj+oy1omPYYDuagoSzA8v9mc= @@ -390,6 +402,8 @@ github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+ github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/googleapis/gax-go/v2 v2.0.4/go.mod h1:0Wqv26UfaUD9n4G6kQubkQ+KchISgw+vpHVxEJEs9eg= github.com/googleapis/gax-go/v2 v2.0.5 h1:sjZBwGj9Jlw33ImPtvFviGYvseOtDM7hkSKB7+Tv3SM= github.com/googleapis/gax-go/v2 v2.0.5/go.mod h1:DWXyrwAJ9X0FpwwEdw+IPEYBICEFu5mhpdKc/us6bOk= @@ -397,8 +411,10 @@ github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORR github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.1/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/graph-gophers/graphql-go v0.0.0-20201112095111-7a585a01e04c/go.mod h1:9CQHMSxwO4MprSdzoIEobiHpoLtHm77vfxsvsIN5Vuc= github.com/grpc-ecosystem/go-grpc-middleware v1.0.0/go.mod h1:FiyG127CGDf3tlThmgyCl78X/SZQqEOJBCDaAfeWzPs= github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0/go.mod h1:8NvIoxWQoOIhqOTXgfV/d3M/q6VIi02HzZEHgUlZvzk= github.com/grpc-ecosystem/grpc-gateway v1.9.0/go.mod h1:vNeuVxBJEsws4ogUvrchl83t/GYV9WGTSLVdBhOQFDY= @@ -431,6 +447,8 @@ github.com/hashicorp/logutils v1.0.0/go.mod h1:QIAnNjmIWmVIIkWDTG1z5v++HQmx9WQRO github.com/hashicorp/mdns v1.0.0/go.mod h1:tL+uN++7HEJ6SQLQ2/p+z2pH24WQKWjBPkE0mNTz8vQ= github.com/hashicorp/memberlist v0.1.3/go.mod h1:ajVTdAv/9Im8oMAAj5G31PhhMCZJV2pPBoIllUwCN7I= github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/JwenrHc= +github.com/hasura/go-graphql-client v0.2.0 h1:Ux6NQFkR6RCoNEMRsdJ1Hz0P43TRpLIzAXTXnNcEMxU= +github.com/hasura/go-graphql-client v0.2.0/go.mod h1:ovhXpnBL/+pYmmdYHndovgAc/wmhaXKvKwJETjtMomQ= github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95 h1:S4qyfL2sEm5Budr4KVMyEniCy+PbS55651I/a+Kn/NQ= github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95/go.mod h1:QiyDdbZLaJ/mZP4Zwc9g2QsfaEA4o7XvvgZegSci5/E= github.com/hinshun/vt10x v0.0.0-20180616224451-1954e6464174/go.mod h1:DqJ97dSdRW1W22yXSB90986pcOyQ7r45iio1KN2ez1A= @@ -462,6 +480,7 @@ github.com/jonboulle/clockwork v0.1.0/go.mod h1:Ii8DK3G1RaLaWxj9trq07+26W01tbo22 github.com/josharian/intern v1.0.0 h1:vlS4z54oSdjm0bgjRigI+G1HpF+tI+9rE5LLzOg8HmY= github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= @@ -482,6 +501,9 @@ github.com/kisielk/errcheck v1.1.0/go.mod h1:EZBBE59ingxPouuu3KfxchcWSUPOHkagtvW github.com/kisielk/errcheck v1.2.0/go.mod h1:/BMXB+zMLi60iA8Vv6Ksmxu/1UDYcXs4uQLJ+jE2L00= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/klauspost/compress v1.9.5/go.mod h1:RyIbtBH6LamlWaDj8nUwkbUhJ87Yi3uG0guNDohfE1A= +github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +github.com/klauspost/compress v1.13.4 h1:0zhec2I8zGnjWcKyLl6i3gPqKANCCn5e9xmviEEeX6s= +github.com/klauspost/compress v1.13.4/go.mod h1:8dP1Hq4DHOhN9w426knH3Rhby4rFm6D8eO+e+Dq5Gzg= github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= github.com/konsorten/go-windows-terminal-sequences v1.0.2 h1:DB17ag19krx9CFsz4o3enTrPXyIXCl+2iCXH/aMAp9s= github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= @@ -507,6 +529,7 @@ github.com/leaanthony/synx v0.1.0 h1:R0lmg2w6VMb8XcotOwAe5DLyzwjLrskNkwU7LLWsyL8 github.com/leaanthony/synx v0.1.0/go.mod h1:Iz7eybeeG8bdq640iR+CwYb8p+9EOsgMWghkSRyZcqs= github.com/leaanthony/wincursor v0.1.0 h1:Dsyp68QcF5cCs65AMBmxoYNEm0n8K7mMchG6a8fYxf8= github.com/leaanthony/wincursor v0.1.0/go.mod h1:7TVwwrzSH/2Y9gLOGH+VhA+bZhoWXBRgbGNTMk+yimE= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= github.com/logrusorgru/aurora v0.0.0-20200102142835-e9ef32dff381/go.mod h1:7rIyQOR62GCctdiQpZ/zOJlFyk6y+94wXzv6RNZgaR4= github.com/lucasb-eyer/go-colorful v0.0.0-20180526135729-345fbb3dbcdb/go.mod h1:NXg0ArsFk0Y01623LgUqoqcouGDB+PwCCQlrwrG6xJ4= github.com/machinebox/graphql v0.2.2 h1:dWKpJligYKhYKO5A2gvNhkJdQMNZeChZYyBbrZkBZfo= @@ -569,7 +592,9 @@ github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RR github.com/mitchellh/mapstructure v1.4.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/mitchellh/mapstructure v1.4.1 h1:CpVNEelQCZBooIPDn+AR3NpivK/TIKU8bDxdASFVQag= github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= @@ -585,6 +610,7 @@ github.com/oklog/ulid v1.3.1 h1:EGfNDEx6MqHz8B3uNV6QAib1UR2Lm97sHi3ocA6ESJ4= github.com/oklog/ulid v1.3.1/go.mod h1:CirwcVhetQ6Lv90oh/F+FBtV6XMibvdAFo93nm5qn4U= github.com/opentracing/basictracer-go v1.0.0/go.mod h1:QfBfYuafItcjQuMwinw9GhYKwFXS9KnPs5lxoYwgW74= github.com/opentracing/opentracing-go v1.0.2/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= github.com/oxtoacart/bpool v0.0.0-20190530202638-03653db5a59c h1:rp5dCmg/yLR3mgFuSOe4oEnDDmGLROTvMragMUXpTQw= @@ -694,6 +720,8 @@ github.com/thoas/go-funk v0.8.0/go.mod h1:+IWnUfUmFO1+WVYQWQtIJHeRRdaIyyYglZN7xz github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= github.com/ulikunitz/xz v0.5.4 h1:zATC2OoZ8H1TZll3FpbX+ikwmadbO699PE06cIkm9oU= github.com/ulikunitz/xz v0.5.4/go.mod h1:2bypXElzHzzJZwzH67Y6wb67pO62Rzfn7BSiF4ABRW8= github.com/urfave/cli/v2 v2.1.1/go.mod h1:SE9GqnLQmjVa0iPEY0f1w3ygNIYcIJ0OKPMoW2caLfQ= @@ -844,6 +872,8 @@ golang.org/x/net v0.0.0-20210119194325-5f4716e94777/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20210421230115-4e50805a0758/go.mod h1:72T/g9IO56b78aLF+1Kcs5dz7/ng1VjMUvfKvpfy+jM= golang.org/x/net v0.0.0-20210614182718-04defd469f4e h1:XpT3nA5TvE525Ne3hInMh6+GETgn27Zfm9dxsThnX2Q= golang.org/x/net v0.0.0-20210614182718-04defd469f4e/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20210825183410-e898025ed96a h1:bRuuGXV8wwSdGTB+CtJf+FjgO1APK1CoO39T4BN/XBw= +golang.org/x/net v0.0.0-20210825183410-e898025ed96a/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= golang.org/x/oauth2 v0.0.0-20190226205417-e64efc72b421/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= golang.org/x/oauth2 v0.0.0-20190604053449-0f29369cfe45/go.mod h1:gOpvHmFTYa4IltrdGE7lF6nIHvwfUNPOp7c8zoXwtLw= @@ -1150,6 +1180,9 @@ modernc.org/token v1.0.0 h1:a0jaWiNMDhDUtqOj09wvjWWAqd3q7WpBulmL9H2egsk= modernc.org/token v1.0.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= modernc.org/z v1.0.1 h1:WyIDpEpAIx4Hel6q/Pcgj/VhaQV5XPJ2I6ryIYbjnpc= modernc.org/z v1.0.1/go.mod h1:8/SRk5C/HgiQWCgXdfpb+1RvhORdkz5sw72d3jjtyqA= +nhooyr.io/websocket v1.8.6/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= +nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g= +nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= rsc.io/binaryregexp v0.2.0/go.mod h1:qTv7/COck+e2FymRvadv62gMdZztPaShugOCi3I+8D8= rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/internal/gqlclient/gqlclient.go b/internal/gqlclient/gqlclient.go index 08fada7b0f..9fb4502844 100644 --- a/internal/gqlclient/gqlclient.go +++ b/internal/gqlclient/gqlclient.go @@ -2,14 +2,19 @@ package gqlclient import ( "context" + "encoding/json" "fmt" + "net/http" "os" "time" + "github.com/ActiveState/cli/internal/errs" "github.com/ActiveState/cli/internal/profile" "github.com/ActiveState/cli/internal/strutils" "github.com/machinebox/graphql" + hsgraphql "github.com/hasura/go-graphql-client" + "github.com/ActiveState/cli/internal/logging" "github.com/ActiveState/cli/internal/machineid" "github.com/ActiveState/cli/internal/retryhttp" @@ -31,26 +36,48 @@ type BearerTokenProvider interface { } type Client struct { - *graphqlClient - tokenProvider BearerTokenProvider - timeout time.Duration + baseUrl string + queryClient *graphqlClient + subscriptionClient *hsgraphql.SubscriptionClient + tokenProvider BearerTokenProvider + timeout time.Duration } -func NewWithOpts(url string, timeout time.Duration, opts ...graphql.ClientOption) *Client { +type ClientOption func(*Client) + +func NewWithOpts(baseUrl string, timeout time.Duration, opts ...ClientOption) *Client { if timeout == 0 { timeout = time.Second * 60 } client := &Client{ - graphqlClient: graphql.NewClient(url, opts...), - timeout: timeout, + baseUrl: baseUrl, + queryClient: graphql.NewClient(baseUrl), + timeout: timeout, } - client.graphqlClient.Log = func(s string) { logging.Debug("graphqlClient log message: %s", s) } + + for _, opt := range opts { + opt(client) + } + client.queryClient.Log = func(s string) { logging.Debug("graphqlClient log message: %s", s) } + return client } func New(url string, timeout time.Duration) *Client { - return NewWithOpts(url, timeout, graphql.WithHTTPClient(retryhttp.DefaultClient.StandardClient())) + return NewWithOpts(url, timeout, WithHTTPClient(retryhttp.DefaultClient.StandardClient())) +} + +func WithHTTPClient(httpclient *http.Client) ClientOption { + return func(c *Client) { + c.queryClient = graphql.NewClient(c.baseUrl, graphql.WithHTTPClient(httpclient)) + } +} + +func WithSubscriptions(subUrl string) ClientOption { + return func(c *Client) { + c.subscriptionClient = hsgraphql.NewSubscriptionClient(subUrl) + } } func (c *Client) SetTokenProvider(tokenProvider BearerTokenProvider) { @@ -58,15 +85,15 @@ func (c *Client) SetTokenProvider(tokenProvider BearerTokenProvider) { } func (c *Client) SetDebug(b bool) { - c.graphqlClient.Log = func(string) {} + c.queryClient.Log = func(string) {} if b { - c.graphqlClient.Log = func(s string) { + c.queryClient.Log = func(s string) { fmt.Fprintln(os.Stderr, s) } } } -func (c *Client) Run(request Request, response interface{}) error { +func (c *Client) RunQuery(request Request, response interface{}) error { ctx := context.Background() if c.timeout != 0 { var cancel context.CancelFunc @@ -95,5 +122,29 @@ func (c *Client) RunWithContext(ctx context.Context, request Request, response i graphRequest.Header.Set("X-Requestor", machineid.UniqID()) - return c.graphqlClient.Run(ctx, graphRequest, &response) + return c.queryClient.Run(ctx, graphRequest, &response) +} + +func (c *Client) RunSubscription(ctx context.Context, response interface{}) (chan interface{}, error) { + result := make(chan interface{}) + _, err := c.subscriptionClient.Subscribe(response, nil, func(message *json.RawMessage, err error) error { + if err != nil { + return nil + } + + err = json.Unmarshal(*message, response) + result <- response + return nil + }) + if err != nil { + return nil, errs.Wrap(err, "Could not subscribe") + } + + go c.subscriptionClient.Run() + + return result, nil +} + +func (c *Client) CloseSubscriptions() error { + return c.subscriptionClient.Close() } diff --git a/internal/graph/response.go b/internal/graph/response.go index c07b1450e6..f80ad5da4f 100644 --- a/internal/graph/response.go +++ b/internal/graph/response.go @@ -15,3 +15,7 @@ type UpdateResponse struct { type AvailableUpdateResponse struct { AvailableUpdate AvailableUpdate `json:"availableUpdate"` } + +type QuitResponse struct { + Quit bool +} diff --git a/pkg/platform/api/svc/svc.go b/pkg/platform/api/svc/svc.go index 7f93573b58..94c2537ee3 100644 --- a/pkg/platform/api/svc/svc.go +++ b/pkg/platform/api/svc/svc.go @@ -7,7 +7,6 @@ import ( "github.com/ActiveState/cli/internal/constants" "github.com/ActiveState/cli/internal/gqlclient" "github.com/ActiveState/cli/internal/locale" - "github.com/machinebox/graphql" ) type Client struct { @@ -27,9 +26,10 @@ func New(cfg configurable) (*Client, error) { } baseUrl := fmt.Sprintf("http://127.0.0.1:%d", port) + subUrl := fmt.Sprintf("ws://127.0.0.1:%d/subscriptions", port) return &Client{ // The custom client bypasses http-retry, which we don't need for doing local requests - Client: gqlclient.NewWithOpts(fmt.Sprintf("%s/query", baseUrl), 0, graphql.WithHTTPClient(&http.Client{})), + Client: gqlclient.NewWithOpts(fmt.Sprintf("%s/query", baseUrl), 0, gqlclient.WithHTTPClient(&http.Client{}), gqlclient.WithSubscriptions(subUrl)), baseUrl: baseUrl, }, nil } diff --git a/pkg/platform/model/checkpoints.go b/pkg/platform/model/checkpoints.go index 57b07b27c4..f11404be1b 100644 --- a/pkg/platform/model/checkpoints.go +++ b/pkg/platform/model/checkpoints.go @@ -88,7 +88,7 @@ func FetchCheckpointForCommit(commitID strfmt.UUID) ([]*gqlModel.Requirement, st gql := graphql.New() response := gqlModel.Checkpoint{} - err := gql.Run(request, &response) + err := gql.RunQuery(request, &response) if err != nil { return nil, strfmt.DateTime{}, errs.Wrap(err, "gql.Run failed") } diff --git a/pkg/platform/model/cve.go b/pkg/platform/model/cve.go index cf37ae383a..94efe817e5 100644 --- a/pkg/platform/model/cve.go +++ b/pkg/platform/model/cve.go @@ -29,7 +29,7 @@ func FetchProjectVulnerabilities(auth *authentication.Auth, org, project string) req := request.VulnerabilitiesByProject(org, project) var resp model.ProjectResponse med := mediator.New(auth) - err := med.Run(req, &resp) + err := med.RunQuery(req, &resp) if err != nil { return nil, errs.Wrap(err, "Failed to run mediator request.") } @@ -54,7 +54,7 @@ func FetchCommitVulnerabilities(auth *authentication.Auth, commitID string) (*mo req := request.VulnerabilitiesByCommit(commitID) var resp model.CommitResponse med := mediator.New(auth) - err := med.Run(req, &resp) + err := med.RunQuery(req, &resp) if err != nil { return nil, errs.Wrap(err, "Failed to run mediator request.") } diff --git a/pkg/platform/model/organizations.go b/pkg/platform/model/organizations.go index 224c2bd480..2e87628a7a 100644 --- a/pkg/platform/model/organizations.go +++ b/pkg/platform/model/organizations.go @@ -111,7 +111,7 @@ func FetchOrganizationsByIDs(ids []strfmt.UUID) ([]model.Organization, error) { gql := graphql.New() response := model.Organizations{} - err := gql.Run(request, &response) + err := gql.RunQuery(request, &response) if err != nil { return nil, errs.Wrap(err, "gql.Run failed") } diff --git a/pkg/platform/model/projects.go b/pkg/platform/model/projects.go index a384faf9ee..3633d30af9 100644 --- a/pkg/platform/model/projects.go +++ b/pkg/platform/model/projects.go @@ -33,7 +33,7 @@ func FetchProjectByName(orgName string, projectName string) (*mono_models.Projec gql := graphql.New() response := model.Projects{} - err := gql.Run(request, &response) + err := gql.RunQuery(request, &response) if err != nil { return nil, errs.Wrap(err, "GraphQL request failed") } diff --git a/pkg/platform/model/supportedlangs.go b/pkg/platform/model/supportedlangs.go index 3241f133dc..9b97a5d99b 100644 --- a/pkg/platform/model/supportedlangs.go +++ b/pkg/platform/model/supportedlangs.go @@ -13,7 +13,7 @@ func FetchSupportedLanguages(osPlatformName string) ([]model.SupportedLanguage, req := request.SupportedLanguages(kernelName) var resp model.SupportedLanguagesResponse med := mediator.New(nil) - err := med.Run(req, &resp) + err := med.RunQuery(req, &resp) if err != nil { return nil, errs.Wrap(err, "Failed to run mediator request.") } diff --git a/pkg/platform/model/svc.go b/pkg/platform/model/svc.go index b4cc2b1cd3..7a447658df 100644 --- a/pkg/platform/model/svc.go +++ b/pkg/platform/model/svc.go @@ -85,6 +85,11 @@ func (m *SvcModel) CheckUpdate(ctx context.Context) (*graph.AvailableUpdate, err return &u.AvailableUpdate, nil } +func (m *SvcModel) Quit(ctx context.Context) (chan interface{}, error) { + response := graph.QuitResponse{} + return m.client.RunSubscription(ctx, &response) +} + func (m *SvcModel) StopServer() error { htClient := retryhttp.DefaultClient.StandardClient() @@ -117,3 +122,7 @@ func (m *SvcModel) Ping() error { _, err := m.StateVersion(context.Background()) return err } + +func (m *SvcModel) CloseSubscriptions() error { + return m.client.CloseSubscriptions() +} diff --git a/vendor/github.com/golang/snappy/AUTHORS b/vendor/github.com/golang/snappy/AUTHORS index bcfa19520a..203e84ebab 100644 --- a/vendor/github.com/golang/snappy/AUTHORS +++ b/vendor/github.com/golang/snappy/AUTHORS @@ -8,8 +8,10 @@ # Please keep the list sorted. +Amazon.com, Inc Damian Gryski Google Inc. Jan Mercl <0xjnml@gmail.com> +Klaus Post Rodolfo Carvalho Sebastien Binet diff --git a/vendor/github.com/golang/snappy/CONTRIBUTORS b/vendor/github.com/golang/snappy/CONTRIBUTORS index 931ae31606..d9914732b5 100644 --- a/vendor/github.com/golang/snappy/CONTRIBUTORS +++ b/vendor/github.com/golang/snappy/CONTRIBUTORS @@ -28,7 +28,9 @@ Damian Gryski Jan Mercl <0xjnml@gmail.com> +Jonathan Swinney Kai Backman +Klaus Post Marc-Antoine Ruel Nigel Tao Rob Pike diff --git a/vendor/github.com/golang/snappy/decode.go b/vendor/github.com/golang/snappy/decode.go index 72efb0353d..f1e04b172c 100644 --- a/vendor/github.com/golang/snappy/decode.go +++ b/vendor/github.com/golang/snappy/decode.go @@ -52,6 +52,8 @@ const ( // Otherwise, a newly allocated slice will be returned. // // The dst and src must not overlap. It is valid to pass a nil dst. +// +// Decode handles the Snappy block format, not the Snappy stream format. func Decode(dst, src []byte) ([]byte, error) { dLen, s, err := decodedLen(src) if err != nil { @@ -83,6 +85,8 @@ func NewReader(r io.Reader) *Reader { } // Reader is an io.Reader that can read Snappy-compressed bytes. +// +// Reader handles the Snappy stream format, not the Snappy block format. type Reader struct { r io.Reader err error diff --git a/vendor/github.com/golang/snappy/decode_arm64.s b/vendor/github.com/golang/snappy/decode_arm64.s new file mode 100644 index 0000000000..7a3ead17ea --- /dev/null +++ b/vendor/github.com/golang/snappy/decode_arm64.s @@ -0,0 +1,494 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !appengine +// +build gc +// +build !noasm + +#include "textflag.h" + +// The asm code generally follows the pure Go code in decode_other.go, except +// where marked with a "!!!". + +// func decode(dst, src []byte) int +// +// All local variables fit into registers. The non-zero stack size is only to +// spill registers and push args when issuing a CALL. The register allocation: +// - R2 scratch +// - R3 scratch +// - R4 length or x +// - R5 offset +// - R6 &src[s] +// - R7 &dst[d] +// + R8 dst_base +// + R9 dst_len +// + R10 dst_base + dst_len +// + R11 src_base +// + R12 src_len +// + R13 src_base + src_len +// - R14 used by doCopy +// - R15 used by doCopy +// +// The registers R8-R13 (marked with a "+") are set at the start of the +// function, and after a CALL returns, and are not otherwise modified. +// +// The d variable is implicitly R7 - R8, and len(dst)-d is R10 - R7. +// The s variable is implicitly R6 - R11, and len(src)-s is R13 - R6. +TEXT ·decode(SB), NOSPLIT, $56-56 + // Initialize R6, R7 and R8-R13. + MOVD dst_base+0(FP), R8 + MOVD dst_len+8(FP), R9 + MOVD R8, R7 + MOVD R8, R10 + ADD R9, R10, R10 + MOVD src_base+24(FP), R11 + MOVD src_len+32(FP), R12 + MOVD R11, R6 + MOVD R11, R13 + ADD R12, R13, R13 + +loop: + // for s < len(src) + CMP R13, R6 + BEQ end + + // R4 = uint32(src[s]) + // + // switch src[s] & 0x03 + MOVBU (R6), R4 + MOVW R4, R3 + ANDW $3, R3 + MOVW $1, R1 + CMPW R1, R3 + BGE tagCopy + + // ---------------------------------------- + // The code below handles literal tags. + + // case tagLiteral: + // x := uint32(src[s] >> 2) + // switch + MOVW $60, R1 + LSRW $2, R4, R4 + CMPW R4, R1 + BLS tagLit60Plus + + // case x < 60: + // s++ + ADD $1, R6, R6 + +doLit: + // This is the end of the inner "switch", when we have a literal tag. + // + // We assume that R4 == x and x fits in a uint32, where x is the variable + // used in the pure Go decode_other.go code. + + // length = int(x) + 1 + // + // Unlike the pure Go code, we don't need to check if length <= 0 because + // R4 can hold 64 bits, so the increment cannot overflow. + ADD $1, R4, R4 + + // Prepare to check if copying length bytes will run past the end of dst or + // src. + // + // R2 = len(dst) - d + // R3 = len(src) - s + MOVD R10, R2 + SUB R7, R2, R2 + MOVD R13, R3 + SUB R6, R3, R3 + + // !!! Try a faster technique for short (16 or fewer bytes) copies. + // + // if length > 16 || len(dst)-d < 16 || len(src)-s < 16 { + // goto callMemmove // Fall back on calling runtime·memmove. + // } + // + // The C++ snappy code calls this TryFastAppend. It also checks len(src)-s + // against 21 instead of 16, because it cannot assume that all of its input + // is contiguous in memory and so it needs to leave enough source bytes to + // read the next tag without refilling buffers, but Go's Decode assumes + // contiguousness (the src argument is a []byte). + CMP $16, R4 + BGT callMemmove + CMP $16, R2 + BLT callMemmove + CMP $16, R3 + BLT callMemmove + + // !!! Implement the copy from src to dst as a 16-byte load and store. + // (Decode's documentation says that dst and src must not overlap.) + // + // This always copies 16 bytes, instead of only length bytes, but that's + // OK. If the input is a valid Snappy encoding then subsequent iterations + // will fix up the overrun. Otherwise, Decode returns a nil []byte (and a + // non-nil error), so the overrun will be ignored. + // + // Note that on arm64, it is legal and cheap to issue unaligned 8-byte or + // 16-byte loads and stores. This technique probably wouldn't be as + // effective on architectures that are fussier about alignment. + LDP 0(R6), (R14, R15) + STP (R14, R15), 0(R7) + + // d += length + // s += length + ADD R4, R7, R7 + ADD R4, R6, R6 + B loop + +callMemmove: + // if length > len(dst)-d || length > len(src)-s { etc } + CMP R2, R4 + BGT errCorrupt + CMP R3, R4 + BGT errCorrupt + + // copy(dst[d:], src[s:s+length]) + // + // This means calling runtime·memmove(&dst[d], &src[s], length), so we push + // R7, R6 and R4 as arguments. Coincidentally, we also need to spill those + // three registers to the stack, to save local variables across the CALL. + MOVD R7, 8(RSP) + MOVD R6, 16(RSP) + MOVD R4, 24(RSP) + MOVD R7, 32(RSP) + MOVD R6, 40(RSP) + MOVD R4, 48(RSP) + CALL runtime·memmove(SB) + + // Restore local variables: unspill registers from the stack and + // re-calculate R8-R13. + MOVD 32(RSP), R7 + MOVD 40(RSP), R6 + MOVD 48(RSP), R4 + MOVD dst_base+0(FP), R8 + MOVD dst_len+8(FP), R9 + MOVD R8, R10 + ADD R9, R10, R10 + MOVD src_base+24(FP), R11 + MOVD src_len+32(FP), R12 + MOVD R11, R13 + ADD R12, R13, R13 + + // d += length + // s += length + ADD R4, R7, R7 + ADD R4, R6, R6 + B loop + +tagLit60Plus: + // !!! This fragment does the + // + // s += x - 58; if uint(s) > uint(len(src)) { etc } + // + // checks. In the asm version, we code it once instead of once per switch case. + ADD R4, R6, R6 + SUB $58, R6, R6 + MOVD R6, R3 + SUB R11, R3, R3 + CMP R12, R3 + BGT errCorrupt + + // case x == 60: + MOVW $61, R1 + CMPW R1, R4 + BEQ tagLit61 + BGT tagLit62Plus + + // x = uint32(src[s-1]) + MOVBU -1(R6), R4 + B doLit + +tagLit61: + // case x == 61: + // x = uint32(src[s-2]) | uint32(src[s-1])<<8 + MOVHU -2(R6), R4 + B doLit + +tagLit62Plus: + CMPW $62, R4 + BHI tagLit63 + + // case x == 62: + // x = uint32(src[s-3]) | uint32(src[s-2])<<8 | uint32(src[s-1])<<16 + MOVHU -3(R6), R4 + MOVBU -1(R6), R3 + ORR R3<<16, R4 + B doLit + +tagLit63: + // case x == 63: + // x = uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24 + MOVWU -4(R6), R4 + B doLit + + // The code above handles literal tags. + // ---------------------------------------- + // The code below handles copy tags. + +tagCopy4: + // case tagCopy4: + // s += 5 + ADD $5, R6, R6 + + // if uint(s) > uint(len(src)) { etc } + MOVD R6, R3 + SUB R11, R3, R3 + CMP R12, R3 + BGT errCorrupt + + // length = 1 + int(src[s-5])>>2 + MOVD $1, R1 + ADD R4>>2, R1, R4 + + // offset = int(uint32(src[s-4]) | uint32(src[s-3])<<8 | uint32(src[s-2])<<16 | uint32(src[s-1])<<24) + MOVWU -4(R6), R5 + B doCopy + +tagCopy2: + // case tagCopy2: + // s += 3 + ADD $3, R6, R6 + + // if uint(s) > uint(len(src)) { etc } + MOVD R6, R3 + SUB R11, R3, R3 + CMP R12, R3 + BGT errCorrupt + + // length = 1 + int(src[s-3])>>2 + MOVD $1, R1 + ADD R4>>2, R1, R4 + + // offset = int(uint32(src[s-2]) | uint32(src[s-1])<<8) + MOVHU -2(R6), R5 + B doCopy + +tagCopy: + // We have a copy tag. We assume that: + // - R3 == src[s] & 0x03 + // - R4 == src[s] + CMP $2, R3 + BEQ tagCopy2 + BGT tagCopy4 + + // case tagCopy1: + // s += 2 + ADD $2, R6, R6 + + // if uint(s) > uint(len(src)) { etc } + MOVD R6, R3 + SUB R11, R3, R3 + CMP R12, R3 + BGT errCorrupt + + // offset = int(uint32(src[s-2])&0xe0<<3 | uint32(src[s-1])) + MOVD R4, R5 + AND $0xe0, R5 + MOVBU -1(R6), R3 + ORR R5<<3, R3, R5 + + // length = 4 + int(src[s-2])>>2&0x7 + MOVD $7, R1 + AND R4>>2, R1, R4 + ADD $4, R4, R4 + +doCopy: + // This is the end of the outer "switch", when we have a copy tag. + // + // We assume that: + // - R4 == length && R4 > 0 + // - R5 == offset + + // if offset <= 0 { etc } + MOVD $0, R1 + CMP R1, R5 + BLE errCorrupt + + // if d < offset { etc } + MOVD R7, R3 + SUB R8, R3, R3 + CMP R5, R3 + BLT errCorrupt + + // if length > len(dst)-d { etc } + MOVD R10, R3 + SUB R7, R3, R3 + CMP R3, R4 + BGT errCorrupt + + // forwardCopy(dst[d:d+length], dst[d-offset:]); d += length + // + // Set: + // - R14 = len(dst)-d + // - R15 = &dst[d-offset] + MOVD R10, R14 + SUB R7, R14, R14 + MOVD R7, R15 + SUB R5, R15, R15 + + // !!! Try a faster technique for short (16 or fewer bytes) forward copies. + // + // First, try using two 8-byte load/stores, similar to the doLit technique + // above. Even if dst[d:d+length] and dst[d-offset:] can overlap, this is + // still OK if offset >= 8. Note that this has to be two 8-byte load/stores + // and not one 16-byte load/store, and the first store has to be before the + // second load, due to the overlap if offset is in the range [8, 16). + // + // if length > 16 || offset < 8 || len(dst)-d < 16 { + // goto slowForwardCopy + // } + // copy 16 bytes + // d += length + CMP $16, R4 + BGT slowForwardCopy + CMP $8, R5 + BLT slowForwardCopy + CMP $16, R14 + BLT slowForwardCopy + MOVD 0(R15), R2 + MOVD R2, 0(R7) + MOVD 8(R15), R3 + MOVD R3, 8(R7) + ADD R4, R7, R7 + B loop + +slowForwardCopy: + // !!! If the forward copy is longer than 16 bytes, or if offset < 8, we + // can still try 8-byte load stores, provided we can overrun up to 10 extra + // bytes. As above, the overrun will be fixed up by subsequent iterations + // of the outermost loop. + // + // The C++ snappy code calls this technique IncrementalCopyFastPath. Its + // commentary says: + // + // ---- + // + // The main part of this loop is a simple copy of eight bytes at a time + // until we've copied (at least) the requested amount of bytes. However, + // if d and d-offset are less than eight bytes apart (indicating a + // repeating pattern of length < 8), we first need to expand the pattern in + // order to get the correct results. For instance, if the buffer looks like + // this, with the eight-byte and patterns marked as + // intervals: + // + // abxxxxxxxxxxxx + // [------] d-offset + // [------] d + // + // a single eight-byte copy from to will repeat the pattern + // once, after which we can move two bytes without moving : + // + // ababxxxxxxxxxx + // [------] d-offset + // [------] d + // + // and repeat the exercise until the two no longer overlap. + // + // This allows us to do very well in the special case of one single byte + // repeated many times, without taking a big hit for more general cases. + // + // The worst case of extra writing past the end of the match occurs when + // offset == 1 and length == 1; the last copy will read from byte positions + // [0..7] and write to [4..11], whereas it was only supposed to write to + // position 1. Thus, ten excess bytes. + // + // ---- + // + // That "10 byte overrun" worst case is confirmed by Go's + // TestSlowForwardCopyOverrun, which also tests the fixUpSlowForwardCopy + // and finishSlowForwardCopy algorithm. + // + // if length > len(dst)-d-10 { + // goto verySlowForwardCopy + // } + SUB $10, R14, R14 + CMP R14, R4 + BGT verySlowForwardCopy + +makeOffsetAtLeast8: + // !!! As above, expand the pattern so that offset >= 8 and we can use + // 8-byte load/stores. + // + // for offset < 8 { + // copy 8 bytes from dst[d-offset:] to dst[d:] + // length -= offset + // d += offset + // offset += offset + // // The two previous lines together means that d-offset, and therefore + // // R15, is unchanged. + // } + CMP $8, R5 + BGE fixUpSlowForwardCopy + MOVD (R15), R3 + MOVD R3, (R7) + SUB R5, R4, R4 + ADD R5, R7, R7 + ADD R5, R5, R5 + B makeOffsetAtLeast8 + +fixUpSlowForwardCopy: + // !!! Add length (which might be negative now) to d (implied by R7 being + // &dst[d]) so that d ends up at the right place when we jump back to the + // top of the loop. Before we do that, though, we save R7 to R2 so that, if + // length is positive, copying the remaining length bytes will write to the + // right place. + MOVD R7, R2 + ADD R4, R7, R7 + +finishSlowForwardCopy: + // !!! Repeat 8-byte load/stores until length <= 0. Ending with a negative + // length means that we overrun, but as above, that will be fixed up by + // subsequent iterations of the outermost loop. + MOVD $0, R1 + CMP R1, R4 + BLE loop + MOVD (R15), R3 + MOVD R3, (R2) + ADD $8, R15, R15 + ADD $8, R2, R2 + SUB $8, R4, R4 + B finishSlowForwardCopy + +verySlowForwardCopy: + // verySlowForwardCopy is a simple implementation of forward copy. In C + // parlance, this is a do/while loop instead of a while loop, since we know + // that length > 0. In Go syntax: + // + // for { + // dst[d] = dst[d - offset] + // d++ + // length-- + // if length == 0 { + // break + // } + // } + MOVB (R15), R3 + MOVB R3, (R7) + ADD $1, R15, R15 + ADD $1, R7, R7 + SUB $1, R4, R4 + CBNZ R4, verySlowForwardCopy + B loop + + // The code above handles copy tags. + // ---------------------------------------- + +end: + // This is the end of the "for s < len(src)". + // + // if d != len(dst) { etc } + CMP R10, R7 + BNE errCorrupt + + // return 0 + MOVD $0, ret+48(FP) + RET + +errCorrupt: + // return decodeErrCodeCorrupt + MOVD $1, R2 + MOVD R2, ret+48(FP) + RET diff --git a/vendor/github.com/golang/snappy/decode_amd64.go b/vendor/github.com/golang/snappy/decode_asm.go similarity index 93% rename from vendor/github.com/golang/snappy/decode_amd64.go rename to vendor/github.com/golang/snappy/decode_asm.go index fcd192b849..7082b34919 100644 --- a/vendor/github.com/golang/snappy/decode_amd64.go +++ b/vendor/github.com/golang/snappy/decode_asm.go @@ -5,6 +5,7 @@ // +build !appengine // +build gc // +build !noasm +// +build amd64 arm64 package snappy diff --git a/vendor/github.com/golang/snappy/decode_other.go b/vendor/github.com/golang/snappy/decode_other.go index 8c9f2049bc..2f672be557 100644 --- a/vendor/github.com/golang/snappy/decode_other.go +++ b/vendor/github.com/golang/snappy/decode_other.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build !amd64 appengine !gc noasm +// +build !amd64,!arm64 appengine !gc noasm package snappy @@ -85,14 +85,28 @@ func decode(dst, src []byte) int { if offset <= 0 || d < offset || length > len(dst)-d { return decodeErrCodeCorrupt } - // Copy from an earlier sub-slice of dst to a later sub-slice. Unlike - // the built-in copy function, this byte-by-byte copy always runs + // Copy from an earlier sub-slice of dst to a later sub-slice. + // If no overlap, use the built-in copy: + if offset >= length { + copy(dst[d:d+length], dst[d-offset:]) + d += length + continue + } + + // Unlike the built-in copy function, this byte-by-byte copy always runs // forwards, even if the slices overlap. Conceptually, this is: // // d += forwardCopy(dst[d:d+length], dst[d-offset:]) - for end := d + length; d != end; d++ { - dst[d] = dst[d-offset] + // + // We align the slices into a and b and show the compiler they are the same size. + // This allows the loop to run without bounds checks. + a := dst[d : d+length] + b := dst[d-offset:] + b = b[:len(a)] + for i := range a { + a[i] = b[i] } + d += length } if d != len(dst) { return decodeErrCodeCorrupt diff --git a/vendor/github.com/golang/snappy/encode.go b/vendor/github.com/golang/snappy/encode.go index 8d393e904b..7f23657076 100644 --- a/vendor/github.com/golang/snappy/encode.go +++ b/vendor/github.com/golang/snappy/encode.go @@ -15,6 +15,8 @@ import ( // Otherwise, a newly allocated slice will be returned. // // The dst and src must not overlap. It is valid to pass a nil dst. +// +// Encode handles the Snappy block format, not the Snappy stream format. func Encode(dst, src []byte) []byte { if n := MaxEncodedLen(len(src)); n < 0 { panic(ErrTooLarge) @@ -139,6 +141,8 @@ func NewBufferedWriter(w io.Writer) *Writer { } // Writer is an io.Writer that can write Snappy-compressed bytes. +// +// Writer handles the Snappy stream format, not the Snappy block format. type Writer struct { w io.Writer err error diff --git a/vendor/github.com/golang/snappy/encode_arm64.s b/vendor/github.com/golang/snappy/encode_arm64.s new file mode 100644 index 0000000000..bf83667d71 --- /dev/null +++ b/vendor/github.com/golang/snappy/encode_arm64.s @@ -0,0 +1,722 @@ +// Copyright 2020 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !appengine +// +build gc +// +build !noasm + +#include "textflag.h" + +// The asm code generally follows the pure Go code in encode_other.go, except +// where marked with a "!!!". + +// ---------------------------------------------------------------------------- + +// func emitLiteral(dst, lit []byte) int +// +// All local variables fit into registers. The register allocation: +// - R3 len(lit) +// - R4 n +// - R6 return value +// - R8 &dst[i] +// - R10 &lit[0] +// +// The 32 bytes of stack space is to call runtime·memmove. +// +// The unusual register allocation of local variables, such as R10 for the +// source pointer, matches the allocation used at the call site in encodeBlock, +// which makes it easier to manually inline this function. +TEXT ·emitLiteral(SB), NOSPLIT, $32-56 + MOVD dst_base+0(FP), R8 + MOVD lit_base+24(FP), R10 + MOVD lit_len+32(FP), R3 + MOVD R3, R6 + MOVW R3, R4 + SUBW $1, R4, R4 + + CMPW $60, R4 + BLT oneByte + CMPW $256, R4 + BLT twoBytes + +threeBytes: + MOVD $0xf4, R2 + MOVB R2, 0(R8) + MOVW R4, 1(R8) + ADD $3, R8, R8 + ADD $3, R6, R6 + B memmove + +twoBytes: + MOVD $0xf0, R2 + MOVB R2, 0(R8) + MOVB R4, 1(R8) + ADD $2, R8, R8 + ADD $2, R6, R6 + B memmove + +oneByte: + LSLW $2, R4, R4 + MOVB R4, 0(R8) + ADD $1, R8, R8 + ADD $1, R6, R6 + +memmove: + MOVD R6, ret+48(FP) + + // copy(dst[i:], lit) + // + // This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push + // R8, R10 and R3 as arguments. + MOVD R8, 8(RSP) + MOVD R10, 16(RSP) + MOVD R3, 24(RSP) + CALL runtime·memmove(SB) + RET + +// ---------------------------------------------------------------------------- + +// func emitCopy(dst []byte, offset, length int) int +// +// All local variables fit into registers. The register allocation: +// - R3 length +// - R7 &dst[0] +// - R8 &dst[i] +// - R11 offset +// +// The unusual register allocation of local variables, such as R11 for the +// offset, matches the allocation used at the call site in encodeBlock, which +// makes it easier to manually inline this function. +TEXT ·emitCopy(SB), NOSPLIT, $0-48 + MOVD dst_base+0(FP), R8 + MOVD R8, R7 + MOVD offset+24(FP), R11 + MOVD length+32(FP), R3 + +loop0: + // for length >= 68 { etc } + CMPW $68, R3 + BLT step1 + + // Emit a length 64 copy, encoded as 3 bytes. + MOVD $0xfe, R2 + MOVB R2, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + SUB $64, R3, R3 + B loop0 + +step1: + // if length > 64 { etc } + CMP $64, R3 + BLE step2 + + // Emit a length 60 copy, encoded as 3 bytes. + MOVD $0xee, R2 + MOVB R2, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + SUB $60, R3, R3 + +step2: + // if length >= 12 || offset >= 2048 { goto step3 } + CMP $12, R3 + BGE step3 + CMPW $2048, R11 + BGE step3 + + // Emit the remaining copy, encoded as 2 bytes. + MOVB R11, 1(R8) + LSRW $3, R11, R11 + AND $0xe0, R11, R11 + SUB $4, R3, R3 + LSLW $2, R3 + AND $0xff, R3, R3 + ORRW R3, R11, R11 + ORRW $1, R11, R11 + MOVB R11, 0(R8) + ADD $2, R8, R8 + + // Return the number of bytes written. + SUB R7, R8, R8 + MOVD R8, ret+40(FP) + RET + +step3: + // Emit the remaining copy, encoded as 3 bytes. + SUB $1, R3, R3 + AND $0xff, R3, R3 + LSLW $2, R3, R3 + ORRW $2, R3, R3 + MOVB R3, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + + // Return the number of bytes written. + SUB R7, R8, R8 + MOVD R8, ret+40(FP) + RET + +// ---------------------------------------------------------------------------- + +// func extendMatch(src []byte, i, j int) int +// +// All local variables fit into registers. The register allocation: +// - R6 &src[0] +// - R7 &src[j] +// - R13 &src[len(src) - 8] +// - R14 &src[len(src)] +// - R15 &src[i] +// +// The unusual register allocation of local variables, such as R15 for a source +// pointer, matches the allocation used at the call site in encodeBlock, which +// makes it easier to manually inline this function. +TEXT ·extendMatch(SB), NOSPLIT, $0-48 + MOVD src_base+0(FP), R6 + MOVD src_len+8(FP), R14 + MOVD i+24(FP), R15 + MOVD j+32(FP), R7 + ADD R6, R14, R14 + ADD R6, R15, R15 + ADD R6, R7, R7 + MOVD R14, R13 + SUB $8, R13, R13 + +cmp8: + // As long as we are 8 or more bytes before the end of src, we can load and + // compare 8 bytes at a time. If those 8 bytes are equal, repeat. + CMP R13, R7 + BHI cmp1 + MOVD (R15), R3 + MOVD (R7), R4 + CMP R4, R3 + BNE bsf + ADD $8, R15, R15 + ADD $8, R7, R7 + B cmp8 + +bsf: + // If those 8 bytes were not equal, XOR the two 8 byte values, and return + // the index of the first byte that differs. + // RBIT reverses the bit order, then CLZ counts the leading zeros, the + // combination of which finds the least significant bit which is set. + // The arm64 architecture is little-endian, and the shift by 3 converts + // a bit index to a byte index. + EOR R3, R4, R4 + RBIT R4, R4 + CLZ R4, R4 + ADD R4>>3, R7, R7 + + // Convert from &src[ret] to ret. + SUB R6, R7, R7 + MOVD R7, ret+40(FP) + RET + +cmp1: + // In src's tail, compare 1 byte at a time. + CMP R7, R14 + BLS extendMatchEnd + MOVB (R15), R3 + MOVB (R7), R4 + CMP R4, R3 + BNE extendMatchEnd + ADD $1, R15, R15 + ADD $1, R7, R7 + B cmp1 + +extendMatchEnd: + // Convert from &src[ret] to ret. + SUB R6, R7, R7 + MOVD R7, ret+40(FP) + RET + +// ---------------------------------------------------------------------------- + +// func encodeBlock(dst, src []byte) (d int) +// +// All local variables fit into registers, other than "var table". The register +// allocation: +// - R3 . . +// - R4 . . +// - R5 64 shift +// - R6 72 &src[0], tableSize +// - R7 80 &src[s] +// - R8 88 &dst[d] +// - R9 96 sLimit +// - R10 . &src[nextEmit] +// - R11 104 prevHash, currHash, nextHash, offset +// - R12 112 &src[base], skip +// - R13 . &src[nextS], &src[len(src) - 8] +// - R14 . len(src), bytesBetweenHashLookups, &src[len(src)], x +// - R15 120 candidate +// - R16 . hash constant, 0x1e35a7bd +// - R17 . &table +// - . 128 table +// +// The second column (64, 72, etc) is the stack offset to spill the registers +// when calling other functions. We could pack this slightly tighter, but it's +// simpler to have a dedicated spill map independent of the function called. +// +// "var table [maxTableSize]uint16" takes up 32768 bytes of stack space. An +// extra 64 bytes, to call other functions, and an extra 64 bytes, to spill +// local variables (registers) during calls gives 32768 + 64 + 64 = 32896. +TEXT ·encodeBlock(SB), 0, $32896-56 + MOVD dst_base+0(FP), R8 + MOVD src_base+24(FP), R7 + MOVD src_len+32(FP), R14 + + // shift, tableSize := uint32(32-8), 1<<8 + MOVD $24, R5 + MOVD $256, R6 + MOVW $0xa7bd, R16 + MOVKW $(0x1e35<<16), R16 + +calcShift: + // for ; tableSize < maxTableSize && tableSize < len(src); tableSize *= 2 { + // shift-- + // } + MOVD $16384, R2 + CMP R2, R6 + BGE varTable + CMP R14, R6 + BGE varTable + SUB $1, R5, R5 + LSL $1, R6, R6 + B calcShift + +varTable: + // var table [maxTableSize]uint16 + // + // In the asm code, unlike the Go code, we can zero-initialize only the + // first tableSize elements. Each uint16 element is 2 bytes and each + // iterations writes 64 bytes, so we can do only tableSize/32 writes + // instead of the 2048 writes that would zero-initialize all of table's + // 32768 bytes. This clear could overrun the first tableSize elements, but + // it won't overrun the allocated stack size. + ADD $128, RSP, R17 + MOVD R17, R4 + + // !!! R6 = &src[tableSize] + ADD R6<<1, R17, R6 + +memclr: + STP.P (ZR, ZR), 64(R4) + STP (ZR, ZR), -48(R4) + STP (ZR, ZR), -32(R4) + STP (ZR, ZR), -16(R4) + CMP R4, R6 + BHI memclr + + // !!! R6 = &src[0] + MOVD R7, R6 + + // sLimit := len(src) - inputMargin + MOVD R14, R9 + SUB $15, R9, R9 + + // !!! Pre-emptively spill R5, R6 and R9 to the stack. Their values don't + // change for the rest of the function. + MOVD R5, 64(RSP) + MOVD R6, 72(RSP) + MOVD R9, 96(RSP) + + // nextEmit := 0 + MOVD R6, R10 + + // s := 1 + ADD $1, R7, R7 + + // nextHash := hash(load32(src, s), shift) + MOVW 0(R7), R11 + MULW R16, R11, R11 + LSRW R5, R11, R11 + +outer: + // for { etc } + + // skip := 32 + MOVD $32, R12 + + // nextS := s + MOVD R7, R13 + + // candidate := 0 + MOVD $0, R15 + +inner0: + // for { etc } + + // s := nextS + MOVD R13, R7 + + // bytesBetweenHashLookups := skip >> 5 + MOVD R12, R14 + LSR $5, R14, R14 + + // nextS = s + bytesBetweenHashLookups + ADD R14, R13, R13 + + // skip += bytesBetweenHashLookups + ADD R14, R12, R12 + + // if nextS > sLimit { goto emitRemainder } + MOVD R13, R3 + SUB R6, R3, R3 + CMP R9, R3 + BHI emitRemainder + + // candidate = int(table[nextHash]) + MOVHU 0(R17)(R11<<1), R15 + + // table[nextHash] = uint16(s) + MOVD R7, R3 + SUB R6, R3, R3 + + MOVH R3, 0(R17)(R11<<1) + + // nextHash = hash(load32(src, nextS), shift) + MOVW 0(R13), R11 + MULW R16, R11 + LSRW R5, R11, R11 + + // if load32(src, s) != load32(src, candidate) { continue } break + MOVW 0(R7), R3 + MOVW (R6)(R15*1), R4 + CMPW R4, R3 + BNE inner0 + +fourByteMatch: + // As per the encode_other.go code: + // + // A 4-byte match has been found. We'll later see etc. + + // !!! Jump to a fast path for short (<= 16 byte) literals. See the comment + // on inputMargin in encode.go. + MOVD R7, R3 + SUB R10, R3, R3 + CMP $16, R3 + BLE emitLiteralFastPath + + // ---------------------------------------- + // Begin inline of the emitLiteral call. + // + // d += emitLiteral(dst[d:], src[nextEmit:s]) + + MOVW R3, R4 + SUBW $1, R4, R4 + + MOVW $60, R2 + CMPW R2, R4 + BLT inlineEmitLiteralOneByte + MOVW $256, R2 + CMPW R2, R4 + BLT inlineEmitLiteralTwoBytes + +inlineEmitLiteralThreeBytes: + MOVD $0xf4, R1 + MOVB R1, 0(R8) + MOVW R4, 1(R8) + ADD $3, R8, R8 + B inlineEmitLiteralMemmove + +inlineEmitLiteralTwoBytes: + MOVD $0xf0, R1 + MOVB R1, 0(R8) + MOVB R4, 1(R8) + ADD $2, R8, R8 + B inlineEmitLiteralMemmove + +inlineEmitLiteralOneByte: + LSLW $2, R4, R4 + MOVB R4, 0(R8) + ADD $1, R8, R8 + +inlineEmitLiteralMemmove: + // Spill local variables (registers) onto the stack; call; unspill. + // + // copy(dst[i:], lit) + // + // This means calling runtime·memmove(&dst[i], &lit[0], len(lit)), so we push + // R8, R10 and R3 as arguments. + MOVD R8, 8(RSP) + MOVD R10, 16(RSP) + MOVD R3, 24(RSP) + + // Finish the "d +=" part of "d += emitLiteral(etc)". + ADD R3, R8, R8 + MOVD R7, 80(RSP) + MOVD R8, 88(RSP) + MOVD R15, 120(RSP) + CALL runtime·memmove(SB) + MOVD 64(RSP), R5 + MOVD 72(RSP), R6 + MOVD 80(RSP), R7 + MOVD 88(RSP), R8 + MOVD 96(RSP), R9 + MOVD 120(RSP), R15 + ADD $128, RSP, R17 + MOVW $0xa7bd, R16 + MOVKW $(0x1e35<<16), R16 + B inner1 + +inlineEmitLiteralEnd: + // End inline of the emitLiteral call. + // ---------------------------------------- + +emitLiteralFastPath: + // !!! Emit the 1-byte encoding "uint8(len(lit)-1)<<2". + MOVB R3, R4 + SUBW $1, R4, R4 + AND $0xff, R4, R4 + LSLW $2, R4, R4 + MOVB R4, (R8) + ADD $1, R8, R8 + + // !!! Implement the copy from lit to dst as a 16-byte load and store. + // (Encode's documentation says that dst and src must not overlap.) + // + // This always copies 16 bytes, instead of only len(lit) bytes, but that's + // OK. Subsequent iterations will fix up the overrun. + // + // Note that on arm64, it is legal and cheap to issue unaligned 8-byte or + // 16-byte loads and stores. This technique probably wouldn't be as + // effective on architectures that are fussier about alignment. + LDP 0(R10), (R0, R1) + STP (R0, R1), 0(R8) + ADD R3, R8, R8 + +inner1: + // for { etc } + + // base := s + MOVD R7, R12 + + // !!! offset := base - candidate + MOVD R12, R11 + SUB R15, R11, R11 + SUB R6, R11, R11 + + // ---------------------------------------- + // Begin inline of the extendMatch call. + // + // s = extendMatch(src, candidate+4, s+4) + + // !!! R14 = &src[len(src)] + MOVD src_len+32(FP), R14 + ADD R6, R14, R14 + + // !!! R13 = &src[len(src) - 8] + MOVD R14, R13 + SUB $8, R13, R13 + + // !!! R15 = &src[candidate + 4] + ADD $4, R15, R15 + ADD R6, R15, R15 + + // !!! s += 4 + ADD $4, R7, R7 + +inlineExtendMatchCmp8: + // As long as we are 8 or more bytes before the end of src, we can load and + // compare 8 bytes at a time. If those 8 bytes are equal, repeat. + CMP R13, R7 + BHI inlineExtendMatchCmp1 + MOVD (R15), R3 + MOVD (R7), R4 + CMP R4, R3 + BNE inlineExtendMatchBSF + ADD $8, R15, R15 + ADD $8, R7, R7 + B inlineExtendMatchCmp8 + +inlineExtendMatchBSF: + // If those 8 bytes were not equal, XOR the two 8 byte values, and return + // the index of the first byte that differs. + // RBIT reverses the bit order, then CLZ counts the leading zeros, the + // combination of which finds the least significant bit which is set. + // The arm64 architecture is little-endian, and the shift by 3 converts + // a bit index to a byte index. + EOR R3, R4, R4 + RBIT R4, R4 + CLZ R4, R4 + ADD R4>>3, R7, R7 + B inlineExtendMatchEnd + +inlineExtendMatchCmp1: + // In src's tail, compare 1 byte at a time. + CMP R7, R14 + BLS inlineExtendMatchEnd + MOVB (R15), R3 + MOVB (R7), R4 + CMP R4, R3 + BNE inlineExtendMatchEnd + ADD $1, R15, R15 + ADD $1, R7, R7 + B inlineExtendMatchCmp1 + +inlineExtendMatchEnd: + // End inline of the extendMatch call. + // ---------------------------------------- + + // ---------------------------------------- + // Begin inline of the emitCopy call. + // + // d += emitCopy(dst[d:], base-candidate, s-base) + + // !!! length := s - base + MOVD R7, R3 + SUB R12, R3, R3 + +inlineEmitCopyLoop0: + // for length >= 68 { etc } + MOVW $68, R2 + CMPW R2, R3 + BLT inlineEmitCopyStep1 + + // Emit a length 64 copy, encoded as 3 bytes. + MOVD $0xfe, R1 + MOVB R1, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + SUBW $64, R3, R3 + B inlineEmitCopyLoop0 + +inlineEmitCopyStep1: + // if length > 64 { etc } + MOVW $64, R2 + CMPW R2, R3 + BLE inlineEmitCopyStep2 + + // Emit a length 60 copy, encoded as 3 bytes. + MOVD $0xee, R1 + MOVB R1, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + SUBW $60, R3, R3 + +inlineEmitCopyStep2: + // if length >= 12 || offset >= 2048 { goto inlineEmitCopyStep3 } + MOVW $12, R2 + CMPW R2, R3 + BGE inlineEmitCopyStep3 + MOVW $2048, R2 + CMPW R2, R11 + BGE inlineEmitCopyStep3 + + // Emit the remaining copy, encoded as 2 bytes. + MOVB R11, 1(R8) + LSRW $8, R11, R11 + LSLW $5, R11, R11 + SUBW $4, R3, R3 + AND $0xff, R3, R3 + LSLW $2, R3, R3 + ORRW R3, R11, R11 + ORRW $1, R11, R11 + MOVB R11, 0(R8) + ADD $2, R8, R8 + B inlineEmitCopyEnd + +inlineEmitCopyStep3: + // Emit the remaining copy, encoded as 3 bytes. + SUBW $1, R3, R3 + LSLW $2, R3, R3 + ORRW $2, R3, R3 + MOVB R3, 0(R8) + MOVW R11, 1(R8) + ADD $3, R8, R8 + +inlineEmitCopyEnd: + // End inline of the emitCopy call. + // ---------------------------------------- + + // nextEmit = s + MOVD R7, R10 + + // if s >= sLimit { goto emitRemainder } + MOVD R7, R3 + SUB R6, R3, R3 + CMP R3, R9 + BLS emitRemainder + + // As per the encode_other.go code: + // + // We could immediately etc. + + // x := load64(src, s-1) + MOVD -1(R7), R14 + + // prevHash := hash(uint32(x>>0), shift) + MOVW R14, R11 + MULW R16, R11, R11 + LSRW R5, R11, R11 + + // table[prevHash] = uint16(s-1) + MOVD R7, R3 + SUB R6, R3, R3 + SUB $1, R3, R3 + + MOVHU R3, 0(R17)(R11<<1) + + // currHash := hash(uint32(x>>8), shift) + LSR $8, R14, R14 + MOVW R14, R11 + MULW R16, R11, R11 + LSRW R5, R11, R11 + + // candidate = int(table[currHash]) + MOVHU 0(R17)(R11<<1), R15 + + // table[currHash] = uint16(s) + ADD $1, R3, R3 + MOVHU R3, 0(R17)(R11<<1) + + // if uint32(x>>8) == load32(src, candidate) { continue } + MOVW (R6)(R15*1), R4 + CMPW R4, R14 + BEQ inner1 + + // nextHash = hash(uint32(x>>16), shift) + LSR $8, R14, R14 + MOVW R14, R11 + MULW R16, R11, R11 + LSRW R5, R11, R11 + + // s++ + ADD $1, R7, R7 + + // break out of the inner1 for loop, i.e. continue the outer loop. + B outer + +emitRemainder: + // if nextEmit < len(src) { etc } + MOVD src_len+32(FP), R3 + ADD R6, R3, R3 + CMP R3, R10 + BEQ encodeBlockEnd + + // d += emitLiteral(dst[d:], src[nextEmit:]) + // + // Push args. + MOVD R8, 8(RSP) + MOVD $0, 16(RSP) // Unnecessary, as the callee ignores it, but conservative. + MOVD $0, 24(RSP) // Unnecessary, as the callee ignores it, but conservative. + MOVD R10, 32(RSP) + SUB R10, R3, R3 + MOVD R3, 40(RSP) + MOVD R3, 48(RSP) // Unnecessary, as the callee ignores it, but conservative. + + // Spill local variables (registers) onto the stack; call; unspill. + MOVD R8, 88(RSP) + CALL ·emitLiteral(SB) + MOVD 88(RSP), R8 + + // Finish the "d +=" part of "d += emitLiteral(etc)". + MOVD 56(RSP), R1 + ADD R1, R8, R8 + +encodeBlockEnd: + MOVD dst_base+0(FP), R3 + SUB R3, R8, R8 + MOVD R8, d+48(FP) + RET diff --git a/vendor/github.com/golang/snappy/encode_amd64.go b/vendor/github.com/golang/snappy/encode_asm.go similarity index 97% rename from vendor/github.com/golang/snappy/encode_amd64.go rename to vendor/github.com/golang/snappy/encode_asm.go index 150d91bc8b..107c1e7141 100644 --- a/vendor/github.com/golang/snappy/encode_amd64.go +++ b/vendor/github.com/golang/snappy/encode_asm.go @@ -5,6 +5,7 @@ // +build !appengine // +build gc // +build !noasm +// +build amd64 arm64 package snappy diff --git a/vendor/github.com/golang/snappy/encode_other.go b/vendor/github.com/golang/snappy/encode_other.go index dbcae905e6..296d7f0beb 100644 --- a/vendor/github.com/golang/snappy/encode_other.go +++ b/vendor/github.com/golang/snappy/encode_other.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// +build !amd64 appengine !gc noasm +// +build !amd64,!arm64 appengine !gc noasm package snappy diff --git a/vendor/github.com/google/uuid/hash.go b/vendor/github.com/google/uuid/hash.go index b174616315..b404f4bec2 100644 --- a/vendor/github.com/google/uuid/hash.go +++ b/vendor/github.com/google/uuid/hash.go @@ -26,8 +26,8 @@ var ( // NewMD5 and NewSHA1. func NewHash(h hash.Hash, space UUID, data []byte, version int) UUID { h.Reset() - h.Write(space[:]) - h.Write(data) + h.Write(space[:]) //nolint:errcheck + h.Write(data) //nolint:errcheck s := h.Sum(nil) var uuid UUID copy(uuid[:], s) diff --git a/vendor/github.com/google/uuid/null.go b/vendor/github.com/google/uuid/null.go new file mode 100644 index 0000000000..d7fcbf2865 --- /dev/null +++ b/vendor/github.com/google/uuid/null.go @@ -0,0 +1,118 @@ +// Copyright 2021 Google Inc. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package uuid + +import ( + "bytes" + "database/sql/driver" + "encoding/json" + "fmt" +) + +var jsonNull = []byte("null") + +// NullUUID represents a UUID that may be null. +// NullUUID implements the SQL driver.Scanner interface so +// it can be used as a scan destination: +// +// var u uuid.NullUUID +// err := db.QueryRow("SELECT name FROM foo WHERE id=?", id).Scan(&u) +// ... +// if u.Valid { +// // use u.UUID +// } else { +// // NULL value +// } +// +type NullUUID struct { + UUID UUID + Valid bool // Valid is true if UUID is not NULL +} + +// Scan implements the SQL driver.Scanner interface. +func (nu *NullUUID) Scan(value interface{}) error { + if value == nil { + nu.UUID, nu.Valid = Nil, false + return nil + } + + err := nu.UUID.Scan(value) + if err != nil { + nu.Valid = false + return err + } + + nu.Valid = true + return nil +} + +// Value implements the driver Valuer interface. +func (nu NullUUID) Value() (driver.Value, error) { + if !nu.Valid { + return nil, nil + } + // Delegate to UUID Value function + return nu.UUID.Value() +} + +// MarshalBinary implements encoding.BinaryMarshaler. +func (nu NullUUID) MarshalBinary() ([]byte, error) { + if nu.Valid { + return nu.UUID[:], nil + } + + return []byte(nil), nil +} + +// UnmarshalBinary implements encoding.BinaryUnmarshaler. +func (nu *NullUUID) UnmarshalBinary(data []byte) error { + if len(data) != 16 { + return fmt.Errorf("invalid UUID (got %d bytes)", len(data)) + } + copy(nu.UUID[:], data) + nu.Valid = true + return nil +} + +// MarshalText implements encoding.TextMarshaler. +func (nu NullUUID) MarshalText() ([]byte, error) { + if nu.Valid { + return nu.UUID.MarshalText() + } + + return jsonNull, nil +} + +// UnmarshalText implements encoding.TextUnmarshaler. +func (nu *NullUUID) UnmarshalText(data []byte) error { + id, err := ParseBytes(data) + if err != nil { + nu.Valid = false + return err + } + nu.UUID = id + nu.Valid = true + return nil +} + +// MarshalJSON implements json.Marshaler. +func (nu NullUUID) MarshalJSON() ([]byte, error) { + if nu.Valid { + return json.Marshal(nu.UUID) + } + + return jsonNull, nil +} + +// UnmarshalJSON implements json.Unmarshaler. +func (nu *NullUUID) UnmarshalJSON(data []byte) error { + if bytes.Equal(data, jsonNull) { + *nu = NullUUID{} + return nil // valid null UUID + } + err := json.Unmarshal(data, &nu.UUID) + nu.Valid = err == nil + return err +} diff --git a/vendor/github.com/google/uuid/sql.go b/vendor/github.com/google/uuid/sql.go index f326b54db3..2e02ec06c0 100644 --- a/vendor/github.com/google/uuid/sql.go +++ b/vendor/github.com/google/uuid/sql.go @@ -9,7 +9,7 @@ import ( "fmt" ) -// Scan implements sql.Scanner so UUIDs can be read from databases transparently +// Scan implements sql.Scanner so UUIDs can be read from databases transparently. // Currently, database types that map to string and []byte are supported. Please // consult database-specific driver documentation for matching types. func (uuid *UUID) Scan(src interface{}) error { diff --git a/vendor/github.com/google/uuid/uuid.go b/vendor/github.com/google/uuid/uuid.go index 524404cc52..a57207aeb6 100644 --- a/vendor/github.com/google/uuid/uuid.go +++ b/vendor/github.com/google/uuid/uuid.go @@ -12,6 +12,7 @@ import ( "fmt" "io" "strings" + "sync" ) // A UUID is a 128 bit (16 byte) Universal Unique IDentifier as defined in RFC @@ -33,7 +34,27 @@ const ( Future // Reserved for future definition. ) -var rander = rand.Reader // random function +const randPoolSize = 16 * 16 + +var ( + rander = rand.Reader // random function + poolEnabled = false + poolMu sync.Mutex + poolPos = randPoolSize // protected with poolMu + pool [randPoolSize]byte // protected with poolMu +) + +type invalidLengthError struct{ len int } + +func (err invalidLengthError) Error() string { + return fmt.Sprintf("invalid UUID length: %d", err.len) +} + +// IsInvalidLengthError is matcher function for custom error invalidLengthError +func IsInvalidLengthError(err error) bool { + _, ok := err.(invalidLengthError) + return ok +} // Parse decodes s into a UUID or returns an error. Both the standard UUID // forms of xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx and @@ -68,7 +89,7 @@ func Parse(s string) (UUID, error) { } return uuid, nil default: - return uuid, fmt.Errorf("invalid UUID length: %d", len(s)) + return uuid, invalidLengthError{len(s)} } // s is now at least 36 bytes long // it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx @@ -112,7 +133,7 @@ func ParseBytes(b []byte) (UUID, error) { } return uuid, nil default: - return uuid, fmt.Errorf("invalid UUID length: %d", len(b)) + return uuid, invalidLengthError{len(b)} } // s is now at least 36 bytes long // it must be of the form xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx @@ -243,3 +264,31 @@ func SetRand(r io.Reader) { } rander = r } + +// EnableRandPool enables internal randomness pool used for Random +// (Version 4) UUID generation. The pool contains random bytes read from +// the random number generator on demand in batches. Enabling the pool +// may improve the UUID generation throughput significantly. +// +// Since the pool is stored on the Go heap, this feature may be a bad fit +// for security sensitive applications. +// +// Both EnableRandPool and DisableRandPool are not thread-safe and should +// only be called when there is no possibility that New or any other +// UUID Version 4 generation function will be called concurrently. +func EnableRandPool() { + poolEnabled = true +} + +// DisableRandPool disables the randomness pool if it was previously +// enabled with EnableRandPool. +// +// Both EnableRandPool and DisableRandPool are not thread-safe and should +// only be called when there is no possibility that New or any other +// UUID Version 4 generation function will be called concurrently. +func DisableRandPool() { + poolEnabled = false + defer poolMu.Unlock() + poolMu.Lock() + poolPos = randPoolSize +} diff --git a/vendor/github.com/google/uuid/version4.go b/vendor/github.com/google/uuid/version4.go index c110465db5..7697802e4d 100644 --- a/vendor/github.com/google/uuid/version4.go +++ b/vendor/github.com/google/uuid/version4.go @@ -14,11 +14,21 @@ func New() UUID { return Must(NewRandom()) } +// NewString creates a new random UUID and returns it as a string or panics. +// NewString is equivalent to the expression +// +// uuid.New().String() +func NewString() string { + return Must(NewRandom()).String() +} + // NewRandom returns a Random (Version 4) UUID. // // The strength of the UUIDs is based on the strength of the crypto/rand // package. // +// Uses the randomness pool if it was enabled with EnableRandPool. +// // A note about uniqueness derived from the UUID Wikipedia entry: // // Randomly generated UUIDs have 122 random bits. One's annual risk of being @@ -27,7 +37,10 @@ func New() UUID { // equivalent to the odds of creating a few tens of trillions of UUIDs in a // year and having one duplicate. func NewRandom() (UUID, error) { - return NewRandomFromReader(rander) + if !poolEnabled { + return NewRandomFromReader(rander) + } + return newRandomFromPool() } // NewRandomFromReader returns a UUID based on bytes read from a given io.Reader. @@ -41,3 +54,23 @@ func NewRandomFromReader(r io.Reader) (UUID, error) { uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10 return uuid, nil } + +func newRandomFromPool() (UUID, error) { + var uuid UUID + poolMu.Lock() + if poolPos == randPoolSize { + _, err := io.ReadFull(rander, pool[:]) + if err != nil { + poolMu.Unlock() + return Nil, err + } + poolPos = 0 + } + copy(uuid[:], pool[poolPos:(poolPos+16)]) + poolPos += 16 + poolMu.Unlock() + + uuid[6] = (uuid[6] & 0x0f) | 0x40 // Version 4 + uuid[8] = (uuid[8] & 0x3f) | 0x80 // Variant is 10 + return uuid, nil +} diff --git a/vendor/github.com/hasura/go-graphql-client/.travis.yml b/vendor/github.com/hasura/go-graphql-client/.travis.yml new file mode 100644 index 0000000000..93b1fcdb31 --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/.travis.yml @@ -0,0 +1,16 @@ +sudo: false +language: go +go: + - 1.x + - master +matrix: + allow_failures: + - go: master + fast_finish: true +install: + - # Do nothing. This is needed to prevent default install action "go get -t -v ./..." from happening here (we want it to happen inside script step). +script: + - go get -t -v ./... + - diff -u <(echo -n) <(gofmt -d -s .) + - go tool vet . + - go test -v -race ./... diff --git a/vendor/github.com/hasura/go-graphql-client/LICENSE b/vendor/github.com/hasura/go-graphql-client/LICENSE new file mode 100644 index 0000000000..91cac9f03c --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/LICENSE @@ -0,0 +1,22 @@ +MIT License + +Copyright (c) 2017 Dmitri Shuralyov +Copyright (c) 2020 Hasura + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/github.com/hasura/go-graphql-client/README.md b/vendor/github.com/hasura/go-graphql-client/README.md new file mode 100644 index 0000000000..6bb2e18df4 --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/README.md @@ -0,0 +1,442 @@ +go-graphql-client +======= + +[![Build Status](https://travis-ci.org/hasura/go-graphql-client.svg?branch=master)](https://travis-ci.org/hasura/go-graphql-client.svg?branch=master) [![GoDoc](https://godoc.org/github.com/hasura/go-graphql-client?status.svg)](https://pkg.go.dev/github.com/hasura/go-graphql-client) + +**Preface:** This is a fork of `https://github.com/shurcooL/graphql` with extended features (subscription client, named operation) + +The subscription client follows Apollo client specification https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md, using websocket protocol with https://github.com/nhooyr/websocket, a minimal and idiomatic WebSocket library for Go. + +Package `graphql` provides a GraphQL client implementation. + +For more information, see package [`github.com/shurcooL/githubv4`](https://github.com/shurcooL/githubv4), which is a specialized version targeting GitHub GraphQL API v4. That package is driving the feature development. + +**Status:** In active early research and development. The API will change when opportunities for improvement are discovered; it is not yet frozen. + +Installation +------------ + +`go-graphql-client` requires Go version 1.13 or later. + +```bash +go get -u github.com/hasura/go-graphql-client +``` + +Usage +----- + +Construct a GraphQL client, specifying the GraphQL server URL. Then, you can use it to make GraphQL queries and mutations. + +```Go +client := graphql.NewClient("https://example.com/graphql", nil) +// Use client... +``` + +### Authentication + +Some GraphQL servers may require authentication. The `graphql` package does not directly handle authentication. Instead, when creating a new client, you're expected to pass an `http.Client` that performs authentication. The easiest and recommended way to do this is to use the [`golang.org/x/oauth2`](https://golang.org/x/oauth2) package. You'll need an OAuth token with the right scopes. Then: + +```Go +import "golang.org/x/oauth2" + +func main() { + src := oauth2.StaticTokenSource( + &oauth2.Token{AccessToken: os.Getenv("GRAPHQL_TOKEN")}, + ) + httpClient := oauth2.NewClient(context.Background(), src) + + client := graphql.NewClient("https://example.com/graphql", httpClient) + // Use client... +``` + +### Simple Query + +To make a GraphQL query, you need to define a corresponding Go type. + +For example, to make the following GraphQL query: + +```GraphQL +query { + me { + name + } +} +``` + +You can define this variable: + +```Go +var query struct { + Me struct { + Name graphql.String + } +} +``` + +Then call `client.Query`, passing a pointer to it: + +```Go +err := client.Query(context.Background(), &query, nil) +if err != nil { + // Handle error. +} +fmt.Println(query.Me.Name) + +// Output: Luke Skywalker +``` + +### Arguments and Variables + +Often, you'll want to specify arguments on some fields. You can use the `graphql` struct field tag for this. + +For example, to make the following GraphQL query: + +```GraphQL +{ + human(id: "1000") { + name + height(unit: METER) + } +} +``` + +You can define this variable: + +```Go +var q struct { + Human struct { + Name graphql.String + Height graphql.Float `graphql:"height(unit: METER)"` + } `graphql:"human(id: \"1000\")"` +} +``` + +Then call `client.Query`: + +```Go +err := client.Query(context.Background(), &q, nil) +if err != nil { + // Handle error. +} +fmt.Println(q.Human.Name) +fmt.Println(q.Human.Height) + +// Output: +// Luke Skywalker +// 1.72 +``` + +However, that'll only work if the arguments are constant and known in advance. Otherwise, you will need to make use of variables. Replace the constants in the struct field tag with variable names: + +```Go +var q struct { + Human struct { + Name graphql.String + Height graphql.Float `graphql:"height(unit: $unit)"` + } `graphql:"human(id: $id)"` +} +``` + +Then, define a `variables` map with their values: + +```Go +variables := map[string]interface{}{ + "id": graphql.ID(id), + "unit": starwars.LengthUnit("METER"), +} +``` + +Finally, call `client.Query` providing `variables`: + +```Go +err := client.Query(context.Background(), &q, variables) +if err != nil { + // Handle error. +} +``` + +### Inline Fragments + +Some GraphQL queries contain inline fragments. You can use the `graphql` struct field tag to express them. + +For example, to make the following GraphQL query: + +```GraphQL +{ + hero(episode: "JEDI") { + name + ... on Droid { + primaryFunction + } + ... on Human { + height + } + } +} +``` + +You can define this variable: + +```Go +var q struct { + Hero struct { + Name graphql.String + Droid struct { + PrimaryFunction graphql.String + } `graphql:"... on Droid"` + Human struct { + Height graphql.Float + } `graphql:"... on Human"` + } `graphql:"hero(episode: \"JEDI\")"` +} +``` + +Alternatively, you can define the struct types corresponding to inline fragments, and use them as embedded fields in your query: + +```Go +type ( + DroidFragment struct { + PrimaryFunction graphql.String + } + HumanFragment struct { + Height graphql.Float + } +) + +var q struct { + Hero struct { + Name graphql.String + DroidFragment `graphql:"... on Droid"` + HumanFragment `graphql:"... on Human"` + } `graphql:"hero(episode: \"JEDI\")"` +} +``` + +Then call `client.Query`: + +```Go +err := client.Query(context.Background(), &q, nil) +if err != nil { + // Handle error. +} +fmt.Println(q.Hero.Name) +fmt.Println(q.Hero.PrimaryFunction) +fmt.Println(q.Hero.Height) + +// Output: +// R2-D2 +// Astromech +// 0 +``` + +### Mutations + +Mutations often require information that you can only find out by performing a query first. Let's suppose you've already done that. + +For example, to make the following GraphQL mutation: + +```GraphQL +mutation($ep: Episode!, $review: ReviewInput!) { + createReview(episode: $ep, review: $review) { + stars + commentary + } +} +variables { + "ep": "JEDI", + "review": { + "stars": 5, + "commentary": "This is a great movie!" + } +} +``` + +You can define: + +```Go +var m struct { + CreateReview struct { + Stars graphql.Int + Commentary graphql.String + } `graphql:"createReview(episode: $ep, review: $review)"` +} +variables := map[string]interface{}{ + "ep": starwars.Episode("JEDI"), + "review": starwars.ReviewInput{ + Stars: graphql.Int(5), + Commentary: graphql.String("This is a great movie!"), + }, +} +``` + +Then call `client.Mutate`: + +```Go +err := client.Mutate(context.Background(), &m, variables) +if err != nil { + // Handle error. +} +fmt.Printf("Created a %v star review: %v\n", m.CreateReview.Stars, m.CreateReview.Commentary) + +// Output: +// Created a 5 star review: This is a great movie! +``` + +### Subcriptions + +Usage +----- + +Construct a Subscription client, specifying the GraphQL server URL. + +```Go +client := graphql.NewSubscriptionClient("wss://example.com/graphql") +defer client.Close() + +// Subscribe subscriptions +// ... +// finally run the client +client.Run() +``` + +#### Subscribe + +To make a GraphQL subscription, you need to define a corresponding Go type. + +For example, to make the following GraphQL query: + +```GraphQL +subscription { + me { + name + } +} +``` + +You can define this variable: + +```Go +var subscription struct { + Me struct { + Name graphql.String + } +} +``` + +Then call `client.Subscribe`, passing a pointer to it: + +```Go +subscriptionId, err := client.Subscribe(&query, nil, func(dataValue *json.RawMessage, errValue error) error { + if errValue != nil { + // handle error + // if returns error, it will failback to `onError` event + return nil + } + data := query{} + err := json.Unmarshal(dataValue, &data) + + fmt.Println(query.Me.Name) + + // Output: Luke Skywalker +}) + +if err != nil { + // Handle error. +} + +// you can unsubscribe the subscription while the client is running +client.Unsubscribe(subscriptionId) +``` + +#### Authentication + +The subscription client is authenticated with GraphQL server through connection params: + +```Go +client := graphql.NewSubscriptionClient("wss://example.com/graphql"). + WithConnectionParams(map[string]interface{}{ + "headers": map[string]string{ + "authentication": "...", + }, + }) + +``` + +#### Options + +```Go +client. + // write timeout of websocket client + WithTimeout(time.Minute). + // When the websocket server was stopped, the client will retry connecting every second until timeout + WithRetryTimeout(time.Minute). + // sets loging function to print out received messages. By default, nothing is printed + WithLog(log.Println). + // max size of response message + WithReadLimit(10*1024*1024). + // these operation event logs won't be printed + WithoutLogTypes(graphql.GQL_DATA, graphql.GQL_CONNECTION_KEEP_ALIVE) + +``` + +#### Events + +```Go +// OnConnected event is triggered when the websocket connected to GraphQL server sucessfully +client.OnConnected(fn func()) + +// OnDisconnected event is triggered when the websocket server was stil down after retry timeout +client.OnDisconnected(fn func()) + +// OnConnected event is triggered when there is any connection error. This is bottom exception handler level +// If this function is empty, or returns nil, the error is ignored +// If returns error, the websocket connection will be terminated +client.OnError(onError func(sc *SubscriptionClient, err error) error) +``` + +### With operation name + +Operation name is still on API decision plan https://github.com/shurcooL/graphql/issues/12. However, in my opinion separate methods are easier choice to avoid breaking changes + +```Go +func (c *Client) NamedQuery(ctx context.Context, name string, q interface{}, variables map[string]interface{}) error + +func (c *Client) NamedMutate(ctx context.Context, name string, q interface{}, variables map[string]interface{}) error + +func (sc *SubscriptionClient) NamedSubscribe(name string, v interface{}, variables map[string]interface{}, handler func(message *json.RawMessage, err error) error) (string, error) +``` + +### Raw bytes response + +In the case we developers want to decode JSON response ourself. Moreover, the default `UnmarshalGraphQL` function isn't ideal with complicated nested interfaces + +```Go +func (c *Client) QueryRaw(ctx context.Context, q interface{}, variables map[string]interface{}) (*json.RawMessage, error) + +func (c *Client) MutateRaw(ctx context.Context, q interface{}, variables map[string]interface{}) (*json.RawMessage, error) + +func (c *Client) NamedQueryRaw(ctx context.Context, name string, q interface{}, variables map[string]interface{}) (*json.RawMessage, error) + +func (c *Client) NamedMutateRaw(ctx context.Context, name string, q interface{}, variables map[string]interface{}) (*json.RawMessage, error) +``` + +Directories +----------- + +| Path | Synopsis | +|----------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------------------------------| +| [example/graphqldev](https://godoc.org/github.com/shurcooL/graphql/example/graphqldev) | graphqldev is a test program currently being used for developing graphql package. | +| [ident](https://godoc.org/github.com/shurcooL/graphql/ident) | Package ident provides functions for parsing and converting identifier names between various naming convention. | +| [internal/jsonutil](https://godoc.org/github.com/shurcooL/graphql/internal/jsonutil) | Package jsonutil provides a function for decoding JSON into a GraphQL query data structure. | + +References +---------- +- https://github.com/shurcooL/graphql +- https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md +- https://github.com/nhooyr/websocket + + +License +------- + +- [MIT License](LICENSE) diff --git a/vendor/github.com/hasura/go-graphql-client/doc.go b/vendor/github.com/hasura/go-graphql-client/doc.go new file mode 100644 index 0000000000..69ec4e0387 --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/doc.go @@ -0,0 +1,11 @@ +// Package graphql provides a GraphQL client implementation. +// +// For more information, see package github.com/shurcooL/githubv4, +// which is a specialized version targeting GitHub GraphQL API v4. +// That package is driving the feature development. +// +// Status: In active early research and development. The API will change when +// opportunities for improvement are discovered; it is not yet frozen. +// +// For now, see README for more details. +package graphql // import "github.com/shurcooL/graphql" diff --git a/vendor/github.com/hasura/go-graphql-client/go.mod b/vendor/github.com/hasura/go-graphql-client/go.mod new file mode 100644 index 0000000000..99ffe31533 --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/go.mod @@ -0,0 +1,11 @@ +module github.com/hasura/go-graphql-client + +go 1.14 + +require ( + github.com/google/uuid v1.1.2 + github.com/graph-gophers/graphql-go v0.0.0-20201112095111-7a585a01e04c + github.com/opentracing/opentracing-go v1.2.0 // indirect + golang.org/x/net v0.0.0-20201110031124-69a78807bb2b + nhooyr.io/websocket v1.8.6 +) diff --git a/vendor/github.com/hasura/go-graphql-client/go.sum b/vendor/github.com/hasura/go-graphql-client/go.sum new file mode 100644 index 0000000000..367132d5b7 --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/go.sum @@ -0,0 +1,90 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= +github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= +github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= +github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= +github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.5 h1:F768QJ1E9tib+q5Sc8MkdJi1RxLTbRcTf8LJV56aRls= +github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= +github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/graph-gophers/graphql-go v0.0.0-20200819123640-3b5ddcd884ae h1:TQuRfD07N7uHp+CW7rCfR579o6PDnwJacRBJH74RMq0= +github.com/graph-gophers/graphql-go v0.0.0-20200819123640-3b5ddcd884ae/go.mod h1:9CQHMSxwO4MprSdzoIEobiHpoLtHm77vfxsvsIN5Vuc= +github.com/graph-gophers/graphql-go v0.0.0-20201112095111-7a585a01e04c h1:kD1FT9NpCUWNij9g+pSETBJ6ddaQb8VIHViZrQ27y58= +github.com/graph-gophers/graphql-go v0.0.0-20201112095111-7a585a01e04c/go.mod h1:9CQHMSxwO4MprSdzoIEobiHpoLtHm77vfxsvsIN5Vuc= +github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= +github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/opentracing/opentracing-go v1.1.0 h1:pWlfV3Bxv7k65HYwkikxat0+s3pV4bsqf19k25Ur8rU= +github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= +github.com/opentracing/opentracing-go v1.2.0 h1:uEJPy/1a5RIPAJ0Ov+OIO8OxWu77jEv+1B0VhjKrZUs= +github.com/opentracing/opentracing-go v1.2.0/go.mod h1:GxEUsuufX4nBwe+T+Wl9TAgYrxe9dPLANfrWvHYVTgc= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20200822124328-c89045814202 h1:VvcQYSHwXgi7W+TpUR6A9g6Up98WAHf3f/ulnJ62IyA= +golang.org/x/net v0.0.0-20200822124328-c89045814202/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME= +golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd h1:xhmwyvizuTgC2qz7ZlMluP20uW+C3Rm0FD/WLDX8884= +golang.org/x/sys v0.0.0-20200323222414-85ca7c5b95cd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +nhooyr.io/websocket v1.8.6 h1:s+C3xAMLwGmlI31Nyn/eAehUlZPwfYZu2JXM621Q5/k= +nhooyr.io/websocket v1.8.6/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0= diff --git a/vendor/github.com/hasura/go-graphql-client/graphql.go b/vendor/github.com/hasura/go-graphql-client/graphql.go new file mode 100644 index 0000000000..cc23340e84 --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/graphql.go @@ -0,0 +1,211 @@ +package graphql + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + + "github.com/hasura/go-graphql-client/internal/jsonutil" + "golang.org/x/net/context/ctxhttp" +) + +// Client is a GraphQL client. +type Client struct { + url string // GraphQL server URL. + httpClient *http.Client +} + +// NewClient creates a GraphQL client targeting the specified GraphQL server URL. +// If httpClient is nil, then http.DefaultClient is used. +func NewClient(url string, httpClient *http.Client) *Client { + if httpClient == nil { + httpClient = http.DefaultClient + } + return &Client{ + url: url, + httpClient: httpClient, + } +} + +// Query executes a single GraphQL query request, +// with a query derived from q, populating the response into it. +// q should be a pointer to struct that corresponds to the GraphQL schema. +func (c *Client) Query(ctx context.Context, q interface{}, variables map[string]interface{}) error { + return c.do(ctx, queryOperation, q, variables, "") +} + +// NamedQuery executes a single GraphQL query request, with operation name +func (c *Client) NamedQuery(ctx context.Context, name string, q interface{}, variables map[string]interface{}) error { + return c.do(ctx, queryOperation, q, variables, name) +} + +// Mutate executes a single GraphQL mutation request, +// with a mutation derived from m, populating the response into it. +// m should be a pointer to struct that corresponds to the GraphQL schema. +func (c *Client) Mutate(ctx context.Context, m interface{}, variables map[string]interface{}) error { + return c.do(ctx, mutationOperation, m, variables, "") +} + +// NamedMutate executes a single GraphQL mutation request, with operation name +func (c *Client) NamedMutate(ctx context.Context, name string, m interface{}, variables map[string]interface{}) error { + return c.do(ctx, mutationOperation, m, variables, name) +} + +// Query executes a single GraphQL query request, +// with a query derived from q, populating the response into it. +// q should be a pointer to struct that corresponds to the GraphQL schema. +// return raw bytes message. +func (c *Client) QueryRaw(ctx context.Context, q interface{}, variables map[string]interface{}) (*json.RawMessage, error) { + return c.doRaw(ctx, queryOperation, q, variables, "") +} + +// NamedQueryRaw executes a single GraphQL query request, with operation name +// return raw bytes message. +func (c *Client) NamedQueryRaw(ctx context.Context, name string, q interface{}, variables map[string]interface{}) (*json.RawMessage, error) { + return c.doRaw(ctx, queryOperation, q, variables, name) +} + +// MutateRaw executes a single GraphQL mutation request, +// with a mutation derived from m, populating the response into it. +// m should be a pointer to struct that corresponds to the GraphQL schema. +// return raw bytes message. +func (c *Client) MutateRaw(ctx context.Context, m interface{}, variables map[string]interface{}) (*json.RawMessage, error) { + return c.doRaw(ctx, mutationOperation, m, variables, "") +} + +// NamedMutateRaw executes a single GraphQL mutation request, with operation name +// return raw bytes message. +func (c *Client) NamedMutateRaw(ctx context.Context, name string, m interface{}, variables map[string]interface{}) (*json.RawMessage, error) { + return c.doRaw(ctx, mutationOperation, m, variables, name) +} + +// do executes a single GraphQL operation. +// return raw message and error +func (c *Client) doRaw(ctx context.Context, op operationType, v interface{}, variables map[string]interface{}, name string) (*json.RawMessage, error) { + var query string + switch op { + case queryOperation: + query = constructQuery(v, variables, name) + case mutationOperation: + query = constructMutation(v, variables, name) + } + in := struct { + Query string `json:"query"` + Variables map[string]interface{} `json:"variables,omitempty"` + }{ + Query: query, + Variables: variables, + } + var buf bytes.Buffer + err := json.NewEncoder(&buf).Encode(in) + if err != nil { + return nil, err + } + resp, err := ctxhttp.Post(ctx, c.httpClient, c.url, "application/json", &buf) + if err != nil { + return nil, err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := ioutil.ReadAll(resp.Body) + return nil, fmt.Errorf("non-200 OK status code: %v body: %q", resp.Status, body) + } + var out struct { + Data *json.RawMessage + Errors errors + //Extensions interface{} // Unused. + } + err = json.NewDecoder(resp.Body).Decode(&out) + if err != nil { + // TODO: Consider including response body in returned error, if deemed helpful. + return nil, err + } + + if len(out.Errors) > 0 { + return out.Data, out.Errors + } + + return out.Data, nil +} + +// do executes a single GraphQL operation and unmarshal json. +func (c *Client) do(ctx context.Context, op operationType, v interface{}, variables map[string]interface{}, name string) error { + + var query string + switch op { + case queryOperation: + query = constructQuery(v, variables, name) + case mutationOperation: + query = constructMutation(v, variables, name) + } + in := struct { + Query string `json:"query"` + Variables map[string]interface{} `json:"variables,omitempty"` + }{ + Query: query, + Variables: variables, + } + var buf bytes.Buffer + err := json.NewEncoder(&buf).Encode(in) + if err != nil { + return err + } + resp, err := ctxhttp.Post(ctx, c.httpClient, c.url, "application/json", &buf) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + body, _ := ioutil.ReadAll(resp.Body) + return fmt.Errorf("non-200 OK status code: %v body: %q", resp.Status, body) + } + var out struct { + Data *json.RawMessage + Errors errors + //Extensions interface{} // Unused. + } + err = json.NewDecoder(resp.Body).Decode(&out) + if err != nil { + // TODO: Consider including response body in returned error, if deemed helpful. + return err + } + if out.Data != nil { + err := jsonutil.UnmarshalGraphQL(*out.Data, v) + if err != nil { + // TODO: Consider including response body in returned error, if deemed helpful. + return err + } + } + if len(out.Errors) > 0 { + return out.Errors + } + return nil +} + +// errors represents the "errors" array in a response from a GraphQL server. +// If returned via error interface, the slice is expected to contain at least 1 element. +// +// Specification: https://facebook.github.io/graphql/#sec-Errors. +type errors []struct { + Message string + Locations []struct { + Line int + Column int + } +} + +// Error implements error interface. +func (e errors) Error() string { + return e[0].Message +} + +type operationType uint8 + +const ( + queryOperation operationType = iota + mutationOperation + //subscriptionOperation // Unused. +) diff --git a/vendor/github.com/hasura/go-graphql-client/ident/ident.go b/vendor/github.com/hasura/go-graphql-client/ident/ident.go new file mode 100644 index 0000000000..29e498ed9e --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/ident/ident.go @@ -0,0 +1,240 @@ +// Package ident provides functions for parsing and converting identifier names +// between various naming convention. It has support for MixedCaps, lowerCamelCase, +// and SCREAMING_SNAKE_CASE naming conventions. +package ident + +import ( + "strings" + "unicode" + "unicode/utf8" +) + +// ParseMixedCaps parses a MixedCaps identifier name. +// +// E.g., "ClientMutationID" -> {"Client", "Mutation", "ID"}. +func ParseMixedCaps(name string) Name { + var words Name + + // Split name at any lower -> Upper or Upper -> Upper,lower transitions. + // Check each word for initialisms. + runes := []rune(name) + w, i := 0, 0 // Index of start of word, scan. + for i+1 <= len(runes) { + eow := false // Whether we hit the end of a word. + if i+1 == len(runes) { + eow = true + } else if unicode.IsLower(runes[i]) && unicode.IsUpper(runes[i+1]) { + // lower -> Upper. + eow = true + } else if i+2 < len(runes) && unicode.IsUpper(runes[i]) && unicode.IsUpper(runes[i+1]) && unicode.IsLower(runes[i+2]) { + // Upper -> Upper,lower. End of acronym, followed by a word. + eow = true + + if string(runes[i:i+3]) == "IDs" { // Special case, plural form of ID initialism. + eow = false + } + } + i++ + if !eow { + continue + } + + // [w, i) is a word. + word := string(runes[w:i]) + if initialism, ok := isInitialism(word); ok { + words = append(words, initialism) + } else if i1, i2, ok := isTwoInitialisms(word); ok { + words = append(words, i1, i2) + } else { + words = append(words, word) + } + w = i + } + return words +} + +// ParseLowerCamelCase parses a lowerCamelCase identifier name. +// +// E.g., "clientMutationId" -> {"client", "Mutation", "Id"}. +func ParseLowerCamelCase(name string) Name { + var words Name + + // Split name at any Upper letters. + runes := []rune(name) + w, i := 0, 0 // Index of start of word, scan. + for i+1 <= len(runes) { + eow := false // Whether we hit the end of a word. + if i+1 == len(runes) { + eow = true + } else if unicode.IsUpper(runes[i+1]) { + // Upper letter. + eow = true + } + i++ + if !eow { + continue + } + + // [w, i) is a word. + words = append(words, string(runes[w:i])) + w = i + } + return words +} + +// ParseScreamingSnakeCase parses a SCREAMING_SNAKE_CASE identifier name. +// +// E.g., "CLIENT_MUTATION_ID" -> {"CLIENT", "MUTATION", "ID"}. +func ParseScreamingSnakeCase(name string) Name { + var words Name + + // Split name at '_' characters. + runes := []rune(name) + w, i := 0, 0 // Index of start of word, scan. + for i+1 <= len(runes) { + eow := false // Whether we hit the end of a word. + if i+1 == len(runes) { + eow = true + } else if runes[i+1] == '_' { + // Underscore. + eow = true + } + i++ + if !eow { + continue + } + + // [w, i) is a word. + words = append(words, string(runes[w:i])) + if i < len(runes) && runes[i] == '_' { + // Skip underscore. + i++ + } + w = i + } + return words +} + +// Name is an identifier name, broken up into individual words. +type Name []string + +// ToMixedCaps expresses identifer name in MixedCaps naming convention. +// +// E.g., "ClientMutationID". +func (n Name) ToMixedCaps() string { + for i, word := range n { + if strings.EqualFold(word, "IDs") { // Special case, plural form of ID initialism. + n[i] = "IDs" + continue + } + if initialism, ok := isInitialism(word); ok { + n[i] = initialism + continue + } + if brand, ok := isBrand(word); ok { + n[i] = brand + continue + } + r, size := utf8.DecodeRuneInString(word) + n[i] = string(unicode.ToUpper(r)) + strings.ToLower(word[size:]) + } + return strings.Join(n, "") +} + +// ToLowerCamelCase expresses identifer name in lowerCamelCase naming convention. +// +// E.g., "clientMutationId". +func (n Name) ToLowerCamelCase() string { + for i, word := range n { + if i == 0 { + n[i] = strings.ToLower(word) + continue + } + r, size := utf8.DecodeRuneInString(word) + n[i] = string(unicode.ToUpper(r)) + strings.ToLower(word[size:]) + } + return strings.Join(n, "") +} + +// isInitialism reports whether word is an initialism. +func isInitialism(word string) (string, bool) { + initialism := strings.ToUpper(word) + _, ok := initialisms[initialism] + return initialism, ok +} + +// isTwoInitialisms reports whether word is two initialisms. +func isTwoInitialisms(word string) (string, string, bool) { + word = strings.ToUpper(word) + for i := 2; i <= len(word)-2; i++ { // Shortest initialism is 2 characters long. + _, ok1 := initialisms[word[:i]] + _, ok2 := initialisms[word[i:]] + if ok1 && ok2 { + return word[:i], word[i:], true + } + } + return "", "", false +} + +// initialisms is the set of initialisms in the MixedCaps naming convention. +// Only add entries that are highly unlikely to be non-initialisms. +// For instance, "ID" is fine (Freudian code is rare), but "AND" is not. +var initialisms = map[string]struct{}{ + // These are the common initialisms from golint. Keep them in sync + // with https://gotools.org/github.com/golang/lint#commonInitialisms. + "ACL": {}, + "API": {}, + "ASCII": {}, + "CPU": {}, + "CSS": {}, + "DNS": {}, + "EOF": {}, + "GUID": {}, + "HTML": {}, + "HTTP": {}, + "HTTPS": {}, + "ID": {}, + "IP": {}, + "JSON": {}, + "LHS": {}, + "QPS": {}, + "RAM": {}, + "RHS": {}, + "RPC": {}, + "SLA": {}, + "SMTP": {}, + "SQL": {}, + "SSH": {}, + "TCP": {}, + "TLS": {}, + "TTL": {}, + "UDP": {}, + "UI": {}, + "UID": {}, + "UUID": {}, + "URI": {}, + "URL": {}, + "UTF8": {}, + "VM": {}, + "XML": {}, + "XMPP": {}, + "XSRF": {}, + "XSS": {}, + + // Additional common initialisms. + "RSS": {}, +} + +// isBrand reports whether word is a brand. +func isBrand(word string) (string, bool) { + brand, ok := brands[strings.ToLower(word)] + return brand, ok +} + +// brands is the map of brands in the MixedCaps naming convention; +// see https://dmitri.shuralyov.com/idiomatic-go#for-brands-or-words-with-more-than-1-capital-letter-lowercase-all-letters. +// Key is the lower case version of the brand, value is the canonical brand spelling. +// Only add entries that are highly unlikely to be non-brands. +var brands = map[string]string{ + "github": "GitHub", +} diff --git a/vendor/github.com/hasura/go-graphql-client/internal/jsonutil/graphql.go b/vendor/github.com/hasura/go-graphql-client/internal/jsonutil/graphql.go new file mode 100644 index 0000000000..15bae246fd --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/internal/jsonutil/graphql.go @@ -0,0 +1,311 @@ +// Package jsonutil provides a function for decoding JSON +// into a GraphQL query data structure. +package jsonutil + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "io" + "reflect" + "strings" +) + +// UnmarshalGraphQL parses the JSON-encoded GraphQL response data and stores +// the result in the GraphQL query data structure pointed to by v. +// +// The implementation is created on top of the JSON tokenizer available +// in "encoding/json".Decoder. +func UnmarshalGraphQL(data []byte, v interface{}) error { + dec := json.NewDecoder(bytes.NewReader(data)) + dec.UseNumber() + err := (&decoder{tokenizer: dec}).Decode(v) + if err != nil { + return err + } + tok, err := dec.Token() + switch err { + case io.EOF: + // Expect to get io.EOF. There shouldn't be any more + // tokens left after we've decoded v successfully. + return nil + case nil: + return fmt.Errorf("invalid token '%v' after top-level value", tok) + default: + return err + } +} + +// decoder is a JSON decoder that performs custom unmarshaling behavior +// for GraphQL query data structures. It's implemented on top of a JSON tokenizer. +type decoder struct { + tokenizer interface { + Token() (json.Token, error) + } + + // Stack of what part of input JSON we're in the middle of - objects, arrays. + parseState []json.Delim + + // Stacks of values where to unmarshal. + // The top of each stack is the reflect.Value where to unmarshal next JSON value. + // + // The reason there's more than one stack is because we might be unmarshaling + // a single JSON value into multiple GraphQL fragments or embedded structs, so + // we keep track of them all. + vs [][]reflect.Value +} + +// Decode decodes a single JSON value from d.tokenizer into v. +func (d *decoder) Decode(v interface{}) error { + rv := reflect.ValueOf(v) + if rv.Kind() != reflect.Ptr { + return fmt.Errorf("cannot decode into non-pointer %T", v) + } + d.vs = [][]reflect.Value{{rv.Elem()}} + return d.decode() +} + +// decode decodes a single JSON value from d.tokenizer into d.vs. +func (d *decoder) decode() error { + // The loop invariant is that the top of each d.vs stack + // is where we try to unmarshal the next JSON value we see. + for len(d.vs) > 0 { + tok, err := d.tokenizer.Token() + if err == io.EOF { + return errors.New("unexpected end of JSON input") + } else if err != nil { + return err + } + + switch { + + // Are we inside an object and seeing next key (rather than end of object)? + case d.state() == '{' && tok != json.Delim('}'): + key, ok := tok.(string) + if !ok { + return errors.New("unexpected non-key in JSON input") + } + someFieldExist := false + for i := range d.vs { + v := d.vs[i][len(d.vs[i])-1] + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + var f reflect.Value + if v.Kind() == reflect.Struct { + f = fieldByGraphQLName(v, key) + if f.IsValid() { + someFieldExist = true + } + } + d.vs[i] = append(d.vs[i], f) + } + if !someFieldExist { + return fmt.Errorf("struct field for %q doesn't exist in any of %v places to unmarshal", key, len(d.vs)) + } + + // We've just consumed the current token, which was the key. + // Read the next token, which should be the value, and let the rest of code process it. + tok, err = d.tokenizer.Token() + if err == io.EOF { + return errors.New("unexpected end of JSON input") + } else if err != nil { + return err + } + + // Are we inside an array and seeing next value (rather than end of array)? + case d.state() == '[' && tok != json.Delim(']'): + someSliceExist := false + for i := range d.vs { + v := d.vs[i][len(d.vs[i])-1] + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + var f reflect.Value + if v.Kind() == reflect.Slice { + v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem()))) // v = append(v, T). + f = v.Index(v.Len() - 1) + someSliceExist = true + } + d.vs[i] = append(d.vs[i], f) + } + if !someSliceExist { + return fmt.Errorf("slice doesn't exist in any of %v places to unmarshal", len(d.vs)) + } + } + + switch tok := tok.(type) { + case string, json.Number, bool, nil: + // Value. + + for i := range d.vs { + v := d.vs[i][len(d.vs[i])-1] + if !v.IsValid() { + continue + } + err := unmarshalValue(tok, v) + if err != nil { + return err + } + } + d.popAllVs() + + case json.Delim: + switch tok { + case '{': + // Start of object. + + d.pushState(tok) + + frontier := make([]reflect.Value, len(d.vs)) // Places to look for GraphQL fragments/embedded structs. + for i := range d.vs { + v := d.vs[i][len(d.vs[i])-1] + frontier[i] = v + // TODO: Do this recursively or not? Add a test case if needed. + if v.Kind() == reflect.Ptr && v.IsNil() { + v.Set(reflect.New(v.Type().Elem())) // v = new(T). + } + } + // Find GraphQL fragments/embedded structs recursively, adding to frontier + // as new ones are discovered and exploring them further. + for len(frontier) > 0 { + v := frontier[0] + frontier = frontier[1:] + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Struct { + continue + } + for i := 0; i < v.NumField(); i++ { + if isGraphQLFragment(v.Type().Field(i)) || v.Type().Field(i).Anonymous { + // Add GraphQL fragment or embedded struct. + d.vs = append(d.vs, []reflect.Value{v.Field(i)}) + frontier = append(frontier, v.Field(i)) + } + } + } + case '[': + // Start of array. + + d.pushState(tok) + + for i := range d.vs { + v := d.vs[i][len(d.vs[i])-1] + // TODO: Confirm this is needed, write a test case. + //if v.Kind() == reflect.Ptr && v.IsNil() { + // v.Set(reflect.New(v.Type().Elem())) // v = new(T). + //} + + // Reset slice to empty (in case it had non-zero initial value). + if v.Kind() == reflect.Ptr { + v = v.Elem() + } + if v.Kind() != reflect.Slice { + continue + } + v.Set(reflect.MakeSlice(v.Type(), 0, 0)) // v = make(T, 0, 0). + } + case '}', ']': + // End of object or array. + d.popAllVs() + d.popState() + default: + return errors.New("unexpected delimiter in JSON input") + } + default: + return errors.New("unexpected token in JSON input") + } + } + return nil +} + +// pushState pushes a new parse state s onto the stack. +func (d *decoder) pushState(s json.Delim) { + d.parseState = append(d.parseState, s) +} + +// popState pops a parse state (already obtained) off the stack. +// The stack must be non-empty. +func (d *decoder) popState() { + d.parseState = d.parseState[:len(d.parseState)-1] +} + +// state reports the parse state on top of stack, or 0 if empty. +func (d *decoder) state() json.Delim { + if len(d.parseState) == 0 { + return 0 + } + return d.parseState[len(d.parseState)-1] +} + +// popAllVs pops from all d.vs stacks, keeping only non-empty ones. +func (d *decoder) popAllVs() { + var nonEmpty [][]reflect.Value + for i := range d.vs { + d.vs[i] = d.vs[i][:len(d.vs[i])-1] + if len(d.vs[i]) > 0 { + nonEmpty = append(nonEmpty, d.vs[i]) + } + } + d.vs = nonEmpty +} + +// fieldByGraphQLName returns an exported struct field of struct v +// that matches GraphQL name, or invalid reflect.Value if none found. +func fieldByGraphQLName(v reflect.Value, name string) reflect.Value { + for i := 0; i < v.NumField(); i++ { + if v.Type().Field(i).PkgPath != "" { + // Skip unexported field. + continue + } + if hasGraphQLName(v.Type().Field(i), name) { + return v.Field(i) + } + } + return reflect.Value{} +} + +// hasGraphQLName reports whether struct field f has GraphQL name. +func hasGraphQLName(f reflect.StructField, name string) bool { + value, ok := f.Tag.Lookup("graphql") + if !ok { + // TODO: caseconv package is relatively slow. Optimize it, then consider using it here. + //return caseconv.MixedCapsToLowerCamelCase(f.Name) == name + return strings.EqualFold(f.Name, name) + } + value = strings.TrimSpace(value) // TODO: Parse better. + if strings.HasPrefix(value, "...") { + // GraphQL fragment. It doesn't have a name. + return false + } + if i := strings.Index(value, "("); i != -1 { + value = value[:i] + } + if i := strings.Index(value, ":"); i != -1 { + value = value[:i] + } + return strings.TrimSpace(value) == name +} + +// isGraphQLFragment reports whether struct field f is a GraphQL fragment. +func isGraphQLFragment(f reflect.StructField) bool { + value, ok := f.Tag.Lookup("graphql") + if !ok { + return false + } + value = strings.TrimSpace(value) // TODO: Parse better. + return strings.HasPrefix(value, "...") +} + +// unmarshalValue unmarshals JSON value into v. +// v must be addressable and not obtained by the use of unexported +// struct fields, otherwise unmarshalValue will panic. +func unmarshalValue(value json.Token, v reflect.Value) error { + b, err := json.Marshal(value) // TODO: Short-circuit (if profiling says it's worth it). + if err != nil { + return err + } + return json.Unmarshal(b, v.Addr().Interface()) +} diff --git a/vendor/github.com/hasura/go-graphql-client/query.go b/vendor/github.com/hasura/go-graphql-client/query.go new file mode 100644 index 0000000000..28b6b80e3c --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/query.go @@ -0,0 +1,149 @@ +package graphql + +import ( + "bytes" + "encoding/json" + "io" + "reflect" + "sort" + + "github.com/hasura/go-graphql-client/ident" +) + +func constructQuery(v interface{}, variables map[string]interface{}, name string) string { + query := query(v) + if len(variables) > 0 { + return "query " + name + "(" + queryArguments(variables) + ")" + query + } + + if name != "" { + return "query " + name + query + } + return query +} + +func constructMutation(v interface{}, variables map[string]interface{}, name string) string { + query := query(v) + if len(variables) > 0 { + return "mutation " + name + "(" + queryArguments(variables) + ")" + query + } + if name != "" { + return "mutation " + name + query + } + return "mutation" + query +} + +func constructSubscription(v interface{}, variables map[string]interface{}, name string) string { + query := query(v) + if len(variables) > 0 { + return "subscription " + name + "(" + queryArguments(variables) + ")" + query + } + if name != "" { + return "subscription " + name + query + } + return "subscription" + query +} + +// queryArguments constructs a minified arguments string for variables. +// +// E.g., map[string]interface{}{"a": Int(123), "b": NewBoolean(true)} -> "$a:Int!$b:Boolean". +func queryArguments(variables map[string]interface{}) string { + // Sort keys in order to produce deterministic output for testing purposes. + // TODO: If tests can be made to work with non-deterministic output, then no need to sort. + keys := make([]string, 0, len(variables)) + for k := range variables { + keys = append(keys, k) + } + sort.Strings(keys) + + var buf bytes.Buffer + for _, k := range keys { + io.WriteString(&buf, "$") + io.WriteString(&buf, k) + io.WriteString(&buf, ":") + writeArgumentType(&buf, reflect.TypeOf(variables[k]), true) + // Don't insert a comma here. + // Commas in GraphQL are insignificant, and we want minified output. + // See https://facebook.github.io/graphql/October2016/#sec-Insignificant-Commas. + } + return buf.String() +} + +// writeArgumentType writes a minified GraphQL type for t to w. +// value indicates whether t is a value (required) type or pointer (optional) type. +// If value is true, then "!" is written at the end of t. +func writeArgumentType(w io.Writer, t reflect.Type, value bool) { + if t.Kind() == reflect.Ptr { + // Pointer is an optional type, so no "!" at the end of the pointer's underlying type. + writeArgumentType(w, t.Elem(), false) + return + } + + switch t.Kind() { + case reflect.Slice, reflect.Array: + // List. E.g., "[Int]". + io.WriteString(w, "[") + writeArgumentType(w, t.Elem(), true) + io.WriteString(w, "]") + default: + // Named type. E.g., "Int". + name := t.Name() + if name == "string" { // HACK: Workaround for https://github.com/shurcooL/githubv4/issues/12. + name = "ID" + } + io.WriteString(w, name) + } + + if value { + // Value is a required type, so add "!" to the end. + io.WriteString(w, "!") + } +} + +// query uses writeQuery to recursively construct +// a minified query string from the provided struct v. +// +// E.g., struct{Foo Int, BarBaz *Boolean} -> "{foo,barBaz}". +func query(v interface{}) string { + var buf bytes.Buffer + writeQuery(&buf, reflect.TypeOf(v), false) + return buf.String() +} + +// writeQuery writes a minified query for t to w. +// If inline is true, the struct fields of t are inlined into parent struct. +func writeQuery(w io.Writer, t reflect.Type, inline bool) { + switch t.Kind() { + case reflect.Ptr, reflect.Slice: + writeQuery(w, t.Elem(), false) + case reflect.Struct: + // If the type implements json.Unmarshaler, it's a scalar. Don't expand it. + if reflect.PtrTo(t).Implements(jsonUnmarshaler) { + return + } + if !inline { + io.WriteString(w, "{") + } + for i := 0; i < t.NumField(); i++ { + if i != 0 { + io.WriteString(w, ",") + } + f := t.Field(i) + value, ok := f.Tag.Lookup("graphql") + inlineField := f.Anonymous && !ok + if !inlineField { + if ok { + io.WriteString(w, value) + } else { + io.WriteString(w, ident.ParseMixedCaps(f.Name).ToLowerCamelCase()) + } + } + writeQuery(w, f.Type, inlineField) + } + if !inline { + io.WriteString(w, "}") + } + } +} + +var jsonUnmarshaler = reflect.TypeOf((*json.Unmarshaler)(nil)).Elem() diff --git a/vendor/github.com/hasura/go-graphql-client/scalar.go b/vendor/github.com/hasura/go-graphql-client/scalar.go new file mode 100644 index 0000000000..0f7ceea5e0 --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/scalar.go @@ -0,0 +1,51 @@ +package graphql + +// Note: These custom types are meant to be used in queries for now. +// But the plan is to switch to using native Go types (string, int, bool, time.Time, etc.). +// See https://github.com/shurcooL/githubv4/issues/9 for details. +// +// These custom types currently provide documentation, and their use +// is required for sending outbound queries. However, native Go types +// can be used for unmarshaling. Once https://github.com/shurcooL/githubv4/issues/9 +// is resolved, native Go types can completely replace these. + +type ( + // Boolean represents true or false values. + Boolean bool + + // Float represents signed double-precision fractional values as + // specified by IEEE 754. + Float float64 + + // ID represents a unique identifier that is Base64 obfuscated. It + // is often used to refetch an object or as key for a cache. The ID + // type appears in a JSON response as a String; however, it is not + // intended to be human-readable. When expected as an input type, + // any string (such as "VXNlci0xMA==") or integer (such as 4) input + // value will be accepted as an ID. + ID interface{} + + // Int represents non-fractional signed whole numeric values. + // Int can represent values between -(2^31) and 2^31 - 1. + Int int32 + + // String represents textual data as UTF-8 character sequences. + // This type is most often used by GraphQL to represent free-form + // human-readable text. + String string +) + +// NewBoolean is a helper to make a new *Boolean. +func NewBoolean(v Boolean) *Boolean { return &v } + +// NewFloat is a helper to make a new *Float. +func NewFloat(v Float) *Float { return &v } + +// NewID is a helper to make a new *ID. +func NewID(v ID) *ID { return &v } + +// NewInt is a helper to make a new *Int. +func NewInt(v Int) *Int { return &v } + +// NewString is a helper to make a new *String. +func NewString(v String) *String { return &v } diff --git a/vendor/github.com/hasura/go-graphql-client/subscription.go b/vendor/github.com/hasura/go-graphql-client/subscription.go new file mode 100644 index 0000000000..139a349495 --- /dev/null +++ b/vendor/github.com/hasura/go-graphql-client/subscription.go @@ -0,0 +1,589 @@ +package graphql + +import ( + "context" + "encoding/json" + "fmt" + "io" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "nhooyr.io/websocket" + "nhooyr.io/websocket/wsjson" +) + +// Subscription transport follow Apollo's subscriptions-transport-ws protocol specification +// https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md + +// OperationMessageType +type OperationMessageType string + +const ( + // Client sends this message after plain websocket connection to start the communication with the server + GQL_CONNECTION_INIT OperationMessageType = "connection_init" + // The server may responses with this message to the GQL_CONNECTION_INIT from client, indicates the server rejected the connection. + GQL_CONNECTION_ERROR OperationMessageType = "conn_err" + // Client sends this message to execute GraphQL operation + GQL_START OperationMessageType = "start" + // Client sends this message in order to stop a running GraphQL operation execution (for example: unsubscribe) + GQL_STOP OperationMessageType = "stop" + // Server sends this message upon a failing operation, before the GraphQL execution, usually due to GraphQL validation errors (resolver errors are part of GQL_DATA message, and will be added as errors array) + GQL_ERROR OperationMessageType = "error" + // The server sends this message to transfter the GraphQL execution result from the server to the client, this message is a response for GQL_START message. + GQL_DATA OperationMessageType = "data" + // Server sends this message to indicate that a GraphQL operation is done, and no more data will arrive for the specific operation. + GQL_COMPLETE OperationMessageType = "complete" + // Server message that should be sent right after each GQL_CONNECTION_ACK processed and then periodically to keep the client connection alive. + // The client starts to consider the keep alive message only upon the first received keep alive message from the server. + GQL_CONNECTION_KEEP_ALIVE OperationMessageType = "ka" + // The server may responses with this message to the GQL_CONNECTION_INIT from client, indicates the server accepted the connection. May optionally include a payload. + GQL_CONNECTION_ACK OperationMessageType = "connection_ack" + // Client sends this message to terminate the connection. + GQL_CONNECTION_TERMINATE OperationMessageType = "connection_terminate" + // Unknown operation type, for logging only + GQL_UNKNOWN OperationMessageType = "unknown" + // Internal status, for logging only + GQL_INTERNAL OperationMessageType = "internal" +) + +type OperationMessage struct { + ID string `json:"id,omitempty"` + Type OperationMessageType `json:"type"` + Payload json.RawMessage `json:"payload,omitempty"` +} + +func (om OperationMessage) String() string { + bs, _ := json.Marshal(om) + + return string(bs) +} + +// WebsocketHandler abstracts WebSocket connecton functions +// ReadJSON and WriteJSON data of a frame from the WebSocket connection. +// Close the WebSocket connection. +type WebsocketConn interface { + ReadJSON(v interface{}) error + WriteJSON(v interface{}) error + Close() error + // SetReadLimit sets the maximum size in bytes for a message read from the peer. If a + // message exceeds the limit, the connection sends a close message to the peer + // and returns ErrReadLimit to the application. + SetReadLimit(limit int64) +} + +type handlerFunc func(data *json.RawMessage, err error) error +type subscription struct { + query string + variables map[string]interface{} + handler func(data *json.RawMessage, err error) + started Boolean +} + +// SubscriptionClient is a GraphQL subscription client. +type SubscriptionClient struct { + url string + conn WebsocketConn + connectionParams map[string]interface{} + context context.Context + subscriptions map[string]*subscription + cancel context.CancelFunc + subscribersMu sync.Mutex + timeout time.Duration + isRunning Boolean + readLimit int64 // max size of response message. Default 10 MB + log func(args ...interface{}) + createConn func(sc *SubscriptionClient) (WebsocketConn, error) + retryTimeout time.Duration + onConnected func() + onDisconnected func() + onError func(sc *SubscriptionClient, err error) error + errorChan chan error + disabledLogTypes []OperationMessageType +} + +func NewSubscriptionClient(url string) *SubscriptionClient { + return &SubscriptionClient{ + url: url, + timeout: time.Minute, + readLimit: 10 * 1024 * 1024, // set default limit 10MB + subscriptions: make(map[string]*subscription), + createConn: newWebsocketConn, + retryTimeout: time.Minute, + errorChan: make(chan error), + } +} + +// GetURL returns GraphQL server's URL +func (sc *SubscriptionClient) GetURL() string { + return sc.url +} + +// GetContext returns current context of subscription client +func (sc *SubscriptionClient) GetContext() context.Context { + return sc.context +} + +// GetContext returns write timeout of websocket client +func (sc *SubscriptionClient) GetTimeout() time.Duration { + return sc.timeout +} + +// WithWebSocket replaces customized websocket client constructor +// In default, subscription client uses https://github.com/nhooyr/websocket +func (sc *SubscriptionClient) WithWebSocket(fn func(sc *SubscriptionClient) (WebsocketConn, error)) *SubscriptionClient { + sc.createConn = fn + return sc +} + +// WithConnectionParams updates connection params for sending to server through GQL_CONNECTION_INIT event +// It's usually used for authentication handshake +func (sc *SubscriptionClient) WithConnectionParams(params map[string]interface{}) *SubscriptionClient { + sc.connectionParams = params + return sc +} + +// WithTimeout updates write timeout of websocket client +func (sc *SubscriptionClient) WithTimeout(timeout time.Duration) *SubscriptionClient { + sc.timeout = timeout + return sc +} + +// WithRetryTimeout updates reconnecting timeout. When the websocket server was stopped, the client will retry connecting every second until timeout +func (sc *SubscriptionClient) WithRetryTimeout(timeout time.Duration) *SubscriptionClient { + sc.retryTimeout = timeout + return sc +} + +// WithLog sets loging function to print out received messages. By default, nothing is printed +func (sc *SubscriptionClient) WithLog(logger func(args ...interface{})) *SubscriptionClient { + sc.log = logger + return sc +} + +// WithoutLogTypes these operation types won't be printed +func (sc *SubscriptionClient) WithoutLogTypes(types ...OperationMessageType) *SubscriptionClient { + sc.disabledLogTypes = types + return sc +} + +// WithReadLimit set max size of response message +func (sc *SubscriptionClient) WithReadLimit(limit int64) *SubscriptionClient { + sc.readLimit = limit + return sc +} + +// OnConnected event is triggered when there is any connection error. This is bottom exception handler level +// If this function is empty, or returns nil, the error is ignored +// If returns error, the websocket connection will be terminated +func (sc *SubscriptionClient) OnError(onError func(sc *SubscriptionClient, err error) error) *SubscriptionClient { + sc.onError = onError + return sc +} + +// OnConnected event is triggered when the websocket connected to GraphQL server sucessfully +func (sc *SubscriptionClient) OnConnected(fn func()) *SubscriptionClient { + sc.onConnected = fn + return sc +} + +// OnDisconnected event is triggered when the websocket server was stil down after retry timeout +func (sc *SubscriptionClient) OnDisconnected(fn func()) *SubscriptionClient { + sc.onDisconnected = fn + return sc +} + +func (sc *SubscriptionClient) setIsRunning(value Boolean) { + sc.subscribersMu.Lock() + sc.isRunning = value + sc.subscribersMu.Unlock() +} + +func (sc *SubscriptionClient) init() error { + + now := time.Now() + ctx, cancel := context.WithCancel(context.Background()) + sc.context = ctx + sc.cancel = cancel + + for { + var err error + var conn WebsocketConn + // allow custom websocket client + if sc.conn == nil { + conn, err = newWebsocketConn(sc) + if err == nil { + sc.conn = conn + } + } + + if err == nil { + sc.conn.SetReadLimit(sc.readLimit) + // send connection init event to the server + err = sc.sendConnectionInit() + } + + if err == nil { + return nil + } + + if now.Add(sc.retryTimeout).Before(time.Now()) { + if sc.onDisconnected != nil { + sc.onDisconnected() + } + return err + } + sc.printLog(err.Error()+". retry in second....", GQL_INTERNAL) + time.Sleep(time.Second) + } +} + +func (sc *SubscriptionClient) printLog(message interface{}, opType OperationMessageType) { + if sc.log == nil { + return + } + for _, ty := range sc.disabledLogTypes { + if ty == opType { + return + } + } + + sc.log(message) +} + +func (sc *SubscriptionClient) sendConnectionInit() (err error) { + var bParams []byte = nil + if sc.connectionParams != nil { + + bParams, err = json.Marshal(sc.connectionParams) + if err != nil { + return + } + } + + // send connection_init event to the server + msg := OperationMessage{ + Type: GQL_CONNECTION_INIT, + Payload: bParams, + } + + sc.printLog(msg, GQL_CONNECTION_INIT) + return sc.conn.WriteJSON(msg) +} + +// Subscribe sends start message to server and open a channel to receive data. +// The handler callback function will receive raw message data or error. If the call return error, onError event will be triggered +// The function returns subscription ID and error. You can use subscription ID to unsubscribe the subscription +func (sc *SubscriptionClient) Subscribe(v interface{}, variables map[string]interface{}, handler func(message *json.RawMessage, err error) error) (string, error) { + return sc.do(v, variables, handler, "") +} + +// NamedSubscribe sends start message to server and open a channel to receive data, with operation name +func (sc *SubscriptionClient) NamedSubscribe(name string, v interface{}, variables map[string]interface{}, handler func(message *json.RawMessage, err error) error) (string, error) { + return sc.do(v, variables, handler, name) +} + +func (sc *SubscriptionClient) do(v interface{}, variables map[string]interface{}, handler func(message *json.RawMessage, err error) error, name string) (string, error) { + id := uuid.New().String() + query := constructSubscription(v, variables, name) + + sub := subscription{ + query: query, + variables: variables, + handler: sc.wrapHandler(handler), + } + + // if the websocket client is running, start subscription immediately + if sc.isRunning { + if err := sc.startSubscription(id, &sub); err != nil { + return "", err + } + } + + sc.subscribersMu.Lock() + sc.subscriptions[id] = &sub + sc.subscribersMu.Unlock() + + return id, nil +} + +// Subscribe sends start message to server and open a channel to receive data +func (sc *SubscriptionClient) startSubscription(id string, sub *subscription) error { + if sub == nil || sub.started { + return nil + } + + in := struct { + Query string `json:"query"` + Variables map[string]interface{} `json:"variables,omitempty"` + }{ + Query: sub.query, + Variables: sub.variables, + } + + payload, err := json.Marshal(in) + if err != nil { + return err + } + + // send stop message to the server + msg := OperationMessage{ + ID: id, + Type: GQL_START, + Payload: payload, + } + + sc.printLog(msg, GQL_START) + if err := sc.conn.WriteJSON(msg); err != nil { + return err + } + + sub.started = true + return nil +} + +func (sc *SubscriptionClient) wrapHandler(fn handlerFunc) func(data *json.RawMessage, err error) { + return func(data *json.RawMessage, err error) { + if errValue := fn(data, err); errValue != nil { + sc.errorChan <- errValue + } + } +} + +// Run start websocket client and subscriptions. If this function is run with goroutine, it can be stopped after closed +func (sc *SubscriptionClient) Run() error { + if err := sc.init(); err != nil { + return fmt.Errorf("retry timeout. exiting...") + } + + // lazily start subscriptions + for k, v := range sc.subscriptions { + if err := sc.startSubscription(k, v); err != nil { + sc.Unsubscribe(k) + return err + } + } + sc.setIsRunning(true) + + for sc.isRunning { + select { + case <-sc.context.Done(): + return nil + case e := <-sc.errorChan: + if sc.onError != nil { + if err := sc.onError(sc, e); err != nil { + return err + } + } + default: + + var message OperationMessage + if err := sc.conn.ReadJSON(&message); err != nil { + // manual EOF check + if err == io.EOF || strings.Contains(err.Error(), "EOF") { + return sc.Reset() + } + closeStatus := websocket.CloseStatus(err) + if closeStatus == websocket.StatusNormalClosure { + // close event from websocket client, exiting... + return nil + } + if closeStatus != -1 { + sc.printLog(fmt.Sprintf("%s. Retry connecting...", err), GQL_INTERNAL) + return sc.Reset() + } + + if sc.onError != nil { + if err = sc.onError(sc, err); err != nil { + return err + } + } + continue + } + + switch message.Type { + case GQL_ERROR: + sc.printLog(message, GQL_ERROR) + fallthrough + case GQL_DATA: + sc.printLog(message, GQL_DATA) + id, err := uuid.Parse(message.ID) + if err != nil { + continue + } + sub, ok := sc.subscriptions[id.String()] + if !ok { + continue + } + var out struct { + Data *json.RawMessage + Errors errors + //Extensions interface{} // Unused. + } + + err = json.Unmarshal(message.Payload, &out) + if err != nil { + go sub.handler(nil, err) + continue + } + if len(out.Errors) > 0 { + go sub.handler(nil, out.Errors) + continue + } + + go sub.handler(out.Data, nil) + case GQL_CONNECTION_ERROR: + sc.printLog(message, GQL_CONNECTION_ERROR) + case GQL_COMPLETE: + sc.printLog(message, GQL_COMPLETE) + sc.Unsubscribe(message.ID) + case GQL_CONNECTION_KEEP_ALIVE: + sc.printLog(message, GQL_CONNECTION_KEEP_ALIVE) + case GQL_CONNECTION_ACK: + sc.printLog(message, GQL_CONNECTION_ACK) + if sc.onConnected != nil { + sc.onConnected() + } + default: + sc.printLog(message, GQL_UNKNOWN) + } + } + } + + // if the running status is false, stop retrying + if !sc.isRunning { + return nil + } + + return sc.Reset() +} + +// Unsubscribe sends stop message to server and close subscription channel +// The input parameter is subscription ID that is returned from Subscribe function +func (sc *SubscriptionClient) Unsubscribe(id string) error { + _, ok := sc.subscriptions[id] + if !ok { + return fmt.Errorf("subscription id %s doesn't not exist", id) + } + + err := sc.stopSubscription(id) + + sc.subscribersMu.Lock() + delete(sc.subscriptions, id) + sc.subscribersMu.Unlock() + return err +} + +func (sc *SubscriptionClient) stopSubscription(id string) error { + if sc.conn != nil { + // send stop message to the server + msg := OperationMessage{ + ID: id, + Type: GQL_STOP, + } + + sc.printLog(msg, GQL_STOP) + if err := sc.conn.WriteJSON(msg); err != nil { + return err + } + + } + + return nil +} + +func (sc *SubscriptionClient) terminate() error { + if sc.conn != nil { + // send terminate message to the server + msg := OperationMessage{ + Type: GQL_CONNECTION_TERMINATE, + } + + sc.printLog(msg, GQL_CONNECTION_TERMINATE) + return sc.conn.WriteJSON(msg) + } + + return nil +} + +// Reset restart websocket connection and subscriptions +func (sc *SubscriptionClient) Reset() error { + if !sc.isRunning { + return nil + } + + for id, sub := range sc.subscriptions { + _ = sc.stopSubscription(id) + sub.started = false + } + + if sc.conn != nil { + _ = sc.terminate() + _ = sc.conn.Close() + sc.conn = nil + } + sc.cancel() + + return sc.Run() +} + +// Close closes all subscription channel and websocket as well +func (sc *SubscriptionClient) Close() (err error) { + sc.setIsRunning(false) + for id := range sc.subscriptions { + if err = sc.Unsubscribe(id); err != nil { + sc.cancel() + return err + } + } + if sc.conn != nil { + _ = sc.terminate() + err = sc.conn.Close() + sc.conn = nil + } + sc.cancel() + + return +} + +// default websocket handler implementation using https://github.com/nhooyr/websocket +type websocketHandler struct { + ctx context.Context + timeout time.Duration + *websocket.Conn +} + +func (wh *websocketHandler) WriteJSON(v interface{}) error { + ctx, cancel := context.WithTimeout(wh.ctx, wh.timeout) + defer cancel() + + return wsjson.Write(ctx, wh.Conn, v) +} + +func (wh *websocketHandler) ReadJSON(v interface{}) error { + ctx, cancel := context.WithTimeout(wh.ctx, wh.timeout) + defer cancel() + return wsjson.Read(ctx, wh.Conn, v) +} + +func (wh *websocketHandler) Close() error { + return wh.Conn.Close(websocket.StatusNormalClosure, "close websocket") +} + +func newWebsocketConn(sc *SubscriptionClient) (WebsocketConn, error) { + + options := &websocket.DialOptions{ + Subprotocols: []string{"graphql-ws"}, + } + c, _, err := websocket.Dial(sc.GetContext(), sc.GetURL(), options) + if err != nil { + return nil, err + } + + return &websocketHandler{ + ctx: sc.GetContext(), + Conn: c, + timeout: sc.GetTimeout(), + }, nil +} diff --git a/vendor/github.com/klauspost/compress/LICENSE b/vendor/github.com/klauspost/compress/LICENSE new file mode 100644 index 0000000000..1eb75ef68e --- /dev/null +++ b/vendor/github.com/klauspost/compress/LICENSE @@ -0,0 +1,28 @@ +Copyright (c) 2012 The Go Authors. All rights reserved. +Copyright (c) 2019 Klaus Post. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright +notice, this list of conditions and the following disclaimer. + * Redistributions in binary form must reproduce the above +copyright notice, this list of conditions and the following disclaimer +in the documentation and/or other materials provided with the +distribution. + * Neither the name of Google Inc. nor the names of its +contributors may be used to endorse or promote products derived from +this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/github.com/klauspost/compress/flate/deflate.go b/vendor/github.com/klauspost/compress/flate/deflate.go new file mode 100644 index 0000000000..5283ac5a53 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/deflate.go @@ -0,0 +1,820 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Copyright (c) 2015 Klaus Post +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import ( + "fmt" + "io" + "math" +) + +const ( + NoCompression = 0 + BestSpeed = 1 + BestCompression = 9 + DefaultCompression = -1 + + // HuffmanOnly disables Lempel-Ziv match searching and only performs Huffman + // entropy encoding. This mode is useful in compressing data that has + // already been compressed with an LZ style algorithm (e.g. Snappy or LZ4) + // that lacks an entropy encoder. Compression gains are achieved when + // certain bytes in the input stream occur more frequently than others. + // + // Note that HuffmanOnly produces a compressed output that is + // RFC 1951 compliant. That is, any valid DEFLATE decompressor will + // continue to be able to decompress this output. + HuffmanOnly = -2 + ConstantCompression = HuffmanOnly // compatibility alias. + + logWindowSize = 15 + windowSize = 1 << logWindowSize + windowMask = windowSize - 1 + logMaxOffsetSize = 15 // Standard DEFLATE + minMatchLength = 4 // The smallest match that the compressor looks for + maxMatchLength = 258 // The longest match for the compressor + minOffsetSize = 1 // The shortest offset that makes any sense + + // The maximum number of tokens we put into a single flat block, just too + // stop things from getting too large. + maxFlateBlockTokens = 1 << 14 + maxStoreBlockSize = 65535 + hashBits = 17 // After 17 performance degrades + hashSize = 1 << hashBits + hashMask = (1 << hashBits) - 1 + hashShift = (hashBits + minMatchLength - 1) / minMatchLength + maxHashOffset = 1 << 24 + + skipNever = math.MaxInt32 + + debugDeflate = false +) + +type compressionLevel struct { + good, lazy, nice, chain, fastSkipHashing, level int +} + +// Compression levels have been rebalanced from zlib deflate defaults +// to give a bigger spread in speed and compression. +// See https://blog.klauspost.com/rebalancing-deflate-compression-levels/ +var levels = []compressionLevel{ + {}, // 0 + // Level 1-6 uses specialized algorithm - values not used + {0, 0, 0, 0, 0, 1}, + {0, 0, 0, 0, 0, 2}, + {0, 0, 0, 0, 0, 3}, + {0, 0, 0, 0, 0, 4}, + {0, 0, 0, 0, 0, 5}, + {0, 0, 0, 0, 0, 6}, + // Levels 7-9 use increasingly more lazy matching + // and increasingly stringent conditions for "good enough". + {8, 8, 24, 16, skipNever, 7}, + {10, 16, 24, 64, skipNever, 8}, + {32, 258, 258, 4096, skipNever, 9}, +} + +// advancedState contains state for the advanced levels, with bigger hash tables, etc. +type advancedState struct { + // deflate state + length int + offset int + maxInsertIndex int + + // Input hash chains + // hashHead[hashValue] contains the largest inputIndex with the specified hash value + // If hashHead[hashValue] is within the current window, then + // hashPrev[hashHead[hashValue] & windowMask] contains the previous index + // with the same hash value. + chainHead int + hashHead [hashSize]uint32 + hashPrev [windowSize]uint32 + hashOffset int + + // input window: unprocessed data is window[index:windowEnd] + index int + hashMatch [maxMatchLength + minMatchLength]uint32 + + hash uint32 + ii uint16 // position of last match, intended to overflow to reset. +} + +type compressor struct { + compressionLevel + + w *huffmanBitWriter + + // compression algorithm + fill func(*compressor, []byte) int // copy data to window + step func(*compressor) // process window + + window []byte + windowEnd int + blockStart int // window index where current tokens start + err error + + // queued output tokens + tokens tokens + fast fastEnc + state *advancedState + + sync bool // requesting flush + byteAvailable bool // if true, still need to process window[index-1]. +} + +func (d *compressor) fillDeflate(b []byte) int { + s := d.state + if s.index >= 2*windowSize-(minMatchLength+maxMatchLength) { + // shift the window by windowSize + copy(d.window[:], d.window[windowSize:2*windowSize]) + s.index -= windowSize + d.windowEnd -= windowSize + if d.blockStart >= windowSize { + d.blockStart -= windowSize + } else { + d.blockStart = math.MaxInt32 + } + s.hashOffset += windowSize + if s.hashOffset > maxHashOffset { + delta := s.hashOffset - 1 + s.hashOffset -= delta + s.chainHead -= delta + // Iterate over slices instead of arrays to avoid copying + // the entire table onto the stack (Issue #18625). + for i, v := range s.hashPrev[:] { + if int(v) > delta { + s.hashPrev[i] = uint32(int(v) - delta) + } else { + s.hashPrev[i] = 0 + } + } + for i, v := range s.hashHead[:] { + if int(v) > delta { + s.hashHead[i] = uint32(int(v) - delta) + } else { + s.hashHead[i] = 0 + } + } + } + } + n := copy(d.window[d.windowEnd:], b) + d.windowEnd += n + return n +} + +func (d *compressor) writeBlock(tok *tokens, index int, eof bool) error { + if index > 0 || eof { + var window []byte + if d.blockStart <= index { + window = d.window[d.blockStart:index] + } + d.blockStart = index + d.w.writeBlock(tok, eof, window) + return d.w.err + } + return nil +} + +// writeBlockSkip writes the current block and uses the number of tokens +// to determine if the block should be stored on no matches, or +// only huffman encoded. +func (d *compressor) writeBlockSkip(tok *tokens, index int, eof bool) error { + if index > 0 || eof { + if d.blockStart <= index { + window := d.window[d.blockStart:index] + // If we removed less than a 64th of all literals + // we huffman compress the block. + if int(tok.n) > len(window)-int(tok.n>>6) { + d.w.writeBlockHuff(eof, window, d.sync) + } else { + // Write a dynamic huffman block. + d.w.writeBlockDynamic(tok, eof, window, d.sync) + } + } else { + d.w.writeBlock(tok, eof, nil) + } + d.blockStart = index + return d.w.err + } + return nil +} + +// fillWindow will fill the current window with the supplied +// dictionary and calculate all hashes. +// This is much faster than doing a full encode. +// Should only be used after a start/reset. +func (d *compressor) fillWindow(b []byte) { + // Do not fill window if we are in store-only or huffman mode. + if d.level <= 0 { + return + } + if d.fast != nil { + // encode the last data, but discard the result + if len(b) > maxMatchOffset { + b = b[len(b)-maxMatchOffset:] + } + d.fast.Encode(&d.tokens, b) + d.tokens.Reset() + return + } + s := d.state + // If we are given too much, cut it. + if len(b) > windowSize { + b = b[len(b)-windowSize:] + } + // Add all to window. + n := copy(d.window[d.windowEnd:], b) + + // Calculate 256 hashes at the time (more L1 cache hits) + loops := (n + 256 - minMatchLength) / 256 + for j := 0; j < loops; j++ { + startindex := j * 256 + end := startindex + 256 + minMatchLength - 1 + if end > n { + end = n + } + tocheck := d.window[startindex:end] + dstSize := len(tocheck) - minMatchLength + 1 + + if dstSize <= 0 { + continue + } + + dst := s.hashMatch[:dstSize] + bulkHash4(tocheck, dst) + var newH uint32 + for i, val := range dst { + di := i + startindex + newH = val & hashMask + // Get previous value with the same hash. + // Our chain should point to the previous value. + s.hashPrev[di&windowMask] = s.hashHead[newH] + // Set the head of the hash chain to us. + s.hashHead[newH] = uint32(di + s.hashOffset) + } + s.hash = newH + } + // Update window information. + d.windowEnd += n + s.index = n +} + +// Try to find a match starting at index whose length is greater than prevSize. +// We only look at chainCount possibilities before giving up. +// pos = s.index, prevHead = s.chainHead-s.hashOffset, prevLength=minMatchLength-1, lookahead +func (d *compressor) findMatch(pos int, prevHead int, prevLength int, lookahead int) (length, offset int, ok bool) { + minMatchLook := maxMatchLength + if lookahead < minMatchLook { + minMatchLook = lookahead + } + + win := d.window[0 : pos+minMatchLook] + + // We quit when we get a match that's at least nice long + nice := len(win) - pos + if d.nice < nice { + nice = d.nice + } + + // If we've got a match that's good enough, only look in 1/4 the chain. + tries := d.chain + length = prevLength + if length >= d.good { + tries >>= 2 + } + + wEnd := win[pos+length] + wPos := win[pos:] + minIndex := pos - windowSize + + for i := prevHead; tries > 0; tries-- { + if wEnd == win[i+length] { + n := matchLen(win[i:i+minMatchLook], wPos) + + if n > length && (n > minMatchLength || pos-i <= 4096) { + length = n + offset = pos - i + ok = true + if n >= nice { + // The match is good enough that we don't try to find a better one. + break + } + wEnd = win[pos+n] + } + } + if i == minIndex { + // hashPrev[i & windowMask] has already been overwritten, so stop now. + break + } + i = int(d.state.hashPrev[i&windowMask]) - d.state.hashOffset + if i < minIndex || i < 0 { + break + } + } + return +} + +func (d *compressor) writeStoredBlock(buf []byte) error { + if d.w.writeStoredHeader(len(buf), false); d.w.err != nil { + return d.w.err + } + d.w.writeBytes(buf) + return d.w.err +} + +// hash4 returns a hash representation of the first 4 bytes +// of the supplied slice. +// The caller must ensure that len(b) >= 4. +func hash4(b []byte) uint32 { + b = b[:4] + return hash4u(uint32(b[3])|uint32(b[2])<<8|uint32(b[1])<<16|uint32(b[0])<<24, hashBits) +} + +// bulkHash4 will compute hashes using the same +// algorithm as hash4 +func bulkHash4(b []byte, dst []uint32) { + if len(b) < 4 { + return + } + hb := uint32(b[3]) | uint32(b[2])<<8 | uint32(b[1])<<16 | uint32(b[0])<<24 + dst[0] = hash4u(hb, hashBits) + end := len(b) - 4 + 1 + for i := 1; i < end; i++ { + hb = (hb << 8) | uint32(b[i+3]) + dst[i] = hash4u(hb, hashBits) + } +} + +func (d *compressor) initDeflate() { + d.window = make([]byte, 2*windowSize) + d.byteAvailable = false + d.err = nil + if d.state == nil { + return + } + s := d.state + s.index = 0 + s.hashOffset = 1 + s.length = minMatchLength - 1 + s.offset = 0 + s.hash = 0 + s.chainHead = -1 +} + +// deflateLazy is the same as deflate, but with d.fastSkipHashing == skipNever, +// meaning it always has lazy matching on. +func (d *compressor) deflateLazy() { + s := d.state + // Sanity enables additional runtime tests. + // It's intended to be used during development + // to supplement the currently ad-hoc unit tests. + const sanity = debugDeflate + + if d.windowEnd-s.index < minMatchLength+maxMatchLength && !d.sync { + return + } + + s.maxInsertIndex = d.windowEnd - (minMatchLength - 1) + if s.index < s.maxInsertIndex { + s.hash = hash4(d.window[s.index : s.index+minMatchLength]) + } + + for { + if sanity && s.index > d.windowEnd { + panic("index > windowEnd") + } + lookahead := d.windowEnd - s.index + if lookahead < minMatchLength+maxMatchLength { + if !d.sync { + return + } + if sanity && s.index > d.windowEnd { + panic("index > windowEnd") + } + if lookahead == 0 { + // Flush current output block if any. + if d.byteAvailable { + // There is still one pending token that needs to be flushed + d.tokens.AddLiteral(d.window[s.index-1]) + d.byteAvailable = false + } + if d.tokens.n > 0 { + if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil { + return + } + d.tokens.Reset() + } + return + } + } + if s.index < s.maxInsertIndex { + // Update the hash + s.hash = hash4(d.window[s.index : s.index+minMatchLength]) + ch := s.hashHead[s.hash&hashMask] + s.chainHead = int(ch) + s.hashPrev[s.index&windowMask] = ch + s.hashHead[s.hash&hashMask] = uint32(s.index + s.hashOffset) + } + prevLength := s.length + prevOffset := s.offset + s.length = minMatchLength - 1 + s.offset = 0 + minIndex := s.index - windowSize + if minIndex < 0 { + minIndex = 0 + } + + if s.chainHead-s.hashOffset >= minIndex && lookahead > prevLength && prevLength < d.lazy { + if newLength, newOffset, ok := d.findMatch(s.index, s.chainHead-s.hashOffset, minMatchLength-1, lookahead); ok { + s.length = newLength + s.offset = newOffset + } + } + if prevLength >= minMatchLength && s.length <= prevLength { + // There was a match at the previous step, and the current match is + // not better. Output the previous match. + d.tokens.AddMatch(uint32(prevLength-3), uint32(prevOffset-minOffsetSize)) + + // Insert in the hash table all strings up to the end of the match. + // index and index-1 are already inserted. If there is not enough + // lookahead, the last two strings are not inserted into the hash + // table. + newIndex := s.index + prevLength - 1 + // Calculate missing hashes + end := newIndex + if end > s.maxInsertIndex { + end = s.maxInsertIndex + } + end += minMatchLength - 1 + startindex := s.index + 1 + if startindex > s.maxInsertIndex { + startindex = s.maxInsertIndex + } + tocheck := d.window[startindex:end] + dstSize := len(tocheck) - minMatchLength + 1 + if dstSize > 0 { + dst := s.hashMatch[:dstSize] + bulkHash4(tocheck, dst) + var newH uint32 + for i, val := range dst { + di := i + startindex + newH = val & hashMask + // Get previous value with the same hash. + // Our chain should point to the previous value. + s.hashPrev[di&windowMask] = s.hashHead[newH] + // Set the head of the hash chain to us. + s.hashHead[newH] = uint32(di + s.hashOffset) + } + s.hash = newH + } + + s.index = newIndex + d.byteAvailable = false + s.length = minMatchLength - 1 + if d.tokens.n == maxFlateBlockTokens { + // The block includes the current character + if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil { + return + } + d.tokens.Reset() + } + } else { + // Reset, if we got a match this run. + if s.length >= minMatchLength { + s.ii = 0 + } + // We have a byte waiting. Emit it. + if d.byteAvailable { + s.ii++ + d.tokens.AddLiteral(d.window[s.index-1]) + if d.tokens.n == maxFlateBlockTokens { + if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil { + return + } + d.tokens.Reset() + } + s.index++ + + // If we have a long run of no matches, skip additional bytes + // Resets when s.ii overflows after 64KB. + if s.ii > 31 { + n := int(s.ii >> 5) + for j := 0; j < n; j++ { + if s.index >= d.windowEnd-1 { + break + } + + d.tokens.AddLiteral(d.window[s.index-1]) + if d.tokens.n == maxFlateBlockTokens { + if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil { + return + } + d.tokens.Reset() + } + s.index++ + } + // Flush last byte + d.tokens.AddLiteral(d.window[s.index-1]) + d.byteAvailable = false + // s.length = minMatchLength - 1 // not needed, since s.ii is reset above, so it should never be > minMatchLength + if d.tokens.n == maxFlateBlockTokens { + if d.err = d.writeBlock(&d.tokens, s.index, false); d.err != nil { + return + } + d.tokens.Reset() + } + } + } else { + s.index++ + d.byteAvailable = true + } + } + } +} + +func (d *compressor) store() { + if d.windowEnd > 0 && (d.windowEnd == maxStoreBlockSize || d.sync) { + d.err = d.writeStoredBlock(d.window[:d.windowEnd]) + d.windowEnd = 0 + } +} + +// fillWindow will fill the buffer with data for huffman-only compression. +// The number of bytes copied is returned. +func (d *compressor) fillBlock(b []byte) int { + n := copy(d.window[d.windowEnd:], b) + d.windowEnd += n + return n +} + +// storeHuff will compress and store the currently added data, +// if enough has been accumulated or we at the end of the stream. +// Any error that occurred will be in d.err +func (d *compressor) storeHuff() { + if d.windowEnd < len(d.window) && !d.sync || d.windowEnd == 0 { + return + } + d.w.writeBlockHuff(false, d.window[:d.windowEnd], d.sync) + d.err = d.w.err + d.windowEnd = 0 +} + +// storeFast will compress and store the currently added data, +// if enough has been accumulated or we at the end of the stream. +// Any error that occurred will be in d.err +func (d *compressor) storeFast() { + // We only compress if we have maxStoreBlockSize. + if d.windowEnd < len(d.window) { + if !d.sync { + return + } + // Handle extremely small sizes. + if d.windowEnd < 128 { + if d.windowEnd == 0 { + return + } + if d.windowEnd <= 32 { + d.err = d.writeStoredBlock(d.window[:d.windowEnd]) + } else { + d.w.writeBlockHuff(false, d.window[:d.windowEnd], true) + d.err = d.w.err + } + d.tokens.Reset() + d.windowEnd = 0 + d.fast.Reset() + return + } + } + + d.fast.Encode(&d.tokens, d.window[:d.windowEnd]) + // If we made zero matches, store the block as is. + if d.tokens.n == 0 { + d.err = d.writeStoredBlock(d.window[:d.windowEnd]) + // If we removed less than 1/16th, huffman compress the block. + } else if int(d.tokens.n) > d.windowEnd-(d.windowEnd>>4) { + d.w.writeBlockHuff(false, d.window[:d.windowEnd], d.sync) + d.err = d.w.err + } else { + d.w.writeBlockDynamic(&d.tokens, false, d.window[:d.windowEnd], d.sync) + d.err = d.w.err + } + d.tokens.Reset() + d.windowEnd = 0 +} + +// write will add input byte to the stream. +// Unless an error occurs all bytes will be consumed. +func (d *compressor) write(b []byte) (n int, err error) { + if d.err != nil { + return 0, d.err + } + n = len(b) + for len(b) > 0 { + d.step(d) + b = b[d.fill(d, b):] + if d.err != nil { + return 0, d.err + } + } + return n, d.err +} + +func (d *compressor) syncFlush() error { + d.sync = true + if d.err != nil { + return d.err + } + d.step(d) + if d.err == nil { + d.w.writeStoredHeader(0, false) + d.w.flush() + d.err = d.w.err + } + d.sync = false + return d.err +} + +func (d *compressor) init(w io.Writer, level int) (err error) { + d.w = newHuffmanBitWriter(w) + + switch { + case level == NoCompression: + d.window = make([]byte, maxStoreBlockSize) + d.fill = (*compressor).fillBlock + d.step = (*compressor).store + case level == ConstantCompression: + d.w.logNewTablePenalty = 10 + d.window = make([]byte, 32<<10) + d.fill = (*compressor).fillBlock + d.step = (*compressor).storeHuff + case level == DefaultCompression: + level = 5 + fallthrough + case level >= 1 && level <= 6: + d.w.logNewTablePenalty = 8 + d.fast = newFastEnc(level) + d.window = make([]byte, maxStoreBlockSize) + d.fill = (*compressor).fillBlock + d.step = (*compressor).storeFast + case 7 <= level && level <= 9: + d.w.logNewTablePenalty = 10 + d.state = &advancedState{} + d.compressionLevel = levels[level] + d.initDeflate() + d.fill = (*compressor).fillDeflate + d.step = (*compressor).deflateLazy + default: + return fmt.Errorf("flate: invalid compression level %d: want value in range [-2, 9]", level) + } + d.level = level + return nil +} + +// reset the state of the compressor. +func (d *compressor) reset(w io.Writer) { + d.w.reset(w) + d.sync = false + d.err = nil + // We only need to reset a few things for Snappy. + if d.fast != nil { + d.fast.Reset() + d.windowEnd = 0 + d.tokens.Reset() + return + } + switch d.compressionLevel.chain { + case 0: + // level was NoCompression or ConstantCompresssion. + d.windowEnd = 0 + default: + s := d.state + s.chainHead = -1 + for i := range s.hashHead { + s.hashHead[i] = 0 + } + for i := range s.hashPrev { + s.hashPrev[i] = 0 + } + s.hashOffset = 1 + s.index, d.windowEnd = 0, 0 + d.blockStart, d.byteAvailable = 0, false + d.tokens.Reset() + s.length = minMatchLength - 1 + s.offset = 0 + s.hash = 0 + s.ii = 0 + s.maxInsertIndex = 0 + } +} + +func (d *compressor) close() error { + if d.err != nil { + return d.err + } + d.sync = true + d.step(d) + if d.err != nil { + return d.err + } + if d.w.writeStoredHeader(0, true); d.w.err != nil { + return d.w.err + } + d.w.flush() + d.w.reset(nil) + return d.w.err +} + +// NewWriter returns a new Writer compressing data at the given level. +// Following zlib, levels range from 1 (BestSpeed) to 9 (BestCompression); +// higher levels typically run slower but compress more. +// Level 0 (NoCompression) does not attempt any compression; it only adds the +// necessary DEFLATE framing. +// Level -1 (DefaultCompression) uses the default compression level. +// Level -2 (ConstantCompression) will use Huffman compression only, giving +// a very fast compression for all types of input, but sacrificing considerable +// compression efficiency. +// +// If level is in the range [-2, 9] then the error returned will be nil. +// Otherwise the error returned will be non-nil. +func NewWriter(w io.Writer, level int) (*Writer, error) { + var dw Writer + if err := dw.d.init(w, level); err != nil { + return nil, err + } + return &dw, nil +} + +// NewWriterDict is like NewWriter but initializes the new +// Writer with a preset dictionary. The returned Writer behaves +// as if the dictionary had been written to it without producing +// any compressed output. The compressed data written to w +// can only be decompressed by a Reader initialized with the +// same dictionary. +func NewWriterDict(w io.Writer, level int, dict []byte) (*Writer, error) { + zw, err := NewWriter(w, level) + if err != nil { + return nil, err + } + zw.d.fillWindow(dict) + zw.dict = append(zw.dict, dict...) // duplicate dictionary for Reset method. + return zw, err +} + +// A Writer takes data written to it and writes the compressed +// form of that data to an underlying writer (see NewWriter). +type Writer struct { + d compressor + dict []byte +} + +// Write writes data to w, which will eventually write the +// compressed form of data to its underlying writer. +func (w *Writer) Write(data []byte) (n int, err error) { + return w.d.write(data) +} + +// Flush flushes any pending data to the underlying writer. +// It is useful mainly in compressed network protocols, to ensure that +// a remote reader has enough data to reconstruct a packet. +// Flush does not return until the data has been written. +// Calling Flush when there is no pending data still causes the Writer +// to emit a sync marker of at least 4 bytes. +// If the underlying writer returns an error, Flush returns that error. +// +// In the terminology of the zlib library, Flush is equivalent to Z_SYNC_FLUSH. +func (w *Writer) Flush() error { + // For more about flushing: + // http://www.bolet.org/~pornin/deflate-flush.html + return w.d.syncFlush() +} + +// Close flushes and closes the writer. +func (w *Writer) Close() error { + return w.d.close() +} + +// Reset discards the writer's state and makes it equivalent to +// the result of NewWriter or NewWriterDict called with dst +// and w's level and dictionary. +func (w *Writer) Reset(dst io.Writer) { + if len(w.dict) > 0 { + // w was created with NewWriterDict + w.d.reset(dst) + if dst != nil { + w.d.fillWindow(w.dict) + } + } else { + // w was created with NewWriter + w.d.reset(dst) + } +} + +// ResetDict discards the writer's state and makes it equivalent to +// the result of NewWriter or NewWriterDict called with dst +// and w's level, but sets a specific dictionary. +func (w *Writer) ResetDict(dst io.Writer, dict []byte) { + w.dict = dict + w.d.reset(dst) + w.d.fillWindow(w.dict) +} diff --git a/vendor/github.com/klauspost/compress/flate/dict_decoder.go b/vendor/github.com/klauspost/compress/flate/dict_decoder.go new file mode 100644 index 0000000000..71c75a065e --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/dict_decoder.go @@ -0,0 +1,184 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +// dictDecoder implements the LZ77 sliding dictionary as used in decompression. +// LZ77 decompresses data through sequences of two forms of commands: +// +// * Literal insertions: Runs of one or more symbols are inserted into the data +// stream as is. This is accomplished through the writeByte method for a +// single symbol, or combinations of writeSlice/writeMark for multiple symbols. +// Any valid stream must start with a literal insertion if no preset dictionary +// is used. +// +// * Backward copies: Runs of one or more symbols are copied from previously +// emitted data. Backward copies come as the tuple (dist, length) where dist +// determines how far back in the stream to copy from and length determines how +// many bytes to copy. Note that it is valid for the length to be greater than +// the distance. Since LZ77 uses forward copies, that situation is used to +// perform a form of run-length encoding on repeated runs of symbols. +// The writeCopy and tryWriteCopy are used to implement this command. +// +// For performance reasons, this implementation performs little to no sanity +// checks about the arguments. As such, the invariants documented for each +// method call must be respected. +type dictDecoder struct { + hist []byte // Sliding window history + + // Invariant: 0 <= rdPos <= wrPos <= len(hist) + wrPos int // Current output position in buffer + rdPos int // Have emitted hist[:rdPos] already + full bool // Has a full window length been written yet? +} + +// init initializes dictDecoder to have a sliding window dictionary of the given +// size. If a preset dict is provided, it will initialize the dictionary with +// the contents of dict. +func (dd *dictDecoder) init(size int, dict []byte) { + *dd = dictDecoder{hist: dd.hist} + + if cap(dd.hist) < size { + dd.hist = make([]byte, size) + } + dd.hist = dd.hist[:size] + + if len(dict) > len(dd.hist) { + dict = dict[len(dict)-len(dd.hist):] + } + dd.wrPos = copy(dd.hist, dict) + if dd.wrPos == len(dd.hist) { + dd.wrPos = 0 + dd.full = true + } + dd.rdPos = dd.wrPos +} + +// histSize reports the total amount of historical data in the dictionary. +func (dd *dictDecoder) histSize() int { + if dd.full { + return len(dd.hist) + } + return dd.wrPos +} + +// availRead reports the number of bytes that can be flushed by readFlush. +func (dd *dictDecoder) availRead() int { + return dd.wrPos - dd.rdPos +} + +// availWrite reports the available amount of output buffer space. +func (dd *dictDecoder) availWrite() int { + return len(dd.hist) - dd.wrPos +} + +// writeSlice returns a slice of the available buffer to write data to. +// +// This invariant will be kept: len(s) <= availWrite() +func (dd *dictDecoder) writeSlice() []byte { + return dd.hist[dd.wrPos:] +} + +// writeMark advances the writer pointer by cnt. +// +// This invariant must be kept: 0 <= cnt <= availWrite() +func (dd *dictDecoder) writeMark(cnt int) { + dd.wrPos += cnt +} + +// writeByte writes a single byte to the dictionary. +// +// This invariant must be kept: 0 < availWrite() +func (dd *dictDecoder) writeByte(c byte) { + dd.hist[dd.wrPos] = c + dd.wrPos++ +} + +// writeCopy copies a string at a given (dist, length) to the output. +// This returns the number of bytes copied and may be less than the requested +// length if the available space in the output buffer is too small. +// +// This invariant must be kept: 0 < dist <= histSize() +func (dd *dictDecoder) writeCopy(dist, length int) int { + dstBase := dd.wrPos + dstPos := dstBase + srcPos := dstPos - dist + endPos := dstPos + length + if endPos > len(dd.hist) { + endPos = len(dd.hist) + } + + // Copy non-overlapping section after destination position. + // + // This section is non-overlapping in that the copy length for this section + // is always less than or equal to the backwards distance. This can occur + // if a distance refers to data that wraps-around in the buffer. + // Thus, a backwards copy is performed here; that is, the exact bytes in + // the source prior to the copy is placed in the destination. + if srcPos < 0 { + srcPos += len(dd.hist) + dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:]) + srcPos = 0 + } + + // Copy possibly overlapping section before destination position. + // + // This section can overlap if the copy length for this section is larger + // than the backwards distance. This is allowed by LZ77 so that repeated + // strings can be succinctly represented using (dist, length) pairs. + // Thus, a forwards copy is performed here; that is, the bytes copied is + // possibly dependent on the resulting bytes in the destination as the copy + // progresses along. This is functionally equivalent to the following: + // + // for i := 0; i < endPos-dstPos; i++ { + // dd.hist[dstPos+i] = dd.hist[srcPos+i] + // } + // dstPos = endPos + // + for dstPos < endPos { + dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos]) + } + + dd.wrPos = dstPos + return dstPos - dstBase +} + +// tryWriteCopy tries to copy a string at a given (distance, length) to the +// output. This specialized version is optimized for short distances. +// +// This method is designed to be inlined for performance reasons. +// +// This invariant must be kept: 0 < dist <= histSize() +func (dd *dictDecoder) tryWriteCopy(dist, length int) int { + dstPos := dd.wrPos + endPos := dstPos + length + if dstPos < dist || endPos > len(dd.hist) { + return 0 + } + dstBase := dstPos + srcPos := dstPos - dist + + // Copy possibly overlapping section before destination position. +loop: + dstPos += copy(dd.hist[dstPos:endPos], dd.hist[srcPos:dstPos]) + if dstPos < endPos { + goto loop // Avoid for-loop so that this function can be inlined + } + + dd.wrPos = dstPos + return dstPos - dstBase +} + +// readFlush returns a slice of the historical buffer that is ready to be +// emitted to the user. The data returned by readFlush must be fully consumed +// before calling any other dictDecoder methods. +func (dd *dictDecoder) readFlush() []byte { + toRead := dd.hist[dd.rdPos:dd.wrPos] + dd.rdPos = dd.wrPos + if dd.wrPos == len(dd.hist) { + dd.wrPos, dd.rdPos = 0, 0 + dd.full = true + } + return toRead +} diff --git a/vendor/github.com/klauspost/compress/flate/fast_encoder.go b/vendor/github.com/klauspost/compress/flate/fast_encoder.go new file mode 100644 index 0000000000..347ac2c902 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/fast_encoder.go @@ -0,0 +1,244 @@ +// Copyright 2011 The Snappy-Go Authors. All rights reserved. +// Modified for deflate by Klaus Post (c) 2015. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import ( + "encoding/binary" + "fmt" + "math/bits" +) + +type fastEnc interface { + Encode(dst *tokens, src []byte) + Reset() +} + +func newFastEnc(level int) fastEnc { + switch level { + case 1: + return &fastEncL1{fastGen: fastGen{cur: maxStoreBlockSize}} + case 2: + return &fastEncL2{fastGen: fastGen{cur: maxStoreBlockSize}} + case 3: + return &fastEncL3{fastGen: fastGen{cur: maxStoreBlockSize}} + case 4: + return &fastEncL4{fastGen: fastGen{cur: maxStoreBlockSize}} + case 5: + return &fastEncL5{fastGen: fastGen{cur: maxStoreBlockSize}} + case 6: + return &fastEncL6{fastGen: fastGen{cur: maxStoreBlockSize}} + default: + panic("invalid level specified") + } +} + +const ( + tableBits = 15 // Bits used in the table + tableSize = 1 << tableBits // Size of the table + tableShift = 32 - tableBits // Right-shift to get the tableBits most significant bits of a uint32. + baseMatchOffset = 1 // The smallest match offset + baseMatchLength = 3 // The smallest match length per the RFC section 3.2.5 + maxMatchOffset = 1 << 15 // The largest match offset + + bTableBits = 17 // Bits used in the big tables + bTableSize = 1 << bTableBits // Size of the table + allocHistory = maxStoreBlockSize * 5 // Size to preallocate for history. + bufferReset = (1 << 31) - allocHistory - maxStoreBlockSize - 1 // Reset the buffer offset when reaching this. +) + +const ( + prime3bytes = 506832829 + prime4bytes = 2654435761 + prime5bytes = 889523592379 + prime6bytes = 227718039650203 + prime7bytes = 58295818150454627 + prime8bytes = 0xcf1bbcdcb7a56463 +) + +func load32(b []byte, i int) uint32 { + // Help the compiler eliminate bounds checks on the read so it can be done in a single read. + b = b[i:] + b = b[:4] + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func load64(b []byte, i int) uint64 { + return binary.LittleEndian.Uint64(b[i:]) +} + +func load3232(b []byte, i int32) uint32 { + return binary.LittleEndian.Uint32(b[i:]) +} + +func load6432(b []byte, i int32) uint64 { + return binary.LittleEndian.Uint64(b[i:]) +} + +func hash(u uint32) uint32 { + return (u * 0x1e35a7bd) >> tableShift +} + +type tableEntry struct { + offset int32 +} + +// fastGen maintains the table for matches, +// and the previous byte block for level 2. +// This is the generic implementation. +type fastGen struct { + hist []byte + cur int32 +} + +func (e *fastGen) addBlock(src []byte) int32 { + // check if we have space already + if len(e.hist)+len(src) > cap(e.hist) { + if cap(e.hist) == 0 { + e.hist = make([]byte, 0, allocHistory) + } else { + if cap(e.hist) < maxMatchOffset*2 { + panic("unexpected buffer size") + } + // Move down + offset := int32(len(e.hist)) - maxMatchOffset + copy(e.hist[0:maxMatchOffset], e.hist[offset:]) + e.cur += offset + e.hist = e.hist[:maxMatchOffset] + } + } + s := int32(len(e.hist)) + e.hist = append(e.hist, src...) + return s +} + +// hash4 returns the hash of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <32. +func hash4u(u uint32, h uint8) uint32 { + return (u * prime4bytes) >> ((32 - h) & reg8SizeMask32) +} + +type tableEntryPrev struct { + Cur tableEntry + Prev tableEntry +} + +// hash4x64 returns the hash of the lowest 4 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <32. +func hash4x64(u uint64, h uint8) uint32 { + return (uint32(u) * prime4bytes) >> ((32 - h) & reg8SizeMask32) +} + +// hash7 returns the hash of the lowest 7 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. +func hash7(u uint64, h uint8) uint32 { + return uint32(((u << (64 - 56)) * prime7bytes) >> ((64 - h) & reg8SizeMask64)) +} + +// hash8 returns the hash of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. +func hash8(u uint64, h uint8) uint32 { + return uint32((u * prime8bytes) >> ((64 - h) & reg8SizeMask64)) +} + +// hash6 returns the hash of the lowest 6 bytes of u to fit in a hash table with h bits. +// Preferably h should be a constant and should always be <64. +func hash6(u uint64, h uint8) uint32 { + return uint32(((u << (64 - 48)) * prime6bytes) >> ((64 - h) & reg8SizeMask64)) +} + +// matchlen will return the match length between offsets and t in src. +// The maximum length returned is maxMatchLength - 4. +// It is assumed that s > t, that t >=0 and s < len(src). +func (e *fastGen) matchlen(s, t int32, src []byte) int32 { + if debugDecode { + if t >= s { + panic(fmt.Sprint("t >=s:", t, s)) + } + if int(s) >= len(src) { + panic(fmt.Sprint("s >= len(src):", s, len(src))) + } + if t < 0 { + panic(fmt.Sprint("t < 0:", t)) + } + if s-t > maxMatchOffset { + panic(fmt.Sprint(s, "-", t, "(", s-t, ") > maxMatchLength (", maxMatchOffset, ")")) + } + } + s1 := int(s) + maxMatchLength - 4 + if s1 > len(src) { + s1 = len(src) + } + + // Extend the match to be as long as possible. + return int32(matchLen(src[s:s1], src[t:])) +} + +// matchlenLong will return the match length between offsets and t in src. +// It is assumed that s > t, that t >=0 and s < len(src). +func (e *fastGen) matchlenLong(s, t int32, src []byte) int32 { + if debugDecode { + if t >= s { + panic(fmt.Sprint("t >=s:", t, s)) + } + if int(s) >= len(src) { + panic(fmt.Sprint("s >= len(src):", s, len(src))) + } + if t < 0 { + panic(fmt.Sprint("t < 0:", t)) + } + if s-t > maxMatchOffset { + panic(fmt.Sprint(s, "-", t, "(", s-t, ") > maxMatchLength (", maxMatchOffset, ")")) + } + } + // Extend the match to be as long as possible. + return int32(matchLen(src[s:], src[t:])) +} + +// Reset the encoding table. +func (e *fastGen) Reset() { + if cap(e.hist) < allocHistory { + e.hist = make([]byte, 0, allocHistory) + } + // We offset current position so everything will be out of reach. + // If we are above the buffer reset it will be cleared anyway since len(hist) == 0. + if e.cur <= bufferReset { + e.cur += maxMatchOffset + int32(len(e.hist)) + } + e.hist = e.hist[:0] +} + +// matchLen returns the maximum length. +// 'a' must be the shortest of the two. +func matchLen(a, b []byte) int { + b = b[:len(a)] + var checked int + if len(a) >= 4 { + // Try 4 bytes first + if diff := binary.LittleEndian.Uint32(a) ^ binary.LittleEndian.Uint32(b); diff != 0 { + return bits.TrailingZeros32(diff) >> 3 + } + // Switch to 8 byte matching. + checked = 4 + a = a[4:] + b = b[4:] + for len(a) >= 8 { + b = b[:len(a)] + if diff := binary.LittleEndian.Uint64(a) ^ binary.LittleEndian.Uint64(b); diff != 0 { + return checked + (bits.TrailingZeros64(diff) >> 3) + } + checked += 8 + a = a[8:] + b = b[8:] + } + } + b = b[:len(a)] + for i := range a { + if a[i] != b[i] { + return i + checked + } + } + return len(a) + checked +} diff --git a/vendor/github.com/klauspost/compress/flate/huffman_bit_writer.go b/vendor/github.com/klauspost/compress/flate/huffman_bit_writer.go new file mode 100644 index 0000000000..3ad5e98072 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/huffman_bit_writer.go @@ -0,0 +1,1034 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import ( + "encoding/binary" + "fmt" + "io" +) + +const ( + // The largest offset code. + offsetCodeCount = 30 + + // The special code used to mark the end of a block. + endBlockMarker = 256 + + // The first length code. + lengthCodesStart = 257 + + // The number of codegen codes. + codegenCodeCount = 19 + badCode = 255 + + // bufferFlushSize indicates the buffer size + // after which bytes are flushed to the writer. + // Should preferably be a multiple of 6, since + // we accumulate 6 bytes between writes to the buffer. + bufferFlushSize = 246 + + // bufferSize is the actual output byte buffer size. + // It must have additional headroom for a flush + // which can contain up to 8 bytes. + bufferSize = bufferFlushSize + 8 +) + +// The number of extra bits needed by length code X - LENGTH_CODES_START. +var lengthExtraBits = [32]int8{ + /* 257 */ 0, 0, 0, + /* 260 */ 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, + /* 270 */ 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, + /* 280 */ 4, 5, 5, 5, 5, 0, +} + +// The length indicated by length code X - LENGTH_CODES_START. +var lengthBase = [32]uint8{ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, + 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, + 64, 80, 96, 112, 128, 160, 192, 224, 255, +} + +// offset code word extra bits. +var offsetExtraBits = [64]int8{ + 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, + 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, + 9, 9, 10, 10, 11, 11, 12, 12, 13, 13, + /* extended window */ + 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 20, 20, +} + +var offsetCombined = [32]uint32{} + +func init() { + var offsetBase = [64]uint32{ + /* normal deflate */ + 0x000000, 0x000001, 0x000002, 0x000003, 0x000004, + 0x000006, 0x000008, 0x00000c, 0x000010, 0x000018, + 0x000020, 0x000030, 0x000040, 0x000060, 0x000080, + 0x0000c0, 0x000100, 0x000180, 0x000200, 0x000300, + 0x000400, 0x000600, 0x000800, 0x000c00, 0x001000, + 0x001800, 0x002000, 0x003000, 0x004000, 0x006000, + + /* extended window */ + 0x008000, 0x00c000, 0x010000, 0x018000, 0x020000, + 0x030000, 0x040000, 0x060000, 0x080000, 0x0c0000, + 0x100000, 0x180000, 0x200000, 0x300000, + } + + for i := range offsetCombined[:] { + // Don't use extended window values... + if offsetBase[i] > 0x006000 { + continue + } + offsetCombined[i] = uint32(offsetExtraBits[i])<<16 | (offsetBase[i]) + } +} + +// The odd order in which the codegen code sizes are written. +var codegenOrder = []uint32{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15} + +type huffmanBitWriter struct { + // writer is the underlying writer. + // Do not use it directly; use the write method, which ensures + // that Write errors are sticky. + writer io.Writer + + // Data waiting to be written is bytes[0:nbytes] + // and then the low nbits of bits. + bits uint64 + nbits uint16 + nbytes uint8 + lastHuffMan bool + literalEncoding *huffmanEncoder + tmpLitEncoding *huffmanEncoder + offsetEncoding *huffmanEncoder + codegenEncoding *huffmanEncoder + err error + lastHeader int + // Set between 0 (reused block can be up to 2x the size) + logNewTablePenalty uint + bytes [256 + 8]byte + literalFreq [lengthCodesStart + 32]uint16 + offsetFreq [32]uint16 + codegenFreq [codegenCodeCount]uint16 + + // codegen must have an extra space for the final symbol. + codegen [literalCount + offsetCodeCount + 1]uint8 +} + +// Huffman reuse. +// +// The huffmanBitWriter supports reusing huffman tables and thereby combining block sections. +// +// This is controlled by several variables: +// +// If lastHeader is non-zero the Huffman table can be reused. +// This also indicates that a Huffman table has been generated that can output all +// possible symbols. +// It also indicates that an EOB has not yet been emitted, so if a new tabel is generated +// an EOB with the previous table must be written. +// +// If lastHuffMan is set, a table for outputting literals has been generated and offsets are invalid. +// +// An incoming block estimates the output size of a new table using a 'fresh' by calculating the +// optimal size and adding a penalty in 'logNewTablePenalty'. +// A Huffman table is not optimal, which is why we add a penalty, and generating a new table +// is slower both for compression and decompression. + +func newHuffmanBitWriter(w io.Writer) *huffmanBitWriter { + return &huffmanBitWriter{ + writer: w, + literalEncoding: newHuffmanEncoder(literalCount), + tmpLitEncoding: newHuffmanEncoder(literalCount), + codegenEncoding: newHuffmanEncoder(codegenCodeCount), + offsetEncoding: newHuffmanEncoder(offsetCodeCount), + } +} + +func (w *huffmanBitWriter) reset(writer io.Writer) { + w.writer = writer + w.bits, w.nbits, w.nbytes, w.err = 0, 0, 0, nil + w.lastHeader = 0 + w.lastHuffMan = false +} + +func (w *huffmanBitWriter) canReuse(t *tokens) (offsets, lits bool) { + offsets, lits = true, true + a := t.offHist[:offsetCodeCount] + b := w.offsetFreq[:len(a)] + for i := range a { + if b[i] == 0 && a[i] != 0 { + offsets = false + break + } + } + + a = t.extraHist[:literalCount-256] + b = w.literalFreq[256:literalCount] + b = b[:len(a)] + for i := range a { + if b[i] == 0 && a[i] != 0 { + lits = false + break + } + } + if lits { + a = t.litHist[:] + b = w.literalFreq[:len(a)] + for i := range a { + if b[i] == 0 && a[i] != 0 { + lits = false + break + } + } + } + return +} + +func (w *huffmanBitWriter) flush() { + if w.err != nil { + w.nbits = 0 + return + } + if w.lastHeader > 0 { + // We owe an EOB + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + } + n := w.nbytes + for w.nbits != 0 { + w.bytes[n] = byte(w.bits) + w.bits >>= 8 + if w.nbits > 8 { // Avoid underflow + w.nbits -= 8 + } else { + w.nbits = 0 + } + n++ + } + w.bits = 0 + w.write(w.bytes[:n]) + w.nbytes = 0 +} + +func (w *huffmanBitWriter) write(b []byte) { + if w.err != nil { + return + } + _, w.err = w.writer.Write(b) +} + +func (w *huffmanBitWriter) writeBits(b int32, nb uint16) { + w.bits |= uint64(b) << w.nbits + w.nbits += nb + if w.nbits >= 48 { + w.writeOutBits() + } +} + +func (w *huffmanBitWriter) writeBytes(bytes []byte) { + if w.err != nil { + return + } + n := w.nbytes + if w.nbits&7 != 0 { + w.err = InternalError("writeBytes with unfinished bits") + return + } + for w.nbits != 0 { + w.bytes[n] = byte(w.bits) + w.bits >>= 8 + w.nbits -= 8 + n++ + } + if n != 0 { + w.write(w.bytes[:n]) + } + w.nbytes = 0 + w.write(bytes) +} + +// RFC 1951 3.2.7 specifies a special run-length encoding for specifying +// the literal and offset lengths arrays (which are concatenated into a single +// array). This method generates that run-length encoding. +// +// The result is written into the codegen array, and the frequencies +// of each code is written into the codegenFreq array. +// Codes 0-15 are single byte codes. Codes 16-18 are followed by additional +// information. Code badCode is an end marker +// +// numLiterals The number of literals in literalEncoding +// numOffsets The number of offsets in offsetEncoding +// litenc, offenc The literal and offset encoder to use +func (w *huffmanBitWriter) generateCodegen(numLiterals int, numOffsets int, litEnc, offEnc *huffmanEncoder) { + for i := range w.codegenFreq { + w.codegenFreq[i] = 0 + } + // Note that we are using codegen both as a temporary variable for holding + // a copy of the frequencies, and as the place where we put the result. + // This is fine because the output is always shorter than the input used + // so far. + codegen := w.codegen[:] // cache + // Copy the concatenated code sizes to codegen. Put a marker at the end. + cgnl := codegen[:numLiterals] + for i := range cgnl { + cgnl[i] = uint8(litEnc.codes[i].len) + } + + cgnl = codegen[numLiterals : numLiterals+numOffsets] + for i := range cgnl { + cgnl[i] = uint8(offEnc.codes[i].len) + } + codegen[numLiterals+numOffsets] = badCode + + size := codegen[0] + count := 1 + outIndex := 0 + for inIndex := 1; size != badCode; inIndex++ { + // INVARIANT: We have seen "count" copies of size that have not yet + // had output generated for them. + nextSize := codegen[inIndex] + if nextSize == size { + count++ + continue + } + // We need to generate codegen indicating "count" of size. + if size != 0 { + codegen[outIndex] = size + outIndex++ + w.codegenFreq[size]++ + count-- + for count >= 3 { + n := 6 + if n > count { + n = count + } + codegen[outIndex] = 16 + outIndex++ + codegen[outIndex] = uint8(n - 3) + outIndex++ + w.codegenFreq[16]++ + count -= n + } + } else { + for count >= 11 { + n := 138 + if n > count { + n = count + } + codegen[outIndex] = 18 + outIndex++ + codegen[outIndex] = uint8(n - 11) + outIndex++ + w.codegenFreq[18]++ + count -= n + } + if count >= 3 { + // count >= 3 && count <= 10 + codegen[outIndex] = 17 + outIndex++ + codegen[outIndex] = uint8(count - 3) + outIndex++ + w.codegenFreq[17]++ + count = 0 + } + } + count-- + for ; count >= 0; count-- { + codegen[outIndex] = size + outIndex++ + w.codegenFreq[size]++ + } + // Set up invariant for next time through the loop. + size = nextSize + count = 1 + } + // Marker indicating the end of the codegen. + codegen[outIndex] = badCode +} + +func (w *huffmanBitWriter) codegens() int { + numCodegens := len(w.codegenFreq) + for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 { + numCodegens-- + } + return numCodegens +} + +func (w *huffmanBitWriter) headerSize() (size, numCodegens int) { + numCodegens = len(w.codegenFreq) + for numCodegens > 4 && w.codegenFreq[codegenOrder[numCodegens-1]] == 0 { + numCodegens-- + } + return 3 + 5 + 5 + 4 + (3 * numCodegens) + + w.codegenEncoding.bitLength(w.codegenFreq[:]) + + int(w.codegenFreq[16])*2 + + int(w.codegenFreq[17])*3 + + int(w.codegenFreq[18])*7, numCodegens +} + +// dynamicSize returns the size of dynamically encoded data in bits. +func (w *huffmanBitWriter) dynamicReuseSize(litEnc, offEnc *huffmanEncoder) (size int) { + size = litEnc.bitLength(w.literalFreq[:]) + + offEnc.bitLength(w.offsetFreq[:]) + return size +} + +// dynamicSize returns the size of dynamically encoded data in bits. +func (w *huffmanBitWriter) dynamicSize(litEnc, offEnc *huffmanEncoder, extraBits int) (size, numCodegens int) { + header, numCodegens := w.headerSize() + size = header + + litEnc.bitLength(w.literalFreq[:]) + + offEnc.bitLength(w.offsetFreq[:]) + + extraBits + return size, numCodegens +} + +// extraBitSize will return the number of bits that will be written +// as "extra" bits on matches. +func (w *huffmanBitWriter) extraBitSize() int { + total := 0 + for i, n := range w.literalFreq[257:literalCount] { + total += int(n) * int(lengthExtraBits[i&31]) + } + for i, n := range w.offsetFreq[:offsetCodeCount] { + total += int(n) * int(offsetExtraBits[i&31]) + } + return total +} + +// fixedSize returns the size of dynamically encoded data in bits. +func (w *huffmanBitWriter) fixedSize(extraBits int) int { + return 3 + + fixedLiteralEncoding.bitLength(w.literalFreq[:]) + + fixedOffsetEncoding.bitLength(w.offsetFreq[:]) + + extraBits +} + +// storedSize calculates the stored size, including header. +// The function returns the size in bits and whether the block +// fits inside a single block. +func (w *huffmanBitWriter) storedSize(in []byte) (int, bool) { + if in == nil { + return 0, false + } + if len(in) <= maxStoreBlockSize { + return (len(in) + 5) * 8, true + } + return 0, false +} + +func (w *huffmanBitWriter) writeCode(c hcode) { + // The function does not get inlined if we "& 63" the shift. + w.bits |= uint64(c.code) << w.nbits + w.nbits += c.len + if w.nbits >= 48 { + w.writeOutBits() + } +} + +// writeOutBits will write bits to the buffer. +func (w *huffmanBitWriter) writeOutBits() { + bits := w.bits + w.bits >>= 48 + w.nbits -= 48 + n := w.nbytes + + // We over-write, but faster... + binary.LittleEndian.PutUint64(w.bytes[n:], bits) + n += 6 + + if n >= bufferFlushSize { + if w.err != nil { + n = 0 + return + } + w.write(w.bytes[:n]) + n = 0 + } + + w.nbytes = n +} + +// Write the header of a dynamic Huffman block to the output stream. +// +// numLiterals The number of literals specified in codegen +// numOffsets The number of offsets specified in codegen +// numCodegens The number of codegens used in codegen +func (w *huffmanBitWriter) writeDynamicHeader(numLiterals int, numOffsets int, numCodegens int, isEof bool) { + if w.err != nil { + return + } + var firstBits int32 = 4 + if isEof { + firstBits = 5 + } + w.writeBits(firstBits, 3) + w.writeBits(int32(numLiterals-257), 5) + w.writeBits(int32(numOffsets-1), 5) + w.writeBits(int32(numCodegens-4), 4) + + for i := 0; i < numCodegens; i++ { + value := uint(w.codegenEncoding.codes[codegenOrder[i]].len) + w.writeBits(int32(value), 3) + } + + i := 0 + for { + var codeWord = uint32(w.codegen[i]) + i++ + if codeWord == badCode { + break + } + w.writeCode(w.codegenEncoding.codes[codeWord]) + + switch codeWord { + case 16: + w.writeBits(int32(w.codegen[i]), 2) + i++ + case 17: + w.writeBits(int32(w.codegen[i]), 3) + i++ + case 18: + w.writeBits(int32(w.codegen[i]), 7) + i++ + } + } +} + +// writeStoredHeader will write a stored header. +// If the stored block is only used for EOF, +// it is replaced with a fixed huffman block. +func (w *huffmanBitWriter) writeStoredHeader(length int, isEof bool) { + if w.err != nil { + return + } + if w.lastHeader > 0 { + // We owe an EOB + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + } + + // To write EOF, use a fixed encoding block. 10 bits instead of 5 bytes. + if length == 0 && isEof { + w.writeFixedHeader(isEof) + // EOB: 7 bits, value: 0 + w.writeBits(0, 7) + w.flush() + return + } + + var flag int32 + if isEof { + flag = 1 + } + w.writeBits(flag, 3) + w.flush() + w.writeBits(int32(length), 16) + w.writeBits(int32(^uint16(length)), 16) +} + +func (w *huffmanBitWriter) writeFixedHeader(isEof bool) { + if w.err != nil { + return + } + if w.lastHeader > 0 { + // We owe an EOB + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + } + + // Indicate that we are a fixed Huffman block + var value int32 = 2 + if isEof { + value = 3 + } + w.writeBits(value, 3) +} + +// writeBlock will write a block of tokens with the smallest encoding. +// The original input can be supplied, and if the huffman encoded data +// is larger than the original bytes, the data will be written as a +// stored block. +// If the input is nil, the tokens will always be Huffman encoded. +func (w *huffmanBitWriter) writeBlock(tokens *tokens, eof bool, input []byte) { + if w.err != nil { + return + } + + tokens.AddEOB() + if w.lastHeader > 0 { + // We owe an EOB + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + } + numLiterals, numOffsets := w.indexTokens(tokens, false) + w.generate(tokens) + var extraBits int + storedSize, storable := w.storedSize(input) + if storable { + extraBits = w.extraBitSize() + } + + // Figure out smallest code. + // Fixed Huffman baseline. + var literalEncoding = fixedLiteralEncoding + var offsetEncoding = fixedOffsetEncoding + var size = w.fixedSize(extraBits) + + // Dynamic Huffman? + var numCodegens int + + // Generate codegen and codegenFrequencies, which indicates how to encode + // the literalEncoding and the offsetEncoding. + w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding) + w.codegenEncoding.generate(w.codegenFreq[:], 7) + dynamicSize, numCodegens := w.dynamicSize(w.literalEncoding, w.offsetEncoding, extraBits) + + if dynamicSize < size { + size = dynamicSize + literalEncoding = w.literalEncoding + offsetEncoding = w.offsetEncoding + } + + // Stored bytes? + if storable && storedSize < size { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + return + } + + // Huffman. + if literalEncoding == fixedLiteralEncoding { + w.writeFixedHeader(eof) + } else { + w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof) + } + + // Write the tokens. + w.writeTokens(tokens.Slice(), literalEncoding.codes, offsetEncoding.codes) +} + +// writeBlockDynamic encodes a block using a dynamic Huffman table. +// This should be used if the symbols used have a disproportionate +// histogram distribution. +// If input is supplied and the compression savings are below 1/16th of the +// input size the block is stored. +func (w *huffmanBitWriter) writeBlockDynamic(tokens *tokens, eof bool, input []byte, sync bool) { + if w.err != nil { + return + } + + sync = sync || eof + if sync { + tokens.AddEOB() + } + + // We cannot reuse pure huffman table, and must mark as EOF. + if (w.lastHuffMan || eof) && w.lastHeader > 0 { + // We will not try to reuse. + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + w.lastHuffMan = false + } + if !sync { + tokens.Fill() + } + numLiterals, numOffsets := w.indexTokens(tokens, !sync) + + var size int + // Check if we should reuse. + if w.lastHeader > 0 { + // Estimate size for using a new table. + // Use the previous header size as the best estimate. + newSize := w.lastHeader + tokens.EstimatedBits() + newSize += newSize >> w.logNewTablePenalty + + // The estimated size is calculated as an optimal table. + // We add a penalty to make it more realistic and re-use a bit more. + reuseSize := w.dynamicReuseSize(w.literalEncoding, w.offsetEncoding) + w.extraBitSize() + + // Check if a new table is better. + if newSize < reuseSize { + // Write the EOB we owe. + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + size = newSize + w.lastHeader = 0 + } else { + size = reuseSize + } + // Check if we get a reasonable size decrease. + if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + w.lastHeader = 0 + return + } + } + + // We want a new block/table + if w.lastHeader == 0 { + w.generate(tokens) + // Generate codegen and codegenFrequencies, which indicates how to encode + // the literalEncoding and the offsetEncoding. + w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, w.offsetEncoding) + w.codegenEncoding.generate(w.codegenFreq[:], 7) + var numCodegens int + size, numCodegens = w.dynamicSize(w.literalEncoding, w.offsetEncoding, w.extraBitSize()) + // Store bytes, if we don't get a reasonable improvement. + if ssize, storable := w.storedSize(input); storable && ssize < (size+size>>4) { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + w.lastHeader = 0 + return + } + + // Write Huffman table. + w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof) + w.lastHeader, _ = w.headerSize() + w.lastHuffMan = false + } + + if sync { + w.lastHeader = 0 + } + // Write the tokens. + w.writeTokens(tokens.Slice(), w.literalEncoding.codes, w.offsetEncoding.codes) +} + +// indexTokens indexes a slice of tokens, and updates +// literalFreq and offsetFreq, and generates literalEncoding +// and offsetEncoding. +// The number of literal and offset tokens is returned. +func (w *huffmanBitWriter) indexTokens(t *tokens, filled bool) (numLiterals, numOffsets int) { + copy(w.literalFreq[:], t.litHist[:]) + copy(w.literalFreq[256:], t.extraHist[:]) + copy(w.offsetFreq[:], t.offHist[:offsetCodeCount]) + + if t.n == 0 { + return + } + if filled { + return maxNumLit, maxNumDist + } + // get the number of literals + numLiterals = len(w.literalFreq) + for w.literalFreq[numLiterals-1] == 0 { + numLiterals-- + } + // get the number of offsets + numOffsets = len(w.offsetFreq) + for numOffsets > 0 && w.offsetFreq[numOffsets-1] == 0 { + numOffsets-- + } + if numOffsets == 0 { + // We haven't found a single match. If we want to go with the dynamic encoding, + // we should count at least one offset to be sure that the offset huffman tree could be encoded. + w.offsetFreq[0] = 1 + numOffsets = 1 + } + return +} + +func (w *huffmanBitWriter) generate(t *tokens) { + w.literalEncoding.generate(w.literalFreq[:literalCount], 15) + w.offsetEncoding.generate(w.offsetFreq[:offsetCodeCount], 15) +} + +// writeTokens writes a slice of tokens to the output. +// codes for literal and offset encoding must be supplied. +func (w *huffmanBitWriter) writeTokens(tokens []token, leCodes, oeCodes []hcode) { + if w.err != nil { + return + } + if len(tokens) == 0 { + return + } + + // Only last token should be endBlockMarker. + var deferEOB bool + if tokens[len(tokens)-1] == endBlockMarker { + tokens = tokens[:len(tokens)-1] + deferEOB = true + } + + // Create slices up to the next power of two to avoid bounds checks. + lits := leCodes[:256] + offs := oeCodes[:32] + lengths := leCodes[lengthCodesStart:] + lengths = lengths[:32] + + // Go 1.16 LOVES having these on stack. + bits, nbits, nbytes := w.bits, w.nbits, w.nbytes + + for _, t := range tokens { + if t < matchType { + //w.writeCode(lits[t.literal()]) + c := lits[t.literal()] + bits |= uint64(c.code) << nbits + nbits += c.len + if nbits >= 48 { + binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits) + //*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits + bits >>= 48 + nbits -= 48 + nbytes += 6 + if nbytes >= bufferFlushSize { + if w.err != nil { + nbytes = 0 + return + } + _, w.err = w.writer.Write(w.bytes[:nbytes]) + nbytes = 0 + } + } + continue + } + + // Write the length + length := t.length() + lengthCode := lengthCode(length) + if false { + w.writeCode(lengths[lengthCode&31]) + } else { + // inlined + c := lengths[lengthCode&31] + bits |= uint64(c.code) << nbits + nbits += c.len + if nbits >= 48 { + binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits) + //*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits + bits >>= 48 + nbits -= 48 + nbytes += 6 + if nbytes >= bufferFlushSize { + if w.err != nil { + nbytes = 0 + return + } + _, w.err = w.writer.Write(w.bytes[:nbytes]) + nbytes = 0 + } + } + } + + extraLengthBits := uint16(lengthExtraBits[lengthCode&31]) + if extraLengthBits > 0 { + //w.writeBits(extraLength, extraLengthBits) + extraLength := int32(length - lengthBase[lengthCode&31]) + bits |= uint64(extraLength) << nbits + nbits += extraLengthBits + if nbits >= 48 { + binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits) + //*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits + bits >>= 48 + nbits -= 48 + nbytes += 6 + if nbytes >= bufferFlushSize { + if w.err != nil { + nbytes = 0 + return + } + _, w.err = w.writer.Write(w.bytes[:nbytes]) + nbytes = 0 + } + } + } + // Write the offset + offset := t.offset() + offsetCode := offset >> 16 + offset &= matchOffsetOnlyMask + if false { + w.writeCode(offs[offsetCode&31]) + } else { + // inlined + c := offs[offsetCode] + bits |= uint64(c.code) << nbits + nbits += c.len + if nbits >= 48 { + binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits) + //*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits + bits >>= 48 + nbits -= 48 + nbytes += 6 + if nbytes >= bufferFlushSize { + if w.err != nil { + nbytes = 0 + return + } + _, w.err = w.writer.Write(w.bytes[:nbytes]) + nbytes = 0 + } + } + } + offsetComb := offsetCombined[offsetCode] + if offsetComb > 1<<16 { + //w.writeBits(extraOffset, extraOffsetBits) + bits |= uint64(offset&matchOffsetOnlyMask-(offsetComb&0xffff)) << nbits + nbits += uint16(offsetComb >> 16) + if nbits >= 48 { + binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits) + //*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits + bits >>= 48 + nbits -= 48 + nbytes += 6 + if nbytes >= bufferFlushSize { + if w.err != nil { + nbytes = 0 + return + } + _, w.err = w.writer.Write(w.bytes[:nbytes]) + nbytes = 0 + } + } + } + } + // Restore... + w.bits, w.nbits, w.nbytes = bits, nbits, nbytes + + if deferEOB { + w.writeCode(leCodes[endBlockMarker]) + } +} + +// huffOffset is a static offset encoder used for huffman only encoding. +// It can be reused since we will not be encoding offset values. +var huffOffset *huffmanEncoder + +func init() { + w := newHuffmanBitWriter(nil) + w.offsetFreq[0] = 1 + huffOffset = newHuffmanEncoder(offsetCodeCount) + huffOffset.generate(w.offsetFreq[:offsetCodeCount], 15) +} + +// writeBlockHuff encodes a block of bytes as either +// Huffman encoded literals or uncompressed bytes if the +// results only gains very little from compression. +func (w *huffmanBitWriter) writeBlockHuff(eof bool, input []byte, sync bool) { + if w.err != nil { + return + } + + // Clear histogram + for i := range w.literalFreq[:] { + w.literalFreq[i] = 0 + } + if !w.lastHuffMan { + for i := range w.offsetFreq[:] { + w.offsetFreq[i] = 0 + } + } + + // Fill is rarely better... + const fill = false + const numLiterals = endBlockMarker + 1 + const numOffsets = 1 + + // Add everything as literals + // We have to estimate the header size. + // Assume header is around 70 bytes: + // https://stackoverflow.com/a/25454430 + const guessHeaderSizeBits = 70 * 8 + histogram(input, w.literalFreq[:numLiterals], fill) + w.literalFreq[endBlockMarker] = 1 + w.tmpLitEncoding.generate(w.literalFreq[:numLiterals], 15) + if fill { + // Clear fill... + for i := range w.literalFreq[:numLiterals] { + w.literalFreq[i] = 0 + } + histogram(input, w.literalFreq[:numLiterals], false) + } + estBits := w.tmpLitEncoding.canReuseBits(w.literalFreq[:numLiterals]) + estBits += w.lastHeader + if w.lastHeader == 0 { + estBits += guessHeaderSizeBits + } + estBits += estBits >> w.logNewTablePenalty + + // Store bytes, if we don't get a reasonable improvement. + ssize, storable := w.storedSize(input) + if storable && ssize <= estBits { + w.writeStoredHeader(len(input), eof) + w.writeBytes(input) + return + } + + if w.lastHeader > 0 { + reuseSize := w.literalEncoding.canReuseBits(w.literalFreq[:256]) + + if estBits < reuseSize { + if debugDeflate { + //fmt.Println("not reusing, reuse:", reuseSize/8, "> new:", estBits/8, "- header est:", w.lastHeader/8) + } + // We owe an EOB + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + } else if debugDeflate { + fmt.Println("reusing, reuse:", reuseSize/8, "> new:", estBits/8, "- header est:", w.lastHeader/8) + } + } + + count := 0 + if w.lastHeader == 0 { + // Use the temp encoding, so swap. + w.literalEncoding, w.tmpLitEncoding = w.tmpLitEncoding, w.literalEncoding + // Generate codegen and codegenFrequencies, which indicates how to encode + // the literalEncoding and the offsetEncoding. + w.generateCodegen(numLiterals, numOffsets, w.literalEncoding, huffOffset) + w.codegenEncoding.generate(w.codegenFreq[:], 7) + numCodegens := w.codegens() + + // Huffman. + w.writeDynamicHeader(numLiterals, numOffsets, numCodegens, eof) + w.lastHuffMan = true + w.lastHeader, _ = w.headerSize() + if debugDeflate { + count += w.lastHeader + fmt.Println("header:", count/8) + } + } + + encoding := w.literalEncoding.codes[:256] + // Go 1.16 LOVES having these on stack. At least 1.5x the speed. + bits, nbits, nbytes := w.bits, w.nbits, w.nbytes + for _, t := range input { + // Bitwriting inlined, ~30% speedup + c := encoding[t] + bits |= uint64(c.code) << nbits + nbits += c.len + if debugDeflate { + count += int(c.len) + } + if nbits >= 48 { + binary.LittleEndian.PutUint64(w.bytes[nbytes:], bits) + //*(*uint64)(unsafe.Pointer(&w.bytes[nbytes])) = bits + bits >>= 48 + nbits -= 48 + nbytes += 6 + if nbytes >= bufferFlushSize { + if w.err != nil { + nbytes = 0 + return + } + _, w.err = w.writer.Write(w.bytes[:nbytes]) + nbytes = 0 + } + } + } + // Restore... + w.bits, w.nbits, w.nbytes = bits, nbits, nbytes + + if debugDeflate { + fmt.Println("wrote", count/8, "bytes") + } + if eof || sync { + w.writeCode(w.literalEncoding.codes[endBlockMarker]) + w.lastHeader = 0 + w.lastHuffMan = false + } +} diff --git a/vendor/github.com/klauspost/compress/flate/huffman_code.go b/vendor/github.com/klauspost/compress/flate/huffman_code.go new file mode 100644 index 0000000000..67b2b38728 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/huffman_code.go @@ -0,0 +1,374 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import ( + "math" + "math/bits" +) + +const ( + maxBitsLimit = 16 + // number of valid literals + literalCount = 286 +) + +// hcode is a huffman code with a bit code and bit length. +type hcode struct { + code, len uint16 +} + +type huffmanEncoder struct { + codes []hcode + bitCount [17]int32 + + // Allocate a reusable buffer with the longest possible frequency table. + // Possible lengths are codegenCodeCount, offsetCodeCount and literalCount. + // The largest of these is literalCount, so we allocate for that case. + freqcache [literalCount + 1]literalNode +} + +type literalNode struct { + literal uint16 + freq uint16 +} + +// A levelInfo describes the state of the constructed tree for a given depth. +type levelInfo struct { + // Our level. for better printing + level int32 + + // The frequency of the last node at this level + lastFreq int32 + + // The frequency of the next character to add to this level + nextCharFreq int32 + + // The frequency of the next pair (from level below) to add to this level. + // Only valid if the "needed" value of the next lower level is 0. + nextPairFreq int32 + + // The number of chains remaining to generate for this level before moving + // up to the next level + needed int32 +} + +// set sets the code and length of an hcode. +func (h *hcode) set(code uint16, length uint16) { + h.len = length + h.code = code +} + +func reverseBits(number uint16, bitLength byte) uint16 { + return bits.Reverse16(number << ((16 - bitLength) & 15)) +} + +func maxNode() literalNode { return literalNode{math.MaxUint16, math.MaxUint16} } + +func newHuffmanEncoder(size int) *huffmanEncoder { + // Make capacity to next power of two. + c := uint(bits.Len32(uint32(size - 1))) + return &huffmanEncoder{codes: make([]hcode, size, 1<= 3 +// The cases of 0, 1, and 2 literals are handled by special case code. +// +// list An array of the literals with non-zero frequencies +// and their associated frequencies. The array is in order of increasing +// frequency, and has as its last element a special element with frequency +// MaxInt32 +// maxBits The maximum number of bits that should be used to encode any literal. +// Must be less than 16. +// return An integer array in which array[i] indicates the number of literals +// that should be encoded in i bits. +func (h *huffmanEncoder) bitCounts(list []literalNode, maxBits int32) []int32 { + if maxBits >= maxBitsLimit { + panic("flate: maxBits too large") + } + n := int32(len(list)) + list = list[0 : n+1] + list[n] = maxNode() + + // The tree can't have greater depth than n - 1, no matter what. This + // saves a little bit of work in some small cases + if maxBits > n-1 { + maxBits = n - 1 + } + + // Create information about each of the levels. + // A bogus "Level 0" whose sole purpose is so that + // level1.prev.needed==0. This makes level1.nextPairFreq + // be a legitimate value that never gets chosen. + var levels [maxBitsLimit]levelInfo + // leafCounts[i] counts the number of literals at the left + // of ancestors of the rightmost node at level i. + // leafCounts[i][j] is the number of literals at the left + // of the level j ancestor. + var leafCounts [maxBitsLimit][maxBitsLimit]int32 + + for level := int32(1); level <= maxBits; level++ { + // For every level, the first two items are the first two characters. + // We initialize the levels as if we had already figured this out. + levels[level] = levelInfo{ + level: level, + lastFreq: int32(list[1].freq), + nextCharFreq: int32(list[2].freq), + nextPairFreq: int32(list[0].freq) + int32(list[1].freq), + } + leafCounts[level][level] = 2 + if level == 1 { + levels[level].nextPairFreq = math.MaxInt32 + } + } + + // We need a total of 2*n - 2 items at top level and have already generated 2. + levels[maxBits].needed = 2*n - 4 + + level := maxBits + for { + l := &levels[level] + if l.nextPairFreq == math.MaxInt32 && l.nextCharFreq == math.MaxInt32 { + // We've run out of both leafs and pairs. + // End all calculations for this level. + // To make sure we never come back to this level or any lower level, + // set nextPairFreq impossibly large. + l.needed = 0 + levels[level+1].nextPairFreq = math.MaxInt32 + level++ + continue + } + + prevFreq := l.lastFreq + if l.nextCharFreq < l.nextPairFreq { + // The next item on this row is a leaf node. + n := leafCounts[level][level] + 1 + l.lastFreq = l.nextCharFreq + // Lower leafCounts are the same of the previous node. + leafCounts[level][level] = n + e := list[n] + if e.literal < math.MaxUint16 { + l.nextCharFreq = int32(e.freq) + } else { + l.nextCharFreq = math.MaxInt32 + } + } else { + // The next item on this row is a pair from the previous row. + // nextPairFreq isn't valid until we generate two + // more values in the level below + l.lastFreq = l.nextPairFreq + // Take leaf counts from the lower level, except counts[level] remains the same. + copy(leafCounts[level][:level], leafCounts[level-1][:level]) + levels[l.level-1].needed = 2 + } + + if l.needed--; l.needed == 0 { + // We've done everything we need to do for this level. + // Continue calculating one level up. Fill in nextPairFreq + // of that level with the sum of the two nodes we've just calculated on + // this level. + if l.level == maxBits { + // All done! + break + } + levels[l.level+1].nextPairFreq = prevFreq + l.lastFreq + level++ + } else { + // If we stole from below, move down temporarily to replenish it. + for levels[level-1].needed > 0 { + level-- + } + } + } + + // Somethings is wrong if at the end, the top level is null or hasn't used + // all of the leaves. + if leafCounts[maxBits][maxBits] != n { + panic("leafCounts[maxBits][maxBits] != n") + } + + bitCount := h.bitCount[:maxBits+1] + bits := 1 + counts := &leafCounts[maxBits] + for level := maxBits; level > 0; level-- { + // chain.leafCount gives the number of literals requiring at least "bits" + // bits to encode. + bitCount[bits] = counts[level] - counts[level-1] + bits++ + } + return bitCount +} + +// Look at the leaves and assign them a bit count and an encoding as specified +// in RFC 1951 3.2.2 +func (h *huffmanEncoder) assignEncodingAndSize(bitCount []int32, list []literalNode) { + code := uint16(0) + for n, bits := range bitCount { + code <<= 1 + if n == 0 || bits == 0 { + continue + } + // The literals list[len(list)-bits] .. list[len(list)-bits] + // are encoded using "bits" bits, and get the values + // code, code + 1, .... The code values are + // assigned in literal order (not frequency order). + chunk := list[len(list)-int(bits):] + + sortByLiteral(chunk) + for _, node := range chunk { + h.codes[node.literal] = hcode{code: reverseBits(code, uint8(n)), len: uint16(n)} + code++ + } + list = list[0 : len(list)-int(bits)] + } +} + +// Update this Huffman Code object to be the minimum code for the specified frequency count. +// +// freq An array of frequencies, in which frequency[i] gives the frequency of literal i. +// maxBits The maximum number of bits to use for any literal. +func (h *huffmanEncoder) generate(freq []uint16, maxBits int32) { + list := h.freqcache[:len(freq)+1] + // Number of non-zero literals + count := 0 + // Set list to be the set of all non-zero literals and their frequencies + for i, f := range freq { + if f != 0 { + list[count] = literalNode{uint16(i), f} + count++ + } else { + list[count] = literalNode{} + h.codes[i].len = 0 + } + } + list[len(freq)] = literalNode{} + + list = list[:count] + if count <= 2 { + // Handle the small cases here, because they are awkward for the general case code. With + // two or fewer literals, everything has bit length 1. + for i, node := range list { + // "list" is in order of increasing literal value. + h.codes[node.literal].set(uint16(i), 1) + } + return + } + sortByFreq(list) + + // Get the number of literals for each bit count + bitCount := h.bitCounts(list, maxBits) + // And do the assignment + h.assignEncodingAndSize(bitCount, list) +} + +// atLeastOne clamps the result between 1 and 15. +func atLeastOne(v float32) float32 { + if v < 1 { + return 1 + } + if v > 15 { + return 15 + } + return v +} + +// Unassigned values are assigned '1' in the histogram. +func fillHist(b []uint16) { + for i, v := range b { + if v == 0 { + b[i] = 1 + } + } +} + +func histogram(b []byte, h []uint16, fill bool) { + h = h[:256] + for _, t := range b { + h[t]++ + } + if fill { + fillHist(h) + } +} diff --git a/vendor/github.com/klauspost/compress/flate/huffman_sortByFreq.go b/vendor/github.com/klauspost/compress/flate/huffman_sortByFreq.go new file mode 100644 index 0000000000..2077802990 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/huffman_sortByFreq.go @@ -0,0 +1,178 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +// Sort sorts data. +// It makes one call to data.Len to determine n, and O(n*log(n)) calls to +// data.Less and data.Swap. The sort is not guaranteed to be stable. +func sortByFreq(data []literalNode) { + n := len(data) + quickSortByFreq(data, 0, n, maxDepth(n)) +} + +func quickSortByFreq(data []literalNode, a, b, maxDepth int) { + for b-a > 12 { // Use ShellSort for slices <= 12 elements + if maxDepth == 0 { + heapSort(data, a, b) + return + } + maxDepth-- + mlo, mhi := doPivotByFreq(data, a, b) + // Avoiding recursion on the larger subproblem guarantees + // a stack depth of at most lg(b-a). + if mlo-a < b-mhi { + quickSortByFreq(data, a, mlo, maxDepth) + a = mhi // i.e., quickSortByFreq(data, mhi, b) + } else { + quickSortByFreq(data, mhi, b, maxDepth) + b = mlo // i.e., quickSortByFreq(data, a, mlo) + } + } + if b-a > 1 { + // Do ShellSort pass with gap 6 + // It could be written in this simplified form cause b-a <= 12 + for i := a + 6; i < b; i++ { + if data[i].freq == data[i-6].freq && data[i].literal < data[i-6].literal || data[i].freq < data[i-6].freq { + data[i], data[i-6] = data[i-6], data[i] + } + } + insertionSortByFreq(data, a, b) + } +} + +// siftDownByFreq implements the heap property on data[lo, hi). +// first is an offset into the array where the root of the heap lies. +func siftDownByFreq(data []literalNode, lo, hi, first int) { + root := lo + for { + child := 2*root + 1 + if child >= hi { + break + } + if child+1 < hi && (data[first+child].freq == data[first+child+1].freq && data[first+child].literal < data[first+child+1].literal || data[first+child].freq < data[first+child+1].freq) { + child++ + } + if data[first+root].freq == data[first+child].freq && data[first+root].literal > data[first+child].literal || data[first+root].freq > data[first+child].freq { + return + } + data[first+root], data[first+child] = data[first+child], data[first+root] + root = child + } +} +func doPivotByFreq(data []literalNode, lo, hi int) (midlo, midhi int) { + m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow. + if hi-lo > 40 { + // Tukey's ``Ninther,'' median of three medians of three. + s := (hi - lo) / 8 + medianOfThreeSortByFreq(data, lo, lo+s, lo+2*s) + medianOfThreeSortByFreq(data, m, m-s, m+s) + medianOfThreeSortByFreq(data, hi-1, hi-1-s, hi-1-2*s) + } + medianOfThreeSortByFreq(data, lo, m, hi-1) + + // Invariants are: + // data[lo] = pivot (set up by ChoosePivot) + // data[lo < i < a] < pivot + // data[a <= i < b] <= pivot + // data[b <= i < c] unexamined + // data[c <= i < hi-1] > pivot + // data[hi-1] >= pivot + pivot := lo + a, c := lo+1, hi-1 + + for ; a < c && (data[a].freq == data[pivot].freq && data[a].literal < data[pivot].literal || data[a].freq < data[pivot].freq); a++ { + } + b := a + for { + for ; b < c && (data[pivot].freq == data[b].freq && data[pivot].literal > data[b].literal || data[pivot].freq > data[b].freq); b++ { // data[b] <= pivot + } + for ; b < c && (data[pivot].freq == data[c-1].freq && data[pivot].literal < data[c-1].literal || data[pivot].freq < data[c-1].freq); c-- { // data[c-1] > pivot + } + if b >= c { + break + } + // data[b] > pivot; data[c-1] <= pivot + data[b], data[c-1] = data[c-1], data[b] + b++ + c-- + } + // If hi-c<3 then there are duplicates (by property of median of nine). + // Let's be a bit more conservative, and set border to 5. + protect := hi-c < 5 + if !protect && hi-c < (hi-lo)/4 { + // Lets test some points for equality to pivot + dups := 0 + if data[pivot].freq == data[hi-1].freq && data[pivot].literal > data[hi-1].literal || data[pivot].freq > data[hi-1].freq { // data[hi-1] = pivot + data[c], data[hi-1] = data[hi-1], data[c] + c++ + dups++ + } + if data[b-1].freq == data[pivot].freq && data[b-1].literal > data[pivot].literal || data[b-1].freq > data[pivot].freq { // data[b-1] = pivot + b-- + dups++ + } + // m-lo = (hi-lo)/2 > 6 + // b-lo > (hi-lo)*3/4-1 > 8 + // ==> m < b ==> data[m] <= pivot + if data[m].freq == data[pivot].freq && data[m].literal > data[pivot].literal || data[m].freq > data[pivot].freq { // data[m] = pivot + data[m], data[b-1] = data[b-1], data[m] + b-- + dups++ + } + // if at least 2 points are equal to pivot, assume skewed distribution + protect = dups > 1 + } + if protect { + // Protect against a lot of duplicates + // Add invariant: + // data[a <= i < b] unexamined + // data[b <= i < c] = pivot + for { + for ; a < b && (data[b-1].freq == data[pivot].freq && data[b-1].literal > data[pivot].literal || data[b-1].freq > data[pivot].freq); b-- { // data[b] == pivot + } + for ; a < b && (data[a].freq == data[pivot].freq && data[a].literal < data[pivot].literal || data[a].freq < data[pivot].freq); a++ { // data[a] < pivot + } + if a >= b { + break + } + // data[a] == pivot; data[b-1] < pivot + data[a], data[b-1] = data[b-1], data[a] + a++ + b-- + } + } + // Swap pivot into middle + data[pivot], data[b-1] = data[b-1], data[pivot] + return b - 1, c +} + +// Insertion sort +func insertionSortByFreq(data []literalNode, a, b int) { + for i := a + 1; i < b; i++ { + for j := i; j > a && (data[j].freq == data[j-1].freq && data[j].literal < data[j-1].literal || data[j].freq < data[j-1].freq); j-- { + data[j], data[j-1] = data[j-1], data[j] + } + } +} + +// quickSortByFreq, loosely following Bentley and McIlroy, +// ``Engineering a Sort Function,'' SP&E November 1993. + +// medianOfThreeSortByFreq moves the median of the three values data[m0], data[m1], data[m2] into data[m1]. +func medianOfThreeSortByFreq(data []literalNode, m1, m0, m2 int) { + // sort 3 elements + if data[m1].freq == data[m0].freq && data[m1].literal < data[m0].literal || data[m1].freq < data[m0].freq { + data[m1], data[m0] = data[m0], data[m1] + } + // data[m0] <= data[m1] + if data[m2].freq == data[m1].freq && data[m2].literal < data[m1].literal || data[m2].freq < data[m1].freq { + data[m2], data[m1] = data[m1], data[m2] + // data[m0] <= data[m2] && data[m1] < data[m2] + if data[m1].freq == data[m0].freq && data[m1].literal < data[m0].literal || data[m1].freq < data[m0].freq { + data[m1], data[m0] = data[m0], data[m1] + } + } + // now data[m0] <= data[m1] <= data[m2] +} diff --git a/vendor/github.com/klauspost/compress/flate/huffman_sortByLiteral.go b/vendor/github.com/klauspost/compress/flate/huffman_sortByLiteral.go new file mode 100644 index 0000000000..93f1aea109 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/huffman_sortByLiteral.go @@ -0,0 +1,201 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +// Sort sorts data. +// It makes one call to data.Len to determine n, and O(n*log(n)) calls to +// data.Less and data.Swap. The sort is not guaranteed to be stable. +func sortByLiteral(data []literalNode) { + n := len(data) + quickSort(data, 0, n, maxDepth(n)) +} + +func quickSort(data []literalNode, a, b, maxDepth int) { + for b-a > 12 { // Use ShellSort for slices <= 12 elements + if maxDepth == 0 { + heapSort(data, a, b) + return + } + maxDepth-- + mlo, mhi := doPivot(data, a, b) + // Avoiding recursion on the larger subproblem guarantees + // a stack depth of at most lg(b-a). + if mlo-a < b-mhi { + quickSort(data, a, mlo, maxDepth) + a = mhi // i.e., quickSort(data, mhi, b) + } else { + quickSort(data, mhi, b, maxDepth) + b = mlo // i.e., quickSort(data, a, mlo) + } + } + if b-a > 1 { + // Do ShellSort pass with gap 6 + // It could be written in this simplified form cause b-a <= 12 + for i := a + 6; i < b; i++ { + if data[i].literal < data[i-6].literal { + data[i], data[i-6] = data[i-6], data[i] + } + } + insertionSort(data, a, b) + } +} +func heapSort(data []literalNode, a, b int) { + first := a + lo := 0 + hi := b - a + + // Build heap with greatest element at top. + for i := (hi - 1) / 2; i >= 0; i-- { + siftDown(data, i, hi, first) + } + + // Pop elements, largest first, into end of data. + for i := hi - 1; i >= 0; i-- { + data[first], data[first+i] = data[first+i], data[first] + siftDown(data, lo, i, first) + } +} + +// siftDown implements the heap property on data[lo, hi). +// first is an offset into the array where the root of the heap lies. +func siftDown(data []literalNode, lo, hi, first int) { + root := lo + for { + child := 2*root + 1 + if child >= hi { + break + } + if child+1 < hi && data[first+child].literal < data[first+child+1].literal { + child++ + } + if data[first+root].literal > data[first+child].literal { + return + } + data[first+root], data[first+child] = data[first+child], data[first+root] + root = child + } +} +func doPivot(data []literalNode, lo, hi int) (midlo, midhi int) { + m := int(uint(lo+hi) >> 1) // Written like this to avoid integer overflow. + if hi-lo > 40 { + // Tukey's ``Ninther,'' median of three medians of three. + s := (hi - lo) / 8 + medianOfThree(data, lo, lo+s, lo+2*s) + medianOfThree(data, m, m-s, m+s) + medianOfThree(data, hi-1, hi-1-s, hi-1-2*s) + } + medianOfThree(data, lo, m, hi-1) + + // Invariants are: + // data[lo] = pivot (set up by ChoosePivot) + // data[lo < i < a] < pivot + // data[a <= i < b] <= pivot + // data[b <= i < c] unexamined + // data[c <= i < hi-1] > pivot + // data[hi-1] >= pivot + pivot := lo + a, c := lo+1, hi-1 + + for ; a < c && data[a].literal < data[pivot].literal; a++ { + } + b := a + for { + for ; b < c && data[pivot].literal > data[b].literal; b++ { // data[b] <= pivot + } + for ; b < c && data[pivot].literal < data[c-1].literal; c-- { // data[c-1] > pivot + } + if b >= c { + break + } + // data[b] > pivot; data[c-1] <= pivot + data[b], data[c-1] = data[c-1], data[b] + b++ + c-- + } + // If hi-c<3 then there are duplicates (by property of median of nine). + // Let's be a bit more conservative, and set border to 5. + protect := hi-c < 5 + if !protect && hi-c < (hi-lo)/4 { + // Lets test some points for equality to pivot + dups := 0 + if data[pivot].literal > data[hi-1].literal { // data[hi-1] = pivot + data[c], data[hi-1] = data[hi-1], data[c] + c++ + dups++ + } + if data[b-1].literal > data[pivot].literal { // data[b-1] = pivot + b-- + dups++ + } + // m-lo = (hi-lo)/2 > 6 + // b-lo > (hi-lo)*3/4-1 > 8 + // ==> m < b ==> data[m] <= pivot + if data[m].literal > data[pivot].literal { // data[m] = pivot + data[m], data[b-1] = data[b-1], data[m] + b-- + dups++ + } + // if at least 2 points are equal to pivot, assume skewed distribution + protect = dups > 1 + } + if protect { + // Protect against a lot of duplicates + // Add invariant: + // data[a <= i < b] unexamined + // data[b <= i < c] = pivot + for { + for ; a < b && data[b-1].literal > data[pivot].literal; b-- { // data[b] == pivot + } + for ; a < b && data[a].literal < data[pivot].literal; a++ { // data[a] < pivot + } + if a >= b { + break + } + // data[a] == pivot; data[b-1] < pivot + data[a], data[b-1] = data[b-1], data[a] + a++ + b-- + } + } + // Swap pivot into middle + data[pivot], data[b-1] = data[b-1], data[pivot] + return b - 1, c +} + +// Insertion sort +func insertionSort(data []literalNode, a, b int) { + for i := a + 1; i < b; i++ { + for j := i; j > a && data[j].literal < data[j-1].literal; j-- { + data[j], data[j-1] = data[j-1], data[j] + } + } +} + +// maxDepth returns a threshold at which quicksort should switch +// to heapsort. It returns 2*ceil(lg(n+1)). +func maxDepth(n int) int { + var depth int + for i := n; i > 0; i >>= 1 { + depth++ + } + return depth * 2 +} + +// medianOfThree moves the median of the three values data[m0], data[m1], data[m2] into data[m1]. +func medianOfThree(data []literalNode, m1, m0, m2 int) { + // sort 3 elements + if data[m1].literal < data[m0].literal { + data[m1], data[m0] = data[m0], data[m1] + } + // data[m0] <= data[m1] + if data[m2].literal < data[m1].literal { + data[m2], data[m1] = data[m1], data[m2] + // data[m0] <= data[m2] && data[m1] < data[m2] + if data[m1].literal < data[m0].literal { + data[m1], data[m0] = data[m0], data[m1] + } + } + // now data[m0] <= data[m1] <= data[m2] +} diff --git a/vendor/github.com/klauspost/compress/flate/inflate.go b/vendor/github.com/klauspost/compress/flate/inflate.go new file mode 100644 index 0000000000..16bc51408e --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/inflate.go @@ -0,0 +1,1010 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package flate implements the DEFLATE compressed data format, described in +// RFC 1951. The gzip and zlib packages implement access to DEFLATE-based file +// formats. +package flate + +import ( + "bufio" + "fmt" + "io" + "math/bits" + "strconv" + "sync" +) + +const ( + maxCodeLen = 16 // max length of Huffman code + maxCodeLenMask = 15 // mask for max length of Huffman code + // The next three numbers come from the RFC section 3.2.7, with the + // additional proviso in section 3.2.5 which implies that distance codes + // 30 and 31 should never occur in compressed data. + maxNumLit = 286 + maxNumDist = 30 + numCodes = 19 // number of codes in Huffman meta-code + + debugDecode = false +) + +// Value of length - 3 and extra bits. +type lengthExtra struct { + length, extra uint8 +} + +var decCodeToLen = [32]lengthExtra{{length: 0x0, extra: 0x0}, {length: 0x1, extra: 0x0}, {length: 0x2, extra: 0x0}, {length: 0x3, extra: 0x0}, {length: 0x4, extra: 0x0}, {length: 0x5, extra: 0x0}, {length: 0x6, extra: 0x0}, {length: 0x7, extra: 0x0}, {length: 0x8, extra: 0x1}, {length: 0xa, extra: 0x1}, {length: 0xc, extra: 0x1}, {length: 0xe, extra: 0x1}, {length: 0x10, extra: 0x2}, {length: 0x14, extra: 0x2}, {length: 0x18, extra: 0x2}, {length: 0x1c, extra: 0x2}, {length: 0x20, extra: 0x3}, {length: 0x28, extra: 0x3}, {length: 0x30, extra: 0x3}, {length: 0x38, extra: 0x3}, {length: 0x40, extra: 0x4}, {length: 0x50, extra: 0x4}, {length: 0x60, extra: 0x4}, {length: 0x70, extra: 0x4}, {length: 0x80, extra: 0x5}, {length: 0xa0, extra: 0x5}, {length: 0xc0, extra: 0x5}, {length: 0xe0, extra: 0x5}, {length: 0xff, extra: 0x0}, {length: 0x0, extra: 0x0}, {length: 0x0, extra: 0x0}, {length: 0x0, extra: 0x0}} + +// Initialize the fixedHuffmanDecoder only once upon first use. +var fixedOnce sync.Once +var fixedHuffmanDecoder huffmanDecoder + +// A CorruptInputError reports the presence of corrupt input at a given offset. +type CorruptInputError int64 + +func (e CorruptInputError) Error() string { + return "flate: corrupt input before offset " + strconv.FormatInt(int64(e), 10) +} + +// An InternalError reports an error in the flate code itself. +type InternalError string + +func (e InternalError) Error() string { return "flate: internal error: " + string(e) } + +// A ReadError reports an error encountered while reading input. +// +// Deprecated: No longer returned. +type ReadError struct { + Offset int64 // byte offset where error occurred + Err error // error returned by underlying Read +} + +func (e *ReadError) Error() string { + return "flate: read error at offset " + strconv.FormatInt(e.Offset, 10) + ": " + e.Err.Error() +} + +// A WriteError reports an error encountered while writing output. +// +// Deprecated: No longer returned. +type WriteError struct { + Offset int64 // byte offset where error occurred + Err error // error returned by underlying Write +} + +func (e *WriteError) Error() string { + return "flate: write error at offset " + strconv.FormatInt(e.Offset, 10) + ": " + e.Err.Error() +} + +// Resetter resets a ReadCloser returned by NewReader or NewReaderDict to +// to switch to a new underlying Reader. This permits reusing a ReadCloser +// instead of allocating a new one. +type Resetter interface { + // Reset discards any buffered data and resets the Resetter as if it was + // newly initialized with the given reader. + Reset(r io.Reader, dict []byte) error +} + +// The data structure for decoding Huffman tables is based on that of +// zlib. There is a lookup table of a fixed bit width (huffmanChunkBits), +// For codes smaller than the table width, there are multiple entries +// (each combination of trailing bits has the same value). For codes +// larger than the table width, the table contains a link to an overflow +// table. The width of each entry in the link table is the maximum code +// size minus the chunk width. +// +// Note that you can do a lookup in the table even without all bits +// filled. Since the extra bits are zero, and the DEFLATE Huffman codes +// have the property that shorter codes come before longer ones, the +// bit length estimate in the result is a lower bound on the actual +// number of bits. +// +// See the following: +// http://www.gzip.org/algorithm.txt + +// chunk & 15 is number of bits +// chunk >> 4 is value, including table link + +const ( + huffmanChunkBits = 9 + huffmanNumChunks = 1 << huffmanChunkBits + huffmanCountMask = 15 + huffmanValueShift = 4 +) + +type huffmanDecoder struct { + maxRead int // the maximum number of bits we can read and not overread + chunks *[huffmanNumChunks]uint16 // chunks as described above + links [][]uint16 // overflow links + linkMask uint32 // mask the width of the link table +} + +// Initialize Huffman decoding tables from array of code lengths. +// Following this function, h is guaranteed to be initialized into a complete +// tree (i.e., neither over-subscribed nor under-subscribed). The exception is a +// degenerate case where the tree has only a single symbol with length 1. Empty +// trees are permitted. +func (h *huffmanDecoder) init(lengths []int) bool { + // Sanity enables additional runtime tests during Huffman + // table construction. It's intended to be used during + // development to supplement the currently ad-hoc unit tests. + const sanity = false + + if h.chunks == nil { + h.chunks = &[huffmanNumChunks]uint16{} + } + if h.maxRead != 0 { + *h = huffmanDecoder{chunks: h.chunks, links: h.links} + } + + // Count number of codes of each length, + // compute maxRead and max length. + var count [maxCodeLen]int + var min, max int + for _, n := range lengths { + if n == 0 { + continue + } + if min == 0 || n < min { + min = n + } + if n > max { + max = n + } + count[n&maxCodeLenMask]++ + } + + // Empty tree. The decompressor.huffSym function will fail later if the tree + // is used. Technically, an empty tree is only valid for the HDIST tree and + // not the HCLEN and HLIT tree. However, a stream with an empty HCLEN tree + // is guaranteed to fail since it will attempt to use the tree to decode the + // codes for the HLIT and HDIST trees. Similarly, an empty HLIT tree is + // guaranteed to fail later since the compressed data section must be + // composed of at least one symbol (the end-of-block marker). + if max == 0 { + return true + } + + code := 0 + var nextcode [maxCodeLen]int + for i := min; i <= max; i++ { + code <<= 1 + nextcode[i&maxCodeLenMask] = code + code += count[i&maxCodeLenMask] + } + + // Check that the coding is complete (i.e., that we've + // assigned all 2-to-the-max possible bit sequences). + // Exception: To be compatible with zlib, we also need to + // accept degenerate single-code codings. See also + // TestDegenerateHuffmanCoding. + if code != 1< huffmanChunkBits { + numLinks := 1 << (uint(max) - huffmanChunkBits) + h.linkMask = uint32(numLinks - 1) + + // create link tables + link := nextcode[huffmanChunkBits+1] >> 1 + if cap(h.links) < huffmanNumChunks-link { + h.links = make([][]uint16, huffmanNumChunks-link) + } else { + h.links = h.links[:huffmanNumChunks-link] + } + for j := uint(link); j < huffmanNumChunks; j++ { + reverse := int(bits.Reverse16(uint16(j))) + reverse >>= uint(16 - huffmanChunkBits) + off := j - uint(link) + if sanity && h.chunks[reverse] != 0 { + panic("impossible: overwriting existing chunk") + } + h.chunks[reverse] = uint16(off<>= uint(16 - n) + if n <= huffmanChunkBits { + for off := reverse; off < len(h.chunks); off += 1 << uint(n) { + // We should never need to overwrite + // an existing chunk. Also, 0 is + // never a valid chunk, because the + // lower 4 "count" bits should be + // between 1 and 15. + if sanity && h.chunks[off] != 0 { + panic("impossible: overwriting existing chunk") + } + h.chunks[off] = chunk + } + } else { + j := reverse & (huffmanNumChunks - 1) + if sanity && h.chunks[j]&huffmanCountMask != huffmanChunkBits+1 { + // Longer codes should have been + // associated with a link table above. + panic("impossible: not an indirect chunk") + } + value := h.chunks[j] >> huffmanValueShift + linktab := h.links[value] + reverse >>= huffmanChunkBits + for off := reverse; off < len(linktab); off += 1 << uint(n-huffmanChunkBits) { + if sanity && linktab[off] != 0 { + panic("impossible: overwriting existing chunk") + } + linktab[off] = chunk + } + } + } + + if sanity { + // Above we've sanity checked that we never overwrote + // an existing entry. Here we additionally check that + // we filled the tables completely. + for i, chunk := range h.chunks { + if chunk == 0 { + // As an exception, in the degenerate + // single-code case, we allow odd + // chunks to be missing. + if code == 1 && i%2 == 1 { + continue + } + panic("impossible: missing chunk") + } + } + for _, linktab := range h.links { + for _, chunk := range linktab { + if chunk == 0 { + panic("impossible: missing chunk") + } + } + } + } + + return true +} + +// The actual read interface needed by NewReader. +// If the passed in io.Reader does not also have ReadByte, +// the NewReader will introduce its own buffering. +type Reader interface { + io.Reader + io.ByteReader +} + +// Decompress state. +type decompressor struct { + // Input source. + r Reader + roffset int64 + + // Huffman decoders for literal/length, distance. + h1, h2 huffmanDecoder + + // Length arrays used to define Huffman codes. + bits *[maxNumLit + maxNumDist]int + codebits *[numCodes]int + + // Output history, buffer. + dict dictDecoder + + // Next step in the decompression, + // and decompression state. + step func(*decompressor) + stepState int + err error + toRead []byte + hl, hd *huffmanDecoder + copyLen int + copyDist int + + // Temporary buffer (avoids repeated allocation). + buf [4]byte + + // Input bits, in top of b. + b uint32 + + nb uint + final bool +} + +func (f *decompressor) nextBlock() { + for f.nb < 1+2 { + if f.err = f.moreBits(); f.err != nil { + return + } + } + f.final = f.b&1 == 1 + f.b >>= 1 + typ := f.b & 3 + f.b >>= 2 + f.nb -= 1 + 2 + switch typ { + case 0: + f.dataBlock() + case 1: + // compressed, fixed Huffman tables + f.hl = &fixedHuffmanDecoder + f.hd = nil + f.huffmanBlockDecoder()() + case 2: + // compressed, dynamic Huffman tables + if f.err = f.readHuffman(); f.err != nil { + break + } + f.hl = &f.h1 + f.hd = &f.h2 + f.huffmanBlockDecoder()() + default: + // 3 is reserved. + if debugDecode { + fmt.Println("reserved data block encountered") + } + f.err = CorruptInputError(f.roffset) + } +} + +func (f *decompressor) Read(b []byte) (int, error) { + for { + if len(f.toRead) > 0 { + n := copy(b, f.toRead) + f.toRead = f.toRead[n:] + if len(f.toRead) == 0 { + return n, f.err + } + return n, nil + } + if f.err != nil { + return 0, f.err + } + f.step(f) + if f.err != nil && len(f.toRead) == 0 { + f.toRead = f.dict.readFlush() // Flush what's left in case of error + } + } +} + +// Support the io.WriteTo interface for io.Copy and friends. +func (f *decompressor) WriteTo(w io.Writer) (int64, error) { + total := int64(0) + flushed := false + for { + if len(f.toRead) > 0 { + n, err := w.Write(f.toRead) + total += int64(n) + if err != nil { + f.err = err + return total, err + } + if n != len(f.toRead) { + return total, io.ErrShortWrite + } + f.toRead = f.toRead[:0] + } + if f.err != nil && flushed { + if f.err == io.EOF { + return total, nil + } + return total, f.err + } + if f.err == nil { + f.step(f) + } + if len(f.toRead) == 0 && f.err != nil && !flushed { + f.toRead = f.dict.readFlush() // Flush what's left in case of error + flushed = true + } + } +} + +func (f *decompressor) Close() error { + if f.err == io.EOF { + return nil + } + return f.err +} + +// RFC 1951 section 3.2.7. +// Compression with dynamic Huffman codes + +var codeOrder = [...]int{16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15} + +func (f *decompressor) readHuffman() error { + // HLIT[5], HDIST[5], HCLEN[4]. + for f.nb < 5+5+4 { + if err := f.moreBits(); err != nil { + return err + } + } + nlit := int(f.b&0x1F) + 257 + if nlit > maxNumLit { + if debugDecode { + fmt.Println("nlit > maxNumLit", nlit) + } + return CorruptInputError(f.roffset) + } + f.b >>= 5 + ndist := int(f.b&0x1F) + 1 + if ndist > maxNumDist { + if debugDecode { + fmt.Println("ndist > maxNumDist", ndist) + } + return CorruptInputError(f.roffset) + } + f.b >>= 5 + nclen := int(f.b&0xF) + 4 + // numCodes is 19, so nclen is always valid. + f.b >>= 4 + f.nb -= 5 + 5 + 4 + + // (HCLEN+4)*3 bits: code lengths in the magic codeOrder order. + for i := 0; i < nclen; i++ { + for f.nb < 3 { + if err := f.moreBits(); err != nil { + return err + } + } + f.codebits[codeOrder[i]] = int(f.b & 0x7) + f.b >>= 3 + f.nb -= 3 + } + for i := nclen; i < len(codeOrder); i++ { + f.codebits[codeOrder[i]] = 0 + } + if !f.h1.init(f.codebits[0:]) { + if debugDecode { + fmt.Println("init codebits failed") + } + return CorruptInputError(f.roffset) + } + + // HLIT + 257 code lengths, HDIST + 1 code lengths, + // using the code length Huffman code. + for i, n := 0, nlit+ndist; i < n; { + x, err := f.huffSym(&f.h1) + if err != nil { + return err + } + if x < 16 { + // Actual length. + f.bits[i] = x + i++ + continue + } + // Repeat previous length or zero. + var rep int + var nb uint + var b int + switch x { + default: + return InternalError("unexpected length code") + case 16: + rep = 3 + nb = 2 + if i == 0 { + if debugDecode { + fmt.Println("i==0") + } + return CorruptInputError(f.roffset) + } + b = f.bits[i-1] + case 17: + rep = 3 + nb = 3 + b = 0 + case 18: + rep = 11 + nb = 7 + b = 0 + } + for f.nb < nb { + if err := f.moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits:", err) + } + return err + } + } + rep += int(f.b & uint32(1<<(nb®SizeMaskUint32)-1)) + f.b >>= nb & regSizeMaskUint32 + f.nb -= nb + if i+rep > n { + if debugDecode { + fmt.Println("i+rep > n", i, rep, n) + } + return CorruptInputError(f.roffset) + } + for j := 0; j < rep; j++ { + f.bits[i] = b + i++ + } + } + + if !f.h1.init(f.bits[0:nlit]) || !f.h2.init(f.bits[nlit:nlit+ndist]) { + if debugDecode { + fmt.Println("init2 failed") + } + return CorruptInputError(f.roffset) + } + + // As an optimization, we can initialize the maxRead bits to read at a time + // for the HLIT tree to the length of the EOB marker since we know that + // every block must terminate with one. This preserves the property that + // we never read any extra bytes after the end of the DEFLATE stream. + if f.h1.maxRead < f.bits[endBlockMarker] { + f.h1.maxRead = f.bits[endBlockMarker] + } + if !f.final { + // If not the final block, the smallest block possible is + // a predefined table, BTYPE=01, with a single EOB marker. + // This will take up 3 + 7 bits. + f.h1.maxRead += 10 + } + + return nil +} + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) huffmanBlockGeneric() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. + { + var v int + { + // Inlined v, err := f.huffSym(f.hl) + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hl.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := f.r.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & regSizeMaskUint32) + nb += 8 + } + chunk := f.hl.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & regSizeMaskUint32) + f.nb = nb - n + v = int(chunk >> huffmanValueShift) + break + } + } + } + + var n uint // number of bits extra + var length int + var err error + switch { + case v < 256: + f.dict.writeByte(byte(v)) + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBlockGeneric + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + n = 0 + case v < 269: + length = v*2 - (265*2 - 11) + n = 1 + case v < 273: + length = v*4 - (269*4 - 19) + n = 2 + case v < 277: + length = v*8 - (273*8 - 35) + n = 3 + case v < 281: + length = v*16 - (277*16 - 67) + n = 4 + case v < 285: + length = v*32 - (281*32 - 131) + n = 5 + case v < maxNumLit: + length = 258 + n = 0 + default: + if debugDecode { + fmt.Println(v, ">= maxNumLit") + } + f.err = CorruptInputError(f.roffset) + return + } + if n > 0 { + for f.nb < n { + if err = f.moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits n>0:", err) + } + f.err = err + return + } + } + length += int(f.b & uint32(1<<(n®SizeMaskUint32)-1)) + f.b >>= n & regSizeMaskUint32 + f.nb -= n + } + + var dist uint32 + if f.hd == nil { + for f.nb < 5 { + if err = f.moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits f.nb<5:", err) + } + f.err = err + return + } + } + dist = uint32(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + sym, err := f.huffSym(f.hd) + if err != nil { + if debugDecode { + fmt.Println("huffsym:", err) + } + f.err = err + return + } + dist = uint32(sym) + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. + extra := (dist & 1) << (nb & regSizeMaskUint32) + for f.nb < nb { + if err = f.moreBits(); err != nil { + if debugDecode { + fmt.Println("morebits f.nb>= nb & regSizeMaskUint32 + f.nb -= nb + dist = 1<<((nb+1)®SizeMaskUint32) + 1 + extra + default: + if debugDecode { + fmt.Println("dist too big:", dist, maxNumDist) + } + f.err = CorruptInputError(f.roffset) + return + } + + // No check on length; encoding can be prescient. + if dist > uint32(f.dict.histSize()) { + if debugDecode { + fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize()) + } + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, int(dist) + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. + { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBlockGeneric // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +// Copy a single uncompressed data block from input to output. +func (f *decompressor) dataBlock() { + // Uncompressed. + // Discard current half-byte. + left := (f.nb) & 7 + f.nb -= left + f.b >>= left + + offBytes := f.nb >> 3 + // Unfilled values will be overwritten. + f.buf[0] = uint8(f.b) + f.buf[1] = uint8(f.b >> 8) + f.buf[2] = uint8(f.b >> 16) + f.buf[3] = uint8(f.b >> 24) + + f.roffset += int64(offBytes) + f.nb, f.b = 0, 0 + + // Length then ones-complement of length. + nr, err := io.ReadFull(f.r, f.buf[offBytes:4]) + f.roffset += int64(nr) + if err != nil { + f.err = noEOF(err) + return + } + n := uint16(f.buf[0]) | uint16(f.buf[1])<<8 + nn := uint16(f.buf[2]) | uint16(f.buf[3])<<8 + if nn != ^n { + if debugDecode { + ncomp := ^n + fmt.Println("uint16(nn) != uint16(^n)", nn, ncomp) + } + f.err = CorruptInputError(f.roffset) + return + } + + if n == 0 { + f.toRead = f.dict.readFlush() + f.finishBlock() + return + } + + f.copyLen = int(n) + f.copyData() +} + +// copyData copies f.copyLen bytes from the underlying reader into f.hist. +// It pauses for reads when f.hist is full. +func (f *decompressor) copyData() { + buf := f.dict.writeSlice() + if len(buf) > f.copyLen { + buf = buf[:f.copyLen] + } + + cnt, err := io.ReadFull(f.r, buf) + f.roffset += int64(cnt) + f.copyLen -= cnt + f.dict.writeMark(cnt) + if err != nil { + f.err = noEOF(err) + return + } + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).copyData + return + } + f.finishBlock() +} + +func (f *decompressor) finishBlock() { + if f.final { + if f.dict.availRead() > 0 { + f.toRead = f.dict.readFlush() + } + f.err = io.EOF + } + f.step = (*decompressor).nextBlock +} + +// noEOF returns err, unless err == io.EOF, in which case it returns io.ErrUnexpectedEOF. +func noEOF(e error) error { + if e == io.EOF { + return io.ErrUnexpectedEOF + } + return e +} + +func (f *decompressor) moreBits() error { + c, err := f.r.ReadByte() + if err != nil { + return noEOF(err) + } + f.roffset++ + f.b |= uint32(c) << (f.nb & regSizeMaskUint32) + f.nb += 8 + return nil +} + +// Read the next Huffman-encoded symbol from f according to h. +func (f *decompressor) huffSym(h *huffmanDecoder) (int, error) { + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(h.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := f.r.ReadByte() + if err != nil { + f.b = b + f.nb = nb + return 0, noEOF(err) + } + f.roffset++ + b |= uint32(c) << (nb & regSizeMaskUint32) + nb += 8 + } + chunk := h.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = h.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&h.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return 0, f.err + } + f.b = b >> (n & regSizeMaskUint32) + f.nb = nb - n + return int(chunk >> huffmanValueShift), nil + } + } +} + +func makeReader(r io.Reader) Reader { + if rr, ok := r.(Reader); ok { + return rr + } + return bufio.NewReader(r) +} + +func fixedHuffmanDecoderInit() { + fixedOnce.Do(func() { + // These come from the RFC section 3.2.6. + var bits [288]int + for i := 0; i < 144; i++ { + bits[i] = 8 + } + for i := 144; i < 256; i++ { + bits[i] = 9 + } + for i := 256; i < 280; i++ { + bits[i] = 7 + } + for i := 280; i < 288; i++ { + bits[i] = 8 + } + fixedHuffmanDecoder.init(bits[:]) + }) +} + +func (f *decompressor) Reset(r io.Reader, dict []byte) error { + *f = decompressor{ + r: makeReader(r), + bits: f.bits, + codebits: f.codebits, + h1: f.h1, + h2: f.h2, + dict: f.dict, + step: (*decompressor).nextBlock, + } + f.dict.init(maxMatchOffset, dict) + return nil +} + +// NewReader returns a new ReadCloser that can be used +// to read the uncompressed version of r. +// If r does not also implement io.ByteReader, +// the decompressor may read more data than necessary from r. +// It is the caller's responsibility to call Close on the ReadCloser +// when finished reading. +// +// The ReadCloser returned by NewReader also implements Resetter. +func NewReader(r io.Reader) io.ReadCloser { + fixedHuffmanDecoderInit() + + var f decompressor + f.r = makeReader(r) + f.bits = new([maxNumLit + maxNumDist]int) + f.codebits = new([numCodes]int) + f.step = (*decompressor).nextBlock + f.dict.init(maxMatchOffset, nil) + return &f +} + +// NewReaderDict is like NewReader but initializes the reader +// with a preset dictionary. The returned Reader behaves as if +// the uncompressed data stream started with the given dictionary, +// which has already been read. NewReaderDict is typically used +// to read data compressed by NewWriterDict. +// +// The ReadCloser returned by NewReader also implements Resetter. +func NewReaderDict(r io.Reader, dict []byte) io.ReadCloser { + fixedHuffmanDecoderInit() + + var f decompressor + f.r = makeReader(r) + f.bits = new([maxNumLit + maxNumDist]int) + f.codebits = new([numCodes]int) + f.step = (*decompressor).nextBlock + f.dict.init(maxMatchOffset, dict) + return &f +} diff --git a/vendor/github.com/klauspost/compress/flate/inflate_gen.go b/vendor/github.com/klauspost/compress/flate/inflate_gen.go new file mode 100644 index 0000000000..cc6db27925 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/inflate_gen.go @@ -0,0 +1,1002 @@ +// Code generated by go generate gen_inflate.go. DO NOT EDIT. + +package flate + +import ( + "bufio" + "bytes" + "fmt" + "math/bits" + "strings" +) + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) huffmanBytesBuffer() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + fr := f.r.(*bytes.Buffer) + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. + { + var v int + { + // Inlined v, err := f.huffSym(f.hl) + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hl.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & regSizeMaskUint32) + nb += 8 + } + chunk := f.hl.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & regSizeMaskUint32) + f.nb = nb - n + v = int(chunk >> huffmanValueShift) + break + } + } + } + + var length int + switch { + case v < 256: + f.dict.writeByte(byte(v)) + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBytesBuffer + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + case v < maxNumLit: + val := decCodeToLen[(v - 257)] + length = int(val.length) + 3 + n := uint(val.extra) + for f.nb < n { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits n>0:", err) + } + f.err = err + return + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + } + length += int(f.b & uint32(1<<(n®SizeMaskUint32)-1)) + f.b >>= n & regSizeMaskUint32 + f.nb -= n + default: + if debugDecode { + fmt.Println(v, ">= maxNumLit") + } + f.err = CorruptInputError(f.roffset) + return + } + + var dist uint32 + if f.hd == nil { + for f.nb < 5 { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits f.nb<5:", err) + } + f.err = err + return + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + } + dist = uint32(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hd.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & regSizeMaskUint32) + nb += 8 + } + chunk := f.hd.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hd.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hd.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & regSizeMaskUint32) + f.nb = nb - n + dist = uint32(chunk >> huffmanValueShift) + break + } + } + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. + extra := (dist & 1) << (nb & regSizeMaskUint32) + for f.nb < nb { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits f.nb>= nb & regSizeMaskUint32 + f.nb -= nb + dist = 1<<((nb+1)®SizeMaskUint32) + 1 + extra + default: + if debugDecode { + fmt.Println("dist too big:", dist, maxNumDist) + } + f.err = CorruptInputError(f.roffset) + return + } + + // No check on length; encoding can be prescient. + if dist > uint32(f.dict.histSize()) { + if debugDecode { + fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize()) + } + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, int(dist) + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. + { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBytesBuffer // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) huffmanBytesReader() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + fr := f.r.(*bytes.Reader) + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. + { + var v int + { + // Inlined v, err := f.huffSym(f.hl) + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hl.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & regSizeMaskUint32) + nb += 8 + } + chunk := f.hl.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & regSizeMaskUint32) + f.nb = nb - n + v = int(chunk >> huffmanValueShift) + break + } + } + } + + var length int + switch { + case v < 256: + f.dict.writeByte(byte(v)) + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBytesReader + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + case v < maxNumLit: + val := decCodeToLen[(v - 257)] + length = int(val.length) + 3 + n := uint(val.extra) + for f.nb < n { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits n>0:", err) + } + f.err = err + return + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + } + length += int(f.b & uint32(1<<(n®SizeMaskUint32)-1)) + f.b >>= n & regSizeMaskUint32 + f.nb -= n + default: + if debugDecode { + fmt.Println(v, ">= maxNumLit") + } + f.err = CorruptInputError(f.roffset) + return + } + + var dist uint32 + if f.hd == nil { + for f.nb < 5 { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits f.nb<5:", err) + } + f.err = err + return + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + } + dist = uint32(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hd.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & regSizeMaskUint32) + nb += 8 + } + chunk := f.hd.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hd.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hd.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & regSizeMaskUint32) + f.nb = nb - n + dist = uint32(chunk >> huffmanValueShift) + break + } + } + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. + extra := (dist & 1) << (nb & regSizeMaskUint32) + for f.nb < nb { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits f.nb>= nb & regSizeMaskUint32 + f.nb -= nb + dist = 1<<((nb+1)®SizeMaskUint32) + 1 + extra + default: + if debugDecode { + fmt.Println("dist too big:", dist, maxNumDist) + } + f.err = CorruptInputError(f.roffset) + return + } + + // No check on length; encoding can be prescient. + if dist > uint32(f.dict.histSize()) { + if debugDecode { + fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize()) + } + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, int(dist) + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. + { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBytesReader // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) huffmanBufioReader() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + fr := f.r.(*bufio.Reader) + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. + { + var v int + { + // Inlined v, err := f.huffSym(f.hl) + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hl.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & regSizeMaskUint32) + nb += 8 + } + chunk := f.hl.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & regSizeMaskUint32) + f.nb = nb - n + v = int(chunk >> huffmanValueShift) + break + } + } + } + + var length int + switch { + case v < 256: + f.dict.writeByte(byte(v)) + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBufioReader + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + case v < maxNumLit: + val := decCodeToLen[(v - 257)] + length = int(val.length) + 3 + n := uint(val.extra) + for f.nb < n { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits n>0:", err) + } + f.err = err + return + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + } + length += int(f.b & uint32(1<<(n®SizeMaskUint32)-1)) + f.b >>= n & regSizeMaskUint32 + f.nb -= n + default: + if debugDecode { + fmt.Println(v, ">= maxNumLit") + } + f.err = CorruptInputError(f.roffset) + return + } + + var dist uint32 + if f.hd == nil { + for f.nb < 5 { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits f.nb<5:", err) + } + f.err = err + return + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + } + dist = uint32(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hd.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & regSizeMaskUint32) + nb += 8 + } + chunk := f.hd.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hd.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hd.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & regSizeMaskUint32) + f.nb = nb - n + dist = uint32(chunk >> huffmanValueShift) + break + } + } + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. + extra := (dist & 1) << (nb & regSizeMaskUint32) + for f.nb < nb { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits f.nb>= nb & regSizeMaskUint32 + f.nb -= nb + dist = 1<<((nb+1)®SizeMaskUint32) + 1 + extra + default: + if debugDecode { + fmt.Println("dist too big:", dist, maxNumDist) + } + f.err = CorruptInputError(f.roffset) + return + } + + // No check on length; encoding can be prescient. + if dist > uint32(f.dict.histSize()) { + if debugDecode { + fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize()) + } + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, int(dist) + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. + { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanBufioReader // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +// Decode a single Huffman block from f. +// hl and hd are the Huffman states for the lit/length values +// and the distance values, respectively. If hd == nil, using the +// fixed distance encoding associated with fixed Huffman blocks. +func (f *decompressor) huffmanStringsReader() { + const ( + stateInit = iota // Zero value must be stateInit + stateDict + ) + fr := f.r.(*strings.Reader) + + switch f.stepState { + case stateInit: + goto readLiteral + case stateDict: + goto copyHistory + } + +readLiteral: + // Read literal and/or (length, distance) according to RFC section 3.2.3. + { + var v int + { + // Inlined v, err := f.huffSym(f.hl) + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hl.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & regSizeMaskUint32) + nb += 8 + } + chunk := f.hl.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hl.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hl.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & regSizeMaskUint32) + f.nb = nb - n + v = int(chunk >> huffmanValueShift) + break + } + } + } + + var length int + switch { + case v < 256: + f.dict.writeByte(byte(v)) + if f.dict.availWrite() == 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanStringsReader + f.stepState = stateInit + return + } + goto readLiteral + case v == 256: + f.finishBlock() + return + // otherwise, reference to older data + case v < 265: + length = v - (257 - 3) + case v < maxNumLit: + val := decCodeToLen[(v - 257)] + length = int(val.length) + 3 + n := uint(val.extra) + for f.nb < n { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits n>0:", err) + } + f.err = err + return + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + } + length += int(f.b & uint32(1<<(n®SizeMaskUint32)-1)) + f.b >>= n & regSizeMaskUint32 + f.nb -= n + default: + if debugDecode { + fmt.Println(v, ">= maxNumLit") + } + f.err = CorruptInputError(f.roffset) + return + } + + var dist uint32 + if f.hd == nil { + for f.nb < 5 { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits f.nb<5:", err) + } + f.err = err + return + } + f.roffset++ + f.b |= uint32(c) << f.nb + f.nb += 8 + } + dist = uint32(bits.Reverse8(uint8(f.b & 0x1F << 3))) + f.b >>= 5 + f.nb -= 5 + } else { + // Since a huffmanDecoder can be empty or be composed of a degenerate tree + // with single element, huffSym must error on these two edge cases. In both + // cases, the chunks slice will be 0 for the invalid sequence, leading it + // satisfy the n == 0 check below. + n := uint(f.hd.maxRead) + // Optimization. Compiler isn't smart enough to keep f.b,f.nb in registers, + // but is smart enough to keep local variables in registers, so use nb and b, + // inline call to moreBits and reassign b,nb back to f on return. + nb, b := f.nb, f.b + for { + for nb < n { + c, err := fr.ReadByte() + if err != nil { + f.b = b + f.nb = nb + f.err = noEOF(err) + return + } + f.roffset++ + b |= uint32(c) << (nb & regSizeMaskUint32) + nb += 8 + } + chunk := f.hd.chunks[b&(huffmanNumChunks-1)] + n = uint(chunk & huffmanCountMask) + if n > huffmanChunkBits { + chunk = f.hd.links[chunk>>huffmanValueShift][(b>>huffmanChunkBits)&f.hd.linkMask] + n = uint(chunk & huffmanCountMask) + } + if n <= nb { + if n == 0 { + f.b = b + f.nb = nb + if debugDecode { + fmt.Println("huffsym: n==0") + } + f.err = CorruptInputError(f.roffset) + return + } + f.b = b >> (n & regSizeMaskUint32) + f.nb = nb - n + dist = uint32(chunk >> huffmanValueShift) + break + } + } + } + + switch { + case dist < 4: + dist++ + case dist < maxNumDist: + nb := uint(dist-2) >> 1 + // have 1 bit in bottom of dist, need nb more. + extra := (dist & 1) << (nb & regSizeMaskUint32) + for f.nb < nb { + c, err := fr.ReadByte() + if err != nil { + if debugDecode { + fmt.Println("morebits f.nb>= nb & regSizeMaskUint32 + f.nb -= nb + dist = 1<<((nb+1)®SizeMaskUint32) + 1 + extra + default: + if debugDecode { + fmt.Println("dist too big:", dist, maxNumDist) + } + f.err = CorruptInputError(f.roffset) + return + } + + // No check on length; encoding can be prescient. + if dist > uint32(f.dict.histSize()) { + if debugDecode { + fmt.Println("dist > f.dict.histSize():", dist, f.dict.histSize()) + } + f.err = CorruptInputError(f.roffset) + return + } + + f.copyLen, f.copyDist = length, int(dist) + goto copyHistory + } + +copyHistory: + // Perform a backwards copy according to RFC section 3.2.3. + { + cnt := f.dict.tryWriteCopy(f.copyDist, f.copyLen) + if cnt == 0 { + cnt = f.dict.writeCopy(f.copyDist, f.copyLen) + } + f.copyLen -= cnt + + if f.dict.availWrite() == 0 || f.copyLen > 0 { + f.toRead = f.dict.readFlush() + f.step = (*decompressor).huffmanStringsReader // We need to continue this work + f.stepState = stateDict + return + } + goto readLiteral + } +} + +func (f *decompressor) huffmanBlockDecoder() func() { + switch f.r.(type) { + case *bytes.Buffer: + return f.huffmanBytesBuffer + case *bytes.Reader: + return f.huffmanBytesReader + case *bufio.Reader: + return f.huffmanBufioReader + case *strings.Reader: + return f.huffmanStringsReader + default: + return f.huffmanBlockGeneric + } +} diff --git a/vendor/github.com/klauspost/compress/flate/level1.go b/vendor/github.com/klauspost/compress/flate/level1.go new file mode 100644 index 0000000000..1e5eea3968 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/level1.go @@ -0,0 +1,179 @@ +package flate + +import "fmt" + +// fastGen maintains the table for matches, +// and the previous byte block for level 2. +// This is the generic implementation. +type fastEncL1 struct { + fastGen + table [tableSize]tableEntry +} + +// EncodeL1 uses a similar algorithm to level 1 +func (e *fastEncL1) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntry{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. + minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.table[i].offset = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load3232(src, s) + + for { + const skipLog = 5 + const doEvery = 2 + + nextS := s + var candidate tableEntry + for { + nextHash := hash(cv) + candidate = e.table[nextHash] + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + + now := load6432(src, nextS) + e.table[nextHash] = tableEntry{offset: s + e.cur} + nextHash = hash(uint32(now)) + + offset := s - (candidate.offset - e.cur) + if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { + e.table[nextHash] = tableEntry{offset: nextS + e.cur} + break + } + + // Do one right away... + cv = uint32(now) + s = nextS + nextS++ + candidate = e.table[nextHash] + now >>= 8 + e.table[nextHash] = tableEntry{offset: s + e.cur} + + offset = s - (candidate.offset - e.cur) + if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { + e.table[nextHash] = tableEntry{offset: nextS + e.cur} + break + } + cv = uint32(now) + s = nextS + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + + // Extend the 4-byte match as long as possible. + t := candidate.offset - e.cur + l := e.matchlenLong(s+4, t+4, src) + 4 + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + + // Save the match found + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + if s >= sLimit { + // Index first pair after match end. + if int(s+l+4) < len(src) { + cv := load3232(src, s) + e.table[hash(cv)] = tableEntry{offset: s + e.cur} + } + goto emitRemainder + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-2 and at s. If + // another emitCopy is not our next move, also calculate nextHash + // at s+1. At least on GOARCH=amd64, these three hash calculations + // are faster as one load64 call (with some shifts) instead of + // three load32 calls. + x := load6432(src, s-2) + o := e.cur + s - 2 + prevHash := hash(uint32(x)) + e.table[prevHash] = tableEntry{offset: o} + x >>= 16 + currHash := hash(uint32(x)) + candidate = e.table[currHash] + e.table[currHash] = tableEntry{offset: o + 2} + + offset := s - (candidate.offset - e.cur) + if offset > maxMatchOffset || uint32(x) != load3232(src, candidate.offset-e.cur) { + cv = uint32(x >> 8) + s++ + break + } + } + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/vendor/github.com/klauspost/compress/flate/level2.go b/vendor/github.com/klauspost/compress/flate/level2.go new file mode 100644 index 0000000000..234c4389ab --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/level2.go @@ -0,0 +1,205 @@ +package flate + +import "fmt" + +// fastGen maintains the table for matches, +// and the previous byte block for level 2. +// This is the generic implementation. +type fastEncL2 struct { + fastGen + table [bTableSize]tableEntry +} + +// EncodeL2 uses a similar algorithm to level 1, but is capable +// of matching across blocks giving better compression at a small slowdown. +func (e *fastEncL2) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntry{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. + minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.table[i].offset = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load3232(src, s) + for { + // When should we start skipping if we haven't found matches in a long while. + const skipLog = 5 + const doEvery = 2 + + nextS := s + var candidate tableEntry + for { + nextHash := hash4u(cv, bTableBits) + s = nextS + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + candidate = e.table[nextHash] + now := load6432(src, nextS) + e.table[nextHash] = tableEntry{offset: s + e.cur} + nextHash = hash4u(uint32(now), bTableBits) + + offset := s - (candidate.offset - e.cur) + if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { + e.table[nextHash] = tableEntry{offset: nextS + e.cur} + break + } + + // Do one right away... + cv = uint32(now) + s = nextS + nextS++ + candidate = e.table[nextHash] + now >>= 8 + e.table[nextHash] = tableEntry{offset: s + e.cur} + + offset = s - (candidate.offset - e.cur) + if offset < maxMatchOffset && cv == load3232(src, candidate.offset-e.cur) { + break + } + cv = uint32(now) + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + // Call emitCopy, and then see if another emitCopy could be our next + // move. Repeat until we find no match for the input immediately after + // what was consumed by the last emitCopy call. + // + // If we exit this loop normally then we need to call emitLiteral next, + // though we don't yet know how big the literal will be. We handle that + // by proceeding to the next iteration of the main loop. We also can + // exit this loop via goto if we get close to exhausting the input. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + + // Extend the 4-byte match as long as possible. + t := candidate.offset - e.cur + l := e.matchlenLong(s+4, t+4, src) + 4 + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + + if s >= sLimit { + // Index first pair after match end. + if int(s+l+4) < len(src) { + cv := load3232(src, s) + e.table[hash4u(cv, bTableBits)] = tableEntry{offset: s + e.cur} + } + goto emitRemainder + } + + // Store every second hash in-between, but offset by 1. + for i := s - l + 2; i < s-5; i += 7 { + x := load6432(src, i) + nextHash := hash4u(uint32(x), bTableBits) + e.table[nextHash] = tableEntry{offset: e.cur + i} + // Skip one + x >>= 16 + nextHash = hash4u(uint32(x), bTableBits) + e.table[nextHash] = tableEntry{offset: e.cur + i + 2} + // Skip one + x >>= 16 + nextHash = hash4u(uint32(x), bTableBits) + e.table[nextHash] = tableEntry{offset: e.cur + i + 4} + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-2 to s. If + // another emitCopy is not our next move, also calculate nextHash + // at s+1. At least on GOARCH=amd64, these three hash calculations + // are faster as one load64 call (with some shifts) instead of + // three load32 calls. + x := load6432(src, s-2) + o := e.cur + s - 2 + prevHash := hash4u(uint32(x), bTableBits) + prevHash2 := hash4u(uint32(x>>8), bTableBits) + e.table[prevHash] = tableEntry{offset: o} + e.table[prevHash2] = tableEntry{offset: o + 1} + currHash := hash4u(uint32(x>>16), bTableBits) + candidate = e.table[currHash] + e.table[currHash] = tableEntry{offset: o + 2} + + offset := s - (candidate.offset - e.cur) + if offset > maxMatchOffset || uint32(x>>16) != load3232(src, candidate.offset-e.cur) { + cv = uint32(x >> 24) + s++ + break + } + } + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/vendor/github.com/klauspost/compress/flate/level3.go b/vendor/github.com/klauspost/compress/flate/level3.go new file mode 100644 index 0000000000..c22b4244a5 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/level3.go @@ -0,0 +1,229 @@ +package flate + +import "fmt" + +// fastEncL3 +type fastEncL3 struct { + fastGen + table [tableSize]tableEntryPrev +} + +// Encode uses a similar algorithm to level 2, will check up to two candidates. +func (e *fastEncL3) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 8 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntryPrev{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. + minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i] + if v.Cur.offset <= minOff { + v.Cur.offset = 0 + } else { + v.Cur.offset = v.Cur.offset - e.cur + maxMatchOffset + } + if v.Prev.offset <= minOff { + v.Prev.offset = 0 + } else { + v.Prev.offset = v.Prev.offset - e.cur + maxMatchOffset + } + e.table[i] = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // Skip if too small. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load3232(src, s) + for { + const skipLog = 6 + nextS := s + var candidate tableEntry + for { + nextHash := hash(cv) + s = nextS + nextS = s + 1 + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + candidates := e.table[nextHash] + now := load3232(src, nextS) + + // Safe offset distance until s + 4... + minOffset := e.cur + s - (maxMatchOffset - 4) + e.table[nextHash] = tableEntryPrev{Prev: candidates.Cur, Cur: tableEntry{offset: s + e.cur}} + + // Check both candidates + candidate = candidates.Cur + if candidate.offset < minOffset { + cv = now + // Previous will also be invalid, we have nothing. + continue + } + + if cv == load3232(src, candidate.offset-e.cur) { + if candidates.Prev.offset < minOffset || cv != load3232(src, candidates.Prev.offset-e.cur) { + break + } + // Both match and are valid, pick longest. + offset := s - (candidate.offset - e.cur) + o2 := s - (candidates.Prev.offset - e.cur) + l1, l2 := matchLen(src[s+4:], src[s-offset+4:]), matchLen(src[s+4:], src[s-o2+4:]) + if l2 > l1 { + candidate = candidates.Prev + } + break + } else { + // We only check if value mismatches. + // Offset will always be invalid in other cases. + candidate = candidates.Prev + if candidate.offset > minOffset && cv == load3232(src, candidate.offset-e.cur) { + break + } + } + cv = now + } + + // Call emitCopy, and then see if another emitCopy could be our next + // move. Repeat until we find no match for the input immediately after + // what was consumed by the last emitCopy call. + // + // If we exit this loop normally then we need to call emitLiteral next, + // though we don't yet know how big the literal will be. We handle that + // by proceeding to the next iteration of the main loop. We also can + // exit this loop via goto if we get close to exhausting the input. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + + // Extend the 4-byte match as long as possible. + // + t := candidate.offset - e.cur + l := e.matchlenLong(s+4, t+4, src) + 4 + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + + if s >= sLimit { + t += l + // Index first pair after match end. + if int(t+4) < len(src) && t > 0 { + cv := load3232(src, t) + nextHash := hash(cv) + e.table[nextHash] = tableEntryPrev{ + Prev: e.table[nextHash].Cur, + Cur: tableEntry{offset: e.cur + t}, + } + } + goto emitRemainder + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-3 to s. + x := load6432(src, s-3) + prevHash := hash(uint32(x)) + e.table[prevHash] = tableEntryPrev{ + Prev: e.table[prevHash].Cur, + Cur: tableEntry{offset: e.cur + s - 3}, + } + x >>= 8 + prevHash = hash(uint32(x)) + + e.table[prevHash] = tableEntryPrev{ + Prev: e.table[prevHash].Cur, + Cur: tableEntry{offset: e.cur + s - 2}, + } + x >>= 8 + prevHash = hash(uint32(x)) + + e.table[prevHash] = tableEntryPrev{ + Prev: e.table[prevHash].Cur, + Cur: tableEntry{offset: e.cur + s - 1}, + } + x >>= 8 + currHash := hash(uint32(x)) + candidates := e.table[currHash] + cv = uint32(x) + e.table[currHash] = tableEntryPrev{ + Prev: candidates.Cur, + Cur: tableEntry{offset: s + e.cur}, + } + + // Check both candidates + candidate = candidates.Cur + minOffset := e.cur + s - (maxMatchOffset - 4) + + if candidate.offset > minOffset && cv != load3232(src, candidate.offset-e.cur) { + // We only check if value mismatches. + // Offset will always be invalid in other cases. + candidate = candidates.Prev + if candidate.offset > minOffset && cv == load3232(src, candidate.offset-e.cur) { + offset := s - (candidate.offset - e.cur) + if offset <= maxMatchOffset { + continue + } + } + } + cv = uint32(x >> 8) + s++ + break + } + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/vendor/github.com/klauspost/compress/flate/level4.go b/vendor/github.com/klauspost/compress/flate/level4.go new file mode 100644 index 0000000000..e62f0c02b1 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/level4.go @@ -0,0 +1,212 @@ +package flate + +import "fmt" + +type fastEncL4 struct { + fastGen + table [tableSize]tableEntry + bTable [tableSize]tableEntry +} + +func (e *fastEncL4) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntry{} + } + for i := range e.bTable[:] { + e.bTable[i] = tableEntry{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. + minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.table[i].offset = v + } + for i := range e.bTable[:] { + v := e.bTable[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.bTable[i].offset = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load6432(src, s) + for { + const skipLog = 6 + const doEvery = 1 + + nextS := s + var t int32 + for { + nextHashS := hash4x64(cv, tableBits) + nextHashL := hash7(cv, tableBits) + + s = nextS + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + // Fetch a short+long candidate + sCandidate := e.table[nextHashS] + lCandidate := e.bTable[nextHashL] + next := load6432(src, nextS) + entry := tableEntry{offset: s + e.cur} + e.table[nextHashS] = entry + e.bTable[nextHashL] = entry + + t = lCandidate.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.offset-e.cur) { + // We got a long match. Use that. + break + } + + t = sCandidate.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) { + // Found a 4 match... + lCandidate = e.bTable[hash7(next, tableBits)] + + // If the next long is a candidate, check if we should use that instead... + lOff := nextS - (lCandidate.offset - e.cur) + if lOff < maxMatchOffset && load3232(src, lCandidate.offset-e.cur) == uint32(next) { + l1, l2 := matchLen(src[s+4:], src[t+4:]), matchLen(src[nextS+4:], src[nextS-lOff+4:]) + if l2 > l1 { + s = nextS + t = lCandidate.offset - e.cur + } + } + break + } + cv = next + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + // Extend the 4-byte match as long as possible. + l := e.matchlenLong(s+4, t+4, src) + 4 + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + if debugDeflate { + if t >= s { + panic("s-t") + } + if (s - t) > maxMatchOffset { + panic(fmt.Sprintln("mmo", t)) + } + if l < baseMatchLength { + panic("bml") + } + } + + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + + if s >= sLimit { + // Index first pair after match end. + if int(s+8) < len(src) { + cv := load6432(src, s) + e.table[hash4x64(cv, tableBits)] = tableEntry{offset: s + e.cur} + e.bTable[hash7(cv, tableBits)] = tableEntry{offset: s + e.cur} + } + goto emitRemainder + } + + // Store every 3rd hash in-between + if true { + i := nextS + if i < s-1 { + cv := load6432(src, i) + t := tableEntry{offset: i + e.cur} + t2 := tableEntry{offset: t.offset + 1} + e.bTable[hash7(cv, tableBits)] = t + e.bTable[hash7(cv>>8, tableBits)] = t2 + e.table[hash4u(uint32(cv>>8), tableBits)] = t2 + + i += 3 + for ; i < s-1; i += 3 { + cv := load6432(src, i) + t := tableEntry{offset: i + e.cur} + t2 := tableEntry{offset: t.offset + 1} + e.bTable[hash7(cv, tableBits)] = t + e.bTable[hash7(cv>>8, tableBits)] = t2 + e.table[hash4u(uint32(cv>>8), tableBits)] = t2 + } + } + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-1 and at s. + x := load6432(src, s-1) + o := e.cur + s - 1 + prevHashS := hash4x64(x, tableBits) + prevHashL := hash7(x, tableBits) + e.table[prevHashS] = tableEntry{offset: o} + e.bTable[prevHashL] = tableEntry{offset: o} + cv = x >> 8 + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/vendor/github.com/klauspost/compress/flate/level5.go b/vendor/github.com/klauspost/compress/flate/level5.go new file mode 100644 index 0000000000..293a3a320b --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/level5.go @@ -0,0 +1,294 @@ +package flate + +import "fmt" + +type fastEncL5 struct { + fastGen + table [tableSize]tableEntry + bTable [tableSize]tableEntryPrev +} + +func (e *fastEncL5) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntry{} + } + for i := range e.bTable[:] { + e.bTable[i] = tableEntryPrev{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. + minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.table[i].offset = v + } + for i := range e.bTable[:] { + v := e.bTable[i] + if v.Cur.offset <= minOff { + v.Cur.offset = 0 + v.Prev.offset = 0 + } else { + v.Cur.offset = v.Cur.offset - e.cur + maxMatchOffset + if v.Prev.offset <= minOff { + v.Prev.offset = 0 + } else { + v.Prev.offset = v.Prev.offset - e.cur + maxMatchOffset + } + } + e.bTable[i] = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load6432(src, s) + for { + const skipLog = 6 + const doEvery = 1 + + nextS := s + var l int32 + var t int32 + for { + nextHashS := hash4x64(cv, tableBits) + nextHashL := hash7(cv, tableBits) + + s = nextS + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + // Fetch a short+long candidate + sCandidate := e.table[nextHashS] + lCandidate := e.bTable[nextHashL] + next := load6432(src, nextS) + entry := tableEntry{offset: s + e.cur} + e.table[nextHashS] = entry + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = entry, eLong.Cur + + nextHashS = hash4x64(next, tableBits) + nextHashL = hash7(next, tableBits) + + t = lCandidate.Cur.offset - e.cur + if s-t < maxMatchOffset { + if uint32(cv) == load3232(src, lCandidate.Cur.offset-e.cur) { + // Store the next match + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + + t2 := lCandidate.Prev.offset - e.cur + if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) { + l = e.matchlen(s+4, t+4, src) + 4 + ml1 := e.matchlen(s+4, t2+4, src) + 4 + if ml1 > l { + t = t2 + l = ml1 + break + } + } + break + } + t = lCandidate.Prev.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) { + // Store the next match + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + break + } + } + + t = sCandidate.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) { + // Found a 4 match... + l = e.matchlen(s+4, t+4, src) + 4 + lCandidate = e.bTable[nextHashL] + // Store the next match + + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + + // If the next long is a candidate, use that... + t2 := lCandidate.Cur.offset - e.cur + if nextS-t2 < maxMatchOffset { + if load3232(src, lCandidate.Cur.offset-e.cur) == uint32(next) { + ml := e.matchlen(nextS+4, t2+4, src) + 4 + if ml > l { + t = t2 + s = nextS + l = ml + break + } + } + // If the previous long is a candidate, use that... + t2 = lCandidate.Prev.offset - e.cur + if nextS-t2 < maxMatchOffset && load3232(src, lCandidate.Prev.offset-e.cur) == uint32(next) { + ml := e.matchlen(nextS+4, t2+4, src) + 4 + if ml > l { + t = t2 + s = nextS + l = ml + break + } + } + } + break + } + cv = next + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + if l == 0 { + // Extend the 4-byte match as long as possible. + l = e.matchlenLong(s+4, t+4, src) + 4 + } else if l == maxMatchLength { + l += e.matchlenLong(s+l, t+l, src) + } + + // Try to locate a better match by checking the end of best match... + if sAt := s + l; l < 30 && sAt < sLimit { + eLong := e.bTable[hash7(load6432(src, sAt), tableBits)].Cur.offset + // Test current + t2 := eLong - e.cur - l + off := s - t2 + if t2 >= 0 && off < maxMatchOffset && off > 0 { + if l2 := e.matchlenLong(s, t2, src); l2 > l { + t = t2 + l = l2 + } + } + } + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + if debugDeflate { + if t >= s { + panic(fmt.Sprintln("s-t", s, t)) + } + if (s - t) > maxMatchOffset { + panic(fmt.Sprintln("mmo", s-t)) + } + if l < baseMatchLength { + panic("bml") + } + } + + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + + if s >= sLimit { + goto emitRemainder + } + + // Store every 3rd hash in-between. + if true { + const hashEvery = 3 + i := s - l + 1 + if i < s-1 { + cv := load6432(src, i) + t := tableEntry{offset: i + e.cur} + e.table[hash4x64(cv, tableBits)] = t + eLong := &e.bTable[hash7(cv, tableBits)] + eLong.Cur, eLong.Prev = t, eLong.Cur + + // Do an long at i+1 + cv >>= 8 + t = tableEntry{offset: t.offset + 1} + eLong = &e.bTable[hash7(cv, tableBits)] + eLong.Cur, eLong.Prev = t, eLong.Cur + + // We only have enough bits for a short entry at i+2 + cv >>= 8 + t = tableEntry{offset: t.offset + 1} + e.table[hash4x64(cv, tableBits)] = t + + // Skip one - otherwise we risk hitting 's' + i += 4 + for ; i < s-1; i += hashEvery { + cv := load6432(src, i) + t := tableEntry{offset: i + e.cur} + t2 := tableEntry{offset: t.offset + 1} + eLong := &e.bTable[hash7(cv, tableBits)] + eLong.Cur, eLong.Prev = t, eLong.Cur + e.table[hash4u(uint32(cv>>8), tableBits)] = t2 + } + } + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-1 and at s. + x := load6432(src, s-1) + o := e.cur + s - 1 + prevHashS := hash4x64(x, tableBits) + prevHashL := hash7(x, tableBits) + e.table[prevHashS] = tableEntry{offset: o} + eLong := &e.bTable[prevHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: o}, eLong.Cur + cv = x >> 8 + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/vendor/github.com/klauspost/compress/flate/level6.go b/vendor/github.com/klauspost/compress/flate/level6.go new file mode 100644 index 0000000000..a709977ec4 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/level6.go @@ -0,0 +1,307 @@ +package flate + +import "fmt" + +type fastEncL6 struct { + fastGen + table [tableSize]tableEntry + bTable [tableSize]tableEntryPrev +} + +func (e *fastEncL6) Encode(dst *tokens, src []byte) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + if debugDeflate && e.cur < 0 { + panic(fmt.Sprint("e.cur < 0: ", e.cur)) + } + + // Protect against e.cur wraparound. + for e.cur >= bufferReset { + if len(e.hist) == 0 { + for i := range e.table[:] { + e.table[i] = tableEntry{} + } + for i := range e.bTable[:] { + e.bTable[i] = tableEntryPrev{} + } + e.cur = maxMatchOffset + break + } + // Shift down everything in the table that isn't already too far away. + minOff := e.cur + int32(len(e.hist)) - maxMatchOffset + for i := range e.table[:] { + v := e.table[i].offset + if v <= minOff { + v = 0 + } else { + v = v - e.cur + maxMatchOffset + } + e.table[i].offset = v + } + for i := range e.bTable[:] { + v := e.bTable[i] + if v.Cur.offset <= minOff { + v.Cur.offset = 0 + v.Prev.offset = 0 + } else { + v.Cur.offset = v.Cur.offset - e.cur + maxMatchOffset + if v.Prev.offset <= minOff { + v.Prev.offset = 0 + } else { + v.Prev.offset = v.Prev.offset - e.cur + maxMatchOffset + } + } + e.bTable[i] = v + } + e.cur = maxMatchOffset + } + + s := e.addBlock(src) + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = uint16(len(src)) + return + } + + // Override src + src = e.hist + nextEmit := s + + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int32(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load6432(src, s) + // Repeat MUST be > 1 and within range + repeat := int32(1) + for { + const skipLog = 7 + const doEvery = 1 + + nextS := s + var l int32 + var t int32 + for { + nextHashS := hash4x64(cv, tableBits) + nextHashL := hash7(cv, tableBits) + s = nextS + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit { + goto emitRemainder + } + // Fetch a short+long candidate + sCandidate := e.table[nextHashS] + lCandidate := e.bTable[nextHashL] + next := load6432(src, nextS) + entry := tableEntry{offset: s + e.cur} + e.table[nextHashS] = entry + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = entry, eLong.Cur + + // Calculate hashes of 'next' + nextHashS = hash4x64(next, tableBits) + nextHashL = hash7(next, tableBits) + + t = lCandidate.Cur.offset - e.cur + if s-t < maxMatchOffset { + if uint32(cv) == load3232(src, lCandidate.Cur.offset-e.cur) { + // Long candidate matches at least 4 bytes. + + // Store the next match + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + + // Check the previous long candidate as well. + t2 := lCandidate.Prev.offset - e.cur + if s-t2 < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) { + l = e.matchlen(s+4, t+4, src) + 4 + ml1 := e.matchlen(s+4, t2+4, src) + 4 + if ml1 > l { + t = t2 + l = ml1 + break + } + } + break + } + // Current value did not match, but check if previous long value does. + t = lCandidate.Prev.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, lCandidate.Prev.offset-e.cur) { + // Store the next match + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + break + } + } + + t = sCandidate.offset - e.cur + if s-t < maxMatchOffset && uint32(cv) == load3232(src, sCandidate.offset-e.cur) { + // Found a 4 match... + l = e.matchlen(s+4, t+4, src) + 4 + + // Look up next long candidate (at nextS) + lCandidate = e.bTable[nextHashL] + + // Store the next match + e.table[nextHashS] = tableEntry{offset: nextS + e.cur} + eLong := &e.bTable[nextHashL] + eLong.Cur, eLong.Prev = tableEntry{offset: nextS + e.cur}, eLong.Cur + + // Check repeat at s + repOff + const repOff = 1 + t2 := s - repeat + repOff + if load3232(src, t2) == uint32(cv>>(8*repOff)) { + ml := e.matchlen(s+4+repOff, t2+4, src) + 4 + if ml > l { + t = t2 + l = ml + s += repOff + // Not worth checking more. + break + } + } + + // If the next long is a candidate, use that... + t2 = lCandidate.Cur.offset - e.cur + if nextS-t2 < maxMatchOffset { + if load3232(src, lCandidate.Cur.offset-e.cur) == uint32(next) { + ml := e.matchlen(nextS+4, t2+4, src) + 4 + if ml > l { + t = t2 + s = nextS + l = ml + // This is ok, but check previous as well. + } + } + // If the previous long is a candidate, use that... + t2 = lCandidate.Prev.offset - e.cur + if nextS-t2 < maxMatchOffset && load3232(src, lCandidate.Prev.offset-e.cur) == uint32(next) { + ml := e.matchlen(nextS+4, t2+4, src) + 4 + if ml > l { + t = t2 + s = nextS + l = ml + break + } + } + } + break + } + cv = next + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + + // Extend the 4-byte match as long as possible. + if l == 0 { + l = e.matchlenLong(s+4, t+4, src) + 4 + } else if l == maxMatchLength { + l += e.matchlenLong(s+l, t+l, src) + } + + // Try to locate a better match by checking the end-of-match... + if sAt := s + l; sAt < sLimit { + eLong := &e.bTable[hash7(load6432(src, sAt), tableBits)] + // Test current + t2 := eLong.Cur.offset - e.cur - l + off := s - t2 + if off < maxMatchOffset { + if off > 0 && t2 >= 0 { + if l2 := e.matchlenLong(s, t2, src); l2 > l { + t = t2 + l = l2 + } + } + // Test next: + t2 = eLong.Prev.offset - e.cur - l + off := s - t2 + if off > 0 && off < maxMatchOffset && t2 >= 0 { + if l2 := e.matchlenLong(s, t2, src); l2 > l { + t = t2 + l = l2 + } + } + } + } + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + if false { + if t >= s { + panic(fmt.Sprintln("s-t", s, t)) + } + if (s - t) > maxMatchOffset { + panic(fmt.Sprintln("mmo", s-t)) + } + if l < baseMatchLength { + panic("bml") + } + } + + dst.AddMatchLong(l, uint32(s-t-baseMatchOffset)) + repeat = s - t + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + + if s >= sLimit { + // Index after match end. + for i := nextS + 1; i < int32(len(src))-8; i += 2 { + cv := load6432(src, i) + e.table[hash4x64(cv, tableBits)] = tableEntry{offset: i + e.cur} + eLong := &e.bTable[hash7(cv, tableBits)] + eLong.Cur, eLong.Prev = tableEntry{offset: i + e.cur}, eLong.Cur + } + goto emitRemainder + } + + // Store every long hash in-between and every second short. + if true { + for i := nextS + 1; i < s-1; i += 2 { + cv := load6432(src, i) + t := tableEntry{offset: i + e.cur} + t2 := tableEntry{offset: t.offset + 1} + eLong := &e.bTable[hash7(cv, tableBits)] + eLong2 := &e.bTable[hash7(cv>>8, tableBits)] + e.table[hash4x64(cv, tableBits)] = t + eLong.Cur, eLong.Prev = t, eLong.Cur + eLong2.Cur, eLong2.Prev = t2, eLong2.Cur + } + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-1 and at s. + cv = load6432(src, s) + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/vendor/github.com/klauspost/compress/flate/regmask_amd64.go b/vendor/github.com/klauspost/compress/flate/regmask_amd64.go new file mode 100644 index 0000000000..6ed28061b2 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/regmask_amd64.go @@ -0,0 +1,37 @@ +package flate + +const ( + // Masks for shifts with register sizes of the shift value. + // This can be used to work around the x86 design of shifting by mod register size. + // It can be used when a variable shift is always smaller than the register size. + + // reg8SizeMaskX - shift value is 8 bits, shifted is X + reg8SizeMask8 = 7 + reg8SizeMask16 = 15 + reg8SizeMask32 = 31 + reg8SizeMask64 = 63 + + // reg16SizeMaskX - shift value is 16 bits, shifted is X + reg16SizeMask8 = reg8SizeMask8 + reg16SizeMask16 = reg8SizeMask16 + reg16SizeMask32 = reg8SizeMask32 + reg16SizeMask64 = reg8SizeMask64 + + // reg32SizeMaskX - shift value is 32 bits, shifted is X + reg32SizeMask8 = reg8SizeMask8 + reg32SizeMask16 = reg8SizeMask16 + reg32SizeMask32 = reg8SizeMask32 + reg32SizeMask64 = reg8SizeMask64 + + // reg64SizeMaskX - shift value is 64 bits, shifted is X + reg64SizeMask8 = reg8SizeMask8 + reg64SizeMask16 = reg8SizeMask16 + reg64SizeMask32 = reg8SizeMask32 + reg64SizeMask64 = reg8SizeMask64 + + // regSizeMaskUintX - shift value is uint, shifted is X + regSizeMaskUint8 = reg8SizeMask8 + regSizeMaskUint16 = reg8SizeMask16 + regSizeMaskUint32 = reg8SizeMask32 + regSizeMaskUint64 = reg8SizeMask64 +) diff --git a/vendor/github.com/klauspost/compress/flate/regmask_other.go b/vendor/github.com/klauspost/compress/flate/regmask_other.go new file mode 100644 index 0000000000..f477a5d6e5 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/regmask_other.go @@ -0,0 +1,39 @@ +//+build !amd64 + +package flate + +const ( + // Masks for shifts with register sizes of the shift value. + // This can be used to work around the x86 design of shifting by mod register size. + // It can be used when a variable shift is always smaller than the register size. + + // reg8SizeMaskX - shift value is 8 bits, shifted is X + reg8SizeMask8 = 0xff + reg8SizeMask16 = 0xff + reg8SizeMask32 = 0xff + reg8SizeMask64 = 0xff + + // reg16SizeMaskX - shift value is 16 bits, shifted is X + reg16SizeMask8 = 0xffff + reg16SizeMask16 = 0xffff + reg16SizeMask32 = 0xffff + reg16SizeMask64 = 0xffff + + // reg32SizeMaskX - shift value is 32 bits, shifted is X + reg32SizeMask8 = 0xffffffff + reg32SizeMask16 = 0xffffffff + reg32SizeMask32 = 0xffffffff + reg32SizeMask64 = 0xffffffff + + // reg64SizeMaskX - shift value is 64 bits, shifted is X + reg64SizeMask8 = 0xffffffffffffffff + reg64SizeMask16 = 0xffffffffffffffff + reg64SizeMask32 = 0xffffffffffffffff + reg64SizeMask64 = 0xffffffffffffffff + + // regSizeMaskUintX - shift value is uint, shifted is X + regSizeMaskUint8 = ^uint(0) + regSizeMaskUint16 = ^uint(0) + regSizeMaskUint32 = ^uint(0) + regSizeMaskUint64 = ^uint(0) +) diff --git a/vendor/github.com/klauspost/compress/flate/stateless.go b/vendor/github.com/klauspost/compress/flate/stateless.go new file mode 100644 index 0000000000..53e8991246 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/stateless.go @@ -0,0 +1,297 @@ +package flate + +import ( + "io" + "math" + "sync" +) + +const ( + maxStatelessBlock = math.MaxInt16 + // dictionary will be taken from maxStatelessBlock, so limit it. + maxStatelessDict = 8 << 10 + + slTableBits = 13 + slTableSize = 1 << slTableBits + slTableShift = 32 - slTableBits +) + +type statelessWriter struct { + dst io.Writer + closed bool +} + +func (s *statelessWriter) Close() error { + if s.closed { + return nil + } + s.closed = true + // Emit EOF block + return StatelessDeflate(s.dst, nil, true, nil) +} + +func (s *statelessWriter) Write(p []byte) (n int, err error) { + err = StatelessDeflate(s.dst, p, false, nil) + if err != nil { + return 0, err + } + return len(p), nil +} + +func (s *statelessWriter) Reset(w io.Writer) { + s.dst = w + s.closed = false +} + +// NewStatelessWriter will do compression but without maintaining any state +// between Write calls. +// There will be no memory kept between Write calls, +// but compression and speed will be suboptimal. +// Because of this, the size of actual Write calls will affect output size. +func NewStatelessWriter(dst io.Writer) io.WriteCloser { + return &statelessWriter{dst: dst} +} + +// bitWriterPool contains bit writers that can be reused. +var bitWriterPool = sync.Pool{ + New: func() interface{} { + return newHuffmanBitWriter(nil) + }, +} + +// StatelessDeflate allows to compress directly to a Writer without retaining state. +// When returning everything will be flushed. +// Up to 8KB of an optional dictionary can be given which is presumed to presumed to precede the block. +// Longer dictionaries will be truncated and will still produce valid output. +// Sending nil dictionary is perfectly fine. +func StatelessDeflate(out io.Writer, in []byte, eof bool, dict []byte) error { + var dst tokens + bw := bitWriterPool.Get().(*huffmanBitWriter) + bw.reset(out) + defer func() { + // don't keep a reference to our output + bw.reset(nil) + bitWriterPool.Put(bw) + }() + if eof && len(in) == 0 { + // Just write an EOF block. + // Could be faster... + bw.writeStoredHeader(0, true) + bw.flush() + return bw.err + } + + // Truncate dict + if len(dict) > maxStatelessDict { + dict = dict[len(dict)-maxStatelessDict:] + } + + for len(in) > 0 { + todo := in + if len(todo) > maxStatelessBlock-len(dict) { + todo = todo[:maxStatelessBlock-len(dict)] + } + in = in[len(todo):] + uncompressed := todo + if len(dict) > 0 { + // combine dict and source + bufLen := len(todo) + len(dict) + combined := make([]byte, bufLen) + copy(combined, dict) + copy(combined[len(dict):], todo) + todo = combined + } + // Compress + statelessEnc(&dst, todo, int16(len(dict))) + isEof := eof && len(in) == 0 + + if dst.n == 0 { + bw.writeStoredHeader(len(uncompressed), isEof) + if bw.err != nil { + return bw.err + } + bw.writeBytes(uncompressed) + } else if int(dst.n) > len(uncompressed)-len(uncompressed)>>4 { + // If we removed less than 1/16th, huffman compress the block. + bw.writeBlockHuff(isEof, uncompressed, len(in) == 0) + } else { + bw.writeBlockDynamic(&dst, isEof, uncompressed, len(in) == 0) + } + if len(in) > 0 { + // Retain a dict if we have more + dict = todo[len(todo)-maxStatelessDict:] + dst.Reset() + } + if bw.err != nil { + return bw.err + } + } + if !eof { + // Align, only a stored block can do that. + bw.writeStoredHeader(0, false) + } + bw.flush() + return bw.err +} + +func hashSL(u uint32) uint32 { + return (u * 0x1e35a7bd) >> slTableShift +} + +func load3216(b []byte, i int16) uint32 { + // Help the compiler eliminate bounds checks on the read so it can be done in a single read. + b = b[i:] + b = b[:4] + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +func load6416(b []byte, i int16) uint64 { + // Help the compiler eliminate bounds checks on the read so it can be done in a single read. + b = b[i:] + b = b[:8] + return uint64(b[0]) | uint64(b[1])<<8 | uint64(b[2])<<16 | uint64(b[3])<<24 | + uint64(b[4])<<32 | uint64(b[5])<<40 | uint64(b[6])<<48 | uint64(b[7])<<56 +} + +func statelessEnc(dst *tokens, src []byte, startAt int16) { + const ( + inputMargin = 12 - 1 + minNonLiteralBlockSize = 1 + 1 + inputMargin + ) + + type tableEntry struct { + offset int16 + } + + var table [slTableSize]tableEntry + + // This check isn't in the Snappy implementation, but there, the caller + // instead of the callee handles this case. + if len(src)-int(startAt) < minNonLiteralBlockSize { + // We do not fill the token table. + // This will be picked up by caller. + dst.n = 0 + return + } + // Index until startAt + if startAt > 0 { + cv := load3232(src, 0) + for i := int16(0); i < startAt; i++ { + table[hashSL(cv)] = tableEntry{offset: i} + cv = (cv >> 8) | (uint32(src[i+4]) << 24) + } + } + + s := startAt + 1 + nextEmit := startAt + // sLimit is when to stop looking for offset/length copies. The inputMargin + // lets us use a fast path for emitLiteral in the main loop, while we are + // looking for copies. + sLimit := int16(len(src) - inputMargin) + + // nextEmit is where in src the next emitLiteral should start from. + cv := load3216(src, s) + + for { + const skipLog = 5 + const doEvery = 2 + + nextS := s + var candidate tableEntry + for { + nextHash := hashSL(cv) + candidate = table[nextHash] + nextS = s + doEvery + (s-nextEmit)>>skipLog + if nextS > sLimit || nextS <= 0 { + goto emitRemainder + } + + now := load6416(src, nextS) + table[nextHash] = tableEntry{offset: s} + nextHash = hashSL(uint32(now)) + + if cv == load3216(src, candidate.offset) { + table[nextHash] = tableEntry{offset: nextS} + break + } + + // Do one right away... + cv = uint32(now) + s = nextS + nextS++ + candidate = table[nextHash] + now >>= 8 + table[nextHash] = tableEntry{offset: s} + + if cv == load3216(src, candidate.offset) { + table[nextHash] = tableEntry{offset: nextS} + break + } + cv = uint32(now) + s = nextS + } + + // A 4-byte match has been found. We'll later see if more than 4 bytes + // match. But, prior to the match, src[nextEmit:s] are unmatched. Emit + // them as literal bytes. + for { + // Invariant: we have a 4-byte match at s, and no need to emit any + // literal bytes prior to s. + + // Extend the 4-byte match as long as possible. + t := candidate.offset + l := int16(matchLen(src[s+4:], src[t+4:]) + 4) + + // Extend backwards + for t > 0 && s > nextEmit && src[t-1] == src[s-1] { + s-- + t-- + l++ + } + if nextEmit < s { + emitLiteral(dst, src[nextEmit:s]) + } + + // Save the match found + dst.AddMatchLong(int32(l), uint32(s-t-baseMatchOffset)) + s += l + nextEmit = s + if nextS >= s { + s = nextS + 1 + } + if s >= sLimit { + goto emitRemainder + } + + // We could immediately start working at s now, but to improve + // compression we first update the hash table at s-2 and at s. If + // another emitCopy is not our next move, also calculate nextHash + // at s+1. At least on GOARCH=amd64, these three hash calculations + // are faster as one load64 call (with some shifts) instead of + // three load32 calls. + x := load6416(src, s-2) + o := s - 2 + prevHash := hashSL(uint32(x)) + table[prevHash] = tableEntry{offset: o} + x >>= 16 + currHash := hashSL(uint32(x)) + candidate = table[currHash] + table[currHash] = tableEntry{offset: o + 2} + + if uint32(x) != load3216(src, candidate.offset) { + cv = uint32(x >> 8) + s++ + break + } + } + } + +emitRemainder: + if int(nextEmit) < len(src) { + // If nothing was added, don't encode literals. + if dst.n == 0 { + return + } + emitLiteral(dst, src[nextEmit:]) + } +} diff --git a/vendor/github.com/klauspost/compress/flate/token.go b/vendor/github.com/klauspost/compress/flate/token.go new file mode 100644 index 0000000000..eb862d7a92 --- /dev/null +++ b/vendor/github.com/klauspost/compress/flate/token.go @@ -0,0 +1,380 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package flate + +import ( + "bytes" + "encoding/binary" + "fmt" + "io" + "math" +) + +const ( + // From top + // 2 bits: type 0 = literal 1=EOF 2=Match 3=Unused + // 8 bits: xlength = length - MIN_MATCH_LENGTH + // 5 bits offsetcode + // 16 bits xoffset = offset - MIN_OFFSET_SIZE, or literal + lengthShift = 22 + offsetMask = 1<maxnumlit + offHist [32]uint16 // offset codes + litHist [256]uint16 // codes 0->255 + n uint16 // Must be able to contain maxStoreBlockSize + tokens [maxStoreBlockSize + 1]token +} + +func (t *tokens) Reset() { + if t.n == 0 { + return + } + t.n = 0 + t.nLits = 0 + for i := range t.litHist[:] { + t.litHist[i] = 0 + } + for i := range t.extraHist[:] { + t.extraHist[i] = 0 + } + for i := range t.offHist[:] { + t.offHist[i] = 0 + } +} + +func (t *tokens) Fill() { + if t.n == 0 { + return + } + for i, v := range t.litHist[:] { + if v == 0 { + t.litHist[i] = 1 + t.nLits++ + } + } + for i, v := range t.extraHist[:literalCount-256] { + if v == 0 { + t.nLits++ + t.extraHist[i] = 1 + } + } + for i, v := range t.offHist[:offsetCodeCount] { + if v == 0 { + t.offHist[i] = 1 + } + } +} + +func indexTokens(in []token) tokens { + var t tokens + t.indexTokens(in) + return t +} + +func (t *tokens) indexTokens(in []token) { + t.Reset() + for _, tok := range in { + if tok < matchType { + t.AddLiteral(tok.literal()) + continue + } + t.AddMatch(uint32(tok.length()), tok.offset()&matchOffsetOnlyMask) + } +} + +// emitLiteral writes a literal chunk and returns the number of bytes written. +func emitLiteral(dst *tokens, lit []byte) { + ol := int(dst.n) + for i, v := range lit { + dst.tokens[(i+ol)&maxStoreBlockSize] = token(v) + dst.litHist[v]++ + } + dst.n += uint16(len(lit)) + dst.nLits += len(lit) +} + +func (t *tokens) AddLiteral(lit byte) { + t.tokens[t.n] = token(lit) + t.litHist[lit]++ + t.n++ + t.nLits++ +} + +// from https://stackoverflow.com/a/28730362 +func mFastLog2(val float32) float32 { + ux := int32(math.Float32bits(val)) + log2 := (float32)(((ux >> 23) & 255) - 128) + ux &= -0x7f800001 + ux += 127 << 23 + uval := math.Float32frombits(uint32(ux)) + log2 += ((-0.34484843)*uval+2.02466578)*uval - 0.67487759 + return log2 +} + +// EstimatedBits will return an minimum size estimated by an *optimal* +// compression of the block. +// The size of the block +func (t *tokens) EstimatedBits() int { + shannon := float32(0) + bits := int(0) + nMatches := 0 + if t.nLits > 0 { + invTotal := 1.0 / float32(t.nLits) + for _, v := range t.litHist[:] { + if v > 0 { + n := float32(v) + shannon += atLeastOne(-mFastLog2(n*invTotal)) * n + } + } + // Just add 15 for EOB + shannon += 15 + for i, v := range t.extraHist[1 : literalCount-256] { + if v > 0 { + n := float32(v) + shannon += atLeastOne(-mFastLog2(n*invTotal)) * n + bits += int(lengthExtraBits[i&31]) * int(v) + nMatches += int(v) + } + } + } + if nMatches > 0 { + invTotal := 1.0 / float32(nMatches) + for i, v := range t.offHist[:offsetCodeCount] { + if v > 0 { + n := float32(v) + shannon += atLeastOne(-mFastLog2(n*invTotal)) * n + bits += int(offsetExtraBits[i&31]) * int(v) + } + } + } + return int(shannon) + bits +} + +// AddMatch adds a match to the tokens. +// This function is very sensitive to inlining and right on the border. +func (t *tokens) AddMatch(xlength uint32, xoffset uint32) { + if debugDeflate { + if xlength >= maxMatchLength+baseMatchLength { + panic(fmt.Errorf("invalid length: %v", xlength)) + } + if xoffset >= maxMatchOffset+baseMatchOffset { + panic(fmt.Errorf("invalid offset: %v", xoffset)) + } + } + oCode := offsetCode(xoffset) + xoffset |= oCode << 16 + t.nLits++ + + t.extraHist[lengthCodes1[uint8(xlength)]]++ + t.offHist[oCode]++ + t.tokens[t.n] = token(matchType | xlength<= maxMatchOffset+baseMatchOffset { + panic(fmt.Errorf("invalid offset: %v", xoffset)) + } + } + oc := offsetCode(xoffset) + xoffset |= oc << 16 + for xlength > 0 { + xl := xlength + if xl > 258 { + // We need to have at least baseMatchLength left over for next loop. + xl = 258 - baseMatchLength + } + xlength -= xl + xl -= baseMatchLength + t.nLits++ + t.extraHist[lengthCodes1[uint8(xl)]]++ + t.offHist[oc]++ + t.tokens[t.n] = token(matchType | uint32(xl)<> lengthShift) } + +// The code is never more than 8 bits, but is returned as uint32 for convenience. +func lengthCode(len uint8) uint32 { return uint32(lengthCodes[len]) } + +// Returns the offset code corresponding to a specific offset +func offsetCode(off uint32) uint32 { + if false { + if off < uint32(len(offsetCodes)) { + return offsetCodes[off&255] + } else if off>>7 < uint32(len(offsetCodes)) { + return offsetCodes[(off>>7)&255] + 14 + } else { + return offsetCodes[(off>>14)&255] + 28 + } + } + if off < uint32(len(offsetCodes)) { + return offsetCodes[uint8(off)] + } + return offsetCodes14[uint8(off>>7)] +} diff --git a/vendor/golang.org/x/net/http2/README b/vendor/golang.org/x/net/http2/README deleted file mode 100644 index 360d5aa379..0000000000 --- a/vendor/golang.org/x/net/http2/README +++ /dev/null @@ -1,20 +0,0 @@ -This is a work-in-progress HTTP/2 implementation for Go. - -It will eventually live in the Go standard library and won't require -any changes to your code to use. It will just be automatic. - -Status: - -* The server support is pretty good. A few things are missing - but are being worked on. -* The client work has just started but shares a lot of code - is coming along much quicker. - -Docs are at https://godoc.org/golang.org/x/net/http2 - -Demo test server at https://http2.golang.org/ - -Help & bug reports welcome! - -Contributing: https://golang.org/doc/contribute.html -Bugs: https://golang.org/issue/new?title=x/net/http2:+ diff --git a/vendor/golang.org/x/net/http2/h2c/h2c.go b/vendor/golang.org/x/net/http2/h2c/h2c.go index 16319b8fda..c0970d8467 100644 --- a/vendor/golang.org/x/net/http2/h2c/h2c.go +++ b/vendor/golang.org/x/net/http2/h2c/h2c.go @@ -338,7 +338,7 @@ func (w *settingsAckSwallowWriter) Write(p []byte) (int, error) { } fSize := fh.Length + 9 if uint32(len(w.buf)) < fSize { - // Have not collected whole frame. Stop processing buf, and withold on + // Have not collected whole frame. Stop processing buf, and withhold on // forward bytes to w.Writer until we get the full frame. break } diff --git a/vendor/golang.org/x/net/http2/server.go b/vendor/golang.org/x/net/http2/server.go index 0ccbe9b4c2..19e449cfae 100644 --- a/vendor/golang.org/x/net/http2/server.go +++ b/vendor/golang.org/x/net/http2/server.go @@ -816,7 +816,7 @@ func (sc *serverConn) serve() { }) sc.unackedSettings++ - // Each connection starts with intialWindowSize inflow tokens. + // Each connection starts with initialWindowSize inflow tokens. // If a higher value is configured, we add more tokens. if diff := sc.srv.initialConnRecvWindowSize() - initialWindowSize; diff > 0 { sc.sendWindowUpdate(nil, int(diff)) @@ -856,6 +856,15 @@ func (sc *serverConn) serve() { case res := <-sc.wroteFrameCh: sc.wroteFrame(res) case res := <-sc.readFrameCh: + // Process any written frames before reading new frames from the client since a + // written frame could have triggered a new stream to be started. + if sc.writingFrameAsync { + select { + case wroteRes := <-sc.wroteFrameCh: + sc.wroteFrame(wroteRes) + default: + } + } if !sc.processFrameFromReader(res) { return } diff --git a/vendor/golang.org/x/net/http2/transport.go b/vendor/golang.org/x/net/http2/transport.go index b97adff7d0..b261beb1d0 100644 --- a/vendor/golang.org/x/net/http2/transport.go +++ b/vendor/golang.org/x/net/http2/transport.go @@ -385,8 +385,13 @@ func (cs *clientStream) abortRequestBodyWrite(err error) { } cc := cs.cc cc.mu.Lock() - cs.stopReqBody = err - cc.cond.Broadcast() + if cs.stopReqBody == nil { + cs.stopReqBody = err + if cs.req.Body != nil { + cs.req.Body.Close() + } + cc.cond.Broadcast() + } cc.mu.Unlock() } @@ -1110,40 +1115,28 @@ func (cc *ClientConn) roundTrip(req *http.Request) (res *http.Response, gotErrAf return res, false, nil } + handleError := func(err error) (*http.Response, bool, error) { + if !hasBody || bodyWritten { + cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) + } else { + bodyWriter.cancel() + cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) + <-bodyWriter.resc + } + cc.forgetStreamID(cs.ID) + return nil, cs.getStartedWrite(), err + } + for { select { case re := <-readLoopResCh: return handleReadLoopResponse(re) case <-respHeaderTimer: - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) - } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), errTimeout + return handleError(errTimeout) case <-ctx.Done(): - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) - } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), ctx.Err() + return handleError(ctx.Err()) case <-req.Cancel: - if !hasBody || bodyWritten { - cc.writeStreamReset(cs.ID, ErrCodeCancel, nil) - } else { - bodyWriter.cancel() - cs.abortRequestBodyWrite(errStopReqBodyWriteAndCancel) - <-bodyWriter.resc - } - cc.forgetStreamID(cs.ID) - return nil, cs.getStartedWrite(), errRequestCanceled + return handleError(errRequestCanceled) case <-cs.peerReset: // processResetStream already removed the // stream from the streams map; no need for @@ -1290,7 +1283,13 @@ func (cs *clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) ( // Request.Body is closed by the Transport, // and in multiple cases: server replies <=299 and >299 // while still writing request body - cerr := bodyCloser.Close() + var cerr error + cc.mu.Lock() + if cs.stopReqBody == nil { + cs.stopReqBody = errStopReqBodyWrite + cerr = bodyCloser.Close() + } + cc.mu.Unlock() if err == nil { err = cerr } diff --git a/vendor/modules.txt b/vendor/modules.txt index cd3879283c..e15fa3a21b 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -235,7 +235,7 @@ github.com/golang/protobuf/ptypes/any github.com/golang/protobuf/ptypes/duration github.com/golang/protobuf/ptypes/empty github.com/golang/protobuf/ptypes/timestamp -# github.com/golang/snappy v0.0.1 +# github.com/golang/snappy v0.0.3 github.com/golang/snappy # github.com/google/go-cmp v0.5.3 github.com/google/go-cmp/cmp @@ -245,7 +245,7 @@ github.com/google/go-cmp/cmp/internal/function github.com/google/go-cmp/cmp/internal/value # github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99 github.com/google/pprof/profile -# github.com/google/uuid v1.1.2 +# github.com/google/uuid v1.3.0 ## explicit github.com/google/uuid # github.com/googleapis/gax-go/v2 v2.0.5 @@ -266,6 +266,11 @@ github.com/hashicorp/go-version ## explicit github.com/hashicorp/golang-lru github.com/hashicorp/golang-lru/simplelru +# github.com/hasura/go-graphql-client v0.2.0 +## explicit +github.com/hasura/go-graphql-client +github.com/hasura/go-graphql-client/ident +github.com/hasura/go-graphql-client/internal/jsonutil # github.com/hectane/go-acl v0.0.0-20190604041725-da78bae5fc95 ## explicit github.com/hectane/go-acl @@ -311,6 +316,9 @@ github.com/kballard/go-shellquote github.com/kennygrant/sanitize # github.com/kevinburke/ssh_config v0.0.0-20190725054713-01f96b0aa0cd github.com/kevinburke/ssh_config +# github.com/klauspost/compress v1.13.4 +## explicit +github.com/klauspost/compress/flate # github.com/konsorten/go-windows-terminal-sequences v1.0.2 github.com/konsorten/go-windows-terminal-sequences # github.com/kr/pty v1.1.8 @@ -562,7 +570,7 @@ golang.org/x/image/draw golang.org/x/image/math/f64 # golang.org/x/mod v0.3.0 golang.org/x/mod/semver -# golang.org/x/net v0.0.0-20210614182718-04defd469f4e +# golang.org/x/net v0.0.0-20210825183410-e898025ed96a ## explicit golang.org/x/net/context golang.org/x/net/context/ctxhttp @@ -847,4 +855,12 @@ modernc.org/sqlite/lib modernc.org/strutil # modernc.org/token v1.0.0 modernc.org/token +# nhooyr.io/websocket v1.8.7 +## explicit +nhooyr.io/websocket +nhooyr.io/websocket/internal/bpool +nhooyr.io/websocket/internal/errd +nhooyr.io/websocket/internal/wsjs +nhooyr.io/websocket/internal/xsync +nhooyr.io/websocket/wsjson # github.com/ActiveState/cli => ./ diff --git a/vendor/nhooyr.io/websocket/.gitignore b/vendor/nhooyr.io/websocket/.gitignore new file mode 100644 index 0000000000..6961e5c894 --- /dev/null +++ b/vendor/nhooyr.io/websocket/.gitignore @@ -0,0 +1 @@ +websocket.test diff --git a/vendor/nhooyr.io/websocket/LICENSE.txt b/vendor/nhooyr.io/websocket/LICENSE.txt new file mode 100644 index 0000000000..b5b5fef31f --- /dev/null +++ b/vendor/nhooyr.io/websocket/LICENSE.txt @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2018 Anmol Sethi + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/vendor/nhooyr.io/websocket/README.md b/vendor/nhooyr.io/websocket/README.md new file mode 100644 index 0000000000..df20c581a5 --- /dev/null +++ b/vendor/nhooyr.io/websocket/README.md @@ -0,0 +1,132 @@ +# websocket + +[![godoc](https://godoc.org/nhooyr.io/websocket?status.svg)](https://pkg.go.dev/nhooyr.io/websocket) +[![coverage](https://img.shields.io/badge/coverage-88%25-success)](https://nhooyrio-websocket-coverage.netlify.app) + +websocket is a minimal and idiomatic WebSocket library for Go. + +## Install + +```bash +go get nhooyr.io/websocket +``` + +## Highlights + +- Minimal and idiomatic API +- First class [context.Context](https://blog.golang.org/context) support +- Fully passes the WebSocket [autobahn-testsuite](https://github.com/crossbario/autobahn-testsuite) +- [Single dependency](https://pkg.go.dev/nhooyr.io/websocket?tab=imports) +- JSON and protobuf helpers in the [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages +- Zero alloc reads and writes +- Concurrent writes +- [Close handshake](https://pkg.go.dev/nhooyr.io/websocket#Conn.Close) +- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper +- [Ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API +- [RFC 7692](https://tools.ietf.org/html/rfc7692) permessage-deflate compression +- Compile to [Wasm](https://pkg.go.dev/nhooyr.io/websocket#hdr-Wasm) + +## Roadmap + +- [ ] HTTP/2 [#4](https://github.com/nhooyr/websocket/issues/4) + +## Examples + +For a production quality example that demonstrates the complete API, see the +[echo example](./examples/echo). + +For a full stack example, see the [chat example](./examples/chat). + +### Server + +```go +http.HandlerFunc(func (w http.ResponseWriter, r *http.Request) { + c, err := websocket.Accept(w, r, nil) + if err != nil { + // ... + } + defer c.Close(websocket.StatusInternalError, "the sky is falling") + + ctx, cancel := context.WithTimeout(r.Context(), time.Second*10) + defer cancel() + + var v interface{} + err = wsjson.Read(ctx, c, &v) + if err != nil { + // ... + } + + log.Printf("received: %v", v) + + c.Close(websocket.StatusNormalClosure, "") +}) +``` + +### Client + +```go +ctx, cancel := context.WithTimeout(context.Background(), time.Minute) +defer cancel() + +c, _, err := websocket.Dial(ctx, "ws://localhost:8080", nil) +if err != nil { + // ... +} +defer c.Close(websocket.StatusInternalError, "the sky is falling") + +err = wsjson.Write(ctx, c, "hi") +if err != nil { + // ... +} + +c.Close(websocket.StatusNormalClosure, "") +``` + +## Comparison + +### gorilla/websocket + +Advantages of [gorilla/websocket](https://github.com/gorilla/websocket): + +- Mature and widely used +- [Prepared writes](https://pkg.go.dev/github.com/gorilla/websocket#PreparedMessage) +- Configurable [buffer sizes](https://pkg.go.dev/github.com/gorilla/websocket#hdr-Buffers) + +Advantages of nhooyr.io/websocket: + +- Minimal and idiomatic API + - Compare godoc of [nhooyr.io/websocket](https://pkg.go.dev/nhooyr.io/websocket) with [gorilla/websocket](https://pkg.go.dev/github.com/gorilla/websocket) side by side. +- [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) wrapper +- Zero alloc reads and writes ([gorilla/websocket#535](https://github.com/gorilla/websocket/issues/535)) +- Full [context.Context](https://blog.golang.org/context) support +- Dial uses [net/http.Client](https://golang.org/pkg/net/http/#Client) + - Will enable easy HTTP/2 support in the future + - Gorilla writes directly to a net.Conn and so duplicates features of net/http.Client. +- Concurrent writes +- Close handshake ([gorilla/websocket#448](https://github.com/gorilla/websocket/issues/448)) +- Idiomatic [ping pong](https://pkg.go.dev/nhooyr.io/websocket#Conn.Ping) API + - Gorilla requires registering a pong callback before sending a Ping +- Can target Wasm ([gorilla/websocket#432](https://github.com/gorilla/websocket/issues/432)) +- Transparent message buffer reuse with [wsjson](https://pkg.go.dev/nhooyr.io/websocket/wsjson) and [wspb](https://pkg.go.dev/nhooyr.io/websocket/wspb) subpackages +- [1.75x](https://github.com/nhooyr/websocket/releases/tag/v1.7.4) faster WebSocket masking implementation in pure Go + - Gorilla's implementation is slower and uses [unsafe](https://golang.org/pkg/unsafe/). +- Full [permessage-deflate](https://tools.ietf.org/html/rfc7692) compression extension support + - Gorilla only supports no context takeover mode + - We use [klauspost/compress](https://github.com/klauspost/compress) for much lower memory usage ([gorilla/websocket#203](https://github.com/gorilla/websocket/issues/203)) +- [CloseRead](https://pkg.go.dev/nhooyr.io/websocket#Conn.CloseRead) helper ([gorilla/websocket#492](https://github.com/gorilla/websocket/issues/492)) +- Actively maintained ([gorilla/websocket#370](https://github.com/gorilla/websocket/issues/370)) + +#### golang.org/x/net/websocket + +[golang.org/x/net/websocket](https://pkg.go.dev/golang.org/x/net/websocket) is deprecated. +See [golang/go/issues/18152](https://github.com/golang/go/issues/18152). + +The [net.Conn](https://pkg.go.dev/nhooyr.io/websocket#NetConn) can help in transitioning +to nhooyr.io/websocket. + +#### gobwas/ws + +[gobwas/ws](https://github.com/gobwas/ws) has an extremely flexible API that allows it to be used +in an event driven style for performance. See the author's [blog post](https://medium.freecodecamp.org/million-websockets-and-go-cc58418460bb). + +However when writing idiomatic Go, nhooyr.io/websocket will be faster and easier to use. diff --git a/vendor/nhooyr.io/websocket/accept.go b/vendor/nhooyr.io/websocket/accept.go new file mode 100644 index 0000000000..18536bdb2c --- /dev/null +++ b/vendor/nhooyr.io/websocket/accept.go @@ -0,0 +1,370 @@ +// +build !js + +package websocket + +import ( + "bytes" + "crypto/sha1" + "encoding/base64" + "errors" + "fmt" + "io" + "log" + "net/http" + "net/textproto" + "net/url" + "path/filepath" + "strings" + + "nhooyr.io/websocket/internal/errd" +) + +// AcceptOptions represents Accept's options. +type AcceptOptions struct { + // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. + // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to + // reject it, close the connection when c.Subprotocol() == "". + Subprotocols []string + + // InsecureSkipVerify is used to disable Accept's origin verification behaviour. + // + // You probably want to use OriginPatterns instead. + InsecureSkipVerify bool + + // OriginPatterns lists the host patterns for authorized origins. + // The request host is always authorized. + // Use this to enable cross origin WebSockets. + // + // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. + // In such a case, example.com is the origin and chat.example.com is the request host. + // One would set this field to []string{"example.com"} to authorize example.com to connect. + // + // Each pattern is matched case insensitively against the request origin host + // with filepath.Match. + // See https://golang.org/pkg/path/filepath/#Match + // + // Please ensure you understand the ramifications of enabling this. + // If used incorrectly your WebSocket server will be open to CSRF attacks. + // + // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead + // to bring attention to the danger of such a setting. + OriginPatterns []string + + // CompressionMode controls the compression mode. + // Defaults to CompressionNoContextTakeover. + // + // See docs on CompressionMode for details. + CompressionMode CompressionMode + + // CompressionThreshold controls the minimum size of a message before compression is applied. + // + // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes + // for CompressionContextTakeover. + CompressionThreshold int +} + +// Accept accepts a WebSocket handshake from a client and upgrades the +// the connection to a WebSocket. +// +// Accept will not allow cross origin requests by default. +// See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. +// +// Accept will write a response to w on all errors. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + return accept(w, r, opts) +} + +func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { + defer errd.Wrap(&err, "failed to accept WebSocket connection") + + if opts == nil { + opts = &AcceptOptions{} + } + opts = &*opts + + errCode, err := verifyClientRequest(w, r) + if err != nil { + http.Error(w, err.Error(), errCode) + return nil, err + } + + if !opts.InsecureSkipVerify { + err = authenticateOrigin(r, opts.OriginPatterns) + if err != nil { + if errors.Is(err, filepath.ErrBadPattern) { + log.Printf("websocket: %v", err) + err = errors.New(http.StatusText(http.StatusForbidden)) + } + http.Error(w, err.Error(), http.StatusForbidden) + return nil, err + } + } + + hj, ok := w.(http.Hijacker) + if !ok { + err = errors.New("http.ResponseWriter does not implement http.Hijacker") + http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) + return nil, err + } + + w.Header().Set("Upgrade", "websocket") + w.Header().Set("Connection", "Upgrade") + + key := r.Header.Get("Sec-WebSocket-Key") + w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) + + subproto := selectSubprotocol(r, opts.Subprotocols) + if subproto != "" { + w.Header().Set("Sec-WebSocket-Protocol", subproto) + } + + copts, err := acceptCompression(r, w, opts.CompressionMode) + if err != nil { + return nil, err + } + + w.WriteHeader(http.StatusSwitchingProtocols) + // See https://github.com/nhooyr/websocket/issues/166 + if ginWriter, ok := w.(interface { + WriteHeaderNow() + }); ok { + ginWriter.WriteHeaderNow() + } + + netConn, brw, err := hj.Hijack() + if err != nil { + err = fmt.Errorf("failed to hijack connection: %w", err) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + return nil, err + } + + // https://github.com/golang/go/issues/32314 + b, _ := brw.Reader.Peek(brw.Reader.Buffered()) + brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) + + return newConn(connConfig{ + subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), + rwc: netConn, + client: false, + copts: copts, + flateThreshold: opts.CompressionThreshold, + + br: brw.Reader, + bw: brw.Writer, + }), nil +} + +func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { + if !r.ProtoAtLeast(1, 1) { + return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) + } + + if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) + } + + if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { + w.Header().Set("Connection", "Upgrade") + w.Header().Set("Upgrade", "websocket") + return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) + } + + if r.Method != "GET" { + return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) + } + + if r.Header.Get("Sec-WebSocket-Version") != "13" { + w.Header().Set("Sec-WebSocket-Version", "13") + return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) + } + + if r.Header.Get("Sec-WebSocket-Key") == "" { + return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") + } + + return 0, nil +} + +func authenticateOrigin(r *http.Request, originHosts []string) error { + origin := r.Header.Get("Origin") + if origin == "" { + return nil + } + + u, err := url.Parse(origin) + if err != nil { + return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) + } + + if strings.EqualFold(r.Host, u.Host) { + return nil + } + + for _, hostPattern := range originHosts { + matched, err := match(hostPattern, u.Host) + if err != nil { + return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) + } + if matched { + return nil + } + } + return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) +} + +func match(pattern, s string) (bool, error) { + return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) +} + +func selectSubprotocol(r *http.Request, subprotocols []string) string { + cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") + for _, sp := range subprotocols { + for _, cp := range cps { + if strings.EqualFold(sp, cp) { + return cp + } + } + } + return "" +} + +func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { + if mode == CompressionDisabled { + return nil, nil + } + + for _, ext := range websocketExtensions(r.Header) { + switch ext.name { + case "permessage-deflate": + return acceptDeflate(w, ext, mode) + // Disabled for now, see https://github.com/nhooyr/websocket/issues/218 + // case "x-webkit-deflate-frame": + // return acceptWebkitDeflate(w, ext, mode) + } + } + return nil, nil +} + +func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { + copts := mode.opts() + + for _, p := range ext.params { + switch p { + case "client_no_context_takeover": + copts.clientNoContextTakeover = true + continue + case "server_no_context_takeover": + copts.serverNoContextTakeover = true + continue + } + + if strings.HasPrefix(p, "client_max_window_bits") { + // We cannot adjust the read sliding window so cannot make use of this. + continue + } + + err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } + + copts.setHeader(w.Header()) + + return copts, nil +} + +func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { + copts := mode.opts() + // The peer must explicitly request it. + copts.serverNoContextTakeover = false + + for _, p := range ext.params { + if p == "no_context_takeover" { + copts.serverNoContextTakeover = true + continue + } + + // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead + // of ignoring it as the draft spec is unclear. It says the server can ignore it + // but the server has no way of signalling to the client it was ignored as the parameters + // are set one way. + // Thus us ignoring it would make the client think we understood it which would cause issues. + // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 + // + // Either way, we're only implementing this for webkit which never sends the max_window_bits + // parameter so we don't need to worry about it. + err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) + http.Error(w, err.Error(), http.StatusBadRequest) + return nil, err + } + + s := "x-webkit-deflate-frame" + if copts.clientNoContextTakeover { + s += "; no_context_takeover" + } + w.Header().Set("Sec-WebSocket-Extensions", s) + + return copts, nil +} + +func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { + for _, t := range headerTokens(h, key) { + if strings.EqualFold(t, token) { + return true + } + } + return false +} + +type websocketExtension struct { + name string + params []string +} + +func websocketExtensions(h http.Header) []websocketExtension { + var exts []websocketExtension + extStrs := headerTokens(h, "Sec-WebSocket-Extensions") + for _, extStr := range extStrs { + if extStr == "" { + continue + } + + vals := strings.Split(extStr, ";") + for i := range vals { + vals[i] = strings.TrimSpace(vals[i]) + } + + e := websocketExtension{ + name: vals[0], + params: vals[1:], + } + + exts = append(exts, e) + } + return exts +} + +func headerTokens(h http.Header, key string) []string { + key = textproto.CanonicalMIMEHeaderKey(key) + var tokens []string + for _, v := range h[key] { + v = strings.TrimSpace(v) + for _, t := range strings.Split(v, ",") { + t = strings.TrimSpace(t) + tokens = append(tokens, t) + } + } + return tokens +} + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +func secWebSocketAccept(secWebSocketKey string) string { + h := sha1.New() + h.Write([]byte(secWebSocketKey)) + h.Write(keyGUID) + + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +} diff --git a/vendor/nhooyr.io/websocket/accept_js.go b/vendor/nhooyr.io/websocket/accept_js.go new file mode 100644 index 0000000000..daad4b79fe --- /dev/null +++ b/vendor/nhooyr.io/websocket/accept_js.go @@ -0,0 +1,20 @@ +package websocket + +import ( + "errors" + "net/http" +) + +// AcceptOptions represents Accept's options. +type AcceptOptions struct { + Subprotocols []string + InsecureSkipVerify bool + OriginPatterns []string + CompressionMode CompressionMode + CompressionThreshold int +} + +// Accept is stubbed out for Wasm. +func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { + return nil, errors.New("unimplemented") +} diff --git a/vendor/nhooyr.io/websocket/close.go b/vendor/nhooyr.io/websocket/close.go new file mode 100644 index 0000000000..7cbc19e9de --- /dev/null +++ b/vendor/nhooyr.io/websocket/close.go @@ -0,0 +1,76 @@ +package websocket + +import ( + "errors" + "fmt" +) + +// StatusCode represents a WebSocket status code. +// https://tools.ietf.org/html/rfc6455#section-7.4 +type StatusCode int + +// https://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// +// These are only the status codes defined by the protocol. +// +// You can define custom codes in the 3000-4999 range. +// The 3000-3999 range is reserved for use by libraries, frameworks and applications. +// The 4000-4999 range is reserved for private use. +const ( + StatusNormalClosure StatusCode = 1000 + StatusGoingAway StatusCode = 1001 + StatusProtocolError StatusCode = 1002 + StatusUnsupportedData StatusCode = 1003 + + // 1004 is reserved and so unexported. + statusReserved StatusCode = 1004 + + // StatusNoStatusRcvd cannot be sent in a close message. + // It is reserved for when a close message is received without + // a status code. + StatusNoStatusRcvd StatusCode = 1005 + + // StatusAbnormalClosure is exported for use only with Wasm. + // In non Wasm Go, the returned error will indicate whether the + // connection was closed abnormally. + StatusAbnormalClosure StatusCode = 1006 + + StatusInvalidFramePayloadData StatusCode = 1007 + StatusPolicyViolation StatusCode = 1008 + StatusMessageTooBig StatusCode = 1009 + StatusMandatoryExtension StatusCode = 1010 + StatusInternalError StatusCode = 1011 + StatusServiceRestart StatusCode = 1012 + StatusTryAgainLater StatusCode = 1013 + StatusBadGateway StatusCode = 1014 + + // StatusTLSHandshake is only exported for use with Wasm. + // In non Wasm Go, the returned error will indicate whether there was + // a TLS handshake failure. + StatusTLSHandshake StatusCode = 1015 +) + +// CloseError is returned when the connection is closed with a status and reason. +// +// Use Go 1.13's errors.As to check for this error. +// Also see the CloseStatus helper. +type CloseError struct { + Code StatusCode + Reason string +} + +func (ce CloseError) Error() string { + return fmt.Sprintf("status = %v and reason = %q", ce.Code, ce.Reason) +} + +// CloseStatus is a convenience wrapper around Go 1.13's errors.As to grab +// the status code from a CloseError. +// +// -1 will be returned if the passed error is nil or not a CloseError. +func CloseStatus(err error) StatusCode { + var ce CloseError + if errors.As(err, &ce) { + return ce.Code + } + return -1 +} diff --git a/vendor/nhooyr.io/websocket/close_notjs.go b/vendor/nhooyr.io/websocket/close_notjs.go new file mode 100644 index 0000000000..4251311d2e --- /dev/null +++ b/vendor/nhooyr.io/websocket/close_notjs.go @@ -0,0 +1,211 @@ +// +build !js + +package websocket + +import ( + "context" + "encoding/binary" + "errors" + "fmt" + "log" + "time" + + "nhooyr.io/websocket/internal/errd" +) + +// Close performs the WebSocket close handshake with the given status code and reason. +// +// It will write a WebSocket close frame with a timeout of 5s and then wait 5s for +// the peer to send a close frame. +// All data messages received from the peer during the close handshake will be discarded. +// +// The connection can only be closed once. Additional calls to Close +// are no-ops. +// +// The maximum length of reason must be 125 bytes. Avoid +// sending a dynamic reason. +// +// Close will unblock all goroutines interacting with the connection once +// complete. +func (c *Conn) Close(code StatusCode, reason string) error { + return c.closeHandshake(code, reason) +} + +func (c *Conn) closeHandshake(code StatusCode, reason string) (err error) { + defer errd.Wrap(&err, "failed to close WebSocket") + + writeErr := c.writeClose(code, reason) + closeHandshakeErr := c.waitCloseHandshake() + + if writeErr != nil { + return writeErr + } + + if CloseStatus(closeHandshakeErr) == -1 { + return closeHandshakeErr + } + + return nil +} + +var errAlreadyWroteClose = errors.New("already wrote close") + +func (c *Conn) writeClose(code StatusCode, reason string) error { + c.closeMu.Lock() + wroteClose := c.wroteClose + c.wroteClose = true + c.closeMu.Unlock() + if wroteClose { + return errAlreadyWroteClose + } + + ce := CloseError{ + Code: code, + Reason: reason, + } + + var p []byte + var marshalErr error + if ce.Code != StatusNoStatusRcvd { + p, marshalErr = ce.bytes() + if marshalErr != nil { + log.Printf("websocket: %v", marshalErr) + } + } + + writeErr := c.writeControl(context.Background(), opClose, p) + if CloseStatus(writeErr) != -1 { + // Not a real error if it's due to a close frame being received. + writeErr = nil + } + + // We do this after in case there was an error writing the close frame. + c.setCloseErr(fmt.Errorf("sent close frame: %w", ce)) + + if marshalErr != nil { + return marshalErr + } + return writeErr +} + +func (c *Conn) waitCloseHandshake() error { + defer c.close(nil) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*5) + defer cancel() + + err := c.readMu.lock(ctx) + if err != nil { + return err + } + defer c.readMu.unlock() + + if c.readCloseFrameErr != nil { + return c.readCloseFrameErr + } + + for { + h, err := c.readLoop(ctx) + if err != nil { + return err + } + + for i := int64(0); i < h.payloadLength; i++ { + _, err := c.br.ReadByte() + if err != nil { + return err + } + } + } +} + +func parseClosePayload(p []byte) (CloseError, error) { + if len(p) == 0 { + return CloseError{ + Code: StatusNoStatusRcvd, + }, nil + } + + if len(p) < 2 { + return CloseError{}, fmt.Errorf("close payload %q too small, cannot even contain the 2 byte status code", p) + } + + ce := CloseError{ + Code: StatusCode(binary.BigEndian.Uint16(p)), + Reason: string(p[2:]), + } + + if !validWireCloseCode(ce.Code) { + return CloseError{}, fmt.Errorf("invalid status code %v", ce.Code) + } + + return ce, nil +} + +// See http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number +// and https://tools.ietf.org/html/rfc6455#section-7.4.1 +func validWireCloseCode(code StatusCode) bool { + switch code { + case statusReserved, StatusNoStatusRcvd, StatusAbnormalClosure, StatusTLSHandshake: + return false + } + + if code >= StatusNormalClosure && code <= StatusBadGateway { + return true + } + if code >= 3000 && code <= 4999 { + return true + } + + return false +} + +func (ce CloseError) bytes() ([]byte, error) { + p, err := ce.bytesErr() + if err != nil { + err = fmt.Errorf("failed to marshal close frame: %w", err) + ce = CloseError{ + Code: StatusInternalError, + } + p, _ = ce.bytesErr() + } + return p, err +} + +const maxCloseReason = maxControlPayload - 2 + +func (ce CloseError) bytesErr() ([]byte, error) { + if len(ce.Reason) > maxCloseReason { + return nil, fmt.Errorf("reason string max is %v but got %q with length %v", maxCloseReason, ce.Reason, len(ce.Reason)) + } + + if !validWireCloseCode(ce.Code) { + return nil, fmt.Errorf("status code %v cannot be set", ce.Code) + } + + buf := make([]byte, 2+len(ce.Reason)) + binary.BigEndian.PutUint16(buf, uint16(ce.Code)) + copy(buf[2:], ce.Reason) + return buf, nil +} + +func (c *Conn) setCloseErr(err error) { + c.closeMu.Lock() + c.setCloseErrLocked(err) + c.closeMu.Unlock() +} + +func (c *Conn) setCloseErrLocked(err error) { + if c.closeErr == nil { + c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + } +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/vendor/nhooyr.io/websocket/compress.go b/vendor/nhooyr.io/websocket/compress.go new file mode 100644 index 0000000000..80b46d1c1d --- /dev/null +++ b/vendor/nhooyr.io/websocket/compress.go @@ -0,0 +1,39 @@ +package websocket + +// CompressionMode represents the modes available to the deflate extension. +// See https://tools.ietf.org/html/rfc7692 +// +// A compatibility layer is implemented for the older deflate-frame extension used +// by safari. See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06 +// It will work the same in every way except that we cannot signal to the peer we +// want to use no context takeover on our side, we can only signal that they should. +// It is however currently disabled due to Safari bugs. See https://github.com/nhooyr/websocket/issues/218 +type CompressionMode int + +const ( + // CompressionNoContextTakeover grabs a new flate.Reader and flate.Writer as needed + // for every message. This applies to both server and client side. + // + // This means less efficient compression as the sliding window from previous messages + // will not be used but the memory overhead will be lower if the connections + // are long lived and seldom used. + // + // The message will only be compressed if greater than 512 bytes. + CompressionNoContextTakeover CompressionMode = iota + + // CompressionContextTakeover uses a flate.Reader and flate.Writer per connection. + // This enables reusing the sliding window from previous messages. + // As most WebSocket protocols are repetitive, this can be very efficient. + // It carries an overhead of 8 kB for every connection compared to CompressionNoContextTakeover. + // + // If the peer negotiates NoContextTakeover on the client or server side, it will be + // used instead as this is required by the RFC. + CompressionContextTakeover + + // CompressionDisabled disables the deflate extension. + // + // Use this if you are using a predominantly binary protocol with very + // little duplication in between messages or CPU and memory are more + // important than bandwidth. + CompressionDisabled +) diff --git a/vendor/nhooyr.io/websocket/compress_notjs.go b/vendor/nhooyr.io/websocket/compress_notjs.go new file mode 100644 index 0000000000..809a272c3d --- /dev/null +++ b/vendor/nhooyr.io/websocket/compress_notjs.go @@ -0,0 +1,181 @@ +// +build !js + +package websocket + +import ( + "io" + "net/http" + "sync" + + "github.com/klauspost/compress/flate" +) + +func (m CompressionMode) opts() *compressionOptions { + return &compressionOptions{ + clientNoContextTakeover: m == CompressionNoContextTakeover, + serverNoContextTakeover: m == CompressionNoContextTakeover, + } +} + +type compressionOptions struct { + clientNoContextTakeover bool + serverNoContextTakeover bool +} + +func (copts *compressionOptions) setHeader(h http.Header) { + s := "permessage-deflate" + if copts.clientNoContextTakeover { + s += "; client_no_context_takeover" + } + if copts.serverNoContextTakeover { + s += "; server_no_context_takeover" + } + h.Set("Sec-WebSocket-Extensions", s) +} + +// These bytes are required to get flate.Reader to return. +// They are removed when sending to avoid the overhead as +// WebSocket framing tell's when the message has ended but then +// we need to add them back otherwise flate.Reader keeps +// trying to return more bytes. +const deflateMessageTail = "\x00\x00\xff\xff" + +type trimLastFourBytesWriter struct { + w io.Writer + tail []byte +} + +func (tw *trimLastFourBytesWriter) reset() { + if tw != nil && tw.tail != nil { + tw.tail = tw.tail[:0] + } +} + +func (tw *trimLastFourBytesWriter) Write(p []byte) (int, error) { + if tw.tail == nil { + tw.tail = make([]byte, 0, 4) + } + + extra := len(tw.tail) + len(p) - 4 + + if extra <= 0 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Now we need to write as many extra bytes as we can from the previous tail. + if extra > len(tw.tail) { + extra = len(tw.tail) + } + if extra > 0 { + _, err := tw.w.Write(tw.tail[:extra]) + if err != nil { + return 0, err + } + + // Shift remaining bytes in tail over. + n := copy(tw.tail, tw.tail[extra:]) + tw.tail = tw.tail[:n] + } + + // If p is less than or equal to 4 bytes, + // all of it is is part of the tail. + if len(p) <= 4 { + tw.tail = append(tw.tail, p...) + return len(p), nil + } + + // Otherwise, only the last 4 bytes are. + tw.tail = append(tw.tail, p[len(p)-4:]...) + + p = p[:len(p)-4] + n, err := tw.w.Write(p) + return n + 4, err +} + +var flateReaderPool sync.Pool + +func getFlateReader(r io.Reader, dict []byte) io.Reader { + fr, ok := flateReaderPool.Get().(io.Reader) + if !ok { + return flate.NewReaderDict(r, dict) + } + fr.(flate.Resetter).Reset(r, dict) + return fr +} + +func putFlateReader(fr io.Reader) { + flateReaderPool.Put(fr) +} + +type slidingWindow struct { + buf []byte +} + +var swPoolMu sync.RWMutex +var swPool = map[int]*sync.Pool{} + +func slidingWindowPool(n int) *sync.Pool { + swPoolMu.RLock() + p, ok := swPool[n] + swPoolMu.RUnlock() + if ok { + return p + } + + p = &sync.Pool{} + + swPoolMu.Lock() + swPool[n] = p + swPoolMu.Unlock() + + return p +} + +func (sw *slidingWindow) init(n int) { + if sw.buf != nil { + return + } + + if n == 0 { + n = 32768 + } + + p := slidingWindowPool(n) + buf, ok := p.Get().([]byte) + if ok { + sw.buf = buf[:0] + } else { + sw.buf = make([]byte, 0, n) + } +} + +func (sw *slidingWindow) close() { + if sw.buf == nil { + return + } + + swPoolMu.Lock() + swPool[cap(sw.buf)].Put(sw.buf) + swPoolMu.Unlock() + sw.buf = nil +} + +func (sw *slidingWindow) write(p []byte) { + if len(p) >= cap(sw.buf) { + sw.buf = sw.buf[:cap(sw.buf)] + p = p[len(p)-cap(sw.buf):] + copy(sw.buf, p) + return + } + + left := cap(sw.buf) - len(sw.buf) + if left < len(p) { + // We need to shift spaceNeeded bytes from the end to make room for p at the end. + spaceNeeded := len(p) - left + copy(sw.buf, sw.buf[spaceNeeded:]) + sw.buf = sw.buf[:len(sw.buf)-spaceNeeded] + } + + sw.buf = append(sw.buf, p...) +} diff --git a/vendor/nhooyr.io/websocket/conn.go b/vendor/nhooyr.io/websocket/conn.go new file mode 100644 index 0000000000..a41808be3f --- /dev/null +++ b/vendor/nhooyr.io/websocket/conn.go @@ -0,0 +1,13 @@ +package websocket + +// MessageType represents the type of a WebSocket message. +// See https://tools.ietf.org/html/rfc6455#section-5.6 +type MessageType int + +// MessageType constants. +const ( + // MessageText is for UTF-8 encoded text messages like JSON. + MessageText MessageType = iota + 1 + // MessageBinary is for binary messages like protobufs. + MessageBinary +) diff --git a/vendor/nhooyr.io/websocket/conn_notjs.go b/vendor/nhooyr.io/websocket/conn_notjs.go new file mode 100644 index 0000000000..0c85ab7711 --- /dev/null +++ b/vendor/nhooyr.io/websocket/conn_notjs.go @@ -0,0 +1,265 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "runtime" + "strconv" + "sync" + "sync/atomic" +) + +// Conn represents a WebSocket connection. +// All methods may be called concurrently except for Reader and Read. +// +// You must always read from the connection. Otherwise control +// frames will not be handled. See Reader and CloseRead. +// +// Be sure to call Close on the connection when you +// are finished with it to release associated resources. +// +// On any error from any method, the connection is closed +// with an appropriate reason. +type Conn struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + br *bufio.Reader + bw *bufio.Writer + + readTimeout chan context.Context + writeTimeout chan context.Context + + // Read state. + readMu *mu + readHeaderBuf [8]byte + readControlBuf [maxControlPayload]byte + msgReader *msgReader + readCloseFrameErr error + + // Write state. + msgWriterState *msgWriterState + writeFrameMu *mu + writeBuf []byte + writeHeaderBuf [8]byte + writeHeader header + + closed chan struct{} + closeMu sync.Mutex + closeErr error + wroteClose bool + + pingCounter int32 + activePingsMu sync.Mutex + activePings map[string]chan<- struct{} +} + +type connConfig struct { + subprotocol string + rwc io.ReadWriteCloser + client bool + copts *compressionOptions + flateThreshold int + + br *bufio.Reader + bw *bufio.Writer +} + +func newConn(cfg connConfig) *Conn { + c := &Conn{ + subprotocol: cfg.subprotocol, + rwc: cfg.rwc, + client: cfg.client, + copts: cfg.copts, + flateThreshold: cfg.flateThreshold, + + br: cfg.br, + bw: cfg.bw, + + readTimeout: make(chan context.Context), + writeTimeout: make(chan context.Context), + + closed: make(chan struct{}), + activePings: make(map[string]chan<- struct{}), + } + + c.readMu = newMu(c) + c.writeFrameMu = newMu(c) + + c.msgReader = newMsgReader(c) + + c.msgWriterState = newMsgWriterState(c) + if c.client { + c.writeBuf = extractBufioWriterBuf(c.bw, c.rwc) + } + + if c.flate() && c.flateThreshold == 0 { + c.flateThreshold = 128 + if !c.msgWriterState.flateContextTakeover() { + c.flateThreshold = 512 + } + } + + runtime.SetFinalizer(c, func(c *Conn) { + c.close(errors.New("connection garbage collected")) + }) + + go c.timeoutLoop() + + return c +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +func (c *Conn) close(err error) { + c.closeMu.Lock() + defer c.closeMu.Unlock() + + if c.isClosed() { + return + } + c.setCloseErrLocked(err) + close(c.closed) + runtime.SetFinalizer(c, nil) + + // Have to close after c.closed is closed to ensure any goroutine that wakes up + // from the connection being closed also sees that c.closed is closed and returns + // closeErr. + c.rwc.Close() + + go func() { + c.msgWriterState.close() + + c.msgReader.close() + }() +} + +func (c *Conn) timeoutLoop() { + readCtx := context.Background() + writeCtx := context.Background() + + for { + select { + case <-c.closed: + return + + case writeCtx = <-c.writeTimeout: + case readCtx = <-c.readTimeout: + + case <-readCtx.Done(): + c.setCloseErr(fmt.Errorf("read timed out: %w", readCtx.Err())) + go c.writeError(StatusPolicyViolation, errors.New("timed out")) + case <-writeCtx.Done(): + c.close(fmt.Errorf("write timed out: %w", writeCtx.Err())) + return + } + } +} + +func (c *Conn) flate() bool { + return c.copts != nil +} + +// Ping sends a ping to the peer and waits for a pong. +// Use this to measure latency or ensure the peer is responsive. +// Ping must be called concurrently with Reader as it does +// not read from the connection but instead waits for a Reader call +// to read the pong. +// +// TCP Keepalives should suffice for most use cases. +func (c *Conn) Ping(ctx context.Context) error { + p := atomic.AddInt32(&c.pingCounter, 1) + + err := c.ping(ctx, strconv.Itoa(int(p))) + if err != nil { + return fmt.Errorf("failed to ping: %w", err) + } + return nil +} + +func (c *Conn) ping(ctx context.Context, p string) error { + pong := make(chan struct{}, 1) + + c.activePingsMu.Lock() + c.activePings[p] = pong + c.activePingsMu.Unlock() + + defer func() { + c.activePingsMu.Lock() + delete(c.activePings, p) + c.activePingsMu.Unlock() + }() + + err := c.writeControl(ctx, opPing, []byte(p)) + if err != nil { + return err + } + + select { + case <-c.closed: + return c.closeErr + case <-ctx.Done(): + err := fmt.Errorf("failed to wait for pong: %w", ctx.Err()) + c.close(err) + return err + case <-pong: + return nil + } +} + +type mu struct { + c *Conn + ch chan struct{} +} + +func newMu(c *Conn) *mu { + return &mu{ + c: c, + ch: make(chan struct{}, 1), + } +} + +func (m *mu) forceLock() { + m.ch <- struct{}{} +} + +func (m *mu) lock(ctx context.Context) error { + select { + case <-m.c.closed: + return m.c.closeErr + case <-ctx.Done(): + err := fmt.Errorf("failed to acquire lock: %w", ctx.Err()) + m.c.close(err) + return err + case m.ch <- struct{}{}: + // To make sure the connection is certainly alive. + // As it's possible the send on m.ch was selected + // over the receive on closed. + select { + case <-m.c.closed: + // Make sure to release. + m.unlock() + return m.c.closeErr + default: + } + return nil + } +} + +func (m *mu) unlock() { + select { + case <-m.ch: + default: + } +} diff --git a/vendor/nhooyr.io/websocket/dial.go b/vendor/nhooyr.io/websocket/dial.go new file mode 100644 index 0000000000..7a7787ff71 --- /dev/null +++ b/vendor/nhooyr.io/websocket/dial.go @@ -0,0 +1,292 @@ +// +build !js + +package websocket + +import ( + "bufio" + "bytes" + "context" + "crypto/rand" + "encoding/base64" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/url" + "strings" + "sync" + "time" + + "nhooyr.io/websocket/internal/errd" +) + +// DialOptions represents Dial's options. +type DialOptions struct { + // HTTPClient is used for the connection. + // Its Transport must return writable bodies for WebSocket handshakes. + // http.Transport does beginning with Go 1.12. + HTTPClient *http.Client + + // HTTPHeader specifies the HTTP headers included in the handshake request. + HTTPHeader http.Header + + // Subprotocols lists the WebSocket subprotocols to negotiate with the server. + Subprotocols []string + + // CompressionMode controls the compression mode. + // Defaults to CompressionNoContextTakeover. + // + // See docs on CompressionMode for details. + CompressionMode CompressionMode + + // CompressionThreshold controls the minimum size of a message before compression is applied. + // + // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes + // for CompressionContextTakeover. + CompressionThreshold int +} + +// Dial performs a WebSocket handshake on url. +// +// The response is the WebSocket handshake response from the server. +// You never need to close resp.Body yourself. +// +// If an error occurs, the returned response may be non nil. +// However, you can only read the first 1024 bytes of the body. +// +// This function requires at least Go 1.12 as it uses a new feature +// in net/http to perform WebSocket handshakes. +// See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 +// +// URLs with http/https schemes will work and are interpreted as ws/wss. +func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { + return dial(ctx, u, opts, nil) +} + +func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { + defer errd.Wrap(&err, "failed to WebSocket dial") + + if opts == nil { + opts = &DialOptions{} + } + + opts = &*opts + if opts.HTTPClient == nil { + opts.HTTPClient = http.DefaultClient + } else if opts.HTTPClient.Timeout > 0 { + var cancel context.CancelFunc + + ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) + defer cancel() + + newClient := *opts.HTTPClient + newClient.Timeout = 0 + opts.HTTPClient = &newClient + } + + if opts.HTTPHeader == nil { + opts.HTTPHeader = http.Header{} + } + + secWebSocketKey, err := secWebSocketKey(rand) + if err != nil { + return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) + } + + var copts *compressionOptions + if opts.CompressionMode != CompressionDisabled { + copts = opts.CompressionMode.opts() + } + + resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) + if err != nil { + return nil, resp, err + } + respBody := resp.Body + resp.Body = nil + defer func() { + if err != nil { + // We read a bit of the body for easier debugging. + r := io.LimitReader(respBody, 1024) + + timer := time.AfterFunc(time.Second*3, func() { + respBody.Close() + }) + defer timer.Stop() + + b, _ := ioutil.ReadAll(r) + respBody.Close() + resp.Body = ioutil.NopCloser(bytes.NewReader(b)) + } + }() + + copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp) + if err != nil { + return nil, resp, err + } + + rwc, ok := respBody.(io.ReadWriteCloser) + if !ok { + return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) + } + + return newConn(connConfig{ + subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), + rwc: rwc, + client: true, + copts: copts, + flateThreshold: opts.CompressionThreshold, + br: getBufioReader(rwc), + bw: getBufioWriter(rwc), + }), resp, nil +} + +func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { + u, err := url.Parse(urls) + if err != nil { + return nil, fmt.Errorf("failed to parse url: %w", err) + } + + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + case "http", "https": + default: + return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) + } + + req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) + req.Header = opts.HTTPHeader.Clone() + req.Header.Set("Connection", "Upgrade") + req.Header.Set("Upgrade", "websocket") + req.Header.Set("Sec-WebSocket-Version", "13") + req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) + if len(opts.Subprotocols) > 0 { + req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) + } + if copts != nil { + copts.setHeader(req.Header) + } + + resp, err := opts.HTTPClient.Do(req) + if err != nil { + return nil, fmt.Errorf("failed to send handshake request: %w", err) + } + return resp, nil +} + +func secWebSocketKey(rr io.Reader) (string, error) { + if rr == nil { + rr = rand.Reader + } + b := make([]byte, 16) + _, err := io.ReadFull(rr, b) + if err != nil { + return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) + } + return base64.StdEncoding.EncodeToString(b), nil +} + +func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { + if resp.StatusCode != http.StatusSwitchingProtocols { + return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) + } + + if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { + return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) + } + + if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { + return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) + } + + if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { + return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", + resp.Header.Get("Sec-WebSocket-Accept"), + secWebSocketKey, + ) + } + + err := verifySubprotocol(opts.Subprotocols, resp) + if err != nil { + return nil, err + } + + return verifyServerExtensions(copts, resp.Header) +} + +func verifySubprotocol(subprotos []string, resp *http.Response) error { + proto := resp.Header.Get("Sec-WebSocket-Protocol") + if proto == "" { + return nil + } + + for _, sp2 := range subprotos { + if strings.EqualFold(sp2, proto) { + return nil + } + } + + return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) +} + +func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) { + exts := websocketExtensions(h) + if len(exts) == 0 { + return nil, nil + } + + ext := exts[0] + if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil { + return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) + } + + copts = &*copts + + for _, p := range ext.params { + switch p { + case "client_no_context_takeover": + copts.clientNoContextTakeover = true + continue + case "server_no_context_takeover": + copts.serverNoContextTakeover = true + continue + } + + return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) + } + + return copts, nil +} + +var bufioReaderPool sync.Pool + +func getBufioReader(r io.Reader) *bufio.Reader { + br, ok := bufioReaderPool.Get().(*bufio.Reader) + if !ok { + return bufio.NewReader(r) + } + br.Reset(r) + return br +} + +func putBufioReader(br *bufio.Reader) { + bufioReaderPool.Put(br) +} + +var bufioWriterPool sync.Pool + +func getBufioWriter(w io.Writer) *bufio.Writer { + bw, ok := bufioWriterPool.Get().(*bufio.Writer) + if !ok { + return bufio.NewWriter(w) + } + bw.Reset(w) + return bw +} + +func putBufioWriter(bw *bufio.Writer) { + bufioWriterPool.Put(bw) +} diff --git a/vendor/nhooyr.io/websocket/doc.go b/vendor/nhooyr.io/websocket/doc.go new file mode 100644 index 0000000000..efa920e3b6 --- /dev/null +++ b/vendor/nhooyr.io/websocket/doc.go @@ -0,0 +1,32 @@ +// +build !js + +// Package websocket implements the RFC 6455 WebSocket protocol. +// +// https://tools.ietf.org/html/rfc6455 +// +// Use Dial to dial a WebSocket server. +// +// Use Accept to accept a WebSocket client. +// +// Conn represents the resulting WebSocket connection. +// +// The examples are the best way to understand how to correctly use the library. +// +// The wsjson and wspb subpackages contain helpers for JSON and protobuf messages. +// +// More documentation at https://nhooyr.io/websocket. +// +// Wasm +// +// The client side supports compiling to Wasm. +// It wraps the WebSocket browser API. +// +// See https://developer.mozilla.org/en-US/docs/Web/API/WebSocket +// +// Some important caveats to be aware of: +// +// - Accept always errors out +// - Conn.Ping is no-op +// - HTTPClient, HTTPHeader and CompressionMode in DialOptions are no-op +// - *http.Response from Dial is &http.Response{} with a 101 status code on success +package websocket // import "nhooyr.io/websocket" diff --git a/vendor/nhooyr.io/websocket/frame.go b/vendor/nhooyr.io/websocket/frame.go new file mode 100644 index 0000000000..2a036f944a --- /dev/null +++ b/vendor/nhooyr.io/websocket/frame.go @@ -0,0 +1,294 @@ +package websocket + +import ( + "bufio" + "encoding/binary" + "fmt" + "io" + "math" + "math/bits" + + "nhooyr.io/websocket/internal/errd" +) + +// opcode represents a WebSocket opcode. +type opcode int + +// https://tools.ietf.org/html/rfc6455#section-11.8. +const ( + opContinuation opcode = iota + opText + opBinary + // 3 - 7 are reserved for further non-control frames. + _ + _ + _ + _ + _ + opClose + opPing + opPong + // 11-16 are reserved for further control frames. +) + +// header represents a WebSocket frame header. +// See https://tools.ietf.org/html/rfc6455#section-5.2. +type header struct { + fin bool + rsv1 bool + rsv2 bool + rsv3 bool + opcode opcode + + payloadLength int64 + + masked bool + maskKey uint32 +} + +// readFrameHeader reads a header from the reader. +// See https://tools.ietf.org/html/rfc6455#section-5.2. +func readFrameHeader(r *bufio.Reader, readBuf []byte) (h header, err error) { + defer errd.Wrap(&err, "failed to read frame header") + + b, err := r.ReadByte() + if err != nil { + return header{}, err + } + + h.fin = b&(1<<7) != 0 + h.rsv1 = b&(1<<6) != 0 + h.rsv2 = b&(1<<5) != 0 + h.rsv3 = b&(1<<4) != 0 + + h.opcode = opcode(b & 0xf) + + b, err = r.ReadByte() + if err != nil { + return header{}, err + } + + h.masked = b&(1<<7) != 0 + + payloadLength := b &^ (1 << 7) + switch { + case payloadLength < 126: + h.payloadLength = int64(payloadLength) + case payloadLength == 126: + _, err = io.ReadFull(r, readBuf[:2]) + h.payloadLength = int64(binary.BigEndian.Uint16(readBuf)) + case payloadLength == 127: + _, err = io.ReadFull(r, readBuf) + h.payloadLength = int64(binary.BigEndian.Uint64(readBuf)) + } + if err != nil { + return header{}, err + } + + if h.payloadLength < 0 { + return header{}, fmt.Errorf("received negative payload length: %v", h.payloadLength) + } + + if h.masked { + _, err = io.ReadFull(r, readBuf[:4]) + if err != nil { + return header{}, err + } + h.maskKey = binary.LittleEndian.Uint32(readBuf) + } + + return h, nil +} + +// maxControlPayload is the maximum length of a control frame payload. +// See https://tools.ietf.org/html/rfc6455#section-5.5. +const maxControlPayload = 125 + +// writeFrameHeader writes the bytes of the header to w. +// See https://tools.ietf.org/html/rfc6455#section-5.2 +func writeFrameHeader(h header, w *bufio.Writer, buf []byte) (err error) { + defer errd.Wrap(&err, "failed to write frame header") + + var b byte + if h.fin { + b |= 1 << 7 + } + if h.rsv1 { + b |= 1 << 6 + } + if h.rsv2 { + b |= 1 << 5 + } + if h.rsv3 { + b |= 1 << 4 + } + + b |= byte(h.opcode) + + err = w.WriteByte(b) + if err != nil { + return err + } + + lengthByte := byte(0) + if h.masked { + lengthByte |= 1 << 7 + } + + switch { + case h.payloadLength > math.MaxUint16: + lengthByte |= 127 + case h.payloadLength > 125: + lengthByte |= 126 + case h.payloadLength >= 0: + lengthByte |= byte(h.payloadLength) + } + err = w.WriteByte(lengthByte) + if err != nil { + return err + } + + switch { + case h.payloadLength > math.MaxUint16: + binary.BigEndian.PutUint64(buf, uint64(h.payloadLength)) + _, err = w.Write(buf) + case h.payloadLength > 125: + binary.BigEndian.PutUint16(buf, uint16(h.payloadLength)) + _, err = w.Write(buf[:2]) + } + if err != nil { + return err + } + + if h.masked { + binary.LittleEndian.PutUint32(buf, h.maskKey) + _, err = w.Write(buf[:4]) + if err != nil { + return err + } + } + + return nil +} + +// mask applies the WebSocket masking algorithm to p +// with the given key. +// See https://tools.ietf.org/html/rfc6455#section-5.3 +// +// The returned value is the correctly rotated key to +// to continue to mask/unmask the message. +// +// It is optimized for LittleEndian and expects the key +// to be in little endian. +// +// See https://github.com/golang/go/issues/31586 +func mask(key uint32, b []byte) uint32 { + if len(b) >= 8 { + key64 := uint64(key)<<32 | uint64(key) + + // At some point in the future we can clean these unrolled loops up. + // See https://github.com/golang/go/issues/31586#issuecomment-487436401 + + // Then we xor until b is less than 128 bytes. + for len(b) >= 128 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + v = binary.LittleEndian.Uint64(b[64:72]) + binary.LittleEndian.PutUint64(b[64:72], v^key64) + v = binary.LittleEndian.Uint64(b[72:80]) + binary.LittleEndian.PutUint64(b[72:80], v^key64) + v = binary.LittleEndian.Uint64(b[80:88]) + binary.LittleEndian.PutUint64(b[80:88], v^key64) + v = binary.LittleEndian.Uint64(b[88:96]) + binary.LittleEndian.PutUint64(b[88:96], v^key64) + v = binary.LittleEndian.Uint64(b[96:104]) + binary.LittleEndian.PutUint64(b[96:104], v^key64) + v = binary.LittleEndian.Uint64(b[104:112]) + binary.LittleEndian.PutUint64(b[104:112], v^key64) + v = binary.LittleEndian.Uint64(b[112:120]) + binary.LittleEndian.PutUint64(b[112:120], v^key64) + v = binary.LittleEndian.Uint64(b[120:128]) + binary.LittleEndian.PutUint64(b[120:128], v^key64) + b = b[128:] + } + + // Then we xor until b is less than 64 bytes. + for len(b) >= 64 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + v = binary.LittleEndian.Uint64(b[32:40]) + binary.LittleEndian.PutUint64(b[32:40], v^key64) + v = binary.LittleEndian.Uint64(b[40:48]) + binary.LittleEndian.PutUint64(b[40:48], v^key64) + v = binary.LittleEndian.Uint64(b[48:56]) + binary.LittleEndian.PutUint64(b[48:56], v^key64) + v = binary.LittleEndian.Uint64(b[56:64]) + binary.LittleEndian.PutUint64(b[56:64], v^key64) + b = b[64:] + } + + // Then we xor until b is less than 32 bytes. + for len(b) >= 32 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + v = binary.LittleEndian.Uint64(b[16:24]) + binary.LittleEndian.PutUint64(b[16:24], v^key64) + v = binary.LittleEndian.Uint64(b[24:32]) + binary.LittleEndian.PutUint64(b[24:32], v^key64) + b = b[32:] + } + + // Then we xor until b is less than 16 bytes. + for len(b) >= 16 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + v = binary.LittleEndian.Uint64(b[8:16]) + binary.LittleEndian.PutUint64(b[8:16], v^key64) + b = b[16:] + } + + // Then we xor until b is less than 8 bytes. + for len(b) >= 8 { + v := binary.LittleEndian.Uint64(b) + binary.LittleEndian.PutUint64(b, v^key64) + b = b[8:] + } + } + + // Then we xor until b is less than 4 bytes. + for len(b) >= 4 { + v := binary.LittleEndian.Uint32(b) + binary.LittleEndian.PutUint32(b, v^key) + b = b[4:] + } + + // xor remaining bytes. + for i := range b { + b[i] ^= byte(key) + key = bits.RotateLeft32(key, -8) + } + + return key +} diff --git a/vendor/nhooyr.io/websocket/go.mod b/vendor/nhooyr.io/websocket/go.mod new file mode 100644 index 0000000000..c5f1a20f59 --- /dev/null +++ b/vendor/nhooyr.io/websocket/go.mod @@ -0,0 +1,15 @@ +module nhooyr.io/websocket + +go 1.13 + +require ( + github.com/gin-gonic/gin v1.6.3 + github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee // indirect + github.com/gobwas/pool v0.2.0 // indirect + github.com/gobwas/ws v1.0.2 + github.com/golang/protobuf v1.3.5 + github.com/google/go-cmp v0.4.0 + github.com/gorilla/websocket v1.4.1 + github.com/klauspost/compress v1.10.3 + golang.org/x/time v0.0.0-20191024005414-555d28b269f0 +) diff --git a/vendor/nhooyr.io/websocket/go.sum b/vendor/nhooyr.io/websocket/go.sum new file mode 100644 index 0000000000..155c301326 --- /dev/null +++ b/vendor/nhooyr.io/websocket/go.sum @@ -0,0 +1,64 @@ +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.6.3 h1:ahKqKTFpO5KTPHxWZjEdPScmYaGtLo8Y4DMHoEsnp14= +github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= +github.com/go-playground/assert/v2 v2.0.1 h1:MsBgLAaY856+nPRTKrp3/OZK38U/wa0CcBYNjji3q3A= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.13.0 h1:HyWk6mgj5qFqCT5fjGBuRArbVDfE4hi8+e8ceBS/t7Q= +github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8= +github.com/go-playground/universal-translator v0.17.0 h1:icxd5fm+REJzpZx7ZfpaD876Lmtgy7VtROAbHHXk8no= +github.com/go-playground/universal-translator v0.17.0/go.mod h1:UkSxE5sNxxRwHyU+Scu5vgOQjsIJAF8j9muTVoKLVtA= +github.com/go-playground/validator/v10 v10.2.0 h1:KgJ0snyC2R9VXYN2rneOtQcw5aHQB1Vv0sFl1UcHBOY= +github.com/go-playground/validator/v10 v10.2.0/go.mod h1:uOYAAleCW8F/7oMFd6aG0GOhaH6EGOAJShg8Id5JGkI= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee h1:s+21KNqlpePfkah2I+gwHF8xmJWRjooY+5248k6m4A0= +github.com/gobwas/httphead v0.0.0-20180130184737-2c6c146eadee/go.mod h1:L0fX3K22YWvt/FAX9NnzrNzcI4wNYi9Yku4O0LKYflo= +github.com/gobwas/pool v0.2.0 h1:QEmUOlnSjWtnpRGHF3SauEiOsy82Cup83Vf2LcMlnc8= +github.com/gobwas/pool v0.2.0/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.0.2 h1:CoAavW/wd/kulfZmSIBt6p24n4j7tHgNVCjsfHVNUbo= +github.com/gobwas/ws v1.0.2/go.mod h1:szmBTxLgaFppYjEmNtny/v3w89xOydFnnZMcgRRu/EM= +github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= +github.com/golang/protobuf v1.3.5 h1:F768QJ1E9tib+q5Sc8MkdJi1RxLTbRcTf8LJV56aRls= +github.com/golang/protobuf v1.3.5/go.mod h1:6O5/vntMXwX2lRkT1hjjk0nAC1IDOTvTlVgjlRvqsdk= +github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= +github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM= +github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= +github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= +github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= +github.com/klauspost/compress v1.10.3 h1:OP96hzwJVBIHYU52pVTI6CczrxPvrGfgqF9N5eTO0Q8= +github.com/klauspost/compress v1.10.3/go.mod h1:aoV0uJVorq1K+umq18yTdKaF57EivdYsUV+/s2qKfXs= +github.com/leodido/go-urn v1.2.0 h1:hpXL4XnriNwQ/ABnpepYM/1vCLWNDfUNts8dX3xTG6Y= +github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg= +github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo= +github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw= +github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs= +github.com/ugorji/go/codec v1.1.7/go.mod h1:Ax+UKWsSmolVDwsd+7N3ZtXu+yMGCf907BLYF3GoBXY= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42 h1:vEOn+mP2zCOVzKckCZy6YsCtDblrpj/w7B9nxGNELpg= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= +golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= +gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= diff --git a/vendor/nhooyr.io/websocket/internal/bpool/bpool.go b/vendor/nhooyr.io/websocket/internal/bpool/bpool.go new file mode 100644 index 0000000000..aa826fba2b --- /dev/null +++ b/vendor/nhooyr.io/websocket/internal/bpool/bpool.go @@ -0,0 +1,24 @@ +package bpool + +import ( + "bytes" + "sync" +) + +var bpool sync.Pool + +// Get returns a buffer from the pool or creates a new one if +// the pool is empty. +func Get() *bytes.Buffer { + b := bpool.Get() + if b == nil { + return &bytes.Buffer{} + } + return b.(*bytes.Buffer) +} + +// Put returns a buffer into the pool. +func Put(b *bytes.Buffer) { + b.Reset() + bpool.Put(b) +} diff --git a/vendor/nhooyr.io/websocket/internal/errd/wrap.go b/vendor/nhooyr.io/websocket/internal/errd/wrap.go new file mode 100644 index 0000000000..6e779131af --- /dev/null +++ b/vendor/nhooyr.io/websocket/internal/errd/wrap.go @@ -0,0 +1,14 @@ +package errd + +import ( + "fmt" +) + +// Wrap wraps err with fmt.Errorf if err is non nil. +// Intended for use with defer and a named error return. +// Inspired by https://github.com/golang/go/issues/32676. +func Wrap(err *error, f string, v ...interface{}) { + if *err != nil { + *err = fmt.Errorf(f+": %w", append(v, *err)...) + } +} diff --git a/vendor/nhooyr.io/websocket/internal/wsjs/wsjs_js.go b/vendor/nhooyr.io/websocket/internal/wsjs/wsjs_js.go new file mode 100644 index 0000000000..26ffb45625 --- /dev/null +++ b/vendor/nhooyr.io/websocket/internal/wsjs/wsjs_js.go @@ -0,0 +1,170 @@ +// +build js + +// Package wsjs implements typed access to the browser javascript WebSocket API. +// +// https://developer.mozilla.org/en-US/docs/Web/API/WebSocket +package wsjs + +import ( + "syscall/js" +) + +func handleJSError(err *error, onErr func()) { + r := recover() + + if jsErr, ok := r.(js.Error); ok { + *err = jsErr + + if onErr != nil { + onErr() + } + return + } + + if r != nil { + panic(r) + } +} + +// New is a wrapper around the javascript WebSocket constructor. +func New(url string, protocols []string) (c WebSocket, err error) { + defer handleJSError(&err, func() { + c = WebSocket{} + }) + + jsProtocols := make([]interface{}, len(protocols)) + for i, p := range protocols { + jsProtocols[i] = p + } + + c = WebSocket{ + v: js.Global().Get("WebSocket").New(url, jsProtocols), + } + + c.setBinaryType("arraybuffer") + + return c, nil +} + +// WebSocket is a wrapper around a javascript WebSocket object. +type WebSocket struct { + v js.Value +} + +func (c WebSocket) setBinaryType(typ string) { + c.v.Set("binaryType", string(typ)) +} + +func (c WebSocket) addEventListener(eventType string, fn func(e js.Value)) func() { + f := js.FuncOf(func(this js.Value, args []js.Value) interface{} { + fn(args[0]) + return nil + }) + c.v.Call("addEventListener", eventType, f) + + return func() { + c.v.Call("removeEventListener", eventType, f) + f.Release() + } +} + +// CloseEvent is the type passed to a WebSocket close handler. +type CloseEvent struct { + Code uint16 + Reason string + WasClean bool +} + +// OnClose registers a function to be called when the WebSocket is closed. +func (c WebSocket) OnClose(fn func(CloseEvent)) (remove func()) { + return c.addEventListener("close", func(e js.Value) { + ce := CloseEvent{ + Code: uint16(e.Get("code").Int()), + Reason: e.Get("reason").String(), + WasClean: e.Get("wasClean").Bool(), + } + fn(ce) + }) +} + +// OnError registers a function to be called when there is an error +// with the WebSocket. +func (c WebSocket) OnError(fn func(e js.Value)) (remove func()) { + return c.addEventListener("error", fn) +} + +// MessageEvent is the type passed to a message handler. +type MessageEvent struct { + // string or []byte. + Data interface{} + + // There are more fields to the interface but we don't use them. + // See https://developer.mozilla.org/en-US/docs/Web/API/MessageEvent +} + +// OnMessage registers a function to be called when the WebSocket receives a message. +func (c WebSocket) OnMessage(fn func(m MessageEvent)) (remove func()) { + return c.addEventListener("message", func(e js.Value) { + var data interface{} + + arrayBuffer := e.Get("data") + if arrayBuffer.Type() == js.TypeString { + data = arrayBuffer.String() + } else { + data = extractArrayBuffer(arrayBuffer) + } + + me := MessageEvent{ + Data: data, + } + fn(me) + + return + }) +} + +// Subprotocol returns the WebSocket subprotocol in use. +func (c WebSocket) Subprotocol() string { + return c.v.Get("protocol").String() +} + +// OnOpen registers a function to be called when the WebSocket is opened. +func (c WebSocket) OnOpen(fn func(e js.Value)) (remove func()) { + return c.addEventListener("open", fn) +} + +// Close closes the WebSocket with the given code and reason. +func (c WebSocket) Close(code int, reason string) (err error) { + defer handleJSError(&err, nil) + c.v.Call("close", code, reason) + return err +} + +// SendText sends the given string as a text message +// on the WebSocket. +func (c WebSocket) SendText(v string) (err error) { + defer handleJSError(&err, nil) + c.v.Call("send", v) + return err +} + +// SendBytes sends the given message as a binary message +// on the WebSocket. +func (c WebSocket) SendBytes(v []byte) (err error) { + defer handleJSError(&err, nil) + c.v.Call("send", uint8Array(v)) + return err +} + +func extractArrayBuffer(arrayBuffer js.Value) []byte { + uint8Array := js.Global().Get("Uint8Array").New(arrayBuffer) + dst := make([]byte, uint8Array.Length()) + js.CopyBytesToGo(dst, uint8Array) + return dst +} + +func uint8Array(src []byte) js.Value { + uint8Array := js.Global().Get("Uint8Array").New(len(src)) + js.CopyBytesToJS(uint8Array, src) + return uint8Array +} diff --git a/vendor/nhooyr.io/websocket/internal/xsync/go.go b/vendor/nhooyr.io/websocket/internal/xsync/go.go new file mode 100644 index 0000000000..7a61f27fa2 --- /dev/null +++ b/vendor/nhooyr.io/websocket/internal/xsync/go.go @@ -0,0 +1,25 @@ +package xsync + +import ( + "fmt" +) + +// Go allows running a function in another goroutine +// and waiting for its error. +func Go(fn func() error) <-chan error { + errs := make(chan error, 1) + go func() { + defer func() { + r := recover() + if r != nil { + select { + case errs <- fmt.Errorf("panic in go fn: %v", r): + default: + } + } + }() + errs <- fn() + }() + + return errs +} diff --git a/vendor/nhooyr.io/websocket/internal/xsync/int64.go b/vendor/nhooyr.io/websocket/internal/xsync/int64.go new file mode 100644 index 0000000000..a0c4020415 --- /dev/null +++ b/vendor/nhooyr.io/websocket/internal/xsync/int64.go @@ -0,0 +1,23 @@ +package xsync + +import ( + "sync/atomic" +) + +// Int64 represents an atomic int64. +type Int64 struct { + // We do not use atomic.Load/StoreInt64 since it does not + // work on 32 bit computers but we need 64 bit integers. + i atomic.Value +} + +// Load loads the int64. +func (v *Int64) Load() int64 { + i, _ := v.i.Load().(int64) + return i +} + +// Store stores the int64. +func (v *Int64) Store(i int64) { + v.i.Store(i) +} diff --git a/vendor/nhooyr.io/websocket/netconn.go b/vendor/nhooyr.io/websocket/netconn.go new file mode 100644 index 0000000000..64aadf0b99 --- /dev/null +++ b/vendor/nhooyr.io/websocket/netconn.go @@ -0,0 +1,166 @@ +package websocket + +import ( + "context" + "fmt" + "io" + "math" + "net" + "sync" + "time" +) + +// NetConn converts a *websocket.Conn into a net.Conn. +// +// It's for tunneling arbitrary protocols over WebSockets. +// Few users of the library will need this but it's tricky to implement +// correctly and so provided in the library. +// See https://github.com/nhooyr/websocket/issues/100. +// +// Every Write to the net.Conn will correspond to a message write of +// the given type on *websocket.Conn. +// +// The passed ctx bounds the lifetime of the net.Conn. If cancelled, +// all reads and writes on the net.Conn will be cancelled. +// +// If a message is read that is not of the correct type, the connection +// will be closed with StatusUnsupportedData and an error will be returned. +// +// Close will close the *websocket.Conn with StatusNormalClosure. +// +// When a deadline is hit, the connection will be closed. This is +// different from most net.Conn implementations where only the +// reading/writing goroutines are interrupted but the connection is kept alive. +// +// The Addr methods will return a mock net.Addr that returns "websocket" for Network +// and "websocket/unknown-addr" for String. +// +// A received StatusNormalClosure or StatusGoingAway close frame will be translated to +// io.EOF when reading. +func NetConn(ctx context.Context, c *Conn, msgType MessageType) net.Conn { + nc := &netConn{ + c: c, + msgType: msgType, + } + + var cancel context.CancelFunc + nc.writeContext, cancel = context.WithCancel(ctx) + nc.writeTimer = time.AfterFunc(math.MaxInt64, cancel) + if !nc.writeTimer.Stop() { + <-nc.writeTimer.C + } + + nc.readContext, cancel = context.WithCancel(ctx) + nc.readTimer = time.AfterFunc(math.MaxInt64, cancel) + if !nc.readTimer.Stop() { + <-nc.readTimer.C + } + + return nc +} + +type netConn struct { + c *Conn + msgType MessageType + + writeTimer *time.Timer + writeContext context.Context + + readTimer *time.Timer + readContext context.Context + + readMu sync.Mutex + eofed bool + reader io.Reader +} + +var _ net.Conn = &netConn{} + +func (c *netConn) Close() error { + return c.c.Close(StatusNormalClosure, "") +} + +func (c *netConn) Write(p []byte) (int, error) { + err := c.c.Write(c.writeContext, c.msgType, p) + if err != nil { + return 0, err + } + return len(p), nil +} + +func (c *netConn) Read(p []byte) (int, error) { + c.readMu.Lock() + defer c.readMu.Unlock() + + if c.eofed { + return 0, io.EOF + } + + if c.reader == nil { + typ, r, err := c.c.Reader(c.readContext) + if err != nil { + switch CloseStatus(err) { + case StatusNormalClosure, StatusGoingAway: + c.eofed = true + return 0, io.EOF + } + return 0, err + } + if typ != c.msgType { + err := fmt.Errorf("unexpected frame type read (expected %v): %v", c.msgType, typ) + c.c.Close(StatusUnsupportedData, err.Error()) + return 0, err + } + c.reader = r + } + + n, err := c.reader.Read(p) + if err == io.EOF { + c.reader = nil + err = nil + } + return n, err +} + +type websocketAddr struct { +} + +func (a websocketAddr) Network() string { + return "websocket" +} + +func (a websocketAddr) String() string { + return "websocket/unknown-addr" +} + +func (c *netConn) RemoteAddr() net.Addr { + return websocketAddr{} +} + +func (c *netConn) LocalAddr() net.Addr { + return websocketAddr{} +} + +func (c *netConn) SetDeadline(t time.Time) error { + c.SetWriteDeadline(t) + c.SetReadDeadline(t) + return nil +} + +func (c *netConn) SetWriteDeadline(t time.Time) error { + if t.IsZero() { + c.writeTimer.Stop() + } else { + c.writeTimer.Reset(t.Sub(time.Now())) + } + return nil +} + +func (c *netConn) SetReadDeadline(t time.Time) error { + if t.IsZero() { + c.readTimer.Stop() + } else { + c.readTimer.Reset(t.Sub(time.Now())) + } + return nil +} diff --git a/vendor/nhooyr.io/websocket/read.go b/vendor/nhooyr.io/websocket/read.go new file mode 100644 index 0000000000..ae05cf93ed --- /dev/null +++ b/vendor/nhooyr.io/websocket/read.go @@ -0,0 +1,474 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "strings" + "time" + + "nhooyr.io/websocket/internal/errd" + "nhooyr.io/websocket/internal/xsync" +) + +// Reader reads from the connection until until there is a WebSocket +// data message to be read. It will handle ping, pong and close frames as appropriate. +// +// It returns the type of the message and an io.Reader to read it. +// The passed context will also bound the reader. +// Ensure you read to EOF otherwise the connection will hang. +// +// Call CloseRead if you do not expect any data messages from the peer. +// +// Only one Reader may be open at a time. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + return c.reader(ctx) +} + +// Read is a convenience method around Reader to read a single message +// from the connection. +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + typ, r, err := c.Reader(ctx) + if err != nil { + return 0, nil, err + } + + b, err := ioutil.ReadAll(r) + return typ, b, err +} + +// CloseRead starts a goroutine to read from the connection until it is closed +// or a data message is received. +// +// Once CloseRead is called you cannot read any messages from the connection. +// The returned context will be cancelled when the connection is closed. +// +// If a data message is received, the connection will be closed with StatusPolicyViolation. +// +// Call CloseRead when you do not expect to read any more messages. +// Since it actively reads from the connection, it will ensure that ping, pong and close +// frames are responded to. This means c.Ping and c.Close will still work as expected. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.Reader(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +// SetReadLimit sets the max number of bytes to read for a single message. +// It applies to the Reader and Read methods. +// +// By default, the connection has a message read limit of 32768 bytes. +// +// When the limit is hit, the connection will be closed with StatusMessageTooBig. +func (c *Conn) SetReadLimit(n int64) { + // We add read one more byte than the limit in case + // there is a fin frame that needs to be read. + c.msgReader.limitReader.limit.Store(n + 1) +} + +const defaultReadLimit = 32768 + +func newMsgReader(c *Conn) *msgReader { + mr := &msgReader{ + c: c, + fin: true, + } + mr.readFunc = mr.read + + mr.limitReader = newLimitReader(c, mr.readFunc, defaultReadLimit+1) + return mr +} + +func (mr *msgReader) resetFlate() { + if mr.flateContextTakeover() { + mr.dict.init(32768) + } + if mr.flateBufio == nil { + mr.flateBufio = getBufioReader(mr.readFunc) + } + + mr.flateReader = getFlateReader(mr.flateBufio, mr.dict.buf) + mr.limitReader.r = mr.flateReader + mr.flateTail.Reset(deflateMessageTail) +} + +func (mr *msgReader) putFlateReader() { + if mr.flateReader != nil { + putFlateReader(mr.flateReader) + mr.flateReader = nil + } +} + +func (mr *msgReader) close() { + mr.c.readMu.forceLock() + mr.putFlateReader() + mr.dict.close() + if mr.flateBufio != nil { + putBufioReader(mr.flateBufio) + } + + if mr.c.client { + putBufioReader(mr.c.br) + mr.c.br = nil + } +} + +func (mr *msgReader) flateContextTakeover() bool { + if mr.c.client { + return !mr.c.copts.serverNoContextTakeover + } + return !mr.c.copts.clientNoContextTakeover +} + +func (c *Conn) readRSV1Illegal(h header) bool { + // If compression is disabled, rsv1 is illegal. + if !c.flate() { + return true + } + // rsv1 is only allowed on data frames beginning messages. + if h.opcode != opText && h.opcode != opBinary { + return true + } + return false +} + +func (c *Conn) readLoop(ctx context.Context) (header, error) { + for { + h, err := c.readFrameHeader(ctx) + if err != nil { + return header{}, err + } + + if h.rsv1 && c.readRSV1Illegal(h) || h.rsv2 || h.rsv3 { + err := fmt.Errorf("received header with unexpected rsv bits set: %v:%v:%v", h.rsv1, h.rsv2, h.rsv3) + c.writeError(StatusProtocolError, err) + return header{}, err + } + + if !c.client && !h.masked { + return header{}, errors.New("received unmasked frame from client") + } + + switch h.opcode { + case opClose, opPing, opPong: + err = c.handleControl(ctx, h) + if err != nil { + // Pass through CloseErrors when receiving a close frame. + if h.opcode == opClose && CloseStatus(err) != -1 { + return header{}, err + } + return header{}, fmt.Errorf("failed to handle control frame %v: %w", h.opcode, err) + } + case opContinuation, opText, opBinary: + return h, nil + default: + err := fmt.Errorf("received unknown opcode %v", h.opcode) + c.writeError(StatusProtocolError, err) + return header{}, err + } + } +} + +func (c *Conn) readFrameHeader(ctx context.Context) (header, error) { + select { + case <-c.closed: + return header{}, c.closeErr + case c.readTimeout <- ctx: + } + + h, err := readFrameHeader(c.br, c.readHeaderBuf[:]) + if err != nil { + select { + case <-c.closed: + return header{}, c.closeErr + case <-ctx.Done(): + return header{}, ctx.Err() + default: + c.close(err) + return header{}, err + } + } + + select { + case <-c.closed: + return header{}, c.closeErr + case c.readTimeout <- context.Background(): + } + + return h, nil +} + +func (c *Conn) readFramePayload(ctx context.Context, p []byte) (int, error) { + select { + case <-c.closed: + return 0, c.closeErr + case c.readTimeout <- ctx: + } + + n, err := io.ReadFull(c.br, p) + if err != nil { + select { + case <-c.closed: + return n, c.closeErr + case <-ctx.Done(): + return n, ctx.Err() + default: + err = fmt.Errorf("failed to read frame payload: %w", err) + c.close(err) + return n, err + } + } + + select { + case <-c.closed: + return n, c.closeErr + case c.readTimeout <- context.Background(): + } + + return n, err +} + +func (c *Conn) handleControl(ctx context.Context, h header) (err error) { + if h.payloadLength < 0 || h.payloadLength > maxControlPayload { + err := fmt.Errorf("received control frame payload with invalid length: %d", h.payloadLength) + c.writeError(StatusProtocolError, err) + return err + } + + if !h.fin { + err := errors.New("received fragmented control frame") + c.writeError(StatusProtocolError, err) + return err + } + + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + b := c.readControlBuf[:h.payloadLength] + _, err = c.readFramePayload(ctx, b) + if err != nil { + return err + } + + if h.masked { + mask(h.maskKey, b) + } + + switch h.opcode { + case opPing: + return c.writeControl(ctx, opPong, b) + case opPong: + c.activePingsMu.Lock() + pong, ok := c.activePings[string(b)] + c.activePingsMu.Unlock() + if ok { + select { + case pong <- struct{}{}: + default: + } + } + return nil + } + + defer func() { + c.readCloseFrameErr = err + }() + + ce, err := parseClosePayload(b) + if err != nil { + err = fmt.Errorf("received invalid close payload: %w", err) + c.writeError(StatusProtocolError, err) + return err + } + + err = fmt.Errorf("received close frame: %w", ce) + c.setCloseErr(err) + c.writeClose(ce.Code, ce.Reason) + c.close(err) + return err +} + +func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err error) { + defer errd.Wrap(&err, "failed to get reader") + + err = c.readMu.lock(ctx) + if err != nil { + return 0, nil, err + } + defer c.readMu.unlock() + + if !c.msgReader.fin { + err = errors.New("previous message not read to completion") + c.close(fmt.Errorf("failed to get reader: %w", err)) + return 0, nil, err + } + + h, err := c.readLoop(ctx) + if err != nil { + return 0, nil, err + } + + if h.opcode == opContinuation { + err := errors.New("received continuation frame without text or binary frame") + c.writeError(StatusProtocolError, err) + return 0, nil, err + } + + c.msgReader.reset(ctx, h) + + return MessageType(h.opcode), c.msgReader, nil +} + +type msgReader struct { + c *Conn + + ctx context.Context + flate bool + flateReader io.Reader + flateBufio *bufio.Reader + flateTail strings.Reader + limitReader *limitReader + dict slidingWindow + + fin bool + payloadLength int64 + maskKey uint32 + + // readerFunc(mr.Read) to avoid continuous allocations. + readFunc readerFunc +} + +func (mr *msgReader) reset(ctx context.Context, h header) { + mr.ctx = ctx + mr.flate = h.rsv1 + mr.limitReader.reset(mr.readFunc) + + if mr.flate { + mr.resetFlate() + } + + mr.setFrame(h) +} + +func (mr *msgReader) setFrame(h header) { + mr.fin = h.fin + mr.payloadLength = h.payloadLength + mr.maskKey = h.maskKey +} + +func (mr *msgReader) Read(p []byte) (n int, err error) { + err = mr.c.readMu.lock(mr.ctx) + if err != nil { + return 0, fmt.Errorf("failed to read: %w", err) + } + defer mr.c.readMu.unlock() + + n, err = mr.limitReader.Read(p) + if mr.flate && mr.flateContextTakeover() { + p = p[:n] + mr.dict.write(p) + } + if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate { + mr.putFlateReader() + return n, io.EOF + } + if err != nil { + err = fmt.Errorf("failed to read: %w", err) + mr.c.close(err) + } + return n, err +} + +func (mr *msgReader) read(p []byte) (int, error) { + for { + if mr.payloadLength == 0 { + if mr.fin { + if mr.flate { + return mr.flateTail.Read(p) + } + return 0, io.EOF + } + + h, err := mr.c.readLoop(mr.ctx) + if err != nil { + return 0, err + } + if h.opcode != opContinuation { + err := errors.New("received new data message without finishing the previous message") + mr.c.writeError(StatusProtocolError, err) + return 0, err + } + mr.setFrame(h) + + continue + } + + if int64(len(p)) > mr.payloadLength { + p = p[:mr.payloadLength] + } + + n, err := mr.c.readFramePayload(mr.ctx, p) + if err != nil { + return n, err + } + + mr.payloadLength -= int64(n) + + if !mr.c.client { + mr.maskKey = mask(mr.maskKey, p) + } + + return n, nil + } +} + +type limitReader struct { + c *Conn + r io.Reader + limit xsync.Int64 + n int64 +} + +func newLimitReader(c *Conn, r io.Reader, limit int64) *limitReader { + lr := &limitReader{ + c: c, + } + lr.limit.Store(limit) + lr.reset(r) + return lr +} + +func (lr *limitReader) reset(r io.Reader) { + lr.n = lr.limit.Load() + lr.r = r +} + +func (lr *limitReader) Read(p []byte) (int, error) { + if lr.n <= 0 { + err := fmt.Errorf("read limited at %v bytes", lr.limit.Load()) + lr.c.writeError(StatusMessageTooBig, err) + return 0, err + } + + if int64(len(p)) > lr.n { + p = p[:lr.n] + } + n, err := lr.r.Read(p) + lr.n -= int64(n) + return n, err +} + +type readerFunc func(p []byte) (int, error) + +func (f readerFunc) Read(p []byte) (int, error) { + return f(p) +} diff --git a/vendor/nhooyr.io/websocket/stringer.go b/vendor/nhooyr.io/websocket/stringer.go new file mode 100644 index 0000000000..5a66ba2907 --- /dev/null +++ b/vendor/nhooyr.io/websocket/stringer.go @@ -0,0 +1,91 @@ +// Code generated by "stringer -type=opcode,MessageType,StatusCode -output=stringer.go"; DO NOT EDIT. + +package websocket + +import "strconv" + +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[opContinuation-0] + _ = x[opText-1] + _ = x[opBinary-2] + _ = x[opClose-8] + _ = x[opPing-9] + _ = x[opPong-10] +} + +const ( + _opcode_name_0 = "opContinuationopTextopBinary" + _opcode_name_1 = "opCloseopPingopPong" +) + +var ( + _opcode_index_0 = [...]uint8{0, 14, 20, 28} + _opcode_index_1 = [...]uint8{0, 7, 13, 19} +) + +func (i opcode) String() string { + switch { + case 0 <= i && i <= 2: + return _opcode_name_0[_opcode_index_0[i]:_opcode_index_0[i+1]] + case 8 <= i && i <= 10: + i -= 8 + return _opcode_name_1[_opcode_index_1[i]:_opcode_index_1[i+1]] + default: + return "opcode(" + strconv.FormatInt(int64(i), 10) + ")" + } +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[MessageText-1] + _ = x[MessageBinary-2] +} + +const _MessageType_name = "MessageTextMessageBinary" + +var _MessageType_index = [...]uint8{0, 11, 24} + +func (i MessageType) String() string { + i -= 1 + if i < 0 || i >= MessageType(len(_MessageType_index)-1) { + return "MessageType(" + strconv.FormatInt(int64(i+1), 10) + ")" + } + return _MessageType_name[_MessageType_index[i]:_MessageType_index[i+1]] +} +func _() { + // An "invalid array index" compiler error signifies that the constant values have changed. + // Re-run the stringer command to generate them again. + var x [1]struct{} + _ = x[StatusNormalClosure-1000] + _ = x[StatusGoingAway-1001] + _ = x[StatusProtocolError-1002] + _ = x[StatusUnsupportedData-1003] + _ = x[statusReserved-1004] + _ = x[StatusNoStatusRcvd-1005] + _ = x[StatusAbnormalClosure-1006] + _ = x[StatusInvalidFramePayloadData-1007] + _ = x[StatusPolicyViolation-1008] + _ = x[StatusMessageTooBig-1009] + _ = x[StatusMandatoryExtension-1010] + _ = x[StatusInternalError-1011] + _ = x[StatusServiceRestart-1012] + _ = x[StatusTryAgainLater-1013] + _ = x[StatusBadGateway-1014] + _ = x[StatusTLSHandshake-1015] +} + +const _StatusCode_name = "StatusNormalClosureStatusGoingAwayStatusProtocolErrorStatusUnsupportedDatastatusReservedStatusNoStatusRcvdStatusAbnormalClosureStatusInvalidFramePayloadDataStatusPolicyViolationStatusMessageTooBigStatusMandatoryExtensionStatusInternalErrorStatusServiceRestartStatusTryAgainLaterStatusBadGatewayStatusTLSHandshake" + +var _StatusCode_index = [...]uint16{0, 19, 34, 53, 74, 88, 106, 127, 156, 177, 196, 220, 239, 259, 278, 294, 312} + +func (i StatusCode) String() string { + i -= 1000 + if i < 0 || i >= StatusCode(len(_StatusCode_index)-1) { + return "StatusCode(" + strconv.FormatInt(int64(i+1000), 10) + ")" + } + return _StatusCode_name[_StatusCode_index[i]:_StatusCode_index[i+1]] +} diff --git a/vendor/nhooyr.io/websocket/write.go b/vendor/nhooyr.io/websocket/write.go new file mode 100644 index 0000000000..2210cf817a --- /dev/null +++ b/vendor/nhooyr.io/websocket/write.go @@ -0,0 +1,397 @@ +// +build !js + +package websocket + +import ( + "bufio" + "context" + "crypto/rand" + "encoding/binary" + "errors" + "fmt" + "io" + "time" + + "github.com/klauspost/compress/flate" + + "nhooyr.io/websocket/internal/errd" +) + +// Writer returns a writer bounded by the context that will write +// a WebSocket message of type dataType to the connection. +// +// You must close the writer once you have written the entire message. +// +// Only one writer can be open at a time, multiple calls will block until the previous writer +// is closed. +func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + w, err := c.writer(ctx, typ) + if err != nil { + return nil, fmt.Errorf("failed to get writer: %w", err) + } + return w, nil +} + +// Write writes a message to the connection. +// +// See the Writer method if you want to stream a message. +// +// If compression is disabled or the threshold is not met, then it +// will write the message in a single frame. +func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { + _, err := c.write(ctx, typ, p) + if err != nil { + return fmt.Errorf("failed to write msg: %w", err) + } + return nil +} + +type msgWriter struct { + mw *msgWriterState + closed bool +} + +func (mw *msgWriter) Write(p []byte) (int, error) { + if mw.closed { + return 0, errors.New("cannot use closed writer") + } + return mw.mw.Write(p) +} + +func (mw *msgWriter) Close() error { + if mw.closed { + return errors.New("cannot use closed writer") + } + mw.closed = true + return mw.mw.Close() +} + +type msgWriterState struct { + c *Conn + + mu *mu + writeMu *mu + + ctx context.Context + opcode opcode + flate bool + + trimWriter *trimLastFourBytesWriter + dict slidingWindow +} + +func newMsgWriterState(c *Conn) *msgWriterState { + mw := &msgWriterState{ + c: c, + mu: newMu(c), + writeMu: newMu(c), + } + return mw +} + +func (mw *msgWriterState) ensureFlate() { + if mw.trimWriter == nil { + mw.trimWriter = &trimLastFourBytesWriter{ + w: writerFunc(mw.write), + } + } + + mw.dict.init(8192) + mw.flate = true +} + +func (mw *msgWriterState) flateContextTakeover() bool { + if mw.c.client { + return !mw.c.copts.clientNoContextTakeover + } + return !mw.c.copts.serverNoContextTakeover +} + +func (c *Conn) writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + err := c.msgWriterState.reset(ctx, typ) + if err != nil { + return nil, err + } + return &msgWriter{ + mw: c.msgWriterState, + closed: false, + }, nil +} + +func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) (int, error) { + mw, err := c.writer(ctx, typ) + if err != nil { + return 0, err + } + + if !c.flate() { + defer c.msgWriterState.mu.unlock() + return c.writeFrame(ctx, true, false, c.msgWriterState.opcode, p) + } + + n, err := mw.Write(p) + if err != nil { + return n, err + } + + err = mw.Close() + return n, err +} + +func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error { + err := mw.mu.lock(ctx) + if err != nil { + return err + } + + mw.ctx = ctx + mw.opcode = opcode(typ) + mw.flate = false + + mw.trimWriter.reset() + + return nil +} + +// Write writes the given bytes to the WebSocket connection. +func (mw *msgWriterState) Write(p []byte) (_ int, err error) { + err = mw.writeMu.lock(mw.ctx) + if err != nil { + return 0, fmt.Errorf("failed to write: %w", err) + } + defer mw.writeMu.unlock() + + defer func() { + if err != nil { + err = fmt.Errorf("failed to write: %w", err) + mw.c.close(err) + } + }() + + if mw.c.flate() { + // Only enables flate if the length crosses the + // threshold on the first frame + if mw.opcode != opContinuation && len(p) >= mw.c.flateThreshold { + mw.ensureFlate() + } + } + + if mw.flate { + err = flate.StatelessDeflate(mw.trimWriter, p, false, mw.dict.buf) + if err != nil { + return 0, err + } + mw.dict.write(p) + return len(p), nil + } + + return mw.write(p) +} + +func (mw *msgWriterState) write(p []byte) (int, error) { + n, err := mw.c.writeFrame(mw.ctx, false, mw.flate, mw.opcode, p) + if err != nil { + return n, fmt.Errorf("failed to write data frame: %w", err) + } + mw.opcode = opContinuation + return n, nil +} + +// Close flushes the frame to the connection. +func (mw *msgWriterState) Close() (err error) { + defer errd.Wrap(&err, "failed to close writer") + + err = mw.writeMu.lock(mw.ctx) + if err != nil { + return err + } + defer mw.writeMu.unlock() + + _, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil) + if err != nil { + return fmt.Errorf("failed to write fin frame: %w", err) + } + + if mw.flate && !mw.flateContextTakeover() { + mw.dict.close() + } + mw.mu.unlock() + return nil +} + +func (mw *msgWriterState) close() { + if mw.c.client { + mw.c.writeFrameMu.forceLock() + putBufioWriter(mw.c.bw) + } + + mw.writeMu.forceLock() + mw.dict.close() +} + +func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error { + ctx, cancel := context.WithTimeout(ctx, time.Second*5) + defer cancel() + + _, err := c.writeFrame(ctx, true, false, opcode, p) + if err != nil { + return fmt.Errorf("failed to write control frame %v: %w", opcode, err) + } + return nil +} + +// frame handles all writes to the connection. +func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) { + err = c.writeFrameMu.lock(ctx) + if err != nil { + return 0, err + } + defer c.writeFrameMu.unlock() + + // If the state says a close has already been written, we wait until + // the connection is closed and return that error. + // + // However, if the frame being written is a close, that means its the close from + // the state being set so we let it go through. + c.closeMu.Lock() + wroteClose := c.wroteClose + c.closeMu.Unlock() + if wroteClose && opcode != opClose { + select { + case <-ctx.Done(): + return 0, ctx.Err() + case <-c.closed: + return 0, c.closeErr + } + } + + select { + case <-c.closed: + return 0, c.closeErr + case c.writeTimeout <- ctx: + } + + defer func() { + if err != nil { + select { + case <-c.closed: + err = c.closeErr + case <-ctx.Done(): + err = ctx.Err() + } + c.close(err) + err = fmt.Errorf("failed to write frame: %w", err) + } + }() + + c.writeHeader.fin = fin + c.writeHeader.opcode = opcode + c.writeHeader.payloadLength = int64(len(p)) + + if c.client { + c.writeHeader.masked = true + _, err = io.ReadFull(rand.Reader, c.writeHeaderBuf[:4]) + if err != nil { + return 0, fmt.Errorf("failed to generate masking key: %w", err) + } + c.writeHeader.maskKey = binary.LittleEndian.Uint32(c.writeHeaderBuf[:]) + } + + c.writeHeader.rsv1 = false + if flate && (opcode == opText || opcode == opBinary) { + c.writeHeader.rsv1 = true + } + + err = writeFrameHeader(c.writeHeader, c.bw, c.writeHeaderBuf[:]) + if err != nil { + return 0, err + } + + n, err := c.writeFramePayload(p) + if err != nil { + return n, err + } + + if c.writeHeader.fin { + err = c.bw.Flush() + if err != nil { + return n, fmt.Errorf("failed to flush: %w", err) + } + } + + select { + case <-c.closed: + return n, c.closeErr + case c.writeTimeout <- context.Background(): + } + + return n, nil +} + +func (c *Conn) writeFramePayload(p []byte) (n int, err error) { + defer errd.Wrap(&err, "failed to write frame payload") + + if !c.writeHeader.masked { + return c.bw.Write(p) + } + + maskKey := c.writeHeader.maskKey + for len(p) > 0 { + // If the buffer is full, we need to flush. + if c.bw.Available() == 0 { + err = c.bw.Flush() + if err != nil { + return n, err + } + } + + // Start of next write in the buffer. + i := c.bw.Buffered() + + j := len(p) + if j > c.bw.Available() { + j = c.bw.Available() + } + + _, err := c.bw.Write(p[:j]) + if err != nil { + return n, err + } + + maskKey = mask(maskKey, c.writeBuf[i:c.bw.Buffered()]) + + p = p[j:] + n += j + } + + return n, nil +} + +type writerFunc func(p []byte) (int, error) + +func (f writerFunc) Write(p []byte) (int, error) { + return f(p) +} + +// extractBufioWriterBuf grabs the []byte backing a *bufio.Writer +// and returns it. +func extractBufioWriterBuf(bw *bufio.Writer, w io.Writer) []byte { + var writeBuf []byte + bw.Reset(writerFunc(func(p2 []byte) (int, error) { + writeBuf = p2[:cap(p2)] + return len(p2), nil + })) + + bw.WriteByte(0) + bw.Flush() + + bw.Reset(w) + + return writeBuf +} + +func (c *Conn) writeError(code StatusCode, err error) { + c.setCloseErr(err) + c.writeClose(code, err.Error()) + c.close(nil) +} diff --git a/vendor/nhooyr.io/websocket/ws_js.go b/vendor/nhooyr.io/websocket/ws_js.go new file mode 100644 index 0000000000..b87e32cdaf --- /dev/null +++ b/vendor/nhooyr.io/websocket/ws_js.go @@ -0,0 +1,379 @@ +package websocket // import "nhooyr.io/websocket" + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net/http" + "reflect" + "runtime" + "strings" + "sync" + "syscall/js" + + "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/wsjs" + "nhooyr.io/websocket/internal/xsync" +) + +// Conn provides a wrapper around the browser WebSocket API. +type Conn struct { + ws wsjs.WebSocket + + // read limit for a message in bytes. + msgReadLimit xsync.Int64 + + closingMu sync.Mutex + isReadClosed xsync.Int64 + closeOnce sync.Once + closed chan struct{} + closeErrOnce sync.Once + closeErr error + closeWasClean bool + + releaseOnClose func() + releaseOnMessage func() + + readSignal chan struct{} + readBufMu sync.Mutex + readBuf []wsjs.MessageEvent +} + +func (c *Conn) close(err error, wasClean bool) { + c.closeOnce.Do(func() { + runtime.SetFinalizer(c, nil) + + if !wasClean { + err = fmt.Errorf("unclean connection close: %w", err) + } + c.setCloseErr(err) + c.closeWasClean = wasClean + close(c.closed) + }) +} + +func (c *Conn) init() { + c.closed = make(chan struct{}) + c.readSignal = make(chan struct{}, 1) + + c.msgReadLimit.Store(32768) + + c.releaseOnClose = c.ws.OnClose(func(e wsjs.CloseEvent) { + err := CloseError{ + Code: StatusCode(e.Code), + Reason: e.Reason, + } + // We do not know if we sent or received this close as + // its possible the browser triggered it without us + // explicitly sending it. + c.close(err, e.WasClean) + + c.releaseOnClose() + c.releaseOnMessage() + }) + + c.releaseOnMessage = c.ws.OnMessage(func(e wsjs.MessageEvent) { + c.readBufMu.Lock() + defer c.readBufMu.Unlock() + + c.readBuf = append(c.readBuf, e) + + // Lets the read goroutine know there is definitely something in readBuf. + select { + case c.readSignal <- struct{}{}: + default: + } + }) + + runtime.SetFinalizer(c, func(c *Conn) { + c.setCloseErr(errors.New("connection garbage collected")) + c.closeWithInternal() + }) +} + +func (c *Conn) closeWithInternal() { + c.Close(StatusInternalError, "something went wrong") +} + +// Read attempts to read a message from the connection. +// The maximum time spent waiting is bounded by the context. +func (c *Conn) Read(ctx context.Context) (MessageType, []byte, error) { + if c.isReadClosed.Load() == 1 { + return 0, nil, errors.New("WebSocket connection read closed") + } + + typ, p, err := c.read(ctx) + if err != nil { + return 0, nil, fmt.Errorf("failed to read: %w", err) + } + if int64(len(p)) > c.msgReadLimit.Load() { + err := fmt.Errorf("read limited at %v bytes", c.msgReadLimit.Load()) + c.Close(StatusMessageTooBig, err.Error()) + return 0, nil, err + } + return typ, p, nil +} + +func (c *Conn) read(ctx context.Context) (MessageType, []byte, error) { + select { + case <-ctx.Done(): + c.Close(StatusPolicyViolation, "read timed out") + return 0, nil, ctx.Err() + case <-c.readSignal: + case <-c.closed: + return 0, nil, c.closeErr + } + + c.readBufMu.Lock() + defer c.readBufMu.Unlock() + + me := c.readBuf[0] + // We copy the messages forward and decrease the size + // of the slice to avoid reallocating. + copy(c.readBuf, c.readBuf[1:]) + c.readBuf = c.readBuf[:len(c.readBuf)-1] + + if len(c.readBuf) > 0 { + // Next time we read, we'll grab the message. + select { + case c.readSignal <- struct{}{}: + default: + } + } + + switch p := me.Data.(type) { + case string: + return MessageText, []byte(p), nil + case []byte: + return MessageBinary, p, nil + default: + panic("websocket: unexpected data type from wsjs OnMessage: " + reflect.TypeOf(me.Data).String()) + } +} + +// Ping is mocked out for Wasm. +func (c *Conn) Ping(ctx context.Context) error { + return nil +} + +// Write writes a message of the given type to the connection. +// Always non blocking. +func (c *Conn) Write(ctx context.Context, typ MessageType, p []byte) error { + err := c.write(ctx, typ, p) + if err != nil { + // Have to ensure the WebSocket is closed after a write error + // to match the Go API. It can only error if the message type + // is unexpected or the passed bytes contain invalid UTF-8 for + // MessageText. + err := fmt.Errorf("failed to write: %w", err) + c.setCloseErr(err) + c.closeWithInternal() + return err + } + return nil +} + +func (c *Conn) write(ctx context.Context, typ MessageType, p []byte) error { + if c.isClosed() { + return c.closeErr + } + switch typ { + case MessageBinary: + return c.ws.SendBytes(p) + case MessageText: + return c.ws.SendText(string(p)) + default: + return fmt.Errorf("unexpected message type: %v", typ) + } +} + +// Close closes the WebSocket with the given code and reason. +// It will wait until the peer responds with a close frame +// or the connection is closed. +// It thus performs the full WebSocket close handshake. +func (c *Conn) Close(code StatusCode, reason string) error { + err := c.exportedClose(code, reason) + if err != nil { + return fmt.Errorf("failed to close WebSocket: %w", err) + } + return nil +} + +func (c *Conn) exportedClose(code StatusCode, reason string) error { + c.closingMu.Lock() + defer c.closingMu.Unlock() + + ce := fmt.Errorf("sent close: %w", CloseError{ + Code: code, + Reason: reason, + }) + + if c.isClosed() { + return fmt.Errorf("tried to close with %q but connection already closed: %w", ce, c.closeErr) + } + + c.setCloseErr(ce) + err := c.ws.Close(int(code), reason) + if err != nil { + return err + } + + <-c.closed + if !c.closeWasClean { + return c.closeErr + } + return nil +} + +// Subprotocol returns the negotiated subprotocol. +// An empty string means the default protocol. +func (c *Conn) Subprotocol() string { + return c.ws.Subprotocol() +} + +// DialOptions represents the options available to pass to Dial. +type DialOptions struct { + // Subprotocols lists the subprotocols to negotiate with the server. + Subprotocols []string +} + +// Dial creates a new WebSocket connection to the given url with the given options. +// The passed context bounds the maximum time spent waiting for the connection to open. +// The returned *http.Response is always nil or a mock. It's only in the signature +// to match the core API. +func Dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { + c, resp, err := dial(ctx, url, opts) + if err != nil { + return nil, nil, fmt.Errorf("failed to WebSocket dial %q: %w", url, err) + } + return c, resp, nil +} + +func dial(ctx context.Context, url string, opts *DialOptions) (*Conn, *http.Response, error) { + if opts == nil { + opts = &DialOptions{} + } + + url = strings.Replace(url, "http://", "ws://", 1) + url = strings.Replace(url, "https://", "wss://", 1) + + ws, err := wsjs.New(url, opts.Subprotocols) + if err != nil { + return nil, nil, err + } + + c := &Conn{ + ws: ws, + } + c.init() + + opench := make(chan struct{}) + releaseOpen := ws.OnOpen(func(e js.Value) { + close(opench) + }) + defer releaseOpen() + + select { + case <-ctx.Done(): + c.Close(StatusPolicyViolation, "dial timed out") + return nil, nil, ctx.Err() + case <-opench: + return c, &http.Response{ + StatusCode: http.StatusSwitchingProtocols, + }, nil + case <-c.closed: + return nil, nil, c.closeErr + } +} + +// Reader attempts to read a message from the connection. +// The maximum time spent waiting is bounded by the context. +func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) { + typ, p, err := c.Read(ctx) + if err != nil { + return 0, nil, err + } + return typ, bytes.NewReader(p), nil +} + +// Writer returns a writer to write a WebSocket data message to the connection. +// It buffers the entire message in memory and then sends it when the writer +// is closed. +func (c *Conn) Writer(ctx context.Context, typ MessageType) (io.WriteCloser, error) { + return writer{ + c: c, + ctx: ctx, + typ: typ, + b: bpool.Get(), + }, nil +} + +type writer struct { + closed bool + + c *Conn + ctx context.Context + typ MessageType + + b *bytes.Buffer +} + +func (w writer) Write(p []byte) (int, error) { + if w.closed { + return 0, errors.New("cannot write to closed writer") + } + n, err := w.b.Write(p) + if err != nil { + return n, fmt.Errorf("failed to write message: %w", err) + } + return n, nil +} + +func (w writer) Close() error { + if w.closed { + return errors.New("cannot close closed writer") + } + w.closed = true + defer bpool.Put(w.b) + + err := w.c.Write(w.ctx, w.typ, w.b.Bytes()) + if err != nil { + return fmt.Errorf("failed to close writer: %w", err) + } + return nil +} + +// CloseRead implements *Conn.CloseRead for wasm. +func (c *Conn) CloseRead(ctx context.Context) context.Context { + c.isReadClosed.Store(1) + + ctx, cancel := context.WithCancel(ctx) + go func() { + defer cancel() + c.read(ctx) + c.Close(StatusPolicyViolation, "unexpected data message") + }() + return ctx +} + +// SetReadLimit implements *Conn.SetReadLimit for wasm. +func (c *Conn) SetReadLimit(n int64) { + c.msgReadLimit.Store(n) +} + +func (c *Conn) setCloseErr(err error) { + c.closeErrOnce.Do(func() { + c.closeErr = fmt.Errorf("WebSocket closed: %w", err) + }) +} + +func (c *Conn) isClosed() bool { + select { + case <-c.closed: + return true + default: + return false + } +} diff --git a/vendor/nhooyr.io/websocket/wsjson/wsjson.go b/vendor/nhooyr.io/websocket/wsjson/wsjson.go new file mode 100644 index 0000000000..2000a77af8 --- /dev/null +++ b/vendor/nhooyr.io/websocket/wsjson/wsjson.go @@ -0,0 +1,67 @@ +// Package wsjson provides helpers for reading and writing JSON messages. +package wsjson // import "nhooyr.io/websocket/wsjson" + +import ( + "context" + "encoding/json" + "fmt" + + "nhooyr.io/websocket" + "nhooyr.io/websocket/internal/bpool" + "nhooyr.io/websocket/internal/errd" +) + +// Read reads a JSON message from c into v. +// It will reuse buffers in between calls to avoid allocations. +func Read(ctx context.Context, c *websocket.Conn, v interface{}) error { + return read(ctx, c, v) +} + +func read(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { + defer errd.Wrap(&err, "failed to read JSON message") + + _, r, err := c.Reader(ctx) + if err != nil { + return err + } + + b := bpool.Get() + defer bpool.Put(b) + + _, err = b.ReadFrom(r) + if err != nil { + return err + } + + err = json.Unmarshal(b.Bytes(), v) + if err != nil { + c.Close(websocket.StatusInvalidFramePayloadData, "failed to unmarshal JSON") + return fmt.Errorf("failed to unmarshal JSON: %w", err) + } + + return nil +} + +// Write writes the JSON message v to c. +// It will reuse buffers in between calls to avoid allocations. +func Write(ctx context.Context, c *websocket.Conn, v interface{}) error { + return write(ctx, c, v) +} + +func write(ctx context.Context, c *websocket.Conn, v interface{}) (err error) { + defer errd.Wrap(&err, "failed to write JSON message") + + w, err := c.Writer(ctx, websocket.MessageText) + if err != nil { + return err + } + + // json.Marshal cannot reuse buffers between calls as it has to return + // a copy of the byte slice but Encoder does as it directly writes to w. + err = json.NewEncoder(w).Encode(v) + if err != nil { + return fmt.Errorf("failed to marshal JSON: %w", err) + } + + return w.Close() +}