L2CAP ERTM: Drop invalid packet on reassembly

Store the expected SDU size on a START I-Frame.

If the reassembled SDU size is not equal to expected SDU size, drop the
SDU. Also don't allow receiving UNSEGMENTED PDU in the middle of
reassembly.

In next iteration, close the channel on receiving some invalid frame.

Bug: 144955385
Test: bluetooth_test_gd
Change-Id: Id36aa242802b6991c59c9f9f45e8fac064b26675
diff --git a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc
index 2a1b930..97f9702 100644
--- a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc
+++ b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.cc
@@ -168,20 +168,20 @@
     }
   }
 
-  void recv_i_frame(Final f, uint8_t tx_seq, uint8_t req_seq, SegmentationAndReassembly sar,
+  void recv_i_frame(Final f, uint8_t tx_seq, uint8_t req_seq, SegmentationAndReassembly sar, uint16_t sdu_size,
                     const packet::PacketView<true>& payload) {
     if (rx_state_ == RxState::RECV) {
       if (f == Final::NOT_SET && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) && with_valid_f_bit(f) &&
           !local_busy()) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         send_ack(Final::NOT_SET);
       } else if (f == Final::POLL_RESPONSE && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) &&
                  with_valid_f_bit(f) && !local_busy()) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         if (!rej_actioned_) {
           retransmit_i_frames(req_seq);
           send_pending_i_frames();
@@ -216,14 +216,14 @@
       if (f == Final::NOT_SET && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) && with_valid_f_bit(f)) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         send_ack(Final::NOT_SET);
         rx_state_ = RxState::RECV;
       } else if (f == Final::POLL_RESPONSE && with_expected_tx_seq(tx_seq) && with_valid_req_seq(req_seq) &&
                  with_valid_f_bit(f)) {
         increment_expected_tx_seq();
         pass_to_tx(req_seq, f);
-        data_indication(sar, payload);
+        data_indication(sar, sdu_size, payload);
         if (!rej_actioned_) {
           retransmit_i_frames(req_seq);
           send_pending_i_frames();
@@ -682,8 +682,8 @@
     recv_f_bit(f);
   }
 
-  void data_indication(SegmentationAndReassembly sar, const packet::PacketView<true>& segment) {
-    controller_->stage_for_reassembly(sar, segment);
+  void data_indication(SegmentationAndReassembly sar, uint16_t sdu_size, const packet::PacketView<true>& segment) {
+    controller_->stage_for_reassembly(sar, sdu_size, segment);
     buffer_seq_ = (buffer_seq_ + 1) % kMaxTxWin;
   }
 
@@ -825,8 +825,21 @@
       LOG_WARN("Received invalid frame");
       return;
     }
-    pimpl_->recv_i_frame(i_frame_view.GetF(), i_frame_view.GetTxSeq(), i_frame_view.GetReqSeq(), i_frame_view.GetSar(),
-                         i_frame_view.GetPayload());
+    Final f = i_frame_view.GetF();
+    uint8_t tx_seq = i_frame_view.GetTxSeq();
+    uint8_t req_seq = i_frame_view.GetReqSeq();
+    auto sar = i_frame_view.GetSar();
+    if (sar == SegmentationAndReassembly::START) {
+      auto i_frame_start_view = EnhancedInformationStartFrameView::Create(i_frame_view);
+      if (!i_frame_start_view.IsValid()) {
+        LOG_WARN("Received invalid I-Frame START");
+        return;
+      }
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, i_frame_start_view.GetL2capSduLength(),
+                           i_frame_start_view.GetPayload());
+    } else {
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, 0, i_frame_view.GetPayload());
+    }
   } else if (type == FrameType::S_FRAME) {
     auto s_frame_view = EnhancedSupervisoryFrameView::Create(standard_frame_view);
     if (!s_frame_view.IsValid()) {
@@ -872,8 +885,21 @@
       LOG_WARN("Received invalid frame");
       return;
     }
-    pimpl_->recv_i_frame(i_frame_view.GetF(), i_frame_view.GetTxSeq(), i_frame_view.GetReqSeq(), i_frame_view.GetSar(),
-                         i_frame_view.GetPayload());
+    Final f = i_frame_view.GetF();
+    uint8_t tx_seq = i_frame_view.GetTxSeq();
+    uint8_t req_seq = i_frame_view.GetReqSeq();
+    auto sar = i_frame_view.GetSar();
+    if (sar == SegmentationAndReassembly::START) {
+      auto i_frame_start_view = EnhancedInformationStartFrameWithFcsView::Create(i_frame_view);
+      if (!i_frame_start_view.IsValid()) {
+        LOG_WARN("Received invalid I-Frame START");
+        return;
+      }
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, i_frame_start_view.GetL2capSduLength(),
+                           i_frame_start_view.GetPayload());
+    } else {
+      pimpl_->recv_i_frame(f, tx_seq, req_seq, sar, 0, i_frame_view.GetPayload());
+    }
   } else if (type == FrameType::S_FRAME) {
     auto s_frame_view = EnhancedSupervisoryFrameWithFcsView::Create(standard_frame_view);
     if (!s_frame_view.IsValid()) {
@@ -908,10 +934,16 @@
   return next;
 }
 
-void ErtmController::stage_for_reassembly(SegmentationAndReassembly sar,
+void ErtmController::stage_for_reassembly(SegmentationAndReassembly sar, uint16_t sdu_size,
                                           const packet::PacketView<kLittleEndian>& payload) {
   switch (sar) {
     case SegmentationAndReassembly::UNSEGMENTED:
+      if (sar_state_ != SegmentationAndReassembly::END) {
+        LOG_WARN("Received invalid SAR");
+        close_channel();
+        return;
+      }
+      // TODO: Enforce MTU
       enqueue_buffer_.Enqueue(std::make_unique<packet::PacketView<kLittleEndian>>(payload), handler_);
       break;
     case SegmentationAndReassembly::START:
@@ -920,8 +952,10 @@
         close_channel();
         return;
       }
+      // TODO: Enforce MTU
       sar_state_ = SegmentationAndReassembly::START;
       reassembly_stage_ = payload;
+      remaining_sdu_continuation_packet_size_ = sdu_size - payload.size();
       break;
     case SegmentationAndReassembly::CONTINUATION:
       if (sar_state_ == SegmentationAndReassembly::END) {
@@ -930,6 +964,7 @@
         return;
       }
       reassembly_stage_.AppendPacketView(payload);
+      remaining_sdu_continuation_packet_size_ -= payload.size();
       break;
     case SegmentationAndReassembly::END:
       if (sar_state_ == SegmentationAndReassembly::END) {
@@ -937,9 +972,17 @@
         close_channel();
         return;
       }
+      sar_state_ = SegmentationAndReassembly::END;
+      remaining_sdu_continuation_packet_size_ -= payload.size();
+      if (remaining_sdu_continuation_packet_size_ != 0) {
+        LOG_WARN("Received invalid END I-Frame");
+        reassembly_stage_ = PacketViewForReassembly(std::make_shared<std::vector<uint8_t>>());
+        remaining_sdu_continuation_packet_size_ = 0;
+        close_channel();
+        return;
+      }
       reassembly_stage_.AppendPacketView(payload);
       enqueue_buffer_.Enqueue(std::make_unique<packet::PacketView<kLittleEndian>>(reassembly_stage_), handler_);
-      sar_state_ = SegmentationAndReassembly::END;
       break;
   }
 }
diff --git a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.h b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.h
index 14deea3..dd1ca64 100644
--- a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.h
+++ b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller.h
@@ -59,7 +59,19 @@
   os::Handler* handler_;
   std::queue<std::unique_ptr<packet::BasePacketBuilder>> pdu_queue_;
   Scheduler* scheduler_;
+
+  // Configuration options
   bool fcs_enabled_ = false;
+  uint16_t local_tx_window_ = 10;
+  uint16_t local_max_transmit_ = 20;
+  uint16_t local_retransmit_timeout_ms_ = 2000;
+  uint16_t local_monitor_timeout_ms_ = 12000;
+
+  uint16_t remote_tx_window_ = 10;
+  uint16_t remote_mps_ = 1010;
+
+  uint16_t size_each_packet_ =
+      (remote_mps_ - 4 /* basic L2CAP header */ - 2 /* SDU length */ - 2 /* Extended control */ - 2 /* FCS */);
 
   class PacketViewForReassembly : public packet::PacketView<kLittleEndian> {
    public:
@@ -83,8 +95,10 @@
 
   PacketViewForReassembly reassembly_stage_{std::make_shared<std::vector<uint8_t>>()};
   SegmentationAndReassembly sar_state_ = SegmentationAndReassembly::END;
+  uint16_t remaining_sdu_continuation_packet_size_ = 0;
 
-  void stage_for_reassembly(SegmentationAndReassembly sar, const packet::PacketView<kLittleEndian>& payload);
+  void stage_for_reassembly(SegmentationAndReassembly sar, uint16_t sdu_size,
+                            const packet::PacketView<kLittleEndian>& payload);
   void send_pdu(std::unique_ptr<packet::BasePacketBuilder> pdu);
 
   void close_channel();
@@ -92,18 +106,6 @@
   void on_pdu_no_fcs(const packet::PacketView<true>& pdu);
   void on_pdu_fcs(const packet::PacketView<true>& pdu);
 
-  // Configuration options
-  uint16_t local_tx_window_ = 10;
-  uint16_t local_max_transmit_ = 20;
-  uint16_t local_retransmit_timeout_ms_ = 2000;
-  uint16_t local_monitor_timeout_ms_ = 12000;
-
-  uint16_t remote_tx_window_ = 10;
-  uint16_t remote_mps_ = 1010;
-
-  uint16_t size_each_packet_ =
-      (remote_mps_ - 4 /* basic L2CAP header */ - 2 /* SDU length */ - 2 /* Extended control */ - 2 /* FCS */);
-
   struct impl;
   std::unique_ptr<impl> pimpl_;
 };
diff --git a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller_test.cc b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller_test.cc
index 04ef82b..e1b5818 100644
--- a/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller_test.cc
+++ b/gd/l2cap/internal/enhanced_retransmission_mode_channel_data_controller_test.cc
@@ -108,6 +108,54 @@
   EXPECT_EQ(data, "abcd");
 }
 
+TEST_F(ErtmDataControllerTest, reassemble_valid_sdu) {
+  common::BidiQueue<Scheduler::UpperEnqueue, Scheduler::UpperDequeue> channel_queue{10};
+  testing::MockScheduler scheduler;
+  ErtmController controller{1, 1, channel_queue.GetDownEnd(), queue_handler_, &scheduler};
+  auto segment1 = CreateSdu({'a'});
+  auto segment2 = CreateSdu({'b', 'c'});
+  auto segment3 = CreateSdu({'d', 'e', 'f'});
+  auto builder1 = EnhancedInformationStartFrameBuilder::Create(1, 0, Final::NOT_SET, 0, 6, std::move(segment1));
+  auto base_view = GetPacketView(std::move(builder1));
+  controller.OnPdu(base_view);
+  auto builder2 = EnhancedInformationFrameBuilder::Create(1, 1, Final::NOT_SET, 0,
+                                                          SegmentationAndReassembly::CONTINUATION, std::move(segment2));
+  base_view = GetPacketView(std::move(builder2));
+  controller.OnPdu(base_view);
+  auto builder3 = EnhancedInformationFrameBuilder::Create(1, 2, Final::NOT_SET, 0, SegmentationAndReassembly::END,
+                                                          std::move(segment3));
+  base_view = GetPacketView(std::move(builder3));
+  controller.OnPdu(base_view);
+  sync_handler(queue_handler_);
+  auto payload = channel_queue.GetUpEnd()->TryDequeue();
+  EXPECT_NE(payload, nullptr);
+  std::string data = std::string(payload->begin(), payload->end());
+  EXPECT_EQ(data, "abcdef");
+}
+
+TEST_F(ErtmDataControllerTest, reassemble_invalid_sdu_size_in_start_frame) {
+  common::BidiQueue<Scheduler::UpperEnqueue, Scheduler::UpperDequeue> channel_queue{10};
+  testing::MockScheduler scheduler;
+  ErtmController controller{1, 1, channel_queue.GetDownEnd(), queue_handler_, &scheduler};
+  auto segment1 = CreateSdu({'a'});
+  auto segment2 = CreateSdu({'b', 'c'});
+  auto segment3 = CreateSdu({'d', 'e', 'f'});
+  auto builder1 = EnhancedInformationStartFrameBuilder::Create(1, 0, Final::NOT_SET, 0, 10, std::move(segment1));
+  auto base_view = GetPacketView(std::move(builder1));
+  controller.OnPdu(base_view);
+  auto builder2 = EnhancedInformationFrameBuilder::Create(1, 1, Final::NOT_SET, 0,
+                                                          SegmentationAndReassembly::CONTINUATION, std::move(segment2));
+  base_view = GetPacketView(std::move(builder2));
+  controller.OnPdu(base_view);
+  auto builder3 = EnhancedInformationFrameBuilder::Create(1, 2, Final::NOT_SET, 0, SegmentationAndReassembly::END,
+                                                          std::move(segment3));
+  base_view = GetPacketView(std::move(builder3));
+  controller.OnPdu(base_view);
+  sync_handler(queue_handler_);
+  auto payload = channel_queue.GetUpEnd()->TryDequeue();
+  EXPECT_EQ(payload, nullptr);
+}
+
 TEST_F(ErtmDataControllerTest, transmit_with_fcs) {
   common::BidiQueue<Scheduler::UpperEnqueue, Scheduler::UpperDequeue> channel_queue{10};
   testing::MockScheduler scheduler;