L2CAP ERTM: Drop invalid packet on reassembly

Store the expected SDU size on a START I-Frame.

If the reassembled SDU size is not equal to the expected SDU size, drop
the SDU. Also, don't allow receiving an UNSEGMENTED PDU in the middle of
reassembly.

In the next iteration, close the channel on receiving an invalid frame.

Bug: 144955385
Test: bluetooth_test_gd
Change-Id: Id36aa242802b6991c59c9f9f45e8fac064b26675
diff --git a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc
index 2a1b930..97f9702 100644
--- a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc
+++ b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc
@@ -168,20 +168,20 @@
     }
   }
 
-  void recv_i_frame(Final f, uint8_t tx_seq, uint8_t req_seq, SegmentationAndReassembly sar,
+  void recv_i_frame(Final f, uint8_t tx_seq, uint8_t req_seq, SegmentationAndReassembly sar, uint16_t sdu_size,
                     const packet::PacketView<true>& payload) {
     if (rx_state_ == RxState::RECV) {
       if (f == Final::NOT_SET && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) && with_valid_f_bit(f) &&
           !local_busy()) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         send_ack(Final::NOT_SET);
       } else if (f == Final::POLL_RESPONSE && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) &&
                  with_valid_f_bit(f) && !local_busy()) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         if (!rej_actioned_) {
           retransmit_i_frames(req_seq);
           send_pending_i_frames();
@@ -216,14 +216,14 @@
       if (f == Final::NOT_SET && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) && with_valid_f_bit(f)) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         send_ack(Final::NOT_SET);
         rx_state_ = RxState::RECV;
       } else if (f == Final::POLL_RESPONSE && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) &&
                  with_valid_f_bit(f)) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         if (!rej_actioned_) {
           retransmit_i_frames(req_seq);
           send_pending_i_frames();
@@ -682,8 +682,8 @@
     recv_f_bit(f);
   }
 
-  void data_indication(SegmentationAndReassembly sar, const packet::PacketView<true>& segment) {
-    controller_->stage_for_reassembly(sar, segment);
+  void data_indication(SegmentationAndReassembly sar, uint16_t sdu_size, const packet::PacketView<true>& segment) {
+    controller_->stage_for_reassembly(sar, sdu_size, segment);
     buffer_seq_ = (buffer_seq_ + 1) % kMaxTxWin;
   }
 
@@ -825,8 +825,21 @@
       LOG_WARN("Received invalid frame");
       return;
     }
-    pimpl_->recv_i_frame(i_frame_view.GetF(), i_frame_view.GetTxSeq(), i_frame_view.GetReqSeq(), i_frame_view.GetSar(),
-                         i_frame_view.GetPayload());
+    Final f = i_frame_view.GetF();
+    uint8_t tx_seq = i_frame_view.GetTxSeq();
+    uint8_t req_seq = i_frame_view.GetReqSeq();
+    auto sar = i_frame_view.GetSar();
+    if (sar == SegmentationAndReassembly::START) {
+      auto i_frame_start_view = EnhancedInformationStartFrameView::Create(i_frame_view);
+      if (!i_frame_start_view.IsValid()) {
+        LOG_WARN("Received invalid I-Frame START");
+        return;
+      }
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, i_frame_start_view.GetL2capSduLength(),
+                           i_frame_start_view.GetPayload());
+    } else {
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, 0, i_frame_view.GetPayload());
+    }
   } else if (type == FrameType::S_FRAME) {
     auto s_frame_view = EnhancedSupervisoryFrameView::Create(standard_frame_view);
     if (!s_frame_view.IsValid()) {
@@ -872,8 +885,21 @@
       LOG_WARN("Received invalid frame");
       return;
     }
-    pimpl_->recv_i_frame(i_frame_view.GetF(), i_frame_view.GetTxSeq(), i_frame_view.GetReqSeq(), i_frame_view.GetSar(),
-                         i_frame_view.GetPayload());
+    Final f = i_frame_view.GetF();
+    uint8_t tx_seq = i_frame_view.GetTxSeq();
+    uint8_t req_seq = i_frame_view.GetReqSeq();
+    auto sar = i_frame_view.GetSar();
+    if (sar == SegmentationAndReassembly::START) {
+      auto i_frame_start_view = EnhancedInformationStartFrameWithFcsView::Create(i_frame_view);
+      if (!i_frame_start_view.IsValid()) {
+        LOG_WARN("Received invalid I-Frame START");
+        return;
+      }
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, i_frame_start_view.GetL2capSduLength(),
+                           i_frame_start_view.GetPayload());
+    } else {
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, 0, i_frame_view.GetPayload());
+    }
   } else if (type == FrameType::S_FRAME) {
     auto s_frame_view = EnhancedSupervisoryFrameWithFcsView::Create(standard_frame_view);
     if (!s_frame_view.IsValid()) {
@@ -908,10 +934,16 @@
   return next;
 }
 
-void ErtmController::stage_for_reassembly(SegmentationAndReassembly sar,
+void ErtmController::stage_for_reassembly(SegmentationAndReassembly sar, uint16_t sdu_size,
                                           const packet::PacketView<kLittleEndian>& payload) {
   switch (sar) {
     case SegmentationAndReassembly::UNSEGMENTED:
+      if (sar_state_ != SegmentationAndReassembly::END) {
+        LOG_WARN("Received invalid SAR");
+        close_channel();
+        return;
+      }
+      // TODO: Enforce MTU
       enqueue_buffer_.Enqueue(std::make_unique<packet::PacketView<kLittleEndian>>(payload), handler_);
       break;
     case SegmentationAndReassembly::START:
@@ -920,8 +952,10 @@
         close_channel();
         return;
       }
+      // TODO: Enforce MTU
       sar_state_ = SegmentationAndReassembly::START;
       reassembly_stage_ = payload;
+      remaining_sdu_continuation_packet_size_ = sdu_size - payload.size();
       break;
     case SegmentationAndReassembly::CONTINUATION:
       if (sar_state_ == SegmentationAndReassembly::END) {
@@ -930,6 +964,7 @@
         return;
       }
       reassembly_stage_.AppendPacketView(payload);
+      remaining_sdu_continuation_packet_size_ -= payload.size();
       break;
     case SegmentationAndReassembly::END:
       if (sar_state_ == SegmentationAndReassembly::END) {
@@ -937,9 +972,17 @@
         close_channel();
         return;
       }
+      sar_state_ = SegmentationAndReassembly::END;
+      remaining_sdu_continuation_packet_size_ -= payload.size();
+      if (remaining_sdu_continuation_packet_size_ != 0) {
+        LOG_WARN("Received invalid END I-Frame");
+        reassembly_stage_ = PacketViewForReassembly(std::make_shared<std::vector<uint8_t>>());
+        remaining_sdu_continuation_packet_size_ = 0;
+        close_channel();
+        return;
+      }
       reassembly_stage_.AppendPacketView(payload);
       enqueue_buffer_.Enqueue(std::make_unique<packet::PacketView<kLittleEndian>>(reassembly_stage_), handler_);
-      sar_state_ = SegmentationAndReassembly::END;
       break;
   }
 }