diff --git a/taggit/management/commands/deduplicate_tags.py b/taggit/management/commands/deduplicate_tags.py
index 647d7873..2f1f512a 100644
--- a/taggit/management/commands/deduplicate_tags.py
+++ b/taggit/management/commands/deduplicate_tags.py
@@ -19,20 +19,51 @@ def handle(self, *args, **kwargs):
 
         tags = Tag.objects.all()
         tag_dict = {}
-        tagged_items_to_update = defaultdict(list)
 
         for tag in tags:
             lower_name = tag.name.lower()
             if lower_name in tag_dict:
                 existing_tag = tag_dict[lower_name]
-                self._collect_tagged_items(tag, existing_tag, tagged_items_to_update)
-                tag.delete()
+                self._deduplicate_tags(existing_tag=existing_tag, tag_to_remove=tag)
             else:
                 tag_dict[lower_name] = tag
 
-        self._remove_duplicates_and_update(tagged_items_to_update)
         self.stdout.write(self.style.SUCCESS("Tag deduplication complete."))
 
+    @transaction.atomic
+    def _deduplicate_tags(self, existing_tag, tag_to_remove):
+        """
+        Merge ``tag_to_remove`` into ``existing_tag``, then delete it.
+
+        Tagged items pointing at ``tag_to_remove`` are re-pointed at
+        ``existing_tag``, or deleted when the same object already carries
+        ``existing_tag``. Runs inside ``transaction.atomic``.
+        """
+        # If this ends up very slow for you, please file a ticket!
+        # This isn't trying to be performant, in order to keep the code simple.
+        for item in TaggedItem.objects.filter(tag=tag_to_remove):
+            # if we already have the same association on the model
+            # (via the existing tag), then we can just remove the
+            # tagged item.
+            tag_exists_other = TaggedItem.objects.filter(
+                tag=existing_tag,
+                content_type_id=item.content_type_id,
+                object_id=item.object_id,
+            ).exists()
+            if tag_exists_other:
+                item.delete()
+            else:
+                item.tag = existing_tag
+                item.save()
+
+        # This should never trigger, but we can never be too sure. Raise
+        # explicitly instead of using ``assert`` so the check is not
+        # stripped when Python runs with -O.
+        if TaggedItem.objects.filter(tag=tag_to_remove).exists():
+            raise AssertionError("Tags were not all cleaned up!")
+
+        tag_to_remove.delete()
+
     def _collect_tagged_items(self, tag, existing_tag, tagged_items_to_update):
         for item in TaggedItem.objects.filter(tag=tag):
             tagged_items_to_update[(item.content_type_id, item.object_id)].append(
diff --git a/tests/test_deduplicate_tags.py b/tests/test_deduplicate_tags.py
index b0f63feb..b7aa782f 100644
--- a/tests/test_deduplicate_tags.py
+++ b/tests/test_deduplicate_tags.py
@@ -21,10 +21,11 @@ def setUp(self):
 
         self.food_item.tags.add(self.tag1)
         self.pet_item.tags.add(self.tag2)
+        self.pet_item.tags.add(self.tag3)
 
     def test_deduplicate_tags(self):
         self.assertEqual(Tag.objects.count(), 3)
-        self.assertEqual(TaggedItem.objects.count(), 2)
+        self.assertEqual(TaggedItem.objects.count(), 3)
 
         out = StringIO()
         call_command("deduplicate_tags", stdout=out)
@@ -41,7 +42,7 @@ def test_deduplicate_tags(self):
 
     def test_no_duplicates(self):
         self.assertEqual(Tag.objects.count(), 3)
-        self.assertEqual(TaggedItem.objects.count(), 2)
+        self.assertEqual(TaggedItem.objects.count(), 3)
 
         out = StringIO()
         call_command("deduplicate_tags", stdout=out)
@@ -65,4 +66,4 @@ def test_taggit_case_insensitive_not_enabled(self):
         self.assertIn("TAGGIT_CASE_INSENSITIVE is not enabled.", out.getvalue())
 
         self.assertEqual(Tag.objects.count(), 3)
-        self.assertEqual(TaggedItem.objects.count(), 2)
+        self.assertEqual(TaggedItem.objects.count(), 3)