]> Cypherpunks.ru repositories - gostls13.git/commitdiff
archive/zip: add File.OpenRaw, Writer.CreateRaw, Writer.Copy
authorEddie Scholtz <escholtz@google.com>
Wed, 21 Apr 2021 17:11:02 +0000 (11:11 -0600)
committerIan Lance Taylor <iant@golang.org>
Mon, 3 May 2021 21:11:47 +0000 (21:11 +0000)
These new methods provide support for cases where performance is a
primary concern. For example, copying files from an existing zip to a
new zip without incurring the decompression and compression overhead.
Using an optimized, external compression method and writing the output
to a zip archive. And compressing file contents in parallel and then
sequentially writing the compressed bytes to a zip archive.

TestWriterCopy is copied verbatim from https://github.com/rsc/zipmerge

Fixes #34974

Change-Id: Iade5bc245ba34cdbb86364bf59f79f38bb9e2eb6
Reviewed-on: https://go-review.googlesource.com/c/go/+/312310
Run-TryBot: Ian Lance Taylor <iant@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Reviewed-by: Ian Lance Taylor <iant@golang.org>
Trust: Carlos Amedee <carlos@golang.org>

src/archive/zip/reader.go
src/archive/zip/reader_test.go
src/archive/zip/struct.go
src/archive/zip/writer.go
src/archive/zip/writer_test.go

index 18f9833db0a08e3ad0040e8d69847e0060792e60..808cf274caddb3883f77863288976c383b710553 100644 (file)
@@ -52,12 +52,9 @@ type File struct {
        FileHeader
        zip          *Reader
        zipr         io.ReaderAt
-       zipsize      int64
        headerOffset int64
-}
-
-func (f *File) hasDataDescriptor() bool {
-       return f.Flags&0x8 != 0
+       zip64        bool  // zip64 extended information extra field presence
+       descErr      error // error reading the data descriptor during init
 }
 
 // OpenReader will open the Zip file specified by name and return a ReadCloser.
@@ -112,7 +109,7 @@ func (z *Reader) init(r io.ReaderAt, size int64) error {
        // a bad one, and then only report an ErrFormat or UnexpectedEOF if
        // the file count modulo 65536 is incorrect.
        for {
-               f := &File{zip: z, zipr: r, zipsize: size}
+               f := &File{zip: z, zipr: r}
                err = readDirectoryHeader(f, buf)
                if err == ErrFormat || err == io.ErrUnexpectedEOF {
                        break
@@ -120,6 +117,7 @@ func (z *Reader) init(r io.ReaderAt, size int64) error {
                if err != nil {
                        return err
                }
+               f.readDataDescriptor()
                z.File = append(z.File, f)
        }
        if uint16(len(z.File)) != uint16(end.directoryRecords) { // only compare 16 bits here
@@ -180,26 +178,68 @@ func (f *File) Open() (io.ReadCloser, error) {
                return nil, ErrAlgorithm
        }
        var rc io.ReadCloser = dcomp(r)
-       var desr io.Reader
-       if f.hasDataDescriptor() {
-               desr = io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset+size, dataDescriptorLen)
-       }
        rc = &checksumReader{
                rc:   rc,
                hash: crc32.NewIEEE(),
                f:    f,
-               desr: desr,
        }
        return rc, nil
 }
 
+// OpenRaw returns a Reader that provides access to the File's contents without
+// decompression.
+func (f *File) OpenRaw() (io.Reader, error) {
+       bodyOffset, err := f.findBodyOffset()
+       if err != nil {
+               return nil, err
+       }
+       r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset, int64(f.CompressedSize64))
+       return r, nil
+}
+
+func (f *File) readDataDescriptor() {
+       if !f.hasDataDescriptor() {
+               return
+       }
+
+       bodyOffset, err := f.findBodyOffset()
+       if err != nil {
+               f.descErr = err
+               return
+       }
+
+       // In section 4.3.9.2 of the spec: "However ZIP64 format MAY be used
+       // regardless of the size of a file.  When extracting, if the zip64
+       // extended information extra field is present for the file the
+       // compressed and uncompressed sizes will be 8 byte values."
+       //
+       // Historically, this package has used the compressed and uncompressed
+       // sizes from the central directory to determine if the package is
+       // zip64.
+       //
+       // For this case we allow either the extra field or sizes to determine
+       // the data descriptor length.
+       zip64 := f.zip64 || f.isZip64()
+       n := int64(dataDescriptorLen)
+       if zip64 {
+               n = dataDescriptor64Len
+       }
+       size := int64(f.CompressedSize64)
+       r := io.NewSectionReader(f.zipr, f.headerOffset+bodyOffset+size, n)
+       dd, err := readDataDescriptor(r, zip64)
+       if err != nil {
+               f.descErr = err
+               return
+       }
+       f.CRC32 = dd.crc32
+}
+
 type checksumReader struct {
        rc    io.ReadCloser
        hash  hash.Hash32
        nread uint64 // number of bytes read so far
        f     *File
-       desr  io.Reader // if non-nil, where to read the data descriptor
-       err   error     // sticky error
+       err   error // sticky error
 }
 
 func (r *checksumReader) Stat() (fs.FileInfo, error) {
@@ -220,12 +260,12 @@ func (r *checksumReader) Read(b []byte) (n int, err error) {
                if r.nread != r.f.UncompressedSize64 {
                        return 0, io.ErrUnexpectedEOF
                }
-               if r.desr != nil {
-                       if err1 := readDataDescriptor(r.desr, r.f); err1 != nil {
-                               if err1 == io.EOF {
+               if r.f.hasDataDescriptor() {
+                       if r.f.descErr != nil {
+                               if r.f.descErr == io.EOF {
                                        err = io.ErrUnexpectedEOF
                                } else {
-                                       err = err1
+                                       err = r.f.descErr
                                }
                        } else if r.hash.Sum32() != r.f.CRC32 {
                                err = ErrChecksum
@@ -336,6 +376,8 @@ parseExtras:
 
                switch fieldTag {
                case zip64ExtraID:
+                       f.zip64 = true
+
                        // update directory values from the zip64 extra block.
                        // They should only be consulted if the sizes read earlier
                        // are maxed out.
@@ -435,8 +477,9 @@ parseExtras:
        return nil
 }
 
-func readDataDescriptor(r io.Reader, f *File) error {
-       var buf [dataDescriptorLen]byte
+func readDataDescriptor(r io.Reader, zip64 bool) (*dataDescriptor, error) {
+       // Create enough space for the largest possible size
+       var buf [dataDescriptor64Len]byte
 
        // The spec says: "Although not originally assigned a
        // signature, the value 0x08074b50 has commonly been adopted
@@ -446,10 +489,9 @@ func readDataDescriptor(r io.Reader, f *File) error {
        // descriptors and should account for either case when reading
        // ZIP files to ensure compatibility."
        //
-       // dataDescriptorLen includes the size of the signature but
-       // first read just those 4 bytes to see if it exists.
+       // First read just those 4 bytes to see if the signature exists.
        if _, err := io.ReadFull(r, buf[:4]); err != nil {
-               return err
+               return nil, err
        }
        off := 0
        maybeSig := readBuf(buf[:4])
@@ -458,21 +500,28 @@ func readDataDescriptor(r io.Reader, f *File) error {
                // bytes.
                off += 4
        }
-       if _, err := io.ReadFull(r, buf[off:12]); err != nil {
-               return err
+
+       end := dataDescriptorLen - 4
+       if zip64 {
+               end = dataDescriptor64Len - 4
        }
-       b := readBuf(buf[:12])
-       if b.uint32() != f.CRC32 {
-               return ErrChecksum
+       if _, err := io.ReadFull(r, buf[off:end]); err != nil {
+               return nil, err
        }
+       b := readBuf(buf[:end])
 
-       // The two sizes that follow here can be either 32 bits or 64 bits
-       // but the spec is not very clear on this and different
-       // interpretations has been made causing incompatibilities. We
-       // already have the sizes from the central directory so we can
-       // just ignore these.
+       out := &dataDescriptor{
+               crc32: b.uint32(),
+       }
 
-       return nil
+       if zip64 {
+               out.compressedSize = b.uint64()
+               out.uncompressedSize = b.uint64()
+       } else {
+               out.compressedSize = uint64(b.uint32())
+               out.uncompressedSize = uint64(b.uint32())
+       }
+       return out, nil
 }
 
 func readDirectoryEnd(r io.ReaderAt, size int64) (dir *directoryEnd, err error) {
index 6ee62ef997414c5126c48fa96d9e6da116ce9293..35e681ec6995e844540df7d5b30475e62a64e350 100644 (file)
@@ -499,9 +499,15 @@ func TestReader(t *testing.T) {
 func readTestZip(t *testing.T, zt ZipTest) {
        var z *Reader
        var err error
+       var raw []byte
        if zt.Source != nil {
                rat, size := zt.Source()
                z, err = NewReader(rat, size)
+               raw = make([]byte, size)
+               if _, err := rat.ReadAt(raw, 0); err != nil {
+                       t.Errorf("ReadAt error=%v", err)
+                       return
+               }
        } else {
                path := filepath.Join("testdata", zt.Name)
                if zt.Obscured {
@@ -519,6 +525,12 @@ func readTestZip(t *testing.T, zt ZipTest) {
                        defer rc.Close()
                        z = &rc.Reader
                }
+               var err2 error
+               raw, err2 = os.ReadFile(path)
+               if err2 != nil {
+                       t.Errorf("ReadFile(%s) error=%v", path, err2)
+                       return
+               }
        }
        if err != zt.Error {
                t.Errorf("error=%v, want %v", err, zt.Error)
@@ -545,7 +557,7 @@ func readTestZip(t *testing.T, zt ZipTest) {
 
        // test read of each file
        for i, ft := range zt.File {
-               readTestFile(t, zt, ft, z.File[i])
+               readTestFile(t, zt, ft, z.File[i], raw)
        }
        if t.Failed() {
                return
@@ -557,7 +569,7 @@ func readTestZip(t *testing.T, zt ZipTest) {
        for i := 0; i < 5; i++ {
                for j, ft := range zt.File {
                        go func(j int, ft ZipTestFile) {
-                               readTestFile(t, zt, ft, z.File[j])
+                               readTestFile(t, zt, ft, z.File[j], raw)
                                done <- true
                        }(j, ft)
                        n++
@@ -574,7 +586,7 @@ func equalTimeAndZone(t1, t2 time.Time) bool {
        return t1.Equal(t2) && name1 == name2 && offset1 == offset2
 }
 
-func readTestFile(t *testing.T, zt ZipTest, ft ZipTestFile, f *File) {
+func readTestFile(t *testing.T, zt ZipTest, ft ZipTestFile, f *File, raw []byte) {
        if f.Name != ft.Name {
                t.Errorf("name=%q, want %q", f.Name, ft.Name)
        }
@@ -594,6 +606,31 @@ func readTestFile(t *testing.T, zt ZipTest, ft ZipTestFile, f *File) {
                t.Errorf("%v: UncompressedSize=%#x does not match UncompressedSize64=%#x", f.Name, size, f.UncompressedSize64)
        }
 
+       // Check that OpenRaw returns the correct byte segment
+       rw, err := f.OpenRaw()
+       if err != nil {
+               t.Errorf("%v: OpenRaw error=%v", f.Name, err)
+               return
+       }
+       start, err := f.DataOffset()
+       if err != nil {
+               t.Errorf("%v: DataOffset error=%v", f.Name, err)
+               return
+       }
+       got, err := io.ReadAll(rw)
+       if err != nil {
+               t.Errorf("%v: OpenRaw ReadAll error=%v", f.Name, err)
+               return
+       }
+       end := uint64(start) + f.CompressedSize64
+       want := raw[start:end]
+       if !bytes.Equal(got, want) {
+               t.Logf("got %q", got)
+               t.Logf("want %q", want)
+               t.Errorf("%v: OpenRaw returned unexpected bytes", f.Name)
+               return
+       }
+
        r, err := f.Open()
        if err != nil {
                t.Errorf("%v", err)
@@ -1166,3 +1203,125 @@ func TestCVE202127919(t *testing.T) {
                t.Errorf("Error reading file: %v", err)
        }
 }
+
+func TestReadDataDescriptor(t *testing.T) {
+       tests := []struct {
+               desc    string
+               in      []byte
+               zip64   bool
+               want    *dataDescriptor
+               wantErr error
+       }{{
+               desc: "valid 32 bit with signature",
+               in: []byte{
+                       0x50, 0x4b, 0x07, 0x08, // signature
+                       0x00, 0x01, 0x02, 0x03, // crc32
+                       0x04, 0x05, 0x06, 0x07, // compressed size
+                       0x08, 0x09, 0x0a, 0x0b, // uncompressed size
+               },
+               want: &dataDescriptor{
+                       crc32:            0x03020100,
+                       compressedSize:   0x07060504,
+                       uncompressedSize: 0x0b0a0908,
+               },
+       }, {
+               desc: "valid 32 bit without signature",
+               in: []byte{
+                       0x00, 0x01, 0x02, 0x03, // crc32
+                       0x04, 0x05, 0x06, 0x07, // compressed size
+                       0x08, 0x09, 0x0a, 0x0b, // uncompressed size
+               },
+               want: &dataDescriptor{
+                       crc32:            0x03020100,
+                       compressedSize:   0x07060504,
+                       uncompressedSize: 0x0b0a0908,
+               },
+       }, {
+               desc: "valid 64 bit with signature",
+               in: []byte{
+                       0x50, 0x4b, 0x07, 0x08, // signature
+                       0x00, 0x01, 0x02, 0x03, // crc32
+                       0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, // compressed size
+                       0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, // uncompressed size
+               },
+               zip64: true,
+               want: &dataDescriptor{
+                       crc32:            0x03020100,
+                       compressedSize:   0x0b0a090807060504,
+                       uncompressedSize: 0x131211100f0e0d0c,
+               },
+       }, {
+               desc: "valid 64 bit without signature",
+               in: []byte{
+                       0x00, 0x01, 0x02, 0x03, // crc32
+                       0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, // compressed size
+                       0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, 0x12, 0x13, // uncompressed size
+               },
+               zip64: true,
+               want: &dataDescriptor{
+                       crc32:            0x03020100,
+                       compressedSize:   0x0b0a090807060504,
+                       uncompressedSize: 0x131211100f0e0d0c,
+               },
+       }, {
+               desc: "invalid 32 bit with signature",
+               in: []byte{
+                       0x50, 0x4b, 0x07, 0x08, // signature
+                       0x00, 0x01, 0x02, 0x03, // crc32
+                       0x04, 0x05, // unexpected end
+               },
+               wantErr: io.ErrUnexpectedEOF,
+       }, {
+               desc: "invalid 32 bit without signature",
+               in: []byte{
+                       0x00, 0x01, 0x02, 0x03, // crc32
+                       0x04, 0x05, // unexpected end
+               },
+               wantErr: io.ErrUnexpectedEOF,
+       }, {
+               desc: "invalid 64 bit with signature",
+               in: []byte{
+                       0x50, 0x4b, 0x07, 0x08, // signature
+                       0x00, 0x01, 0x02, 0x03, // crc32
+                       0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, // compressed size
+                       0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, // unexpected end
+               },
+               zip64:   true,
+               wantErr: io.ErrUnexpectedEOF,
+       }, {
+               desc: "invalid 64 bit without signature",
+               in: []byte{
+                       0x00, 0x01, 0x02, 0x03, // crc32
+                       0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, // compressed size
+                       0x0c, 0x0d, 0x0e, 0x0f, 0x10, 0x11, // unexpected end
+               },
+               zip64:   true,
+               wantErr: io.ErrUnexpectedEOF,
+       }}
+
+       for _, test := range tests {
+               t.Run(test.desc, func(t *testing.T) {
+                       r := bytes.NewReader(test.in)
+
+                       desc, err := readDataDescriptor(r, test.zip64)
+                       if err != test.wantErr {
+                               t.Fatalf("got err %v; want nil", err)
+                       }
+                       if test.want == nil {
+                               return
+                       }
+                       if desc == nil {
+                               t.Fatalf("got nil DataDescriptor; want non-nil")
+                       }
+                       if desc.crc32 != test.want.crc32 {
+                               t.Errorf("got CRC32 %#x; want %#x", desc.crc32, test.want.crc32)
+                       }
+                       if desc.compressedSize != test.want.compressedSize {
+                               t.Errorf("got CompressedSize %#x; want %#x", desc.compressedSize, test.want.compressedSize)
+                       }
+                       if desc.uncompressedSize != test.want.uncompressedSize {
+                               t.Errorf("got UncompressedSize %#x; want %#x", desc.uncompressedSize, test.want.uncompressedSize)
+                       }
+               })
+       }
+}
index 3dc0c50122a83426cf23d39829b283b486365c02..ff9f605eb697eec319602d455917953b70b906e6 100644 (file)
@@ -42,7 +42,7 @@ const (
        directoryHeaderLen       = 46         // + filename + extra + comment
        directoryEndLen          = 22         // + comment
        dataDescriptorLen        = 16         // four uint32: descriptor signature, crc32, compressed size, size
-       dataDescriptor64Len      = 24         // descriptor with 8 byte sizes
+       dataDescriptor64Len      = 24         // two uint32: signature, crc32 | two uint64: compressed size, size
        directory64LocLen        = 20         //
        directory64EndLen        = 56         // + extra
 
@@ -315,6 +315,10 @@ func (h *FileHeader) isZip64() bool {
        return h.CompressedSize64 >= uint32max || h.UncompressedSize64 >= uint32max
 }
 
+func (f *FileHeader) hasDataDescriptor() bool {
+       return f.Flags&0x8 != 0
+}
+
 func msdosModeToFileMode(m uint32) (mode fs.FileMode) {
        if m&msdosDir != 0 {
                mode = fs.ModeDir | 0777
@@ -386,3 +390,11 @@ func unixModeToFileMode(m uint32) fs.FileMode {
        }
        return mode
 }
+
+// dataDescriptor holds the data descriptor that optionally follows the file
+// contents in the zip file.
+type dataDescriptor struct {
+       crc32            uint32
+       compressedSize   uint64
+       uncompressedSize uint64
+}
index cdc534eaf01922df345eb4a44b372ef4c8fbc1bc..3b23cc3391d9e676a6735266fa8bb76a73ace755 100644 (file)
@@ -37,6 +37,7 @@ type Writer struct {
 type header struct {
        *FileHeader
        offset uint64
+       raw    bool
 }
 
 // NewWriter returns a new Writer writing a zip file to w.
@@ -245,22 +246,31 @@ func detectUTF8(s string) (valid, require bool) {
        return true, require
 }
 
+// prepare performs the bookkeeping operations required at the start of
+// CreateHeader and CreateRaw.
+func (w *Writer) prepare(fh *FileHeader) error {
+       if w.last != nil && !w.last.closed {
+               if err := w.last.close(); err != nil {
+                       return err
+               }
+       }
+       if len(w.dir) > 0 && w.dir[len(w.dir)-1].FileHeader == fh {
+               // See https://golang.org/issue/11144 confusion.
+               return errors.New("archive/zip: invalid duplicate FileHeader")
+       }
+       return nil
+}
+
 // CreateHeader adds a file to the zip archive using the provided FileHeader
 // for the file metadata. Writer takes ownership of fh and may mutate
 // its fields. The caller must not modify fh after calling CreateHeader.
 //
 // This returns a Writer to which the file contents should be written.
 // The file's contents must be written to the io.Writer before the next
-// call to Create, CreateHeader, or Close.
+// call to Create, CreateHeader, CreateRaw, or Close.
 func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) {
-       if w.last != nil && !w.last.closed {
-               if err := w.last.close(); err != nil {
-                       return nil, err
-               }
-       }
-       if len(w.dir) > 0 && w.dir[len(w.dir)-1].FileHeader == fh {
-               // See https://golang.org/issue/11144 confusion.
-               return nil, errors.New("archive/zip: invalid duplicate FileHeader")
+       if err := w.prepare(fh); err != nil {
+               return nil, err
        }
 
        // The ZIP format has a sad state of affairs regarding character encoding.
@@ -365,7 +375,7 @@ func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) {
                ow = fw
        }
        w.dir = append(w.dir, h)
-       if err := writeHeader(w.cw, fh); err != nil {
+       if err := writeHeader(w.cw, h); err != nil {
                return nil, err
        }
        // If we're creating a directory, fw is nil.
@@ -373,7 +383,7 @@ func (w *Writer) CreateHeader(fh *FileHeader) (io.Writer, error) {
        return ow, nil
 }
 
-func writeHeader(w io.Writer, h *FileHeader) error {
+func writeHeader(w io.Writer, h *header) error {
        const maxUint16 = 1<<16 - 1
        if len(h.Name) > maxUint16 {
                return errLongName
@@ -390,9 +400,20 @@ func writeHeader(w io.Writer, h *FileHeader) error {
        b.uint16(h.Method)
        b.uint16(h.ModifiedTime)
        b.uint16(h.ModifiedDate)
-       b.uint32(0) // since we are writing a data descriptor crc32,
-       b.uint32(0) // compressed size,
-       b.uint32(0) // and uncompressed size should be zero
+       // In raw mode (caller does the compression), the values are either
+       // written here or in the trailing data descriptor based on the header
+       // flags.
+       if h.raw && !h.hasDataDescriptor() {
+               b.uint32(h.CRC32)
+               b.uint32(uint32(min64(h.CompressedSize64, uint32max)))
+               b.uint32(uint32(min64(h.UncompressedSize64, uint32max)))
+       } else {
+               // When this package handle the compression, these values are
+               // always written to the trailing data descriptor.
+               b.uint32(0) // crc32
+               b.uint32(0) // compressed size
+               b.uint32(0) // uncompressed size
+       }
        b.uint16(uint16(len(h.Name)))
        b.uint16(uint16(len(h.Extra)))
        if _, err := w.Write(buf[:]); err != nil {
@@ -405,6 +426,65 @@ func writeHeader(w io.Writer, h *FileHeader) error {
        return err
 }
 
+func min64(x, y uint64) uint64 {
+       if x < y {
+               return x
+       }
+       return y
+}
+
+// CreateRaw adds a file to the zip archive using the provided FileHeader and
+// returns a Writer to which the file contents should be written. The file's
+// contents must be written to the io.Writer before the next call to Create,
+// CreateHeader, CreateRaw, or Close.
+//
+// In contrast to CreateHeader, the bytes passed to Writer are not compressed.
+func (w *Writer) CreateRaw(fh *FileHeader) (io.Writer, error) {
+       if err := w.prepare(fh); err != nil {
+               return nil, err
+       }
+
+       fh.CompressedSize = uint32(min64(fh.CompressedSize64, uint32max))
+       fh.UncompressedSize = uint32(min64(fh.UncompressedSize64, uint32max))
+
+       h := &header{
+               FileHeader: fh,
+               offset:     uint64(w.cw.count),
+               raw:        true,
+       }
+       w.dir = append(w.dir, h)
+       if err := writeHeader(w.cw, h); err != nil {
+               return nil, err
+       }
+
+       if strings.HasSuffix(fh.Name, "/") {
+               w.last = nil
+               return dirWriter{}, nil
+       }
+
+       fw := &fileWriter{
+               header: h,
+               zipw:   w.cw,
+       }
+       w.last = fw
+       return fw, nil
+}
+
+// Copy copies the file f (obtained from a Reader) into w. It copies the raw
+// form directly bypassing decompression, compression, and validation.
+func (w *Writer) Copy(f *File) error {
+       r, err := f.OpenRaw()
+       if err != nil {
+               return err
+       }
+       fw, err := w.CreateRaw(&f.FileHeader)
+       if err != nil {
+               return err
+       }
+       _, err = io.Copy(fw, r)
+       return err
+}
+
 // RegisterCompressor registers or overrides a custom compressor for a specific
 // method ID. If a compressor for a given method is not found, Writer will
 // default to looking up the compressor at the package level.
@@ -446,6 +526,9 @@ func (w *fileWriter) Write(p []byte) (int, error) {
        if w.closed {
                return 0, errors.New("zip: write to closed file")
        }
+       if w.raw {
+               return w.zipw.Write(p)
+       }
        w.crc32.Write(p)
        return w.rawCount.Write(p)
 }
@@ -455,6 +538,9 @@ func (w *fileWriter) close() error {
                return errors.New("zip: file closed twice")
        }
        w.closed = true
+       if w.raw {
+               return w.writeDataDescriptor()
+       }
        if err := w.comp.Close(); err != nil {
                return err
        }
@@ -474,26 +560,33 @@ func (w *fileWriter) close() error {
                fh.UncompressedSize = uint32(fh.UncompressedSize64)
        }
 
+       return w.writeDataDescriptor()
+}
+
+func (w *fileWriter) writeDataDescriptor() error {
+       if !w.hasDataDescriptor() {
+               return nil
+       }
        // Write data descriptor. This is more complicated than one would
        // think, see e.g. comments in zipfile.c:putextended() and
        // http://bugs.sun.com/bugdatabase/view_bug.do?bug_id=7073588.
        // The approach here is to write 8 byte sizes if needed without
        // adding a zip64 extra in the local header (too late anyway).
        var buf []byte
-       if fh.isZip64() {
+       if w.isZip64() {
                buf = make([]byte, dataDescriptor64Len)
        } else {
                buf = make([]byte, dataDescriptorLen)
        }
        b := writeBuf(buf)
        b.uint32(dataDescriptorSignature) // de-facto standard, required by OS X
-       b.uint32(fh.CRC32)
-       if fh.isZip64() {
-               b.uint64(fh.CompressedSize64)
-               b.uint64(fh.UncompressedSize64)
+       b.uint32(w.CRC32)
+       if w.isZip64() {
+               b.uint64(w.CompressedSize64)
+               b.uint64(w.UncompressedSize64)
        } else {
-               b.uint32(fh.CompressedSize)
-               b.uint32(fh.UncompressedSize)
+               b.uint32(w.CompressedSize)
+               b.uint32(w.UncompressedSize)
        }
        _, err := w.zipw.Write(buf)
        return err
index 3fa8bef0553c4de1a3135ec335c90668042aa820..97c6c5297994684ca7fcfc5ccc492025709d795c 100644 (file)
@@ -6,8 +6,10 @@ package zip
 
 import (
        "bytes"
+       "compress/flate"
        "encoding/binary"
        "fmt"
+       "hash/crc32"
        "io"
        "io/fs"
        "math/rand"
@@ -365,6 +367,171 @@ func TestWriterDirAttributes(t *testing.T) {
        }
 }
 
+func TestWriterCopy(t *testing.T) {
+       // make a zip file
+       buf := new(bytes.Buffer)
+       w := NewWriter(buf)
+       for _, wt := range writeTests {
+               testCreate(t, w, &wt)
+       }
+       if err := w.Close(); err != nil {
+               t.Fatal(err)
+       }
+
+       // read it back
+       src, err := NewReader(bytes.NewReader(buf.Bytes()), int64(buf.Len()))
+       if err != nil {
+               t.Fatal(err)
+       }
+       for i, wt := range writeTests {
+               testReadFile(t, src.File[i], &wt)
+       }
+
+       // make a new zip file copying the old compressed data.
+       buf2 := new(bytes.Buffer)
+       dst := NewWriter(buf2)
+       for _, f := range src.File {
+               if err := dst.Copy(f); err != nil {
+                       t.Fatal(err)
+               }
+       }
+       if err := dst.Close(); err != nil {
+               t.Fatal(err)
+       }
+
+       // read the new one back
+       r, err := NewReader(bytes.NewReader(buf2.Bytes()), int64(buf2.Len()))
+       if err != nil {
+               t.Fatal(err)
+       }
+       for i, wt := range writeTests {
+               testReadFile(t, r.File[i], &wt)
+       }
+}
+
+func TestWriterCreateRaw(t *testing.T) {
+       files := []struct {
+               name             string
+               content          []byte
+               method           uint16
+               flags            uint16
+               crc32            uint32
+               uncompressedSize uint64
+               compressedSize   uint64
+       }{
+               {
+                       name:    "small store w desc",
+                       content: []byte("gophers"),
+                       method:  Store,
+                       flags:   0x8,
+               },
+               {
+                       name:    "small deflate wo desc",
+                       content: bytes.Repeat([]byte("abcdefg"), 2048),
+                       method:  Deflate,
+               },
+       }
+
+       // write a zip file
+       archive := new(bytes.Buffer)
+       w := NewWriter(archive)
+
+       for i := range files {
+               f := &files[i]
+               f.crc32 = crc32.ChecksumIEEE(f.content)
+               size := uint64(len(f.content))
+               f.uncompressedSize = size
+               f.compressedSize = size
+
+               var compressedContent []byte
+               if f.method == Deflate {
+                       var buf bytes.Buffer
+                       w, err := flate.NewWriter(&buf, flate.BestSpeed)
+                       if err != nil {
+                               t.Fatalf("flate.NewWriter err = %v", err)
+                       }
+                       _, err = w.Write(f.content)
+                       if err != nil {
+                               t.Fatalf("flate Write err = %v", err)
+                       }
+                       err = w.Close()
+                       if err != nil {
+                               t.Fatalf("flate Writer.Close err = %v", err)
+                       }
+                       compressedContent = buf.Bytes()
+                       f.compressedSize = uint64(len(compressedContent))
+               }
+
+               h := &FileHeader{
+                       Name:               f.name,
+                       Method:             f.method,
+                       Flags:              f.flags,
+                       CRC32:              f.crc32,
+                       CompressedSize64:   f.compressedSize,
+                       UncompressedSize64: f.uncompressedSize,
+               }
+               w, err := w.CreateRaw(h)
+               if err != nil {
+                       t.Fatal(err)
+               }
+               if compressedContent != nil {
+                       _, err = w.Write(compressedContent)
+               } else {
+                       _, err = w.Write(f.content)
+               }
+               if err != nil {
+                       t.Fatalf("%s Write got %v; want nil", f.name, err)
+               }
+       }
+
+       if err := w.Close(); err != nil {
+               t.Fatal(err)
+       }
+
+       // read it back
+       r, err := NewReader(bytes.NewReader(archive.Bytes()), int64(archive.Len()))
+       if err != nil {
+               t.Fatal(err)
+       }
+       for i, want := range files {
+               got := r.File[i]
+               if got.Name != want.name {
+                       t.Errorf("got Name %s; want %s", got.Name, want.name)
+               }
+               if got.Method != want.method {
+                       t.Errorf("%s: got Method %#x; want %#x", want.name, got.Method, want.method)
+               }
+               if got.Flags != want.flags {
+                       t.Errorf("%s: got Flags %#x; want %#x", want.name, got.Flags, want.flags)
+               }
+               if got.CRC32 != want.crc32 {
+                       t.Errorf("%s: got CRC32 %#x; want %#x", want.name, got.CRC32, want.crc32)
+               }
+               if got.CompressedSize64 != want.compressedSize {
+                       t.Errorf("%s: got CompressedSize64 %d; want %d", want.name, got.CompressedSize64, want.compressedSize)
+               }
+               if got.UncompressedSize64 != want.uncompressedSize {
+                       t.Errorf("%s: got UncompressedSize64 %d; want %d", want.name, got.UncompressedSize64, want.uncompressedSize)
+               }
+
+               r, err := got.Open()
+               if err != nil {
+                       t.Errorf("%s: Open err = %v", got.Name, err)
+                       continue
+               }
+
+               buf, err := io.ReadAll(r)
+               if err != nil {
+                       t.Errorf("%s: ReadAll err = %v", got.Name, err)
+                       continue
+               }
+
+               if !bytes.Equal(buf, want.content) {
+                       t.Errorf("%v: ReadAll returned unexpected bytes", got.Name)
+               }
+       }
+}
+
 func testCreate(t *testing.T, w *Writer, wt *WriteTest) {
        header := &FileHeader{
                Name:   wt.Name,
@@ -390,15 +557,15 @@ func testReadFile(t *testing.T, f *File, wt *WriteTest) {
        testFileMode(t, f, wt.Mode)
        rc, err := f.Open()
        if err != nil {
-               t.Fatal("opening:", err)
+               t.Fatalf("opening %s: %v", f.Name, err)
        }
        b, err := io.ReadAll(rc)
        if err != nil {
-               t.Fatal("reading:", err)
+               t.Fatalf("reading %s: %v", f.Name, err)
        }
        err = rc.Close()
        if err != nil {
-               t.Fatal("closing:", err)
+               t.Fatalf("closing %s: %v", f.Name, err)
        }
        if !bytes.Equal(b, wt.Data) {
                t.Errorf("File contents %q, want %q", b, wt.Data)