Merge "Improvements in snapshot data integrity tests"
diff --git a/cinder_tempest_plugin/scenario/manager.py b/cinder_tempest_plugin/scenario/manager.py
index 3b25bb1..862432c 100644
--- a/cinder_tempest_plugin/scenario/manager.py
+++ b/cinder_tempest_plugin/scenario/manager.py
@@ -13,6 +13,8 @@
 #    License for the specific language governing permissions and limitations
 #    under the License.
 
+import contextlib
+
 from oslo_log import log
 
 from tempest.common import waiters
@@ -55,20 +57,33 @@
                        if item not in disks_list_before_attach][0]
         return volume_name
 
+    @contextlib.contextmanager
+    def mount_dev_path(self, ssh_client, dev_name, mount_path):
+        """Mount /dev/<dev_name> on mount_path while the with-block runs.
+
+        If dev_name is None nothing is mounted and the block simply runs.
+        The umount is issued from a ``finally`` clause so that an exception
+        raised inside the block does not leave the device mounted.
+        """
+        if dev_name is not None:
+            ssh_client.exec_command('sudo mount /dev/%s %s' % (dev_name,
+                                                               mount_path))
+            try:
+                yield
+            finally:
+                ssh_client.exec_command('sudo umount %s' % mount_path)
+        else:
+            yield
+
     def _get_file_md5(self, ip_address, filename, dev_name=None,
                       mount_path='/mnt', private_key=None, server=None):
 
         ssh_client = self.get_remote_client(ip_address,
                                             private_key=private_key,
                                             server=server)
-        if dev_name is not None:
-            ssh_client.exec_command('sudo mount /dev/%s %s' % (dev_name,
-                                                               mount_path))
-
-        md5_sum = ssh_client.exec_command(
-            'sudo md5sum %s/%s|cut -c 1-32' % (mount_path, filename))
-        if dev_name is not None:
-            ssh_client.exec_command('sudo umount %s' % mount_path)
+        with self.mount_dev_path(ssh_client, dev_name, mount_path):
+            md5_sum = ssh_client.exec_command(
+                'sudo md5sum %s/%s|cut -c 1-32' % (mount_path, filename))
         return md5_sum
 
     def _count_files(self, ip_address, dev_name=None, mount_path='/mnt',
@@ -76,12 +91,9 @@
         ssh_client = self.get_remote_client(ip_address,
                                             private_key=private_key,
                                             server=server)
-        if dev_name is not None:
-            ssh_client.exec_command('sudo mount /dev/%s %s' % (dev_name,
-                                                               mount_path))
-        count = ssh_client.exec_command('sudo ls -l %s | wc -l' % mount_path)
-        if dev_name is not None:
-            ssh_client.exec_command('sudo umount %s' % mount_path)
+        with self.mount_dev_path(ssh_client, dev_name, mount_path):
+            count = ssh_client.exec_command(
+                'sudo ls -l %s | wc -l' % mount_path)
         # We subtract 2 from the count since `wc -l` also includes the count
         # of new line character and while creating the filesystem, a
         # lost+found folder is also created
@@ -100,17 +112,13 @@
                                             private_key=private_key,
                                             server=server)
 
-        if dev_name is not None:
-            ssh_client.exec_command('sudo mount /dev/%s %s' % (dev_name,
-                                                               mount_path))
-        ssh_client.exec_command(
-            'sudo dd bs=1024 count=100 if=/dev/urandom of=/%s/%s' %
-            (mount_path, filename))
-        md5 = ssh_client.exec_command(
-            'sudo md5sum -b %s/%s|cut -c 1-32' % (mount_path, filename))
-        ssh_client.exec_command('sudo sync')
-        if dev_name is not None:
-            ssh_client.exec_command('sudo umount %s' % mount_path)
+        with self.mount_dev_path(ssh_client, dev_name, mount_path):
+            ssh_client.exec_command(
+                'sudo dd bs=1024 count=100 if=/dev/urandom of=/%s/%s' %
+                (mount_path, filename))
+            md5 = ssh_client.exec_command(
+                'sudo md5sum -b %s/%s|cut -c 1-32' % (mount_path, filename))
+            ssh_client.exec_command('sudo sync')
         return md5
 
     def get_md5_from_file(self, instance, instance_ip, filename,
diff --git a/cinder_tempest_plugin/scenario/test_snapshots.py b/cinder_tempest_plugin/scenario/test_snapshots.py
index 5a9611f..99e1057 100644
--- a/cinder_tempest_plugin/scenario/test_snapshots.py
+++ b/cinder_tempest_plugin/scenario/test_snapshots.py
@@ -36,7 +36,7 @@
         1) Create an instance with ephemeral disk
         2) Create a volume, attach it to the instance and create a filesystem
            on it and mount it
-        3) Mount the volume, create a file and write data into it, Unmount it
+        3) Create a file and write data into it, then unmount it
         4) create snapshot
         5) repeat 3 and 4 two more times (simply creating 3 snapshots)
 
@@ -93,41 +93,24 @@
         # Detach the volume
         self.nova_volume_detach(server, volume)
 
-        # Create volume from snapshot, attach it to instance and check file
-        # and contents for snap1
-        volume_snap_1 = self.create_volume(snapshot_id=snapshot1['id'])
-        volume_device_name, __ = self._attach_and_get_volume_device_name(
-            server, volume_snap_1, instance_ip, self.keypair['private_key'])
-        count_snap_1, md5_file_1 = self.get_md5_from_file(
-            server, instance_ip, 'file1', dev_name=volume_device_name)
-        # Detach the volume
-        self.nova_volume_detach(server, volume_snap_1)
+        # Snapshot i should contain i files, and its 'file<i>' should match
+        # the md5 saved in file<i>_md5.
+        snapshot_md5s = [(snapshot1, file1_md5),
+                         (snapshot2, file2_md5),
+                         (snapshot3, file3_md5)]
 
-        self.assertEqual(count_snap_1, 1)
-        self.assertEqual(file1_md5, md5_file_1)
+        # Check the data integrity of all 3 snapshots
+        for i, (snapshot, file_md5) in enumerate(snapshot_md5s, start=1):
+            # Create volume from snapshot, attach it to instance and check
+            # the file count and the contents of file<i>
+            volume_snap = self.create_volume(snapshot_id=snapshot['id'])
+            volume_device_name, __ = self._attach_and_get_volume_device_name(
+                server, volume_snap, instance_ip, self.keypair['private_key'])
+            count_snap, md5_file = self.get_md5_from_file(
+                server, instance_ip, 'file%d' % i,
+                dev_name=volume_device_name)
+            # Detach the volume
+            self.nova_volume_detach(server, volume_snap)
 
-        # Create volume from snapshot, attach it to instance and check file
-        # and contents for snap2
-        volume_snap_2 = self.create_volume(snapshot_id=snapshot2['id'])
-        volume_device_name, __ = self._attach_and_get_volume_device_name(
-            server, volume_snap_2, instance_ip, self.keypair['private_key'])
-        count_snap_2, md5_file_2 = self.get_md5_from_file(
-            server, instance_ip, 'file2', dev_name=volume_device_name)
-        # Detach the volume
-        self.nova_volume_detach(server, volume_snap_2)
-
-        self.assertEqual(count_snap_2, 2)
-        self.assertEqual(file2_md5, md5_file_2)
-
-        # Create volume from snapshot, attach it to instance and check file
-        # and contents for snap3
-        volume_snap_3 = self.create_volume(snapshot_id=snapshot3['id'])
-        volume_device_name, __ = self._attach_and_get_volume_device_name(
-            server, volume_snap_3, instance_ip, self.keypair['private_key'])
-        count_snap_3, md5_file_3 = self.get_md5_from_file(
-            server, instance_ip, 'file3', dev_name=volume_device_name)
-        # Detach the volume
-        self.nova_volume_detach(server, volume_snap_3)
-
-        self.assertEqual(count_snap_3, 3)
-        self.assertEqual(file3_md5, md5_file_3)
+            self.assertEqual(count_snap, i)
+            self.assertEqual(file_md5, md5_file)