@@ -592,6 +592,65 @@ class Blog(Document):
592592
593593 Blog .drop_collection ()
594594
595+ def test_update_array_filters (self ):
596+ """Ensure that updating by array_filters works."""
597+
598+ class Comment (EmbeddedDocument ):
599+ comment_tags = ListField (StringField ())
600+
601+ class Blog (Document ):
602+ tags = ListField (StringField ())
603+ comments = EmbeddedDocumentField (Comment )
604+
605+ Blog .drop_collection ()
606+
607+ # update one
608+ Blog .objects .create (tags = ["test1" , "test2" , "test3" ])
609+
610+ Blog .objects ().update_one (
611+ __raw__ = {"$set" : {"tags.$[element]" : "test11111" }},
612+ array_filters = [{"element" : {"$eq" : "test2" }}],
613+ )
614+ testc_blogs = Blog .objects (tags = "test11111" )
615+
616+ assert testc_blogs .count () == 1
617+
618+ Blog .drop_collection ()
619+
620+ # update one inner list
621+ comments = Comment (comment_tags = ["test1" , "test2" , "test3" ])
622+ Blog .objects .create (comments = comments )
623+
624+ Blog .objects ().update_one (
625+ __raw__ = {"$set" : {"comments.comment_tags.$[element]" : "test11111" }},
626+ array_filters = [{"element" : {"$eq" : "test2" }}],
627+ )
628+ testc_blogs = Blog .objects (comments__comment_tags = "test11111" )
629+
630+ assert testc_blogs .count () == 1
631+
632+ # update many
633+ Blog .drop_collection ()
634+
635+ Blog .objects .create (tags = ["test1" , "test2" , "test3" , "test_all" ])
636+ Blog .objects .create (tags = ["test4" , "test5" , "test6" , "test_all" ])
637+
638+ Blog .objects ().update (
639+ __raw__ = {"$set" : {"tags.$[element]" : "test11111" }},
640+ array_filters = [{"element" : {"$eq" : "test2" }}],
641+ )
642+ testc_blogs = Blog .objects (tags = "test11111" )
643+
644+ assert testc_blogs .count () == 1
645+
646+ Blog .objects ().update (
647+ __raw__ = {"$set" : {"tags.$[element]" : "test_all1234577" }},
648+ array_filters = [{"element" : {"$eq" : "test_all" }}],
649+ )
650+ testc_blogs = Blog .objects (tags = "test_all1234577" )
651+
652+ assert testc_blogs .count () == 2
653+
595654 def test_update_using_positional_operator (self ):
596655 """Ensure that the list fields can be updated using the positional
597656 operator."""
0 commit comments